From 8648de3706efe0d9ab2504793bd0dd8a9566a1d6 Mon Sep 17 00:00:00 2001 From: Steven Siloti Date: Thu, 21 Sep 2017 21:00:38 -0700 Subject: [PATCH] DHT nodes should only handle requests on their socket (#2355) --- include/libtorrent/aux_/session_impl.hpp | 1 + include/libtorrent/kademlia/dht_tracker.hpp | 3 +- include/libtorrent/kademlia/node.hpp | 7 +++-- simulation/setup_dht.cpp | 2 +- simulation/test_dht_rate_limit.cpp | 2 +- src/kademlia/dht_tracker.cpp | 6 ++-- src/kademlia/node.cpp | 9 +++--- src/session_impl.cpp | 12 ++++---- test/test_dht.cpp | 31 +++++++++++++++++++-- 9 files changed, 52 insertions(+), 21 deletions(-) diff --git a/include/libtorrent/aux_/session_impl.hpp b/include/libtorrent/aux_/session_impl.hpp index a884c74cb..9197af5fd 100644 --- a/include/libtorrent/aux_/session_impl.hpp +++ b/include/libtorrent/aux_/session_impl.hpp @@ -1102,6 +1102,7 @@ namespace aux { void on_udp_writeable(std::weak_ptr s, error_code const& ec); void on_udp_packet(std::weak_ptr s + , std::weak_ptr ls , transport ssl, error_code const& ec); libtorrent::utp_socket_manager m_utp_socket_manager; diff --git a/include/libtorrent/kademlia/dht_tracker.hpp b/include/libtorrent/kademlia/dht_tracker.hpp index c52d7e93b..da9517ca2 100644 --- a/include/libtorrent/kademlia/dht_tracker.hpp +++ b/include/libtorrent/kademlia/dht_tracker.hpp @@ -141,7 +141,8 @@ namespace libtorrent { namespace dht { void update_stats_counters(counters& c) const; void incoming_error(error_code const& ec, udp::endpoint const& ep); - bool incoming_packet(udp::endpoint const& ep, span buf); + bool incoming_packet(aux::listen_socket_handle const& s + , udp::endpoint const& ep, span buf); std::vector> live_nodes(node_id const& nid); diff --git a/include/libtorrent/kademlia/node.hpp b/include/libtorrent/kademlia/node.hpp index 24c1e154d..0e040de93 100644 --- a/include/libtorrent/kademlia/node.hpp +++ b/include/libtorrent/kademlia/node.hpp @@ -110,7 +110,7 @@ public: void add_router_node(udp::endpoint const& router); void unreachable(udp::endpoint const& ep); - void incoming(msg const& m); + void incoming(aux::listen_socket_handle const& s, msg const& m); #ifndef TORRENT_NO_DEPRECATE int num_torrents() const { return int(m_storage.num_torrents()); } @@ -241,6 +241,7 @@ private: public: routing_table m_table; rpc_manager m_rpc; + aux::listen_socket_handle const m_sock; private: @@ -253,6 +254,8 @@ private: static protocol_descriptor const& map_protocol_to_descriptor(udp protocol); + socket_manager* m_sock_man; + get_foreign_node_t m_get_foreign_node; dht_observer* m_observer; @@ -268,8 +271,6 @@ private: // secret random numbers used to create write tokens std::uint32_t m_secret[2]; - socket_manager* m_sock_man; - aux::listen_socket_handle m_sock; counters& m_counters; dht_storage_interface& m_storage; diff --git a/simulation/setup_dht.cpp b/simulation/setup_dht.cpp index 259c4a297..5322aa73f 100644 --- a/simulation/setup_dht.cpp +++ b/simulation/setup_dht.cpp @@ -145,7 +145,7 @@ struct dht_node final : lt::dht::socket_manager if (msg.type() != bdecode_node::dict_t) return; lt::dht::msg m(msg, m_ep); - dht().incoming(m); + dht().incoming(m_ls, m); sock().async_receive_from(asio::mutable_buffers_1(m_buffer, sizeof(m_buffer)) , m_ep, [&](lt::error_code const& ec, std::size_t bytes_transferred) diff --git a/simulation/test_dht_rate_limit.cpp b/simulation/test_dht_rate_limit.cpp index ca17f7550..11d80f603 100644 --- a/simulation/test_dht_rate_limit.cpp +++ b/simulation/test_dht_rate_limit.cpp @@ -132,7 +132,7 @@ TORRENT_TEST(dht_rate_limit) udp_socket::packet p; error_code err; int const num = int(sock.read(lt::span(&p, 1), err)); - if (num) dht->incoming_packet(p.from, p.data); + if (num) dht->incoming_packet(ls, p.from, p.data); if (stop || err) return; sock.async_read(on_read); }; diff --git a/src/kademlia/dht_tracker.cpp b/src/kademlia/dht_tracker.cpp index cf4a219ca..2c610cba1 100644 --- a/src/kademlia/dht_tracker.cpp +++ b/src/kademlia/dht_tracker.cpp @@ -498,8 +498,8 @@ namespace libtorrent { namespace dht { } } - bool dht_tracker::incoming_packet(udp::endpoint const& ep - , span const buf) + bool dht_tracker::incoming_packet(aux::listen_socket_handle const& s + , udp::endpoint const& ep, span const buf) { int const buf_size = int(buf.size()); if (buf_size <= 20 @@ -564,7 +564,7 @@ namespace libtorrent { namespace dht { libtorrent::dht::msg const m(m_msg, ep); for (auto& n : m_nodes) - n.second.dht.incoming(m); + n.second.dht.incoming(s, m); return true; } diff --git a/src/kademlia/node.cpp b/src/kademlia/node.cpp index e4368f2f8..71a596702 100644 --- a/src/kademlia/node.cpp +++ b/src/kademlia/node.cpp @@ -113,13 +113,13 @@ node::node(aux::listen_socket_handle const& sock, socket_manager* sock_man , m_id(calculate_node_id(nid, sock)) , m_table(m_id, sock.get_local_endpoint().protocol() == tcp::v4() ? udp::v4() : udp::v6(), 8, settings, observer) , m_rpc(m_id, m_settings, m_table, sock, sock_man, observer) + , m_sock(sock) + , m_sock_man(sock_man) , m_get_foreign_node(get_foreign_node) , m_observer(observer) , m_protocol(map_protocol_to_descriptor(sock.get_local_endpoint().protocol() == tcp::v4() ? udp::v4() : udp::v6())) , m_last_tracker_tick(aux::time_now()) , m_last_self_refresh(min_time()) - , m_sock_man(sock_man) - , m_sock(sock) , m_counters(cnt) , m_storage(storage) { @@ -256,7 +256,7 @@ void node::unreachable(udp::endpoint const& ep) m_rpc.unreachable(ep); } -void node::incoming(msg const& m) +void node::incoming(aux::listen_socket_handle const& s, msg const& m) { // is this a reply? bdecode_node const y_ent = m.message.dict_find_string("y"); @@ -317,7 +317,8 @@ void node::incoming(msg const& m) // responds to 'query' messages that it receives. if (m_settings.read_only) break; - if (!native_address(m.addr)) break; + // only respond to requests if they're addressed to this node + if (s != m_sock) break; if (!m_sock_man->has_quota()) { diff --git a/src/session_impl.cpp b/src/session_impl.cpp index 7e53626d6..be4567fec 100644 --- a/src/session_impl.cpp +++ b/src/session_impl.cpp @@ -1686,7 +1686,7 @@ namespace { // TODO: 2 use a handler allocator here ADD_OUTSTANDING_ASYNC("session_impl::on_udp_packet"); ret->udp_sock->sock.async_read(std::bind(&session_impl::on_udp_packet - , this, ret->udp_sock, ret->ssl, _1)); + , this, ret->udp_sock, ret, ret->ssl, _1)); #ifndef TORRENT_DISABLE_LOGGING if (should_log()) @@ -2078,7 +2078,7 @@ namespace { // TODO: 2 use a handler allocator here ADD_OUTSTANDING_ASYNC("session_impl::on_udp_packet"); udp_sock->sock.async_read(std::bind(&session_impl::on_udp_packet - , this, udp_sock, ep.ssl, _1)); + , this, udp_sock, std::weak_ptr(), ep.ssl, _1)); if (!ec && udp_sock) { @@ -2328,7 +2328,7 @@ namespace { void session_impl::on_udp_packet(std::weak_ptr socket - , transport const ssl, error_code const& ec) + , std::weak_ptr ls, transport const ssl, error_code const& ec) { COMPLETE_ASYNC("session_impl::on_udp_packet"); if (ec) @@ -2402,7 +2402,9 @@ namespace { #ifndef TORRENT_DISABLE_DHT if (m_dht && buf.size() > 20 && buf.front() == 'd' && buf.back() == 'e') { - handled = m_dht->incoming_packet(packet.from, buf); + auto listen_socket = ls.lock(); + if (listen_socket) + handled = m_dht->incoming_packet(listen_socket, packet.from, buf); } #endif @@ -2470,7 +2472,7 @@ namespace { ADD_OUTSTANDING_ASYNC("session_impl::on_udp_packet"); s->sock.async_read(std::bind(&session_impl::on_udp_packet - , this, socket, ssl, _1)); + , this, std::move(socket), std::move(ls), ssl, _1)); } void session_impl::async_accept(std::shared_ptr const& listener diff --git a/test/test_dht.cpp b/test/test_dht.cpp index 2ed25a202..07af199c0 100644 --- a/test/test_dht.cpp +++ b/test/test_dht.cpp @@ -288,7 +288,7 @@ void send_dht_request(node& node, char const* msg, udp::endpoint const& ep if (ec) std::printf("bdecode failed: %s\n", ec.message().c_str()); dht::msg m(decoded, ep); - node.incoming(m); + node.incoming(node.m_sock, m); // If the request is supposed to get a response, by now the node should have // invoked the send function and put the response in g_sent_packets @@ -333,7 +333,7 @@ void send_dht_response(node& node, bdecode_node const& request, udp::endpoint co if (ec) std::printf("bdecode failed: %s\n", ec.message().c_str()); dht::msg m(decoded, ep); - node.incoming(m); + node.incoming(node.m_sock, m); } struct announce_item @@ -2765,6 +2765,31 @@ TORRENT_TEST(dht_dual_stack) } #endif +TORRENT_TEST(multi_home) +{ + // send a request with a different listen socket and make sure the node ignores it + dht_test_setup t(udp::endpoint(rand_v4(), 20)); + bdecode_node response; + + entry e; + e["q"] = "ping"; + e["t"] = "10"; + e["y"] = "q"; + e["a"].dict().insert(std::make_pair("id", generate_next().to_string())); + char msg_buf[1500]; + int size = bencode(msg_buf, e); + + bdecode_node decoded; + error_code ec; + bdecode(msg_buf, msg_buf + size, decoded, ec); + if (ec) std::printf("bdecode failed: %s\n", ec.message().c_str()); + + dht::msg m(decoded, t.source); + t.dht_node.incoming(dummy_listen_socket(udp::endpoint(rand_v4(), 21)), m); + TEST_CHECK(g_sent_packets.empty()); + g_sent_packets.clear(); +} + TORRENT_TEST(signing_test1) { // test vector 1 @@ -3254,7 +3279,7 @@ TORRENT_TEST(invalid_error_msg) if (ec) std::printf("bdecode failed: %s\n", ec.message().c_str()); dht::msg m(decoded, source); - node.incoming(m); + node.incoming(node.m_sock, m); bool found = false; for (int i = 0; i < int(observer.m_log.size()); ++i)