diff --git a/include/libtorrent/broadcast_socket.hpp b/include/libtorrent/broadcast_socket.hpp index 0225b7887..283b42038 100644 --- a/include/libtorrent/broadcast_socket.hpp +++ b/include/libtorrent/broadcast_socket.hpp @@ -59,9 +59,6 @@ namespace libtorrent // determines if the operating system supports IPv6 TORRENT_EXTRA_EXPORT bool supports_ipv6(); - TORRENT_EXTRA_EXPORT int common_bits(unsigned char const* b1 - , unsigned char const* b2, int n); - typedef boost::function receive_handler_t; diff --git a/include/libtorrent/kademlia/routing_table.hpp b/include/libtorrent/kademlia/routing_table.hpp index 0c6da4e53..05057b718 100644 --- a/include/libtorrent/kademlia/routing_table.hpp +++ b/include/libtorrent/kademlia/routing_table.hpp @@ -124,6 +124,8 @@ namespace impl } } +TORRENT_EXTRA_EXPORT bool compare_ip_cidr(address const& lhs, address const& rhs); + class TORRENT_EXTRA_EXPORT routing_table : boost::noncopyable { public: diff --git a/src/broadcast_socket.cpp b/src/broadcast_socket.cpp index c679564c4..a62ffdcd3 100644 --- a/src/broadcast_socket.cpp +++ b/src/broadcast_socket.cpp @@ -149,49 +149,6 @@ namespace libtorrent #endif } - // count the length of the common bit prefix - int common_bits(unsigned char const* b1 - , unsigned char const* b2, int n) - { - for (int i = 0; i < n; ++i, ++b1, ++b2) - { - unsigned char a = *b1 ^ *b2; - if (a == 0) continue; - int ret = i * 8 + 8; - for (; a > 0; a >>= 1) --ret; - return ret; - } - return n * 8; - } - - // returns the number of bits in that differ from the right - // between the addresses. The larger number, the further apart - // the IPs are - int cidr_distance(address const& a1, address const& a2) - { -#if TORRENT_USE_IPV6 - if (a1.is_v4() && a2.is_v4()) - { -#endif - // both are v4 - address_v4::bytes_type b1 = a1.to_v4().to_bytes(); - address_v4::bytes_type b2 = a2.to_v4().to_bytes(); - return int(address_v4::bytes_type().size()) * 8 - - common_bits(b1.data(), b2.data(), int(b1.size())); -#if TORRENT_USE_IPV6 - } - - address_v6::bytes_type b1; - address_v6::bytes_type b2; - if (a1.is_v4()) b1 = address_v6::v4_mapped(a1.to_v4()).to_bytes(); - else b1 = a1.to_v6().to_bytes(); - if (a2.is_v4()) b2 = address_v6::v4_mapped(a2.to_v4()).to_bytes(); - else b2 = a2.to_v6().to_bytes(); - return int(address_v6::bytes_type().size()) * 8 - - common_bits(b1.data(), b2.data(), int(b1.size())); -#endif - } - broadcast_socket::broadcast_socket( udp::endpoint const& multicast_endpoint) : m_multicast_endpoint(multicast_endpoint) diff --git a/src/kademlia/routing_table.cpp b/src/kademlia/routing_table.cpp index 9437cae75..ad1cb1bd2 100644 --- a/src/kademlia/routing_table.cpp +++ b/src/kademlia/routing_table.cpp @@ -469,10 +469,9 @@ out: void routing_table::replacement_cache(bucket_t& nodes) const { - for (table_t::const_iterator i = m_buckets.begin() - , end(m_buckets.end()); i != end; ++i) + for (auto const& b : m_buckets) { - std::copy(i->replacements.begin(), i->replacements.end() + std::copy(b.replacements.begin(), b.replacements.end() , std::back_inserter(nodes)); } } @@ -497,21 +496,37 @@ routing_table::table_t::iterator routing_table::find_bucket(node_id const& id) return i; } -namespace { - -bool compare_ip_cidr(node_entry const& lhs, node_entry const& rhs) +// returns true if the two IPs are "too close" to each other to be allowed in +// the same DHT lookup. If they are, the last one to be found will be ignored +bool compare_ip_cidr(address const& lhs, address const& rhs) { - TORRENT_ASSERT(lhs.addr().is_v4() == rhs.addr().is_v4()); - // the number of bits in the IPs that may match. If - // more bits that this matches, something suspicious is - // going on and we shouldn't add the second one to our - // routing table - int cutoff = rhs.addr().is_v4() ? 8 : 64; - int dist = cidr_distance(lhs.addr(), rhs.addr()); - return dist <= cutoff; -} + TORRENT_ASSERT(lhs.is_v4() == rhs.is_v4()); -} // anonymous namespace +#if TORRENT_USE_IPV6 + if (lhs.is_v6()) + { + // if IPv6 addresses is in the same /64, they're too close and we won't + // trust the second one + boost::uint64_t lhs_ip; + memcpy(&lhs_ip, lhs.to_v6().to_bytes().data(), 8); + boost::uint64_t rhs_ip; + memcpy(&rhs_ip, rhs.to_v6().to_bytes().data(), 8); + + // since the condition we're looking for is all the first bits being + // zero, there's no need to byte-swap into host byte order here. + boost::uint64_t const mask = lhs_ip ^ rhs_ip; + return mask == 0; + } + else +#endif + { + // if IPv4 addresses is in the same /24, they're too close and we won't + // trust the second one + boost::uint32_t const mask + = lhs.to_v4().to_ulong() ^ rhs.to_v4().to_ulong(); + return mask <= 0x000000ff; + } +} node_entry* routing_table::find_node(udp::endpoint const& ep , routing_table::table_t::iterator* bucket) @@ -735,12 +750,11 @@ routing_table::add_node_status_t routing_table::add_node_impl(node_entry e) if (m_settings.restrict_routing_ips) { // don't allow multiple entries from IPs very close to each other - // TODO: 3 the call to compare_ip_cidr here is expensive. peel off some - // layers of abstraction here to make it quicker. Look at xoring and using _builtin_ctz() - j = std::find_if(b.begin(), b.end(), boost::bind(&compare_ip_cidr, _1, e)); + address const cmp = e.addr(); + j = std::find_if(b.begin(), b.end(), [&](node_entry const& a) { return compare_ip_cidr(a.addr(), cmp); }); if (j == b.end()) { - j = std::find_if(rb.begin(), rb.end(), boost::bind(&compare_ip_cidr, _1, e)); + j = std::find_if(rb.begin(), rb.end(), [&](node_entry const& a) { return compare_ip_cidr(a.addr(), cmp); }); if (j == rb.end()) goto ip_ok; } diff --git a/test/test_dht.cpp b/test/test_dht.cpp index 075ac8021..7e0189307 100644 --- a/test/test_dht.cpp +++ b/test/test_dht.cpp @@ -2825,5 +2825,53 @@ TORRENT_TEST(distance_exp) ), std::get<2>(t)); } } + +TORRENT_TEST(compare_ip_cidr) +{ + using tst = std::tuple; + std::vector const v4tests = { + tst{"10.255.255.0", "10.255.255.255", true}, + tst{"11.0.0.0", "10.255.255.255", false}, + tst{"0.0.0.0", "128.255.255.255", false}, + tst{"0.0.0.0", "127.255.255.255", false}, + tst{"255.255.255.0", "255.255.255.255", true}, + tst{"255.254.255.0", "255.255.255.255", false}, + tst{"0.0.0.0", "0.0.0.0", true}, + tst{"255.255.255.255", "255.255.255.255", true}, + }; + + for (auto const& t : v4tests) + { + fprintf(stderr, "%s %s\n", std::get<0>(t), std::get<1>(t)); + TEST_EQUAL(compare_ip_cidr( + address_v4::from_string(std::get<0>(t)) + , address_v4::from_string(std::get<1>(t))) + , std::get<2>(t)); + } + +#if TORRENT_USE_IPV6 + std::vector const v6tests = { + tst{"::1", "::ffff:ffff:ffff:ffff", true}, + tst{"::2:0000:0000:0000:0000", "::1:ffff:ffff:ffff:ffff", false}, + tst{"::ff:0000:0000:0000:0000", "::ffff:ffff:ffff:ffff", false}, + tst{"::caca:0000:0000:0000:0000", "::ffff:ffff:ffff:ffff:ffff", false}, + tst{"::a:0000:0000:0000:0000", "::b:ffff:ffff:ffff:ffff", false}, + tst{"::7f:0000:0000:0000:0000", "::ffff:ffff:ffff:ffff", false}, + tst{"7f::", "ff::", false}, + tst{"ff::", "ff::", true}, + tst{"::", "::", true}, + tst{"ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", true}, + }; + + for (auto const& t : v6tests) + { + TEST_EQUAL(compare_ip_cidr( + address_v6::from_string(std::get<0>(t)) + , address_v6::from_string(std::get<1>(t))) + , std::get<2>(t)); + } +#endif +} + #endif diff --git a/test/test_primitives.cpp b/test/test_primitives.cpp index 7a1f91c0f..148d3bbdd 100644 --- a/test/test_primitives.cpp +++ b/test/test_primitives.cpp @@ -118,17 +118,6 @@ TORRENT_TEST(primitives) // test network functions - // CIDR distance test - sha1_hash h1 = to_hash("0123456789abcdef01232456789abcdef0123456"); - sha1_hash h2 = to_hash("0123456789abcdef01232456789abcdef0123456"); - TEST_CHECK(common_bits(&h1[0], &h2[0], 20) == 160); - h2 = to_hash("0120456789abcdef01232456789abcdef0123456"); - TEST_CHECK(common_bits(&h1[0], &h2[0], 20) == 14); - h2 = to_hash("012f456789abcdef01232456789abcdef0123456"); - TEST_CHECK(common_bits(&h1[0], &h2[0], 20) == 12); - h2 = to_hash("0123456789abcdef11232456789abcdef0123456"); - TEST_CHECK(common_bits(&h1[0], &h2[0], 20) == 16 * 4 + 3); - // test print_endpoint, parse_endpoint and print_address TEST_EQUAL(print_endpoint(ep("127.0.0.1", 23)), "127.0.0.1:23"); #if TORRENT_USE_IPV6