From 06c53a41d10ed96da48e9f52be2b9619d2308d6f Mon Sep 17 00:00:00 2001 From: Steven Siloti Date: Sun, 20 Nov 2016 17:06:52 -0800 Subject: [PATCH] revisions based on code review Moved the logic to determine which sockets to keep to a separate function and add unit tests for it. Store the port which was originally specified for a listen socket so that we can match sockets reliably even with port retries. --- include/libtorrent/aux_/session_impl.hpp | 25 +++ src/session_impl.cpp | 90 +++++----- test/Jamfile | 1 + test/Makefile.am | 1 + test/test_listen_socket.cpp | 201 +++++++++++++++++++++++ 5 files changed, 267 insertions(+), 51 deletions(-) create mode 100644 test/test_listen_socket.cpp diff --git a/include/libtorrent/aux_/session_impl.hpp b/include/libtorrent/aux_/session_impl.hpp index d42a0b5a7..9ce28870c 100644 --- a/include/libtorrent/aux_/session_impl.hpp +++ b/include/libtorrent/aux_/session_impl.hpp @@ -142,6 +142,11 @@ namespace libtorrent // if the socket is not bound to a device std::string device; + // this is the port that was originally specified to listen on + // it may be different from local_endpoint.port() if we could + // had to retry binding with a higher port + int original_port; + // this is typically set to the same as the local // listen port. In case a NAT port forward was // successfully opened, this will be set to the @@ -188,6 +193,26 @@ namespace libtorrent TORRENT_EXTRA_EXPORT entry save_dht_settings(dht_settings const& settings); #endif + struct TORRENT_EXTRA_EXPORT listen_endpoint_t + { + listen_endpoint_t(address adr, int p, std::string dev, bool s) + : addr(adr), port(p), device(dev), ssl(s) {} + + address addr; + int port; + std::string device; + bool ssl; + }; + + // partitions sockets based on whether they match one of the given endpoints + // all matched sockets are ordered before unmatched sockets + // matched endpoints are removed from the vector + // returns an iterator to the first unmatched socket + TORRENT_EXTRA_EXPORT std::list::iterator + partition_listen_sockets( + std::vector& eps + , std::list& sockets); + // this is the link between the main thread and the // thread started to run the main downloader loop struct TORRENT_EXTRA_EXPORT session_impl final diff --git a/src/session_impl.cpp b/src/session_impl.cpp index 9f5d0a59c..310de80e8 100644 --- a/src/session_impl.cpp +++ b/src/session_impl.cpp @@ -270,6 +270,37 @@ namespace aux { } #endif // TORRENT_DISABLE_DHT + std::list::iterator partition_listen_sockets( + std::vector& eps + , std::list& sockets) + { + return std::partition(sockets.begin(), sockets.end() + , [&eps](listen_socket_t const& sock) + { + auto match = std::find_if(eps.begin(), eps.end() + , [&sock](listen_endpoint_t const& ep) + { + return ep.ssl == sock.ssl + && ep.port == sock.original_port + && ep.device == sock.device + && ep.addr == sock.local_endpoint.address(); + }); + + if (match != eps.end()) + { + // remove the matched endpoint so that another socket can't match it + // this also signals to the caller that it doesn't need to create a + // socket for the endpoint + eps.erase(match); + return true; + } + else + { + return false; + } + }); + } + void session_impl::init_peer_class_filter(bool unlimited_local) { // set the default peer_class_filter to use the local peer class @@ -1384,6 +1415,7 @@ namespace aux { listen_socket_t ret; ret.ssl = (flags & open_ssl_socket) != 0; + ret.original_port = bind_ep.port(); int last_op = 0; socket_type_t const sock_type = (flags & open_ssl_socket) @@ -1702,20 +1734,6 @@ namespace aux { reopen_listen_sockets(); } - namespace - { - struct listen_endpoint_t - { - listen_endpoint_t(address adr, int p, std::string dev, bool s) - : addr(adr), port(p), device(dev), ssl(s) {} - - address addr; - int port; - std::string device; - bool ssl; - }; - } - void session_impl::reopen_listen_sockets() { #ifndef TORRENT_DISABLE_LOGGING @@ -1810,51 +1828,21 @@ namespace aux { } } - int const port_retries = m_settings.get_int(settings_pack::max_retry_port_bind); + auto remove_iter = partition_listen_sockets(eps, m_listen_sockets); - // sockets we are keeping get moved to this list to prevent a socket from matching - // multiple endpoints - std::list keep; - - // remove any sockets which are no longer in the set of endpoints - // to listen on - // warning: O(n^2) operation! - // hopefully the system doesn't have too many interfaces - for (auto sock = m_listen_sockets.begin() - ; sock != m_listen_sockets.end();) + while (remove_iter != m_listen_sockets.end()) { - auto match = std::find_if(eps.begin(), eps.end() - , [sock, port_retries](listen_endpoint_t const& ep) - { return ep.ssl == sock->ssl - && (ep.port == 0 || (sock->local_endpoint.port() >= ep.port - && sock->local_endpoint.port() - ep.port < port_retries)) - && ep.device == sock->device - && ep.addr == sock->local_endpoint.address(); }); - - if (match != eps.end()) - { - // we don't need to create a new listen socket for this endpoint - // so remove it from the list - eps.erase(match); - keep.splice(keep.end(), m_listen_sockets, sock++); - continue; - } - - // this socket's local_endpoint is not on the list of endpoints to listen on - // it has got to go // TODO notify interested parties of this socket's demise #ifndef TORRENT_DISABLE_LOGGING session_log("Closing listen socket for %s on device \"%s\"" - , print_endpoint(sock->local_endpoint).c_str(), sock->device.c_str()); + , print_endpoint(remove_iter->local_endpoint).c_str() + , remove_iter->device.c_str()); #endif - if (sock->sock) sock->sock->close(ec); - if (sock->udp_sock) sock->udp_sock->close(); - sock = m_listen_sockets.erase(sock); + if (remove_iter->sock) remove_iter->sock->close(ec); + if (remove_iter->udp_sock) remove_iter->udp_sock->close(); + remove_iter = m_listen_sockets.erase(remove_iter); } - TORRENT_ASSERT(m_listen_sockets.empty()); - m_listen_sockets.swap(keep); - // open new sockets on any endpoints that didn't match with // an existing socket for (auto const& ep : eps) diff --git a/test/Jamfile b/test/Jamfile index 0c560f887..b4e5d1506 100644 --- a/test/Jamfile +++ b/test/Jamfile @@ -146,6 +146,7 @@ test-suite libtorrent : test_enum_net.cpp test_linked_list.cpp test_stack_allocator.cpp + test_listen_socket.cpp test_file_progress.cpp ] [ run test_piece_picker.cpp ] diff --git a/test/Makefile.am b/test/Makefile.am index 2f41e1617..b46fd80b9 100644 --- a/test/Makefile.am +++ b/test/Makefile.am @@ -155,6 +155,7 @@ test_primitives_SOURCES = \ test_resolve_links.cpp \ test_crc32.cpp \ test_heterogeneous_queue.cpp \ + test_listen_socket.cpp \ test_ip_voter.cpp \ test_sliding_average.cpp \ test_socket_io.cpp \ diff --git a/test/test_listen_socket.cpp b/test/test_listen_socket.cpp new file mode 100644 index 000000000..93e07695f --- /dev/null +++ b/test/test_listen_socket.cpp @@ -0,0 +1,201 @@ +/* + +Copyright (c) 2016, Steven Siloti +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +* Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the distribution. +* Neither the name of the author nor the names of its +contributors may be used to endorse or promote products derived +from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "test.hpp" +#include "libtorrent/aux_/session_impl.hpp" + +using namespace libtorrent; + +namespace +{ + void test_equal(listen_socket_t const& s, address addr, int port, std::string dev, bool ssl) + { + TEST_EQUAL(s.ssl, ssl); + TEST_EQUAL(s.local_endpoint.address(), addr); + TEST_EQUAL(s.original_port, port); + TEST_EQUAL(s.device, dev); + } + + void test_equal(aux::listen_endpoint_t const& e1, address addr, int port, std::string dev, bool ssl) + { + TEST_EQUAL(e1.ssl, ssl); + TEST_EQUAL(e1.port, port); + TEST_EQUAL(e1.addr, addr); + TEST_EQUAL(e1.device, dev); + } +} + +TORRENT_TEST(partition_listen_sockets) +{ + { + std::list sockets; + listen_socket_t s; + s.local_endpoint = tcp::endpoint(tcp::v4(), 6881); + s.original_port = 6881; + sockets.push_back(s); + s.local_endpoint = tcp::endpoint(tcp::v6(), 6881); + sockets.push_back(s); + + // remove the wildcard v6 socket and replace it with a specific global IP + std::vector eps; + eps.emplace_back(address_v4(), 6881, "", false); + eps.emplace_back(address_v6::from_string("2001::1"), 6881, "", false); + auto remove_iter = aux::partition_listen_sockets(eps, sockets); + TEST_EQUAL(eps.size(), 1); + TEST_EQUAL(std::distance(sockets.begin(), remove_iter), 1); + TEST_EQUAL(std::distance(remove_iter, sockets.end()), 1); + test_equal(sockets.front(), address_v4(), 6881, "", false); + test_equal(sockets.back(), address_v6(), 6881, "", false); + test_equal(eps.front(), address_v6::from_string("2001::1"), 6881, "", false); + } + + { + std::list sockets; + listen_socket_t s; + s.local_endpoint = tcp::endpoint(tcp::v4(), 6881); + s.original_port = 6881; + sockets.push_back(s); + s.local_endpoint = tcp::endpoint(tcp::v6(), 6881); + sockets.push_back(s); + + // change the ports + std::vector eps; + eps.emplace_back(address_v4(), 6882, "", false); + eps.emplace_back(address_v6(), 6882, "", false); + auto remove_iter = aux::partition_listen_sockets(eps, sockets); + TEST_CHECK(sockets.begin() == remove_iter); + TEST_EQUAL(eps.size(), 2); + } + + { + std::list sockets; + listen_socket_t s; + s.local_endpoint = tcp::endpoint(address_v6::from_string("2001::1"), 6881); + s.original_port = 6881; + sockets.push_back(s); + s.local_endpoint = tcp::endpoint(tcp::v4(), 6881); + sockets.push_back(s); + + + // replace the IPv6 socket with a pair of device bound sockets + std::vector eps; + eps.emplace_back(address_v4(), 6881, "", false); + eps.emplace_back(address_v6::from_string("2001::1"), 6881, "eth1", false); + eps.emplace_back(address_v6::from_string("2001::2"), 6881, "eth1", false); + auto remove_iter = aux::partition_listen_sockets(eps, sockets); + TEST_EQUAL(std::distance(sockets.begin(), remove_iter), 1); + TEST_EQUAL(std::distance(remove_iter, sockets.end()), 1); + test_equal(sockets.front(), address_v4(), 6881, "", false); + test_equal(sockets.back(), address_v6::from_string("2001::1"), 6881, "", false); + TEST_EQUAL(eps.size(), 2); + } + + { + std::list sockets; + listen_socket_t s; + s.local_endpoint = tcp::endpoint(address_v6::from_string("fe80::d250:99ff:fe0c:9b74"), 6881); + s.device = "enp3s0"; + s.original_port = 6881; + sockets.push_back(s); + s.local_endpoint = tcp::endpoint(address_v6::from_string("2001::1"), 6881); + sockets.push_back(s); + + // change the global IP of a device bound socket + std::vector eps; + eps.emplace_back(address_v6::from_string("fe80::d250:99ff:fe0c:9b74"), 6881, "enp3s0", false); + eps.emplace_back(address_v6::from_string("2001::2"), 6881, "enp3s0", false); + auto remove_iter = aux::partition_listen_sockets(eps, sockets); + TEST_EQUAL(std::distance(sockets.begin(), remove_iter), 1); + TEST_EQUAL(std::distance(remove_iter, sockets.end()), 1); + test_equal(sockets.front(), address_v6::from_string("fe80::d250:99ff:fe0c:9b74"), 6881, "enp3s0", false); + test_equal(sockets.back(), address_v6::from_string("2001::1"), 6881, "enp3s0", false); + TEST_EQUAL(eps.size(), 1); + test_equal(eps.front(), address_v6::from_string("2001::2"), 6881, "enp3s0", false); + } + + { + std::list sockets; + listen_socket_t s; + s.local_endpoint = tcp::endpoint(tcp::v4(), 6883); + s.original_port = 6881; + sockets.push_back(s); + s.local_endpoint = tcp::endpoint(tcp::v6(), 6883); + sockets.push_back(s); + + // make sure all sockets are kept when the actual port is different from the original + std::vector eps; + eps.emplace_back(address_v4(), 6881, "", false); + eps.emplace_back(address_v6(), 6881, "", false); + auto remove_iter = aux::partition_listen_sockets(eps, sockets); + TEST_CHECK(remove_iter == sockets.end()); + TEST_CHECK(eps.empty()); + } + + { + std::list sockets; + listen_socket_t s; + s.local_endpoint = tcp::endpoint(tcp::v4(), 6881); + s.original_port = 6881; + sockets.push_back(s); + s.local_endpoint = tcp::endpoint(tcp::v6(), 6881); + sockets.push_back(s); + + // add ssl sockets + std::vector eps; + eps.emplace_back(address_v4(), 6881, "", false); + eps.emplace_back(address_v6(), 6881, "", false); + eps.emplace_back(address_v4(), 6881, "", true); + eps.emplace_back(address_v6(), 6881, "", true); + auto remove_iter = aux::partition_listen_sockets(eps, sockets); + TEST_CHECK(remove_iter == sockets.end()); + TEST_EQUAL(eps.size(), 2); + } + + { + std::list sockets; + listen_socket_t s; + s.local_endpoint = tcp::endpoint(tcp::v4(), 6881); + s.original_port = 0; + sockets.push_back(s); + s.local_endpoint = tcp::endpoint(tcp::v6(), 6881); + sockets.push_back(s); + + // replace OS assigned ports with explicit ports + std::vector eps; + eps.emplace_back(address_v4(), 6882, "", false); + eps.emplace_back(address_v6(), 6882, "", false); + auto remove_iter = aux::partition_listen_sockets(eps, sockets); + TEST_CHECK(remove_iter == sockets.begin()); + TEST_EQUAL(eps.size(), 2); + } + +}