forked from premiere/premiere-libtorrent
make DHT bootstrapping more robust by not throwing away nodes
This commit is contained in:
parent
4eb5155263
commit
a9a12e873f
|
@ -68,7 +68,7 @@ struct TORRENT_EXTRA_EXPORT traversal_algorithm : boost::noncopyable
|
|||
|
||||
node_id const& target() const { return m_target; }
|
||||
|
||||
void resort_results();
|
||||
void resort_result(observer*);
|
||||
void add_entry(node_id const& id, udp::endpoint const& addr, unsigned char flags);
|
||||
|
||||
traversal_algorithm(node& dht_node, node_id const& target);
|
||||
|
@ -104,15 +104,27 @@ protected:
|
|||
int num_timeouts() const { return m_timeouts; }
|
||||
|
||||
node& m_node;
|
||||
|
||||
// this vector is sorted by node-id distance from the traversal target. Closer nodes
|
||||
// are earlier in the vector. However, not the entire vector is necessarily
|
||||
// sorted, the tail of the vector may contain nodes out-of-order. This is
|
||||
// used when bootstrapping. The ``m_sorted_results`` member indicates how
|
||||
// many of the first elements are sorted.
|
||||
std::vector<observer_ptr> m_results;
|
||||
|
||||
int num_sorted_results() const { return m_sorted_results; }
|
||||
|
||||
private:
|
||||
|
||||
node_id const m_target;
|
||||
std::int16_t m_invoke_count = 0;
|
||||
std::int16_t m_branch_factor = 3;
|
||||
std::int8_t m_invoke_count = 0;
|
||||
std::int8_t m_branch_factor = 3;
|
||||
// the number of elements at the beginning of m_results that are sorted by
|
||||
// node-id distance from the traversal target.
|
||||
std::int8_t m_sorted_results = 0;
|
||||
std::int16_t m_responses = 0;
|
||||
std::int16_t m_timeouts = 0;
|
||||
|
||||
#ifndef TORRENT_DISABLE_LOGGING
|
||||
// this is a unique ID for this specific traversal_algorithm instance,
|
||||
// just used for logging
|
||||
|
|
|
@ -132,7 +132,7 @@ void observer::set_id(node_id const& id)
|
|||
{
|
||||
if (m_id == id) return;
|
||||
m_id = id;
|
||||
if (m_algorithm) m_algorithm->resort_results();
|
||||
if (m_algorithm) m_algorithm->resort_result(this);
|
||||
}
|
||||
|
||||
using observer_storage = aux::aligned_union<1
|
||||
|
|
|
@ -93,11 +93,33 @@ traversal_algorithm::traversal_algorithm(
|
|||
#endif
|
||||
}
|
||||
|
||||
void traversal_algorithm::resort_results()
|
||||
void traversal_algorithm::resort_result(observer* o)
|
||||
{
|
||||
std::sort(m_results.begin(), m_results.end()
|
||||
// find the given observer, remove it and insert it in its sorted location
|
||||
auto it = std::find_if(m_results.begin(), m_results.end()
|
||||
, [=](observer_ptr const& ptr) { return ptr.get() == o; });
|
||||
|
||||
if (it == m_results.end()) return;
|
||||
|
||||
if (it - m_results.begin() < m_sorted_results)
|
||||
--m_sorted_results;
|
||||
|
||||
observer_ptr ptr = std::move(*it);
|
||||
m_results.erase(it);
|
||||
|
||||
TORRENT_ASSERT(std::size_t(m_sorted_results) <= m_results.size());
|
||||
auto end = m_results.begin() + m_sorted_results;
|
||||
|
||||
TORRENT_ASSERT(libtorrent::dht::is_sorted(m_results.begin(), end
|
||||
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
|
||||
{ return compare_ref(lhs->id(), rhs->id(), m_target); }));
|
||||
|
||||
auto iter = std::lower_bound(m_results.begin(), end, ptr
|
||||
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
|
||||
{ return compare_ref(lhs->id(), rhs->id(), m_target); });
|
||||
|
||||
m_results.insert(iter, ptr);
|
||||
++m_sorted_results;
|
||||
}
|
||||
|
||||
void traversal_algorithm::add_entry(node_id const& id
|
||||
|
@ -117,85 +139,110 @@ void traversal_algorithm::add_entry(node_id const& id
|
|||
done();
|
||||
return;
|
||||
}
|
||||
|
||||
o->flags |= flags;
|
||||
|
||||
if (id.is_all_zeros())
|
||||
{
|
||||
o->set_id(generate_random_id());
|
||||
o->flags |= observer::flag_no_id;
|
||||
}
|
||||
|
||||
o->flags |= flags;
|
||||
|
||||
TORRENT_ASSERT(libtorrent::dht::is_sorted(m_results.begin(), m_results.end()
|
||||
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
|
||||
{ return compare_ref(lhs->id(), rhs->id(), m_target); }));
|
||||
|
||||
auto iter = std::lower_bound(m_results.begin(), m_results.end(), o
|
||||
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
|
||||
{ return compare_ref(lhs->id(), rhs->id(), m_target); });
|
||||
|
||||
if (iter == m_results.end() || (*iter)->id() != id)
|
||||
{
|
||||
if (m_node.settings().restrict_search_ips
|
||||
&& !(flags & observer::flag_initial))
|
||||
{
|
||||
#if TORRENT_USE_IPV6
|
||||
if (o->target_addr().is_v6())
|
||||
{
|
||||
address_v6::bytes_type addr_bytes = o->target_addr().to_v6().to_bytes();
|
||||
address_v6::bytes_type::const_iterator prefix_it = addr_bytes.begin();
|
||||
std::uint64_t const prefix6 = detail::read_uint64(prefix_it);
|
||||
|
||||
if (m_peer6_prefixes.insert(prefix6).second)
|
||||
goto add_result;
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
// mask the lower octet
|
||||
std::uint32_t const prefix4
|
||||
= o->target_addr().to_v4().to_ulong() & 0xffffff00;
|
||||
|
||||
if (m_peer4_prefixes.insert(prefix4).second)
|
||||
goto add_result;
|
||||
}
|
||||
|
||||
// we already have a node in this search with an IP very
|
||||
// close to this one. We know that it's not the same, because
|
||||
// it claims a different node-ID. Ignore this to avoid attacks
|
||||
#ifndef TORRENT_DISABLE_LOGGING
|
||||
dht_observer* logger = get_node().observer();
|
||||
if (logger != nullptr && logger->should_log(dht_logger::traversal))
|
||||
{
|
||||
logger->log(dht_logger::traversal
|
||||
, "[%u] traversal DUPLICATE node. id: %s addr: %s type: %s"
|
||||
, m_id, aux::to_hex(o->id()).c_str(), print_address(o->target_addr()).c_str(), name());
|
||||
}
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
add_result:
|
||||
|
||||
TORRENT_ASSERT((o->flags & observer::flag_no_id)
|
||||
|| std::none_of(m_results.begin(), m_results.end()
|
||||
, [&id](observer_ptr const& ob) { return ob->id() == id; }));
|
||||
m_results.push_back(o);
|
||||
|
||||
#ifndef TORRENT_DISABLE_LOGGING
|
||||
dht_observer* logger = get_node().observer();
|
||||
if (logger != nullptr && logger->should_log(dht_logger::traversal))
|
||||
{
|
||||
logger->log(dht_logger::traversal
|
||||
, "[%u] ADD id: %s addr: %s distance: %d invoke-count: %d type: %s"
|
||||
, "[%u] ADD (no-id) id: %s addr: %s distance: %d invoke-count: %d type: %s"
|
||||
, m_id, aux::to_hex(id).c_str(), print_endpoint(addr).c_str()
|
||||
, distance_exp(m_target, id), m_invoke_count, name());
|
||||
}
|
||||
#endif
|
||||
iter = m_results.insert(iter, o);
|
||||
|
||||
TORRENT_ASSERT(libtorrent::dht::is_sorted(m_results.begin(), m_results.end()
|
||||
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
|
||||
{ return compare_ref(lhs->id(), rhs->id(), m_target); }));
|
||||
}
|
||||
else
|
||||
{
|
||||
TORRENT_ASSERT(std::size_t(m_sorted_results) <= m_results.size());
|
||||
auto end = m_results.begin() + m_sorted_results;
|
||||
|
||||
TORRENT_ASSERT(libtorrent::dht::is_sorted(m_results.begin(), end
|
||||
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
|
||||
{ return compare_ref(lhs->id(), rhs->id(), m_target); }));
|
||||
|
||||
auto iter = std::lower_bound(m_results.begin(), end, o
|
||||
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
|
||||
{ return compare_ref(lhs->id(), rhs->id(), m_target); });
|
||||
|
||||
if (iter == end || (*iter)->id() != id)
|
||||
{
|
||||
// this IP restriction does not apply to the nodes we loaded from our
|
||||
// node cache
|
||||
if (m_node.settings().restrict_search_ips
|
||||
&& !(flags & observer::flag_initial))
|
||||
{
|
||||
#if TORRENT_USE_IPV6
|
||||
if (o->target_addr().is_v6())
|
||||
{
|
||||
address_v6::bytes_type addr_bytes = o->target_addr().to_v6().to_bytes();
|
||||
address_v6::bytes_type::const_iterator prefix_it = addr_bytes.begin();
|
||||
std::uint64_t const prefix6 = detail::read_uint64(prefix_it);
|
||||
|
||||
if (m_peer6_prefixes.insert(prefix6).second)
|
||||
goto add_result;
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
// mask the lower octet
|
||||
std::uint32_t const prefix4
|
||||
= o->target_addr().to_v4().to_ulong() & 0xffffff00;
|
||||
|
||||
if (m_peer4_prefixes.insert(prefix4).second)
|
||||
goto add_result;
|
||||
}
|
||||
|
||||
// we already have a node in this search with an IP very
|
||||
// close to this one. We know that it's not the same, because
|
||||
// it claims a different node-ID. Ignore this to avoid attacks
|
||||
#ifndef TORRENT_DISABLE_LOGGING
|
||||
dht_observer* logger = get_node().observer();
|
||||
if (logger != nullptr && logger->should_log(dht_logger::traversal))
|
||||
{
|
||||
logger->log(dht_logger::traversal
|
||||
, "[%u] traversal DUPLICATE node. id: %s addr: %s type: %s"
|
||||
, m_id, aux::to_hex(o->id()).c_str(), print_address(o->target_addr()).c_str(), name());
|
||||
}
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
add_result:
|
||||
|
||||
TORRENT_ASSERT((o->flags & observer::flag_no_id)
|
||||
|| std::none_of(m_results.begin(), end
|
||||
, [&id](observer_ptr const& ob) { return ob->id() == id; }));
|
||||
|
||||
#ifndef TORRENT_DISABLE_LOGGING
|
||||
dht_observer* logger = get_node().observer();
|
||||
if (logger != nullptr && logger->should_log(dht_logger::traversal))
|
||||
{
|
||||
logger->log(dht_logger::traversal
|
||||
, "[%u] ADD id: %s addr: %s distance: %d invoke-count: %d type: %s"
|
||||
, m_id, aux::to_hex(id).c_str(), print_endpoint(addr).c_str()
|
||||
, distance_exp(m_target, id), m_invoke_count, name());
|
||||
}
|
||||
#endif
|
||||
m_results.insert(iter, o);
|
||||
++m_sorted_results;
|
||||
}
|
||||
}
|
||||
|
||||
TORRENT_ASSERT(std::size_t(m_sorted_results) <= m_results.size());
|
||||
auto end = m_results.begin() + m_sorted_results;
|
||||
|
||||
TORRENT_ASSERT(libtorrent::dht::is_sorted(m_results.begin(), end
|
||||
, [this](observer_ptr const& lhs, observer_ptr const& rhs)
|
||||
{ return compare_ref(lhs->id(), rhs->id(), m_target); }));
|
||||
|
||||
if (m_results.size() > 100)
|
||||
{
|
||||
|
@ -217,6 +264,7 @@ void traversal_algorithm::add_entry(node_id const& id
|
|||
#endif
|
||||
});
|
||||
m_results.resize(100);
|
||||
m_sorted_results = std::min(std::int8_t(100), m_sorted_results);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -226,7 +274,7 @@ void traversal_algorithm::start()
|
|||
// router nodes in the table
|
||||
if (m_results.size() < 3) add_router_entries();
|
||||
init();
|
||||
bool is_done = add_requests();
|
||||
bool const is_done = add_requests();
|
||||
if (is_done) done();
|
||||
}
|
||||
|
||||
|
@ -274,7 +322,7 @@ void traversal_algorithm::finished(observer_ptr o)
|
|||
++m_responses;
|
||||
TORRENT_ASSERT(m_invoke_count > 0);
|
||||
--m_invoke_count;
|
||||
bool is_done = add_requests();
|
||||
bool const is_done = add_requests();
|
||||
if (is_done) done();
|
||||
}
|
||||
|
||||
|
@ -303,7 +351,7 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
|
|||
// by increasing the branch factor
|
||||
if ((o->flags & observer::flag_short_timeout) == 0)
|
||||
{
|
||||
TORRENT_ASSERT(m_branch_factor < (std::numeric_limits<std::int16_t>::max)());
|
||||
TORRENT_ASSERT(m_branch_factor < (std::numeric_limits<std::int8_t>::max)());
|
||||
++m_branch_factor;
|
||||
}
|
||||
o->flags |= observer::flag_short_timeout;
|
||||
|
@ -388,7 +436,7 @@ void traversal_algorithm::done()
|
|||
, print_endpoint(o->target_ep()).c_str());
|
||||
|
||||
--results_target;
|
||||
int dist = distance_exp(m_target, o->id());
|
||||
int const dist = distance_exp(m_target, o->id());
|
||||
if (dist < closest_target) closest_target = dist;
|
||||
}
|
||||
#endif
|
||||
|
@ -406,6 +454,7 @@ void traversal_algorithm::done()
|
|||
// delete all our references to the observer objects so
|
||||
// they will in turn release the traversal algorithm
|
||||
m_results.clear();
|
||||
m_sorted_results = 0;
|
||||
m_invoke_count = 0;
|
||||
}
|
||||
|
||||
|
@ -423,7 +472,7 @@ bool traversal_algorithm::add_requests()
|
|||
// if we're doing aggressive lookups, we keep branch-factor
|
||||
// outstanding requests _at the tops_ of the result list. Otherwise
|
||||
// we just keep any branch-factor outstanding requests
|
||||
bool agg = m_node.settings().aggressive_lookups;
|
||||
bool const agg = m_node.settings().aggressive_lookups;
|
||||
|
||||
// Find the first node that hasn't already been queried.
|
||||
// and make sure that the 'm_branch_factor' top nodes
|
||||
|
@ -474,7 +523,7 @@ bool traversal_algorithm::add_requests()
|
|||
o->flags |= observer::flag_queried;
|
||||
if (invoke(*i))
|
||||
{
|
||||
TORRENT_ASSERT(m_invoke_count < (std::numeric_limits<std::int16_t>::max)());
|
||||
TORRENT_ASSERT(m_invoke_count < (std::numeric_limits<std::int8_t>::max)());
|
||||
++m_invoke_count;
|
||||
++outstanding;
|
||||
}
|
||||
|
@ -509,7 +558,7 @@ void traversal_algorithm::add_router_entries()
|
|||
|
||||
void traversal_algorithm::init()
|
||||
{
|
||||
m_branch_factor = aux::numeric_cast<std::int16_t>(m_node.branch_factor());
|
||||
m_branch_factor = aux::numeric_cast<std::int8_t>(m_node.branch_factor());
|
||||
m_node.add_traversal_algorithm(this);
|
||||
}
|
||||
|
||||
|
|
|
@ -107,6 +107,8 @@ void node_push_back(std::vector<node_entry>* nv, node_entry const& n)
|
|||
|
||||
void nop_node() {}
|
||||
|
||||
// TODO: 3 make the mock_socket hold a reference to the list of where to record
|
||||
// packets instead of having a global variable
|
||||
std::list<std::pair<udp::endpoint, entry>> g_sent_packets;
|
||||
|
||||
struct mock_socket final : socket_manager
|
||||
|
@ -3237,6 +3239,53 @@ TORRENT_TEST(invalid_error_msg)
|
|||
TEST_EQUAL(found, true);
|
||||
}
|
||||
|
||||
struct test_algo : dht::traversal_algorithm
|
||||
{
|
||||
test_algo(node& dht_node, node_id const& target)
|
||||
: traversal_algorithm(dht_node, target)
|
||||
{}
|
||||
|
||||
std::vector<observer_ptr> const& results() const { return m_results; };
|
||||
|
||||
using traversal_algorithm::num_sorted_results;
|
||||
};
|
||||
|
||||
TORRENT_TEST(unsorted_traversal_results)
{
	// Verifies that the traversal algorithm handles an unsorted tail of
	// results correctly. Entries added with an all-zero node ID (as is done
	// for the initial nodes we bootstrap from) are expected to stay in
	// insertion order and not count as sorted.

	dht_test_setup setup(udp::endpoint(rand_v4(), 20));

	node_id const our_id = setup.dht_node.nid();
	auto algo = std::make_shared<test_algo>(setup.dht_node, our_id);

	std::vector<udp::endpoint> eps;
	for (int n = 0; n < 10; ++n)
	{
		udp::endpoint const ep = rand_udp_ep(rand_v4);
		eps.push_back(ep);
		algo->add_entry(node_id(), eps.back(), observer::flag_initial);
	}

	// all 10 nodes should be present, none of them sorted, and in the same
	// order they were added
	TEST_CHECK(algo->num_sorted_results() == 0);
	std::vector<observer_ptr> results = algo->results();
	TEST_CHECK(results.size() == eps.size());
	for (std::size_t i = 0; i < eps.size(); ++i)
	{
		TEST_CHECK(eps[i] == results[i]->target_ep());
	}

	// assigning a node ID to one observer, no matter which ID, should make
	// that observer sorted, i.e. it is moved to the front of the result
	// list
	results[5]->set_id(node_id("abababababababababab"));

	TEST_CHECK(algo->num_sorted_results() == 1);
	results = algo->results();
	TEST_CHECK(results.size() == eps.size());
	TEST_CHECK(eps[5] == results[0]->target_ep());
}
|
||||
|
||||
TORRENT_TEST(rpc_invalid_error_msg)
|
||||
{
|
||||
// TODO: 3 use dht_test_setup class to simplify the node setup
|
||||
|
|
Loading…
Reference in New Issue