diff --git a/include/libtorrent/kademlia/node.hpp b/include/libtorrent/kademlia/node.hpp index 895088622..5ba4e48fa 100644 --- a/include/libtorrent/kademlia/node.hpp +++ b/include/libtorrent/kademlia/node.hpp @@ -261,7 +261,7 @@ protected: // is called when a find data request is received. Should // return false if the data is not stored on this node. If // the data is stored, it should be serialized into 'data'. - bool lookup_peers(sha1_hash const& info_hash, entry& reply) const; + bool lookup_peers(sha1_hash const& info_hash, int prefix, entry& reply) const; bool lookup_torrents(sha1_hash const& target, entry& reply , char* tags) const; diff --git a/include/libtorrent/peer_id.hpp b/include/libtorrent/peer_id.hpp index 3624147ad..2a1e86a1e 100644 --- a/include/libtorrent/peer_id.hpp +++ b/include/libtorrent/peer_id.hpp @@ -118,6 +118,48 @@ namespace libtorrent return true; } + big_number& operator<<=(int n) + { + TORRENT_ASSERT(n >= 0); + if (n > number_size * 8) n = number_size; + int num_bytes = n / 8; + if (num_bytes > 0) + { + std::memmove(m_number, m_number + num_bytes, number_size - num_bytes); + std::memset(m_number + number_size - num_bytes, 0, num_bytes); + n -= num_bytes * 8; + } + if (n > 0) + { + for (int i = 0; i < number_size - 1; ++i) + { + m_number[i] <<= n; + m_number[i] |= m_number[i+1] >> (8 - n); + } + } + return *this; + } + + big_number& operator>>=(int n) + { + int num_bytes = n / 8; + if (num_bytes > 0) + { + std::memmove(m_number + num_bytes, m_number, number_size - num_bytes); + std::memset(m_number, 0, num_bytes); + n -= num_bytes * 8; + } + if (n > 0) + { + for (int i = number_size - 1; i > 0; --i) + { + m_number[i] >>= n; + m_number[i] |= m_number[i-1] << (8 - n); + } + } + return *this; + } + bool operator==(big_number const& n) const { return std::equal(n.m_number, n.m_number+number_size, m_number); @@ -146,6 +188,20 @@ namespace libtorrent return ret; } + big_number operator^ (big_number const& n) const + { + big_number ret = *this; + ret ^= n; + return ret; + } + + big_number operator& (big_number const& n) const + { + big_number ret = *this; + ret &= n; + return ret; + } + big_number& operator &= (big_number const& n) { for (int i = 0; i< number_size; ++i) diff --git a/src/kademlia/node.cpp b/src/kademlia/node.cpp index 07955060f..3e63a4083 100644 --- a/src/kademlia/node.cpp +++ b/src/kademlia/node.cpp @@ -525,13 +525,20 @@ bool node_impl::lookup_torrents(sha1_hash const& target return true; } -bool node_impl::lookup_peers(sha1_hash const& info_hash, entry& reply) const +bool node_impl::lookup_peers(sha1_hash const& info_hash, int prefix, entry& reply) const { if (m_ses.m_alerts.should_post()) m_ses.m_alerts.post_alert(dht_get_peers_alert(info_hash)); - table_t::const_iterator i = m_map.find(info_hash); + table_t::const_iterator i = m_map.lower_bound(info_hash); if (i == m_map.end()) return false; + if (i->first != info_hash && prefix == 20) return false; + if (prefix != 20) + { + sha1_hash mask = sha1_hash::max(); + mask <<= (20 - prefix) * 8; + if ((i->first & mask) != (info_hash & mask)) return false; + } torrent_entry const& v = i->second; if (v.peers.empty()) return false; @@ -704,10 +711,11 @@ void node_impl::incoming_request(msg const& m, entry& e) { key_desc_t msg_desc[] = { {"info_hash", lazy_entry::string_t, 20, 0}, + {"ifhpfxl", lazy_entry::int_t, 0, key_desc_t::optional}, }; - lazy_entry const* msg_keys[1]; - if (!verify_message(arg_ent, msg_desc, msg_keys, 1, error_string, sizeof(error_string))) + lazy_entry const* msg_keys[2]; + if (!verify_message(arg_ent, msg_desc, msg_keys, 2, error_string, sizeof(error_string))) { incoming_error(e, error_string); return; @@ -721,7 +729,11 @@ void node_impl::incoming_request(msg const& m, entry& e) m_table.find_node(info_hash, n, 0); write_nodes_entry(reply, n); - bool ret = lookup_peers(info_hash, reply); + int prefix = msg_keys[1] ? msg_keys[1]->int_value() : 20; + if (prefix > 20) prefix = 20; + else if (prefix < 4) prefix = 4; + + bool ret = lookup_peers(info_hash, prefix, reply); #ifdef TORRENT_DHT_VERBOSE_LOGGING if (ret) TORRENT_LOG(node) << " values: " << reply["values"].list().size(); #endif diff --git a/test/test_primitives.cpp b/test/test_primitives.cpp index f45fc98f5..f6e2e147c 100644 --- a/test/test_primitives.cpp +++ b/test/test_primitives.cpp @@ -1356,6 +1356,21 @@ int test_main() h2 = sha1_hash(" "); TEST_CHECK(h2 == to_hash("2020202020202020202020202020202020202020")); + + h1 = to_hash("ffffffffff0000000000ffffffffff0000000000"); +#if TORRENT_USE_IOSTREAM + std::cerr << h1 << std::endl; +#endif + h1 <<= 12; +#if TORRENT_USE_IOSTREAM + std::cerr << h1 << std::endl; +#endif + TEST_CHECK(h1 == to_hash("fffffff0000000000ffffffffff0000000000000")); + h1 >>= 12; +#if TORRENT_USE_IOSTREAM + std::cerr << h1 << std::endl; +#endif + TEST_CHECK(h1 == to_hash("000fffffff0000000000ffffffffff0000000000")); // CIDR distance test h1 = to_hash("0123456789abcdef01232456789abcdef0123456");