forked from premiere/premiere-libtorrent
use signed counters in traversa_algorithm, and add some more asserts (#1208)
use signed counters in traversal_algorithm, and add some more asserts
This commit is contained in:
parent
e0c8ad738d
commit
540b5b28d7
|
@ -63,8 +63,6 @@ struct find_data : traversal_algorithm
|
|||
|
||||
virtual char const* name() const;
|
||||
|
||||
node_id const& target() const { return m_target; }
|
||||
|
||||
protected:
|
||||
|
||||
virtual void done();
|
||||
|
|
|
@ -73,8 +73,8 @@ struct TORRENT_EXTRA_EXPORT traversal_algorithm : boost::noncopyable
|
|||
void add_entry(node_id const& id, udp::endpoint const& addr, unsigned char flags);
|
||||
|
||||
traversal_algorithm(node& dht_node, node_id const& target);
|
||||
int invoke_count() const { return m_invoke_count; }
|
||||
int branch_factor() const { return m_branch_factor; }
|
||||
int invoke_count() const { TORRENT_ASSERT(m_invoke_count >= 0); return m_invoke_count; }
|
||||
int branch_factor() const { TORRENT_ASSERT(m_branch_factor >= 0); return m_branch_factor; }
|
||||
|
||||
node& get_node() const { return m_node; }
|
||||
|
||||
|
@ -97,13 +97,19 @@ protected:
|
|||
|
||||
virtual bool invoke(observer_ptr) { return false; }
|
||||
|
||||
int num_responses() const { return m_responses; }
|
||||
int num_timeouts() const { return m_timeouts; }
|
||||
|
||||
node& m_node;
|
||||
std::vector<observer_ptr> m_results;
|
||||
|
||||
private:
|
||||
|
||||
node_id const m_target;
|
||||
std::uint16_t m_invoke_count;
|
||||
std::uint16_t m_branch_factor;
|
||||
std::uint16_t m_responses;
|
||||
std::uint16_t m_timeouts;
|
||||
std::int16_t m_invoke_count = 0;
|
||||
std::int16_t m_branch_factor = 3;
|
||||
std::int16_t m_responses = 0;
|
||||
std::int16_t m_timeouts = 0;
|
||||
|
||||
// the IP addresses of the nodes in m_results
|
||||
std::set<std::uint32_t> m_peer4_prefixes;
|
||||
|
|
|
@ -95,7 +95,7 @@ void find_data::start()
|
|||
if (m_results.empty())
|
||||
{
|
||||
std::vector<node_entry> nodes;
|
||||
m_node.m_table.find_node(m_target, nodes, routing_table::include_failed);
|
||||
m_node.m_table.find_node(target(), nodes, routing_table::include_failed);
|
||||
|
||||
for (auto const& n : nodes)
|
||||
{
|
||||
|
|
|
@ -61,7 +61,7 @@ void get_item::got_data(bdecode_node const& v,
|
|||
if (!m_data.empty()) return;
|
||||
|
||||
sha1_hash incoming_target = item_target_id(v.data_section());
|
||||
if (incoming_target != m_target) return;
|
||||
if (incoming_target != target()) return;
|
||||
|
||||
m_data.assign(v);
|
||||
|
||||
|
@ -79,7 +79,7 @@ void get_item::got_data(bdecode_node const& v,
|
|||
|
||||
std::string const salt_copy(m_data.salt());
|
||||
sha1_hash const incoming_target = item_target_id(salt_copy, pk);
|
||||
if (incoming_target != m_target) return;
|
||||
if (incoming_target != target()) return;
|
||||
|
||||
// this is mutable data. If it passes the signature
|
||||
// check, remember it. Just keep the version with
|
||||
|
@ -137,18 +137,14 @@ observer_ptr get_item::new_observer(udp::endpoint const& ep
|
|||
|
||||
bool get_item::invoke(observer_ptr o)
|
||||
{
|
||||
if (m_done)
|
||||
{
|
||||
m_invoke_count = -1;
|
||||
return false;
|
||||
}
|
||||
if (m_done) return false;
|
||||
|
||||
entry e;
|
||||
e["y"] = "q";
|
||||
entry& a = e["a"];
|
||||
|
||||
e["q"] = "get";
|
||||
a["target"] = m_target.to_string();
|
||||
a["target"] = target().to_string();
|
||||
|
||||
return m_node.m_rpc.invoke(e, o->target_ep(), o);
|
||||
}
|
||||
|
@ -168,7 +164,7 @@ void get_item::done()
|
|||
#if TORRENT_USE_ASSERTS
|
||||
if (m_data.is_mutable())
|
||||
{
|
||||
TORRENT_ASSERT(m_target == item_target_id(m_data.salt(), m_data.pk()));
|
||||
TORRENT_ASSERT(target() == item_target_id(m_data.salt(), m_data.pk()));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -141,23 +141,19 @@ char const* get_peers::name() const { return "get_peers"; }
|
|||
|
||||
bool get_peers::invoke(observer_ptr o)
|
||||
{
|
||||
if (m_done)
|
||||
{
|
||||
m_invoke_count = -1;
|
||||
return false;
|
||||
}
|
||||
if (m_done) return false;
|
||||
|
||||
entry e;
|
||||
e["y"] = "q";
|
||||
entry& a = e["a"];
|
||||
|
||||
e["q"] = "get_peers";
|
||||
a["info_hash"] = m_target.to_string();
|
||||
a["info_hash"] = target().to_string();
|
||||
if (m_noseeds) a["noseed"] = 1;
|
||||
|
||||
if (m_node.observer() != nullptr)
|
||||
{
|
||||
m_node.observer()->outgoing_get_peers(m_target, m_target, o->target_ep());
|
||||
m_node.observer()->outgoing_get_peers(target(), target(), o->target_ep());
|
||||
}
|
||||
|
||||
m_node.stats_counters().inc_stats_counter(counters::dht_get_peers_out);
|
||||
|
@ -217,7 +213,7 @@ bool obfuscated_get_peers::invoke(observer_ptr o)
|
|||
if (!m_obfuscated) return get_peers::invoke(o);
|
||||
|
||||
node_id const& id = o->id();
|
||||
int const shared_prefix = 160 - distance_exp(id, m_target);
|
||||
int const shared_prefix = 160 - distance_exp(id, target());
|
||||
|
||||
// when we get close to the target zone in the DHT
|
||||
// start using the correct info-hash, in order to
|
||||
|
@ -254,12 +250,12 @@ bool obfuscated_get_peers::invoke(observer_ptr o)
|
|||
// now, obfuscate the bits past shared_prefix + 3
|
||||
node_id mask = generate_prefix_mask(shared_prefix + 3);
|
||||
node_id obfuscated_target = generate_random_id() & ~mask;
|
||||
obfuscated_target |= m_target & mask;
|
||||
obfuscated_target |= target() & mask;
|
||||
a["info_hash"] = obfuscated_target.to_string();
|
||||
|
||||
if (m_node.observer() != nullptr)
|
||||
{
|
||||
m_node.observer()->outgoing_get_peers(m_target, obfuscated_target
|
||||
m_node.observer()->outgoing_get_peers(target(), obfuscated_target
|
||||
, o->target_ep());
|
||||
}
|
||||
|
||||
|
@ -275,7 +271,7 @@ void obfuscated_get_peers::done()
|
|||
// oops, we failed to switch over to the non-obfuscated
|
||||
// mode early enough. do it now
|
||||
|
||||
auto ta = std::make_shared<get_peers>(m_node, m_target
|
||||
auto ta = std::make_shared<get_peers>(m_node, target()
|
||||
, m_data_callback, m_nodes_callback, m_noseeds);
|
||||
|
||||
// don't call these when the obfuscated_get_peers
|
||||
|
|
|
@ -74,20 +74,16 @@ void put_data::done()
|
|||
|
||||
#ifndef TORRENT_DISABLE_LOGGING
|
||||
get_node().observer()->log(dht_logger::traversal, "[%p] %s DONE, response %d, timeout %d"
|
||||
, static_cast<void*>(this), name(), m_responses, m_timeouts);
|
||||
, static_cast<void*>(this), name(), num_responses(), num_timeouts());
|
||||
#endif
|
||||
|
||||
m_put_callback(m_data, m_responses);
|
||||
m_put_callback(m_data, num_responses());
|
||||
traversal_algorithm::done();
|
||||
}
|
||||
|
||||
bool put_data::invoke(observer_ptr o)
|
||||
{
|
||||
if (m_done)
|
||||
{
|
||||
m_invoke_count = -1;
|
||||
return false;
|
||||
}
|
||||
if (m_done) return false;
|
||||
|
||||
// TODO: what if o is not an isntance of put_data_observer? This need to be
|
||||
// redesigned for better type saftey.
|
||||
|
|
|
@ -84,10 +84,6 @@ traversal_algorithm::traversal_algorithm(
|
|||
, node_id const& target)
|
||||
: m_node(dht_node)
|
||||
, m_target(target)
|
||||
, m_invoke_count(0)
|
||||
, m_branch_factor(3)
|
||||
, m_responses(0)
|
||||
, m_timeouts(0)
|
||||
{
|
||||
#ifndef TORRENT_DISABLE_LOGGING
|
||||
dht_observer* logger = get_node().observer();
|
||||
|
@ -301,6 +297,8 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
|
|||
|
||||
if (m_results.empty()) return;
|
||||
|
||||
bool decrement_branch_factor = false;
|
||||
|
||||
TORRENT_ASSERT(o->flags & observer::flag_queried);
|
||||
if (flags & short_timeout)
|
||||
{
|
||||
|
@ -311,7 +309,10 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
|
|||
// around for some more, but open up the slot
|
||||
// by increasing the branch factor
|
||||
if ((o->flags & observer::flag_short_timeout) == 0)
|
||||
{
|
||||
TORRENT_ASSERT(m_branch_factor < (std::numeric_limits<std::int16_t>::max)());
|
||||
++m_branch_factor;
|
||||
}
|
||||
o->flags |= observer::flag_short_timeout;
|
||||
#ifndef TORRENT_DISABLE_LOGGING
|
||||
dht_observer* logger = get_node().observer();
|
||||
|
@ -333,8 +334,7 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
|
|||
o->flags |= observer::flag_failed;
|
||||
// if this flag is set, it means we increased the
|
||||
// branch factor for it, and we should restore it
|
||||
if (o->flags & observer::flag_short_timeout)
|
||||
--m_branch_factor;
|
||||
decrement_branch_factor = (o->flags & observer::flag_short_timeout) != 0;
|
||||
|
||||
#ifndef TORRENT_DISABLE_LOGGING
|
||||
dht_observer* logger = get_node().observer();
|
||||
|
@ -356,12 +356,18 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
|
|||
--m_invoke_count;
|
||||
}
|
||||
|
||||
if (flags & prevent_request)
|
||||
// this is another reason to decrement the branch factor, to prevent another
|
||||
// request from filling this slot. Only ever decrement once per response though
|
||||
decrement_branch_factor |= (flags & prevent_request);
|
||||
|
||||
if (decrement_branch_factor)
|
||||
{
|
||||
TORRENT_ASSERT(m_branch_factor > 0);
|
||||
--m_branch_factor;
|
||||
if (m_branch_factor <= 0) m_branch_factor = 1;
|
||||
}
|
||||
bool is_done = add_requests();
|
||||
|
||||
bool const is_done = add_requests();
|
||||
if (is_done) done();
|
||||
}
|
||||
|
||||
|
@ -484,7 +490,7 @@ bool traversal_algorithm::add_requests()
|
|||
o->flags |= observer::flag_queried;
|
||||
if (invoke(*i))
|
||||
{
|
||||
TORRENT_ASSERT(m_invoke_count < (std::numeric_limits<std::uint16_t>::max)());
|
||||
TORRENT_ASSERT(m_invoke_count < (std::numeric_limits<std::int16_t>::max)());
|
||||
++m_invoke_count;
|
||||
++outstanding;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue