make DHT bootstrapping more robust by not throwing away nodes

This commit is contained in:
arvidn 2017-03-19 12:09:32 -04:00 committed by Arvid Norberg
parent 4eb5155263
commit a9a12e873f
4 changed files with 186 additions and 76 deletions

View File

@ -68,7 +68,7 @@ struct TORRENT_EXTRA_EXPORT traversal_algorithm : boost::noncopyable
// returns the target node ID of this traversal (m_target)
node_id const& target() const { return m_target; }
void resort_results();
void resort_result(observer*);
void add_entry(node_id const& id, udp::endpoint const& addr, unsigned char flags);
traversal_algorithm(node& dht_node, node_id const& target);
@ -104,15 +104,27 @@ protected:
// returns the current value of the m_timeouts counter
int num_timeouts() const { return m_timeouts; }
node& m_node;
// this vector is sorted by node-id distance from the traversal target
// (m_target). Closer nodes are earlier in the vector. However, the entire
// vector is not necessarily sorted; the tail of the vector may contain nodes
// out-of-order. This is used when bootstrapping. The ``m_sorted_results``
// member indicates how many of the first elements are sorted.
std::vector<observer_ptr> m_results;
// returns m_sorted_results, the number of elements at the beginning of
// m_results that are sorted
int num_sorted_results() const { return m_sorted_results; }
private:
node_id const m_target;
std::int16_t m_invoke_count = 0;
std::int16_t m_branch_factor = 3;
std::int8_t m_invoke_count = 0;
std::int8_t m_branch_factor = 3;
// the number of elements at the beginning of m_results that are sorted by
// node_id.
std::int8_t m_sorted_results = 0;
std::int16_t m_responses = 0;
std::int16_t m_timeouts = 0;
#ifndef TORRENT_DISABLE_LOGGING
// this is a unique ID for this specific traversal_algorithm instance,
// just used for logging

View File

@ -132,7 +132,7 @@ void observer::set_id(node_id const& id)
{
if (m_id == id) return;
m_id = id;
if (m_algorithm) m_algorithm->resort_results();
if (m_algorithm) m_algorithm->resort_result(this);
}
using observer_storage = aux::aligned_union<1

View File

