use signed counters in traversa_algorithm, and add some more asserts (#1208)

use signed counters in traversal_algorithm, and add some more asserts
This commit is contained in:
Arvid Norberg 2016-10-14 08:03:09 -04:00 committed by GitHub
parent e0c8ad738d
commit 540b5b28d7
7 changed files with 43 additions and 45 deletions

View File

@ -63,8 +63,6 @@ struct find_data : traversal_algorithm
virtual char const* name() const;
node_id const& target() const { return m_target; }
protected:
virtual void done();

View File

@ -73,8 +73,8 @@ struct TORRENT_EXTRA_EXPORT traversal_algorithm : boost::noncopyable
void add_entry(node_id const& id, udp::endpoint const& addr, unsigned char flags);
traversal_algorithm(node& dht_node, node_id const& target);
int invoke_count() const { return m_invoke_count; }
int branch_factor() const { return m_branch_factor; }
int invoke_count() const { TORRENT_ASSERT(m_invoke_count >= 0); return m_invoke_count; }
int branch_factor() const { TORRENT_ASSERT(m_branch_factor >= 0); return m_branch_factor; }
node& get_node() const { return m_node; }
@ -97,13 +97,19 @@ protected:
virtual bool invoke(observer_ptr) { return false; }
int num_responses() const { return m_responses; }
int num_timeouts() const { return m_timeouts; }
node& m_node;
std::vector<observer_ptr> m_results;
private:
node_id const m_target;
std::uint16_t m_invoke_count;
std::uint16_t m_branch_factor;
std::uint16_t m_responses;
std::uint16_t m_timeouts;
std::int16_t m_invoke_count = 0;
std::int16_t m_branch_factor = 3;
std::int16_t m_responses = 0;
std::int16_t m_timeouts = 0;
// the IP addresses of the nodes in m_results
std::set<std::uint32_t> m_peer4_prefixes;

View File

@ -95,7 +95,7 @@ void find_data::start()
if (m_results.empty())
{
std::vector<node_entry> nodes;
m_node.m_table.find_node(m_target, nodes, routing_table::include_failed);
m_node.m_table.find_node(target(), nodes, routing_table::include_failed);
for (auto const& n : nodes)
{

View File

@ -61,7 +61,7 @@ void get_item::got_data(bdecode_node const& v,
if (!m_data.empty()) return;
sha1_hash incoming_target = item_target_id(v.data_section());
if (incoming_target != m_target) return;
if (incoming_target != target()) return;
m_data.assign(v);
@ -79,7 +79,7 @@ void get_item::got_data(bdecode_node const& v,
std::string const salt_copy(m_data.salt());
sha1_hash const incoming_target = item_target_id(salt_copy, pk);
if (incoming_target != m_target) return;
if (incoming_target != target()) return;
// this is mutable data. If it passes the signature
// check, remember it. Just keep the version with
@ -137,18 +137,14 @@ observer_ptr get_item::new_observer(udp::endpoint const& ep
bool get_item::invoke(observer_ptr o)
{
if (m_done)
{
m_invoke_count = -1;
return false;
}
if (m_done) return false;
entry e;
e["y"] = "q";
entry& a = e["a"];
e["q"] = "get";
a["target"] = m_target.to_string();
a["target"] = target().to_string();
return m_node.m_rpc.invoke(e, o->target_ep(), o);
}
@ -168,7 +164,7 @@ void get_item::done()
#if TORRENT_USE_ASSERTS
if (m_data.is_mutable())
{
TORRENT_ASSERT(m_target == item_target_id(m_data.salt(), m_data.pk()));
TORRENT_ASSERT(target() == item_target_id(m_data.salt(), m_data.pk()));
}
#endif
}

View File

@ -141,23 +141,19 @@ char const* get_peers::name() const { return "get_peers"; }
bool get_peers::invoke(observer_ptr o)
{
if (m_done)
{
m_invoke_count = -1;
return false;
}
if (m_done) return false;
entry e;
e["y"] = "q";
entry& a = e["a"];
e["q"] = "get_peers";
a["info_hash"] = m_target.to_string();
a["info_hash"] = target().to_string();
if (m_noseeds) a["noseed"] = 1;
if (m_node.observer() != nullptr)
{
m_node.observer()->outgoing_get_peers(m_target, m_target, o->target_ep());
m_node.observer()->outgoing_get_peers(target(), target(), o->target_ep());
}
m_node.stats_counters().inc_stats_counter(counters::dht_get_peers_out);
@ -217,7 +213,7 @@ bool obfuscated_get_peers::invoke(observer_ptr o)
if (!m_obfuscated) return get_peers::invoke(o);
node_id const& id = o->id();
int const shared_prefix = 160 - distance_exp(id, m_target);
int const shared_prefix = 160 - distance_exp(id, target());
// when we get close to the target zone in the DHT
// start using the correct info-hash, in order to
@ -254,12 +250,12 @@ bool obfuscated_get_peers::invoke(observer_ptr o)
// now, obfuscate the bits past shared_prefix + 3
node_id mask = generate_prefix_mask(shared_prefix + 3);
node_id obfuscated_target = generate_random_id() & ~mask;
obfuscated_target |= m_target & mask;
obfuscated_target |= target() & mask;
a["info_hash"] = obfuscated_target.to_string();
if (m_node.observer() != nullptr)
{
m_node.observer()->outgoing_get_peers(m_target, obfuscated_target
m_node.observer()->outgoing_get_peers(target(), obfuscated_target
, o->target_ep());
}
@ -275,7 +271,7 @@ void obfuscated_get_peers::done()
// oops, we failed to switch over to the non-obfuscated
// mode early enough. do it now
auto ta = std::make_shared<get_peers>(m_node, m_target
auto ta = std::make_shared<get_peers>(m_node, target()
, m_data_callback, m_nodes_callback, m_noseeds);
// don't call these when the obfuscated_get_peers

View File

@ -74,20 +74,16 @@ void put_data::done()
#ifndef TORRENT_DISABLE_LOGGING
get_node().observer()->log(dht_logger::traversal, "[%p] %s DONE, response %d, timeout %d"
, static_cast<void*>(this), name(), m_responses, m_timeouts);
, static_cast<void*>(this), name(), num_responses(), num_timeouts());
#endif
m_put_callback(m_data, m_responses);
m_put_callback(m_data, num_responses());
traversal_algorithm::done();
}
bool put_data::invoke(observer_ptr o)
{
if (m_done)
{
m_invoke_count = -1;
return false;
}
if (m_done) return false;
// TODO: what if o is not an isntance of put_data_observer? This need to be
// redesigned for better type saftey.

View File

@ -84,10 +84,6 @@ traversal_algorithm::traversal_algorithm(
, node_id const& target)
: m_node(dht_node)
, m_target(target)
, m_invoke_count(0)
, m_branch_factor(3)
, m_responses(0)
, m_timeouts(0)
{
#ifndef TORRENT_DISABLE_LOGGING
dht_observer* logger = get_node().observer();
@ -301,6 +297,8 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
if (m_results.empty()) return;
bool decrement_branch_factor = false;
TORRENT_ASSERT(o->flags & observer::flag_queried);
if (flags & short_timeout)
{
@ -311,7 +309,10 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
// around for some more, but open up the slot
// by increasing the branch factor
if ((o->flags & observer::flag_short_timeout) == 0)
{
TORRENT_ASSERT(m_branch_factor < (std::numeric_limits<std::int16_t>::max)());
++m_branch_factor;
}
o->flags |= observer::flag_short_timeout;
#ifndef TORRENT_DISABLE_LOGGING
dht_observer* logger = get_node().observer();
@ -333,8 +334,7 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
o->flags |= observer::flag_failed;
// if this flag is set, it means we increased the
// branch factor for it, and we should restore it
if (o->flags & observer::flag_short_timeout)
--m_branch_factor;
decrement_branch_factor = (o->flags & observer::flag_short_timeout) != 0;
#ifndef TORRENT_DISABLE_LOGGING
dht_observer* logger = get_node().observer();
@ -356,12 +356,18 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
--m_invoke_count;
}
if (flags & prevent_request)
// this is another reason to decrement the branch factor, to prevent another
// request from filling this slot. Only ever decrement once per response though
decrement_branch_factor |= (flags & prevent_request);
if (decrement_branch_factor)
{
TORRENT_ASSERT(m_branch_factor > 0);
--m_branch_factor;
if (m_branch_factor <= 0) m_branch_factor = 1;
}
bool is_done = add_requests();
bool const is_done = add_requests();
if (is_done) done();
}
@ -484,7 +490,7 @@ bool traversal_algorithm::add_requests()
o->flags |= observer::flag_queried;
if (invoke(*i))
{
TORRENT_ASSERT(m_invoke_count < (std::numeric_limits<std::uint16_t>::max)());
TORRENT_ASSERT(m_invoke_count < (std::numeric_limits<std::int16_t>::max)());
++m_invoke_count;
++outstanding;
}