use signed counters in traversa_algorithm, and add some more asserts (#1208)

use signed counters in traversal_algorithm, and add some more asserts
This commit is contained in:
Arvid Norberg 2016-10-14 08:03:09 -04:00 committed by GitHub
parent e0c8ad738d
commit 540b5b28d7
7 changed files with 43 additions and 45 deletions

View File

@ -63,8 +63,6 @@ struct find_data : traversal_algorithm
virtual char const* name() const; virtual char const* name() const;
node_id const& target() const { return m_target; }
protected: protected:
virtual void done(); virtual void done();

View File

@ -73,8 +73,8 @@ struct TORRENT_EXTRA_EXPORT traversal_algorithm : boost::noncopyable
void add_entry(node_id const& id, udp::endpoint const& addr, unsigned char flags); void add_entry(node_id const& id, udp::endpoint const& addr, unsigned char flags);
traversal_algorithm(node& dht_node, node_id const& target); traversal_algorithm(node& dht_node, node_id const& target);
int invoke_count() const { return m_invoke_count; } int invoke_count() const { TORRENT_ASSERT(m_invoke_count >= 0); return m_invoke_count; }
int branch_factor() const { return m_branch_factor; } int branch_factor() const { TORRENT_ASSERT(m_branch_factor >= 0); return m_branch_factor; }
node& get_node() const { return m_node; } node& get_node() const { return m_node; }
@ -97,13 +97,19 @@ protected:
virtual bool invoke(observer_ptr) { return false; } virtual bool invoke(observer_ptr) { return false; }
int num_responses() const { return m_responses; }
int num_timeouts() const { return m_timeouts; }
node& m_node; node& m_node;
std::vector<observer_ptr> m_results; std::vector<observer_ptr> m_results;
private:
node_id const m_target; node_id const m_target;
std::uint16_t m_invoke_count; std::int16_t m_invoke_count = 0;
std::uint16_t m_branch_factor; std::int16_t m_branch_factor = 3;
std::uint16_t m_responses; std::int16_t m_responses = 0;
std::uint16_t m_timeouts; std::int16_t m_timeouts = 0;
// the IP addresses of the nodes in m_results // the IP addresses of the nodes in m_results
std::set<std::uint32_t> m_peer4_prefixes; std::set<std::uint32_t> m_peer4_prefixes;

View File

@ -95,7 +95,7 @@ void find_data::start()
if (m_results.empty()) if (m_results.empty())
{ {
std::vector<node_entry> nodes; std::vector<node_entry> nodes;
m_node.m_table.find_node(m_target, nodes, routing_table::include_failed); m_node.m_table.find_node(target(), nodes, routing_table::include_failed);
for (auto const& n : nodes) for (auto const& n : nodes)
{ {

View File

@ -61,7 +61,7 @@ void get_item::got_data(bdecode_node const& v,
if (!m_data.empty()) return; if (!m_data.empty()) return;
sha1_hash incoming_target = item_target_id(v.data_section()); sha1_hash incoming_target = item_target_id(v.data_section());
if (incoming_target != m_target) return; if (incoming_target != target()) return;
m_data.assign(v); m_data.assign(v);
@ -79,7 +79,7 @@ void get_item::got_data(bdecode_node const& v,
std::string const salt_copy(m_data.salt()); std::string const salt_copy(m_data.salt());
sha1_hash const incoming_target = item_target_id(salt_copy, pk); sha1_hash const incoming_target = item_target_id(salt_copy, pk);
if (incoming_target != m_target) return; if (incoming_target != target()) return;
// this is mutable data. If it passes the signature // this is mutable data. If it passes the signature
// check, remember it. Just keep the version with // check, remember it. Just keep the version with
@ -137,18 +137,14 @@ observer_ptr get_item::new_observer(udp::endpoint const& ep
bool get_item::invoke(observer_ptr o) bool get_item::invoke(observer_ptr o)
{ {
if (m_done) if (m_done) return false;
{
m_invoke_count = -1;
return false;
}
entry e; entry e;
e["y"] = "q"; e["y"] = "q";
entry& a = e["a"]; entry& a = e["a"];
e["q"] = "get"; e["q"] = "get";
a["target"] = m_target.to_string(); a["target"] = target().to_string();
return m_node.m_rpc.invoke(e, o->target_ep(), o); return m_node.m_rpc.invoke(e, o->target_ep(), o);
} }
@ -168,7 +164,7 @@ void get_item::done()
#if TORRENT_USE_ASSERTS #if TORRENT_USE_ASSERTS
if (m_data.is_mutable()) if (m_data.is_mutable())
{ {
TORRENT_ASSERT(m_target == item_target_id(m_data.salt(), m_data.pk())); TORRENT_ASSERT(target() == item_target_id(m_data.salt(), m_data.pk()));
} }
#endif #endif
} }

View File

@ -141,23 +141,19 @@ char const* get_peers::name() const { return "get_peers"; }
bool get_peers::invoke(observer_ptr o) bool get_peers::invoke(observer_ptr o)
{ {
if (m_done) if (m_done) return false;
{
m_invoke_count = -1;
return false;
}
entry e; entry e;
e["y"] = "q"; e["y"] = "q";
entry& a = e["a"]; entry& a = e["a"];
e["q"] = "get_peers"; e["q"] = "get_peers";
a["info_hash"] = m_target.to_string(); a["info_hash"] = target().to_string();
if (m_noseeds) a["noseed"] = 1; if (m_noseeds) a["noseed"] = 1;
if (m_node.observer() != nullptr) if (m_node.observer() != nullptr)
{ {
m_node.observer()->outgoing_get_peers(m_target, m_target, o->target_ep()); m_node.observer()->outgoing_get_peers(target(), target(), o->target_ep());
} }
m_node.stats_counters().inc_stats_counter(counters::dht_get_peers_out); m_node.stats_counters().inc_stats_counter(counters::dht_get_peers_out);
@ -217,7 +213,7 @@ bool obfuscated_get_peers::invoke(observer_ptr o)
if (!m_obfuscated) return get_peers::invoke(o); if (!m_obfuscated) return get_peers::invoke(o);
node_id const& id = o->id(); node_id const& id = o->id();
int const shared_prefix = 160 - distance_exp(id, m_target); int const shared_prefix = 160 - distance_exp(id, target());
// when we get close to the target zone in the DHT // when we get close to the target zone in the DHT
// start using the correct info-hash, in order to // start using the correct info-hash, in order to
@ -254,12 +250,12 @@ bool obfuscated_get_peers::invoke(observer_ptr o)
// now, obfuscate the bits past shared_prefix + 3 // now, obfuscate the bits past shared_prefix + 3
node_id mask = generate_prefix_mask(shared_prefix + 3); node_id mask = generate_prefix_mask(shared_prefix + 3);
node_id obfuscated_target = generate_random_id() & ~mask; node_id obfuscated_target = generate_random_id() & ~mask;
obfuscated_target |= m_target & mask; obfuscated_target |= target() & mask;
a["info_hash"] = obfuscated_target.to_string(); a["info_hash"] = obfuscated_target.to_string();
if (m_node.observer() != nullptr) if (m_node.observer() != nullptr)
{ {
m_node.observer()->outgoing_get_peers(m_target, obfuscated_target m_node.observer()->outgoing_get_peers(target(), obfuscated_target
, o->target_ep()); , o->target_ep());
} }
@ -275,7 +271,7 @@ void obfuscated_get_peers::done()
// oops, we failed to switch over to the non-obfuscated // oops, we failed to switch over to the non-obfuscated
// mode early enough. do it now // mode early enough. do it now
auto ta = std::make_shared<get_peers>(m_node, m_target auto ta = std::make_shared<get_peers>(m_node, target()
, m_data_callback, m_nodes_callback, m_noseeds); , m_data_callback, m_nodes_callback, m_noseeds);
// don't call these when the obfuscated_get_peers // don't call these when the obfuscated_get_peers

View File

@ -74,20 +74,16 @@ void put_data::done()
#ifndef TORRENT_DISABLE_LOGGING #ifndef TORRENT_DISABLE_LOGGING
get_node().observer()->log(dht_logger::traversal, "[%p] %s DONE, response %d, timeout %d" get_node().observer()->log(dht_logger::traversal, "[%p] %s DONE, response %d, timeout %d"
, static_cast<void*>(this), name(), m_responses, m_timeouts); , static_cast<void*>(this), name(), num_responses(), num_timeouts());
#endif #endif
m_put_callback(m_data, m_responses); m_put_callback(m_data, num_responses());
traversal_algorithm::done(); traversal_algorithm::done();
} }
bool put_data::invoke(observer_ptr o) bool put_data::invoke(observer_ptr o)
{ {
if (m_done) if (m_done) return false;
{
m_invoke_count = -1;
return false;
}
// TODO: what if o is not an isntance of put_data_observer? This need to be // TODO: what if o is not an isntance of put_data_observer? This need to be
// redesigned for better type saftey. // redesigned for better type saftey.

View File

@ -84,10 +84,6 @@ traversal_algorithm::traversal_algorithm(
, node_id const& target) , node_id const& target)
: m_node(dht_node) : m_node(dht_node)
, m_target(target) , m_target(target)
, m_invoke_count(0)
, m_branch_factor(3)
, m_responses(0)
, m_timeouts(0)
{ {
#ifndef TORRENT_DISABLE_LOGGING #ifndef TORRENT_DISABLE_LOGGING
dht_observer* logger = get_node().observer(); dht_observer* logger = get_node().observer();
@ -301,6 +297,8 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
if (m_results.empty()) return; if (m_results.empty()) return;
bool decrement_branch_factor = false;
TORRENT_ASSERT(o->flags & observer::flag_queried); TORRENT_ASSERT(o->flags & observer::flag_queried);
if (flags & short_timeout) if (flags & short_timeout)
{ {
@ -311,7 +309,10 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
// around for some more, but open up the slot // around for some more, but open up the slot
// by increasing the branch factor // by increasing the branch factor
if ((o->flags & observer::flag_short_timeout) == 0) if ((o->flags & observer::flag_short_timeout) == 0)
{
TORRENT_ASSERT(m_branch_factor < (std::numeric_limits<std::int16_t>::max)());
++m_branch_factor; ++m_branch_factor;
}
o->flags |= observer::flag_short_timeout; o->flags |= observer::flag_short_timeout;
#ifndef TORRENT_DISABLE_LOGGING #ifndef TORRENT_DISABLE_LOGGING
dht_observer* logger = get_node().observer(); dht_observer* logger = get_node().observer();
@ -333,8 +334,7 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
o->flags |= observer::flag_failed; o->flags |= observer::flag_failed;
// if this flag is set, it means we increased the // if this flag is set, it means we increased the
// branch factor for it, and we should restore it // branch factor for it, and we should restore it
if (o->flags & observer::flag_short_timeout) decrement_branch_factor = (o->flags & observer::flag_short_timeout) != 0;
--m_branch_factor;
#ifndef TORRENT_DISABLE_LOGGING #ifndef TORRENT_DISABLE_LOGGING
dht_observer* logger = get_node().observer(); dht_observer* logger = get_node().observer();
@ -356,12 +356,18 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
--m_invoke_count; --m_invoke_count;
} }
if (flags & prevent_request) // this is another reason to decrement the branch factor, to prevent another
// request from filling this slot. Only ever decrement once per response though
decrement_branch_factor |= (flags & prevent_request);
if (decrement_branch_factor)
{ {
TORRENT_ASSERT(m_branch_factor > 0);
--m_branch_factor; --m_branch_factor;
if (m_branch_factor <= 0) m_branch_factor = 1; if (m_branch_factor <= 0) m_branch_factor = 1;
} }
bool is_done = add_requests();
bool const is_done = add_requests();
if (is_done) done(); if (is_done) done();
} }
@ -484,7 +490,7 @@ bool traversal_algorithm::add_requests()
o->flags |= observer::flag_queried; o->flags |= observer::flag_queried;
if (invoke(*i)) if (invoke(*i))
{ {
TORRENT_ASSERT(m_invoke_count < (std::numeric_limits<std::uint16_t>::max)()); TORRENT_ASSERT(m_invoke_count < (std::numeric_limits<std::int16_t>::max)());
++m_invoke_count; ++m_invoke_count;
++outstanding; ++outstanding;
} }