@ -93,11 +93,33 @@ traversal_algorithm::traversal_algorithm(
#endif
}
void traversal_algorithm::resort_results()
void traversal_algorithm::resort_result(observer* o)
{
std::sort(m_results.begin(), m_results.end()
// find the given observer, remove it and insert it in its sorted location
auto it = std::find_if(m_results.begin(), m_results.end()
, [=](observer_ptr const& ptr) { return ptr.get() == o; });
if (it == m_results.end()) return;
if (it - m_results.begin() < m_sorted_results)
--m_sorted_results;
observer_ptr ptr = std::move(*it);
m_results.erase(it);
TORRENT_ASSERT(std::size_t(m_sorted_results) <= m_results.size());
auto end = m_results.begin() + m_sorted_results;
TORRENT_ASSERT(libtorrent::dht::is_sorted(m_results.begin(), end
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
{ return compare_ref(lhs->id(), rhs->id(), m_target); }));
auto iter = std::lower_bound(m_results.begin(), end, ptr
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
{ return compare_ref(lhs->id(), rhs->id(), m_target); });
m_results.insert(iter, ptr);
++m_sorted_results;
}
void traversal_algorithm::add_entry(node_id const& id
@ -117,85 +139,110 @@ void traversal_algorithm::add_entry(node_id const& id
done();
return;
}
o->flags |= flags;
if (id.is_all_zeros())
{
o->set_id(generate_random_id());
o->flags |= observer::flag_no_id;
}
o->flags |= flags;
TORRENT_ASSERT(libtorrent::dht::is_sorted(m_results.begin(), m_results.end()
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
{ return compare_ref(lhs->id(), rhs->id(), m_target); }));
auto iter = std::lower_bound(m_results.begin(), m_results.end(), o
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
{ return compare_ref(lhs->id(), rhs->id(), m_target); });
if (iter == m_results.end() || (*iter)->id() != id)
{
if (m_node.settings().restrict_search_ips
&& !(flags & observer::flag_initial))
{
#if TORRENT_USE_IPV6
if (o->target_addr().is_v6())
{
address_v6::bytes_type addr_bytes = o->target_addr().to_v6().to_bytes();
address_v6::bytes_type::const_iterator prefix_it = addr_bytes.begin();
std::uint64_t const prefix6 = detail::read_uint64(prefix_it);
if (m_peer6_prefixes.insert(prefix6).second)
goto add_result;
}
else
#endif
{
// mask the lower octet
std::uint32_t const prefix4
= o->target_addr().to_v4().to_ulong() & 0xffffff00;
if (m_peer4_prefixes.insert(prefix4).second)
goto add_result;
}
// we already have a node in this search with an IP very
// close to this one. We know that it's not the same, because
// it claims a different node-ID. Ignore this to avoid attacks
#ifndef TORRENT_DISABLE_LOGGING
dht_observer* logger = get_node().observer();
if (logger != nullptr && logger->should_log(dht_logger::traversal))
{
logger->log(dht_logger::traversal
, "[%u] traversal DUPLICATE node. id: %s addr: %s type: %s"
, m_id, aux::to_hex(o->id()).c_str(), print_address(o->target_addr()).c_str(), name());
}
#endif
return;
}
add_result:
TORRENT_ASSERT((o->flags & observer::flag_no_id)
|| std::none_of(m_results.begin(), m_results.end()
, [&id](observer_ptr const& ob) { return ob->id() == id; }));
m_results.push_back(o);
#ifndef TORRENT_DISABLE_LOGGING
dht_observer* logger = get_node().observer();
if (logger != nullptr && logger->should_log(dht_logger::traversal))
{
logger->log(dht_logger::traversal
, "[%u] ADD id: %s addr: %s distance: %d invoke-count: %d type: %s"
, "[%u] ADD (no-id) id: %s addr: %s distance: %d invoke-count: %d type: %s"
, m_id, aux::to_hex(id).c_str(), print_endpoint(addr).c_str()
, distance_exp(m_target, id), m_invoke_count, name());
}
#endif
iter = m_results.insert(iter, o);
TORRENT_ASSERT(libtorrent::dht::is_sorted(m_results.begin(), m_results.end()
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
{ return compare_ref(lhs->id(), rhs->id(), m_target); }));
}
else
{
TORRENT_ASSERT(std::size_t(m_sorted_results) <= m_results.size());
auto end = m_results.begin() + m_sorted_results;
TORRENT_ASSERT(libtorrent::dht::is_sorted(m_results.begin(), end
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
{ return compare_ref(lhs->id(), rhs->id(), m_target); }));
auto iter = std::lower_bound(m_results.begin(), end, o
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
{ return compare_ref(lhs->id(), rhs->id(), m_target); });
if (iter == end || (*iter)->id() != id)
{
// this IP restriction does not apply to the nodes we loaded from our
// node cache
if (m_node.settings().restrict_search_ips
&& !(flags & observer::flag_initial))
{
#if TORRENT_USE_IPV6
if (o->target_addr().is_v6())
{
address_v6::bytes_type addr_bytes = o->target_addr().to_v6().to_bytes();
address_v6::bytes_type::const_iterator prefix_it = addr_bytes.begin();
std::uint64_t const prefix6 = detail::read_uint64(prefix_it);
if (m_peer6_prefixes.insert(prefix6).second)
goto add_result;
}
else
#endif
{
// mask the lower octet
std::uint32_t const prefix4
= o->target_addr().to_v4().to_ulong() & 0xffffff00;
if (m_peer4_prefixes.insert(prefix4).second)
goto add_result;
}
// we already have a node in this search with an IP very
// close to this one. We know that it's not the same, because
// it claims a different node-ID. Ignore this to avoid attacks
#ifndef TORRENT_DISABLE_LOGGING
dht_observer* logger = get_node().observer();
if (logger != nullptr && logger->should_log(dht_logger::traversal))
{
logger->log(dht_logger::traversal
, "[%u] traversal DUPLICATE node. id: %s addr: %s type: %s"
, m_id, aux::to_hex(o->id()).c_str(), print_address(o->target_addr()).c_str(), name());
}
#endif
return;
}
add_result:
TORRENT_ASSERT((o->flags & observer::flag_no_id)
|| std::none_of(m_results.begin(), end
, [&id](observer_ptr const& ob) { return ob->id() == id; }));
#ifndef TORRENT_DISABLE_LOGGING
dht_observer* logger = get_node().observer();
if (logger != nullptr && logger->should_log(dht_logger::traversal))
{
logger->log(dht_logger::traversal
, "[%u] ADD id: %s addr: %s distance: %d invoke-count: %d type: %s"
, m_id, aux::to_hex(id).c_str(), print_endpoint(addr).c_str()
, distance_exp(m_target, id), m_invoke_count, name());
}
#endif
m_results.insert(iter, o);
++m_sorted_results;
}
}
TORRENT_ASSERT(std::size_t(m_sorted_results) <= m_results.size());
auto end = m_results.begin() + m_sorted_results;
TORRENT_ASSERT(libtorrent::dht::is_sorted(m_results.begin(), end
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
{ return compare_ref(lhs->id(), rhs->id(), m_target); }));
if (m_results.size() > 100)
{
@ -217,6 +264,7 @@ void traversal_algorithm::add_entry(node_id const& id
#endif
});
m_results.resize(100);
m_sorted_results = std::min(std::int8_t(100), m_sorted_results);
}
}
@ -226,7 +274,7 @@ void traversal_algorithm::start()
// router nodes in the table
if (m_results.size() < 3) add_router_entries();
init();
bool is_done = add_requests();
bool const is_done = add_requests();
if (is_done) done();
}
@ -274,7 +322,7 @@ void traversal_algorithm::finished(observer_ptr o)
++m_responses;
TORRENT_ASSERT(m_invoke_count > 0);
--m_invoke_count;
bool is_done = add_requests();
bool const is_done = add_requests();
if (is_done) done();
}
@ -303,7 +351,7 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
// by increasing the branch factor
if ((o->flags & observer::flag_short_timeout) == 0)
{
TORRENT_ASSERT(m_branch_factor < (std::numeric_limits<std::int16_t>::max)());
TORRENT_ASSERT(m_branch_factor < (std::numeric_limits<std::int8_t>::max)());
++m_branch_factor;
}
o->flags |= observer::flag_short_timeout;
@ -388,7 +436,7 @@ void traversal_algorithm::done()
, print_endpoint(o->target_ep()).c_str());
--results_target;
int dist = distance_exp(m_target, o->id());
int const dist = distance_exp(m_target, o->id());
if (dist < closest_target) closest_target = dist;
}
#endif
@ -406,6 +454,7 @@ void traversal_algorithm::done()
// delete all our references to the observer objects so
// they will in turn release the traversal algorithm
m_results.clear();
m_sorted_results = 0;
m_invoke_count = 0;
}
@ -423,7 +472,7 @@ bool traversal_algorithm::add_requests()
// if we're doing aggressive lookups, we keep branch-factor
// outstanding requests _at the tops_ of the result list. Otherwise
// we just keep any branch-factor outstanding requests
bool agg = m_node.settings().aggressive_lookups;
bool const agg = m_node.settings().aggressive_lookups;
// Find the first node that hasn't already been queried.
// and make sure that the 'm_branch_factor' top nodes
@ -474,7 +523,7 @@ bool traversal_algorithm::add_requests()
o->flags |= observer::flag_queried;
if (invoke(*i))
{
TORRENT_ASSERT(m_invoke_count < (std::numeric_limits<std::int16_t>::max)());
TORRENT_ASSERT(m_invoke_count < (std::numeric_limits<std::int8_t>::max)());
++m_invoke_count;
++outstanding;
}
@ -509,7 +558,7 @@ void traversal_algorithm::add_router_entries()
void traversal_algorithm::init()
{
m_branch_factor = aux::numeric_cast<std::int16_t>(m_node.branch_factor());
m_branch_factor = aux::numeric_cast<std::int8_t>(m_node.branch_factor());
m_node.add_traversal_algorithm(this);
}

