fixed bug in dht routing table. added unit test to trunk

Arvid Norberg 2008-03-16 03:52:13 +00:00
parent 5b4590c6b1
commit f8b2b60634
2 changed files with 119 additions and 17 deletions

View File

@@ -355,6 +355,19 @@ bool routing_table::need_bootstrap() const
return true;
}
template <class SrcIter, class DstIter, class Pred>
DstIter copy_if_n(SrcIter begin, SrcIter end, DstIter target, size_t n, Pred p)
{
for (; n > 0 && begin != end; ++begin)
{
if (!p(*begin)) continue;
*target = *begin;
--n;
++target;
}
return target;
}
// fills the vector with the k nodes from our buckets that
// are nearest to the given id.
void routing_table::find_node(node_id const& target
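
The copy_if_n helper added above copies at most n elements satisfying the predicate into the output iterator and returns the advanced output iterator. A minimal usage sketch (the integer data and the is_even predicate are made-up illustrations, assuming copy_if_n is in scope):

#include <vector>
#include <iterator>

bool is_even(int x) { return (x & 1) == 0; }

void copy_if_n_demo()
{
    std::vector<int> src;
    for (int i = 1; i <= 8; ++i) src.push_back(i);
    std::vector<int> dst;
    // copies 2, 4 and 6, then stops: at most 3 matching elements are taken
    copy_if_n(src.begin(), src.end(), std::back_inserter(dst), 3, &is_even);
}
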
@@ -369,8 +382,8 @@ void routing_table::find_node(node_id const& target
// copy all nodes that hasn't failed into the target
// vector.
std::remove_copy_if(b.begin(), b.end(), std::back_inserter(l)
, bind(&node_entry::fail_count, _1));
copy_if_n(b.begin(), b.end(), std::back_inserter(l)
, (std::min)(size_t(count), b.size()), bind(&node_entry::fail_count, _1) == 0);
TORRENT_ASSERT((int)l.size() <= count);
if ((int)l.size() == count)
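
Note the flipped predicate sense in this hunk: std::remove_copy_if skips elements for which the predicate is true, so the bare bind(&node_entry::fail_count, _1) (a non-zero fail count converts to true) dropped failed nodes, while copy_if_n copies elements for which the predicate is true, hence the explicit == 0 comparison. A stand-alone sketch of the same predicate, using a stand-in struct instead of the real node_entry:

#include <vector>
#include <iterator>
#include <algorithm>
#include <boost/bind.hpp>

struct entry { int fail_count; };

void keep_unfailed(std::vector<entry> const& bucket
    , std::vector<entry>& out, std::size_t count)
{
    using boost::bind;
    // copy at most `count` entries that have never failed
    copy_if_n(bucket.begin(), bucket.end(), std::back_inserter(out)
        , (std::min)(count, bucket.size()), bind(&entry::fail_count, _1) == 0);
}
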
@@ -386,26 +399,30 @@ void routing_table::find_node(node_id const& target
// [0, bucket_index) if we are to include ourself
// or [1, bucket_index) if not.
bucket_t tmpb;
for (int i = include_self?0:1; i < count; ++i)
for (int i = include_self?0:1; i < bucket_index; ++i)
{
bucket_t& b = m_buckets[i].first;
std::remove_copy_if(b.begin(), b.end(), std::back_inserter(tmpb)
, bind(&node_entry::fail_count, _1));
}
std::random_shuffle(tmpb.begin(), tmpb.end());
size_t to_copy = (std::min)(m_bucket_size - l.size()
, tmpb.size());
std::copy(tmpb.begin(), tmpb.begin() + to_copy
, std::back_inserter(l));
if (count - l.size() < tmpb.size())
{
std::random_shuffle(tmpb.begin(), tmpb.end());
size_t to_copy = count - l.size();
std::copy(tmpb.begin(), tmpb.begin() + to_copy, std::back_inserter(l));
}
else
{
std::copy(tmpb.begin(), tmpb.end(), std::back_inserter(l));
}
TORRENT_ASSERT((int)l.size() <= m_bucket_size);
TORRENT_ASSERT((int)l.size() <= count);
// return if we have enough nodes or if the bucket index
// is the biggest index available (there are no more buckets)
// to look in.
if ((int)l.size() == count
|| bucket_index == (int)m_buckets.size() - 1)
if ((int)l.size() == count)
{
TORRENT_ASSERT(std::count_if(l.begin(), l.end()
, boost::bind(&node_entry::fail_count, _1) != 0) == 0);
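
Two things change in this hunk: the loop over the shallower buckets now runs over [0, bucket_index) (or [1, bucket_index) when excluding ourself) instead of the bogus [0, count), and the collected spare nodes are only shuffled when there are more of them than are still needed. A stand-alone sketch of that sampling pattern (sample_into is a hypothetical name, not part of libtorrent):

#include <vector>
#include <iterator>
#include <algorithm>

template <class T>
void sample_into(std::vector<T>& pool, std::vector<T>& out, std::size_t still_needed)
{
    if (still_needed < pool.size())
    {
        // more candidates than needed: take a random subset of the right size
        std::random_shuffle(pool.begin(), pool.end());
        std::copy(pool.begin(), pool.begin() + still_needed
            , std::back_inserter(out));
    }
    else
    {
        // not enough (or exactly enough): take them all
        std::copy(pool.begin(), pool.end(), std::back_inserter(out));
    }
}
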
@@ -416,18 +433,17 @@ void routing_table::find_node(node_id const& target
{
bucket_t& b = m_buckets[i].first;
std::remove_copy_if(b.begin(), b.end(), std::back_inserter(l)
, bind(&node_entry::fail_count, _1));
if ((int)l.size() >= count)
size_t to_copy = (std::min)(count - l.size(), b.size());
copy_if_n(b.begin(), b.end(), std::back_inserter(l)
, to_copy, bind(&node_entry::fail_count, _1) == 0);
TORRENT_ASSERT((int)l.size() <= count);
if ((int)l.size() == count)
{
l.erase(l.begin() + count, l.end());
TORRENT_ASSERT(std::count_if(l.begin(), l.end()
, boost::bind(&node_entry::fail_count, _1) != 0) == 0);
return;
}
}
TORRENT_ASSERT((int)l.size() == count
|| std::distance(l.begin(), l.end()) < m_bucket_size);
TORRENT_ASSERT((int)l.size() <= count);
TORRENT_ASSERT(std::count_if(l.begin(), l.end()

View File

@@ -7,6 +7,7 @@
#include "libtorrent/torrent_info.hpp"
#include "libtorrent/escape_string.hpp"
#include "libtorrent/kademlia/node_id.hpp"
#include "libtorrent/kademlia/routing_table.hpp"
#include "libtorrent/broadcast_socket.hpp"
#include <boost/tuple/tuple.hpp>
@@ -63,6 +64,17 @@ void parser_callback(std::string& out, int token, char const* s, char const* val
}
}
void add_and_replace(libtorrent::dht::node_id& dst, libtorrent::dht::node_id const& add)
{
bool carry = false;
for (int k = 19; k >= 0; --k)
{
int sum = dst[k] + add[k] + (carry?1:0);
dst[k] = sum & 255;
carry = sum > 255;
}
}
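
add_and_replace treats the 20-byte node_id as a single big-endian 160-bit integer and adds `add` into `dst` in place, byte by byte with carry; a carry out of the most significant byte is dropped, so the sum wraps. A rough usage sketch, reusing the lexical_cast construction the test already relies on:

using namespace libtorrent;

dht::node_id a = boost::lexical_cast<sha1_hash>(
    "00000000000000000000000000000000000000ff");
dht::node_id b = boost::lexical_cast<sha1_hash>(
    "0000000000000000000000000000000000000001");
add_and_replace(a, b);
// 0xff + 0x01 in the last byte carries into the byte above:
// a now ends in ...0100
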
int test_main()
{
using namespace libtorrent;
@@ -330,6 +342,80 @@ int test_main()
}
}
// test kademlia routing table
dht_settings s;
node_id id = boost::lexical_cast<sha1_hash>("6123456789abcdef01232456789abcdef0123456");
dht::routing_table table(id, 10, s);
table.node_seen(id, udp::endpoint(address_v4::any(), rand()));
node_id tmp;
node_id diff = boost::lexical_cast<sha1_hash>("00000f7459456a9453f8719b09547c11d5f34064");
std::vector<node_entry> nodes;
for (int i = 0; i < 1000000; ++i)
{
table.node_seen(tmp, udp::endpoint(address_v4::any(), rand()));
add_and_replace(tmp, diff);
}
std::copy(table.begin(), table.end(), std::back_inserter(nodes));
std::cout << "nodes: " << nodes.size() << std::endl;
std::vector<node_entry> temp;
std::generate(tmp.begin(), tmp.end(), &std::rand);
table.find_node(tmp, temp, false, nodes.size() + 1);
std::cout << "returned: " << temp.size() << std::endl;
TEST_CHECK(temp.size() == nodes.size() - 1);
std::generate(tmp.begin(), tmp.end(), &std::rand);
table.find_node(tmp, temp, true, nodes.size() + 1);
std::cout << "returned: " << temp.size() << std::endl;
TEST_CHECK(temp.size() == nodes.size());
std::generate(tmp.begin(), tmp.end(), &std::rand);
table.find_node(tmp, temp, false, 7);
std::cout << "returned: " << temp.size() << std::endl;
TEST_CHECK(temp.size() == 7);
std::sort(nodes.begin(), nodes.end(), bind(&compare_ref
, bind(&node_entry::id, _1)
, bind(&node_entry::id, _2), tmp));
int hits = 0;
for (std::vector<node_entry>::iterator i = temp.begin()
, end(temp.end()); i != end; ++i)
{
int hit = std::find_if(nodes.begin(), nodes.end()
, bind(&node_entry::id, _1) == i->id) - nodes.begin();
std::cerr << hit << std::endl;
if (hit < int(temp.size())) ++hits;
}
TEST_CHECK(hits > int(temp.size()) / 2);
std::generate(tmp.begin(), tmp.end(), &std::rand);
table.find_node(tmp, temp, false, 15);
std::cout << "returned: " << temp.size() << std::endl;
TEST_CHECK(temp.size() == 15);
std::sort(nodes.begin(), nodes.end(), bind(&compare_ref
, bind(&node_entry::id, _1)
, bind(&node_entry::id, _2), tmp));
hits = 0;
for (std::vector<node_entry>::iterator i = temp.begin()
, end(temp.end()); i != end; ++i)
{
int hit = std::find_if(nodes.begin(), nodes.end()
, bind(&node_entry::id, _1) == i->id) - nodes.begin();
std::cerr << hit << std::endl;
if (hit < int(temp.size())) ++hits;
}
TEST_CHECK(hits > int(temp.size()) / 2);
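
The hit counting above asserts that more than half of the nodes returned by find_node rank among the temp.size() closest known nodes, where closeness is the Kademlia XOR distance that compare_ref sorts by. A minimal sketch of that metric (closer_to is an illustrative stand-in, not the library's compare_ref):

#include "libtorrent/kademlia/node_id.hpp"

// true if a is strictly closer to target than b under the XOR metric
bool closer_to(libtorrent::dht::node_id const& a
    , libtorrent::dht::node_id const& b
    , libtorrent::dht::node_id const& target)
{
    for (int i = 0; i < 20; ++i)
    {
        unsigned char da = a[i] ^ target[i];
        unsigned char db = b[i] ^ target[i];
        if (da != db) return da < db;
    }
    return false; // equal distance
}
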
// test peer_id/sha1_hash type
sha1_hash h1(0);