From f8b2b60634fcffecf2f16cd2c83ee88af0ecc535 Mon Sep 17 00:00:00 2001 From: Arvid Norberg Date: Sun, 16 Mar 2008 03:52:13 +0000 Subject: [PATCH] fixed bug in dht routing table. added unit test to trunk --- src/kademlia/routing_table.cpp | 50 +++++++++++++------- test/test_primitives.cpp | 86 ++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 17 deletions(-) diff --git a/src/kademlia/routing_table.cpp b/src/kademlia/routing_table.cpp index 02a7ffb12..ac2e09b46 100644 --- a/src/kademlia/routing_table.cpp +++ b/src/kademlia/routing_table.cpp @@ -355,6 +355,19 @@ bool routing_table::need_bootstrap() const return true; } +template +DstIter copy_if_n(SrcIter begin, SrcIter end, DstIter target, size_t n, Pred p) +{ + for (; n > 0 && begin != end; ++begin) + { + if (!p(*begin)) continue; + *target = *begin; + --n; + ++target; + } + return target; +} + // fills the vector with the k nodes from our buckets that // are nearest to the given id. void routing_table::find_node(node_id const& target @@ -369,8 +382,8 @@ void routing_table::find_node(node_id const& target // copy all nodes that hasn't failed into the target // vector. - std::remove_copy_if(b.begin(), b.end(), std::back_inserter(l) - , bind(&node_entry::fail_count, _1)); + copy_if_n(b.begin(), b.end(), std::back_inserter(l) + , (std::min)(size_t(count), b.size()), bind(&node_entry::fail_count, _1) == 0); TORRENT_ASSERT((int)l.size() <= count); if ((int)l.size() == count) @@ -386,26 +399,30 @@ void routing_table::find_node(node_id const& target // [0, bucket_index) if we are to include ourself // or [1, bucket_index) if not. bucket_t tmpb; - for (int i = include_self?0:1; i < count; ++i) + for (int i = include_self?0:1; i < bucket_index; ++i) { bucket_t& b = m_buckets[i].first; std::remove_copy_if(b.begin(), b.end(), std::back_inserter(tmpb) , bind(&node_entry::fail_count, _1)); } - std::random_shuffle(tmpb.begin(), tmpb.end()); - size_t to_copy = (std::min)(m_bucket_size - l.size() - , tmpb.size()); - std::copy(tmpb.begin(), tmpb.begin() + to_copy - , std::back_inserter(l)); + if (count - l.size() < tmpb.size()) + { + std::random_shuffle(tmpb.begin(), tmpb.end()); + size_t to_copy = count - l.size(); + std::copy(tmpb.begin(), tmpb.begin() + to_copy, std::back_inserter(l)); + } + else + { + std::copy(tmpb.begin(), tmpb.end(), std::back_inserter(l)); + } - TORRENT_ASSERT((int)l.size() <= m_bucket_size); + TORRENT_ASSERT((int)l.size() <= count); // return if we have enough nodes or if the bucket index // is the biggest index available (there are no more buckets) // to look in. - if ((int)l.size() == count - || bucket_index == (int)m_buckets.size() - 1) + if ((int)l.size() == count) { TORRENT_ASSERT(std::count_if(l.begin(), l.end() , boost::bind(&node_entry::fail_count, _1) != 0) == 0); @@ -416,18 +433,17 @@ void routing_table::find_node(node_id const& target { bucket_t& b = m_buckets[i].first; - std::remove_copy_if(b.begin(), b.end(), std::back_inserter(l) - , bind(&node_entry::fail_count, _1)); - if ((int)l.size() >= count) + size_t to_copy = (std::min)(count - l.size(), b.size()); + copy_if_n(b.begin(), b.end(), std::back_inserter(l) + , to_copy, bind(&node_entry::fail_count, _1) == 0); + TORRENT_ASSERT((int)l.size() <= count); + if ((int)l.size() == count) { - l.erase(l.begin() + count, l.end()); TORRENT_ASSERT(std::count_if(l.begin(), l.end() , boost::bind(&node_entry::fail_count, _1) != 0) == 0); return; } } - TORRENT_ASSERT((int)l.size() == count - || std::distance(l.begin(), l.end()) < m_bucket_size); TORRENT_ASSERT((int)l.size() <= count); TORRENT_ASSERT(std::count_if(l.begin(), l.end() diff --git a/test/test_primitives.cpp b/test/test_primitives.cpp index 776fb3885..b2619d9c8 100644 --- a/test/test_primitives.cpp +++ b/test/test_primitives.cpp @@ -7,6 +7,7 @@ #include "libtorrent/torrent_info.hpp" #include "libtorrent/escape_string.hpp" #include "libtorrent/kademlia/node_id.hpp" +#include "libtorrent/kademlia/routing_table.hpp" #include "libtorrent/broadcast_socket.hpp" #include @@ -63,6 +64,17 @@ void parser_callback(std::string& out, int token, char const* s, char const* val } } +void add_and_replace(libtorrent::dht::node_id& dst, libtorrent::dht::node_id const& add) +{ + bool carry = false; + for (int k = 19; k >= 0; --k) + { + int sum = dst[k] + add[k] + (carry?1:0); + dst[k] = sum & 255; + carry = sum > 255; + } +} + int test_main() { using namespace libtorrent; @@ -330,6 +342,80 @@ int test_main() } } + // test kademlia routing table + dht_settings s; + node_id id = boost::lexical_cast("6123456789abcdef01232456789abcdef0123456"); + dht::routing_table table(id, 10, s); + table.node_seen(id, udp::endpoint(address_v4::any(), rand())); + + node_id tmp; + node_id diff = boost::lexical_cast("00000f7459456a9453f8719b09547c11d5f34064"); + std::vector nodes; + for (int i = 0; i < 1000000; ++i) + { + table.node_seen(tmp, udp::endpoint(address_v4::any(), rand())); + add_and_replace(tmp, diff); + } + + std::copy(table.begin(), table.end(), std::back_inserter(nodes)); + + std::cout << "nodes: " << nodes.size() << std::endl; + + std::vector temp; + + std::generate(tmp.begin(), tmp.end(), &std::rand); + table.find_node(tmp, temp, false, nodes.size() + 1); + std::cout << "returned: " << temp.size() << std::endl; + TEST_CHECK(temp.size() == nodes.size() - 1); + + std::generate(tmp.begin(), tmp.end(), &std::rand); + table.find_node(tmp, temp, true, nodes.size() + 1); + std::cout << "returned: " << temp.size() << std::endl; + TEST_CHECK(temp.size() == nodes.size()); + + std::generate(tmp.begin(), tmp.end(), &std::rand); + table.find_node(tmp, temp, false, 7); + std::cout << "returned: " << temp.size() << std::endl; + TEST_CHECK(temp.size() == 7); + + std::sort(nodes.begin(), nodes.end(), bind(&compare_ref + , bind(&node_entry::id, _1) + , bind(&node_entry::id, _2), tmp)); + + int hits = 0; + for (std::vector::iterator i = temp.begin() + , end(temp.end()); i != end; ++i) + { + int hit = std::find_if(nodes.begin(), nodes.end() + , bind(&node_entry::id, _1) == i->id) - nodes.begin(); + std::cerr << hit << std::endl; + if (hit < int(temp.size())) ++hits; + } + TEST_CHECK(hits > int(temp.size()) / 2); + + std::generate(tmp.begin(), tmp.end(), &std::rand); + table.find_node(tmp, temp, false, 15); + std::cout << "returned: " << temp.size() << std::endl; + TEST_CHECK(temp.size() == 15); + + std::sort(nodes.begin(), nodes.end(), bind(&compare_ref + , bind(&node_entry::id, _1) + , bind(&node_entry::id, _2), tmp)); + + hits = 0; + for (std::vector::iterator i = temp.begin() + , end(temp.end()); i != end; ++i) + { + int hit = std::find_if(nodes.begin(), nodes.end() + , bind(&node_entry::id, _1) == i->id) - nodes.begin(); + std::cerr << hit << std::endl; + if (hit < int(temp.size())) ++hits; + } + TEST_CHECK(hits > int(temp.size()) / 2); + + + + // test peer_id/sha1_hash type sha1_hash h1(0);