View File

@ -107,6 +107,8 @@ void node_push_back(std::vector<node_entry>* nv, node_entry const& n)
void nop_node() {}
// TODO: 3 make the mock_socket hold a reference to the list of where to record
// packets instead of having a global variable
std::list<std::pair<udp::endpoint, entry>> g_sent_packets;
struct mock_socket final : socket_manager
@ -3237,6 +3239,53 @@ TORRENT_TEST(invalid_error_msg)
TEST_EQUAL(found, true);
}
// Test-only subclass of dht::traversal_algorithm. It exposes the result list
// and the sorted-prefix count so that a test can inspect them directly.
struct test_algo : dht::traversal_algorithm
{
	// forwards the DHT node and the traversal target to the base class
	test_algo(node& dht_node, node_id const& target)
		: traversal_algorithm(dht_node, target)
	{}

	// read-only access to the m_results vector of the base class
	std::vector<observer_ptr> const& results() const { return m_results; }

	// make num_sorted_results() callable from outside of the class
	using traversal_algorithm::num_sorted_results;
};
TORRENT_TEST(unsorted_traversal_results)
{
	// Exercise the traversal algorithm's handling of an unsorted tail in its
	// result list. Entries added without a node ID (like the initial nodes
	// added below) are not counted as sorted.
	dht_test_setup t(udp::endpoint(rand_v4(), 20));

	node_id const target = t.dht_node.nid();
	auto algo = std::make_shared<test_algo>(t.dht_node, target);

	int const num_nodes = 10;
	std::vector<udp::endpoint> eps;
	for (int n = 0; n < num_nodes; ++n)
	{
		udp::endpoint const ep = rand_udp_ep(rand_v4);
		eps.push_back(ep);
		algo->add_entry(node_id(), ep, observer::flag_initial);
	}

	// all 10 nodes are expected to be unsorted at this point, and to appear
	// in the result list in the same order they were added
	TEST_CHECK(algo->num_sorted_results() == 0);
	auto results = algo->results();
	TEST_CHECK(results.size() == eps.size());
	for (int i = 0; i < int(eps.size()); ++i)
	{
		TEST_CHECK(eps[i] == results[i]->target_ep());
	}

	// assigning a node ID, no matter which one, is expected to turn this
	// observer into a sorted entry, i.e. move it to the beginning of the
	// result list.
	results[5]->set_id(node_id("abababababababababab"));

	TEST_CHECK(algo->num_sorted_results() == 1);
	results = algo->results();
	TEST_CHECK(results.size() == eps.size());
	TEST_CHECK(eps[5] == results[0]->target_ep());
}
TORRENT_TEST(rpc_invalid_error_msg)
{
// TODO: 3 use dht_test_setup class to simplify the node setup