From 540b5b28d702669cab09c49a823633072da87420 Mon Sep 17 00:00:00 2001 From: Arvid Norberg Date: Fri, 14 Oct 2016 08:03:09 -0400 Subject: [PATCH] use signed counters in traversa_algorithm, and add some more asserts (#1208) use signed counters in traversal_algorithm, and add some more asserts --- include/libtorrent/kademlia/find_data.hpp | 2 -- .../kademlia/traversal_algorithm.hpp | 18 +++++++++----- src/kademlia/find_data.cpp | 2 +- src/kademlia/get_item.cpp | 14 ++++------- src/kademlia/get_peers.cpp | 18 ++++++-------- src/kademlia/put_data.cpp | 10 +++----- src/kademlia/traversal_algorithm.cpp | 24 ++++++++++++------- 7 files changed, 43 insertions(+), 45 deletions(-) diff --git a/include/libtorrent/kademlia/find_data.hpp b/include/libtorrent/kademlia/find_data.hpp index 6955e3f10..02a188143 100644 --- a/include/libtorrent/kademlia/find_data.hpp +++ b/include/libtorrent/kademlia/find_data.hpp @@ -63,8 +63,6 @@ struct find_data : traversal_algorithm virtual char const* name() const; - node_id const& target() const { return m_target; } - protected: virtual void done(); diff --git a/include/libtorrent/kademlia/traversal_algorithm.hpp b/include/libtorrent/kademlia/traversal_algorithm.hpp index 9c7640e29..4ab0117eb 100644 --- a/include/libtorrent/kademlia/traversal_algorithm.hpp +++ b/include/libtorrent/kademlia/traversal_algorithm.hpp @@ -73,8 +73,8 @@ struct TORRENT_EXTRA_EXPORT traversal_algorithm : boost::noncopyable void add_entry(node_id const& id, udp::endpoint const& addr, unsigned char flags); traversal_algorithm(node& dht_node, node_id const& target); - int invoke_count() const { return m_invoke_count; } - int branch_factor() const { return m_branch_factor; } + int invoke_count() const { TORRENT_ASSERT(m_invoke_count >= 0); return m_invoke_count; } + int branch_factor() const { TORRENT_ASSERT(m_branch_factor >= 0); return m_branch_factor; } node& get_node() const { return m_node; } @@ -97,13 +97,19 @@ protected: virtual bool invoke(observer_ptr) { return false; } + int num_responses() const { return m_responses; } + int num_timeouts() const { return m_timeouts; } + node& m_node; std::vector m_results; + +private: + node_id const m_target; - std::uint16_t m_invoke_count; - std::uint16_t m_branch_factor; - std::uint16_t m_responses; - std::uint16_t m_timeouts; + std::int16_t m_invoke_count = 0; + std::int16_t m_branch_factor = 3; + std::int16_t m_responses = 0; + std::int16_t m_timeouts = 0; // the IP addresses of the nodes in m_results std::set m_peer4_prefixes; diff --git a/src/kademlia/find_data.cpp b/src/kademlia/find_data.cpp index bafd47e14..0e802b7b3 100644 --- a/src/kademlia/find_data.cpp +++ b/src/kademlia/find_data.cpp @@ -95,7 +95,7 @@ void find_data::start() if (m_results.empty()) { std::vector nodes; - m_node.m_table.find_node(m_target, nodes, routing_table::include_failed); + m_node.m_table.find_node(target(), nodes, routing_table::include_failed); for (auto const& n : nodes) { diff --git a/src/kademlia/get_item.cpp b/src/kademlia/get_item.cpp index f0393e7c1..7047f9ece 100644 --- a/src/kademlia/get_item.cpp +++ b/src/kademlia/get_item.cpp @@ -61,7 +61,7 @@ void get_item::got_data(bdecode_node const& v, if (!m_data.empty()) return; sha1_hash incoming_target = item_target_id(v.data_section()); - if (incoming_target != m_target) return; + if (incoming_target != target()) return; m_data.assign(v); @@ -79,7 +79,7 @@ void get_item::got_data(bdecode_node const& v, std::string const salt_copy(m_data.salt()); sha1_hash const incoming_target = item_target_id(salt_copy, pk); - if (incoming_target != m_target) return; + if (incoming_target != target()) return; // this is mutable data. If it passes the signature // check, remember it. Just keep the version with @@ -137,18 +137,14 @@ observer_ptr get_item::new_observer(udp::endpoint const& ep bool get_item::invoke(observer_ptr o) { - if (m_done) - { - m_invoke_count = -1; - return false; - } + if (m_done) return false; entry e; e["y"] = "q"; entry& a = e["a"]; e["q"] = "get"; - a["target"] = m_target.to_string(); + a["target"] = target().to_string(); return m_node.m_rpc.invoke(e, o->target_ep(), o); } @@ -168,7 +164,7 @@ void get_item::done() #if TORRENT_USE_ASSERTS if (m_data.is_mutable()) { - TORRENT_ASSERT(m_target == item_target_id(m_data.salt(), m_data.pk())); + TORRENT_ASSERT(target() == item_target_id(m_data.salt(), m_data.pk())); } #endif } diff --git a/src/kademlia/get_peers.cpp b/src/kademlia/get_peers.cpp index 892edc32f..1dd0cacbb 100644 --- a/src/kademlia/get_peers.cpp +++ b/src/kademlia/get_peers.cpp @@ -141,23 +141,19 @@ char const* get_peers::name() const { return "get_peers"; } bool get_peers::invoke(observer_ptr o) { - if (m_done) - { - m_invoke_count = -1; - return false; - } + if (m_done) return false; entry e; e["y"] = "q"; entry& a = e["a"]; e["q"] = "get_peers"; - a["info_hash"] = m_target.to_string(); + a["info_hash"] = target().to_string(); if (m_noseeds) a["noseed"] = 1; if (m_node.observer() != nullptr) { - m_node.observer()->outgoing_get_peers(m_target, m_target, o->target_ep()); + m_node.observer()->outgoing_get_peers(target(), target(), o->target_ep()); } m_node.stats_counters().inc_stats_counter(counters::dht_get_peers_out); @@ -217,7 +213,7 @@ bool obfuscated_get_peers::invoke(observer_ptr o) if (!m_obfuscated) return get_peers::invoke(o); node_id const& id = o->id(); - int const shared_prefix = 160 - distance_exp(id, m_target); + int const shared_prefix = 160 - distance_exp(id, target()); // when we get close to the target zone in the DHT // start using the correct info-hash, in order to @@ -254,12 +250,12 @@ bool obfuscated_get_peers::invoke(observer_ptr o) // now, obfuscate the bits past shared_prefix + 3 node_id mask = generate_prefix_mask(shared_prefix + 3); node_id obfuscated_target = generate_random_id() & ~mask; - obfuscated_target |= m_target & mask; + obfuscated_target |= target() & mask; a["info_hash"] = obfuscated_target.to_string(); if (m_node.observer() != nullptr) { - m_node.observer()->outgoing_get_peers(m_target, obfuscated_target + m_node.observer()->outgoing_get_peers(target(), obfuscated_target , o->target_ep()); } @@ -275,7 +271,7 @@ void obfuscated_get_peers::done() // oops, we failed to switch over to the non-obfuscated // mode early enough. do it now - auto ta = std::make_shared(m_node, m_target + auto ta = std::make_shared(m_node, target() , m_data_callback, m_nodes_callback, m_noseeds); // don't call these when the obfuscated_get_peers diff --git a/src/kademlia/put_data.cpp b/src/kademlia/put_data.cpp index a0c175496..031dbb7d1 100644 --- a/src/kademlia/put_data.cpp +++ b/src/kademlia/put_data.cpp @@ -74,20 +74,16 @@ void put_data::done() #ifndef TORRENT_DISABLE_LOGGING get_node().observer()->log(dht_logger::traversal, "[%p] %s DONE, response %d, timeout %d" - , static_cast(this), name(), m_responses, m_timeouts); + , static_cast(this), name(), num_responses(), num_timeouts()); #endif - m_put_callback(m_data, m_responses); + m_put_callback(m_data, num_responses()); traversal_algorithm::done(); } bool put_data::invoke(observer_ptr o) { - if (m_done) - { - m_invoke_count = -1; - return false; - } + if (m_done) return false; // TODO: what if o is not an isntance of put_data_observer? This need to be // redesigned for better type saftey. diff --git a/src/kademlia/traversal_algorithm.cpp b/src/kademlia/traversal_algorithm.cpp index 5be48ed2d..5e71627d9 100644 --- a/src/kademlia/traversal_algorithm.cpp +++ b/src/kademlia/traversal_algorithm.cpp @@ -84,10 +84,6 @@ traversal_algorithm::traversal_algorithm( , node_id const& target) : m_node(dht_node) , m_target(target) - , m_invoke_count(0) - , m_branch_factor(3) - , m_responses(0) - , m_timeouts(0) { #ifndef TORRENT_DISABLE_LOGGING dht_observer* logger = get_node().observer(); @@ -301,6 +297,8 @@ void traversal_algorithm::failed(observer_ptr o, int const flags) if (m_results.empty()) return; + bool decrement_branch_factor = false; + TORRENT_ASSERT(o->flags & observer::flag_queried); if (flags & short_timeout) { @@ -311,7 +309,10 @@ void traversal_algorithm::failed(observer_ptr o, int const flags) // around for some more, but open up the slot // by increasing the branch factor if ((o->flags & observer::flag_short_timeout) == 0) + { + TORRENT_ASSERT(m_branch_factor < (std::numeric_limits::max)()); ++m_branch_factor; + } o->flags |= observer::flag_short_timeout; #ifndef TORRENT_DISABLE_LOGGING dht_observer* logger = get_node().observer(); @@ -333,8 +334,7 @@ void traversal_algorithm::failed(observer_ptr o, int const flags) o->flags |= observer::flag_failed; // if this flag is set, it means we increased the // branch factor for it, and we should restore it - if (o->flags & observer::flag_short_timeout) - --m_branch_factor; + decrement_branch_factor = (o->flags & observer::flag_short_timeout) != 0; #ifndef TORRENT_DISABLE_LOGGING dht_observer* logger = get_node().observer(); @@ -356,12 +356,18 @@ void traversal_algorithm::failed(observer_ptr o, int const flags) --m_invoke_count; } - if (flags & prevent_request) + // this is another reason to decrement the branch factor, to prevent another + // request from filling this slot. Only ever decrement once per response though + decrement_branch_factor |= (flags & prevent_request); + + if (decrement_branch_factor) { + TORRENT_ASSERT(m_branch_factor > 0); --m_branch_factor; if (m_branch_factor <= 0) m_branch_factor = 1; } - bool is_done = add_requests(); + + bool const is_done = add_requests(); if (is_done) done(); } @@ -484,7 +490,7 @@ bool traversal_algorithm::add_requests() o->flags |= observer::flag_queried; if (invoke(*i)) { - TORRENT_ASSERT(m_invoke_count < (std::numeric_limits::max)()); + TORRENT_ASSERT(m_invoke_count < (std::numeric_limits::max)()); ++m_invoke_count; ++outstanding; }