revisions based on code review

Moved the logic to determine which sockets to keep to a separate function and
add unit tests for it.

Store the port which was originally specified for a listen socket so that we
can match sockets reliably even with port retries.
This commit is contained in:
Steven Siloti 2016-11-20 17:06:52 -08:00 committed by Arvid Norberg
parent babb93fb1e
commit 06c53a41d1
5 changed files with 267 additions and 51 deletions

View File

@ -142,6 +142,11 @@ namespace libtorrent
// if the socket is not bound to a device
std::string device;
// this is the port that was originally specified to listen on
// it may be different from local_endpoint.port() if we could
// had to retry binding with a higher port
int original_port;
// this is typically set to the same as the local
// listen port. In case a NAT port forward was
// successfully opened, this will be set to the
@ -188,6 +193,26 @@ namespace libtorrent
TORRENT_EXTRA_EXPORT entry save_dht_settings(dht_settings const& settings);
#endif
struct TORRENT_EXTRA_EXPORT listen_endpoint_t
{
listen_endpoint_t(address adr, int p, std::string dev, bool s)
: addr(adr), port(p), device(dev), ssl(s) {}
address addr;
int port;
std::string device;
bool ssl;
};
// partitions sockets based on whether they match one of the given endpoints
// all matched sockets are ordered before unmatched sockets
// matched endpoints are removed from the vector
// returns an iterator to the first unmatched socket
TORRENT_EXTRA_EXPORT std::list<listen_socket_t>::iterator
partition_listen_sockets(
std::vector<listen_endpoint_t>& eps
, std::list<listen_socket_t>& sockets);
// this is the link between the main thread and the
// thread started to run the main downloader loop
struct TORRENT_EXTRA_EXPORT session_impl final

View File

@ -270,6 +270,37 @@ namespace aux {
}
#endif // TORRENT_DISABLE_DHT
std::list<listen_socket_t>::iterator partition_listen_sockets(
std::vector<listen_endpoint_t>& eps
, std::list<listen_socket_t>& sockets)
{
return std::partition(sockets.begin(), sockets.end()
, [&eps](listen_socket_t const& sock)
{
auto match = std::find_if(eps.begin(), eps.end()
, [&sock](listen_endpoint_t const& ep)
{
return ep.ssl == sock.ssl
&& ep.port == sock.original_port
&& ep.device == sock.device
&& ep.addr == sock.local_endpoint.address();
});
if (match != eps.end())
{
// remove the matched endpoint so that another socket can't match it
// this also signals to the caller that it doesn't need to create a
// socket for the endpoint
eps.erase(match);
return true;
}
else
{
return false;
}
});
}
void session_impl::init_peer_class_filter(bool unlimited_local)
{
// set the default peer_class_filter to use the local peer class
@ -1384,6 +1415,7 @@ namespace aux {
listen_socket_t ret;
ret.ssl = (flags & open_ssl_socket) != 0;
ret.original_port = bind_ep.port();
int last_op = 0;
socket_type_t const sock_type
= (flags & open_ssl_socket)
@ -1702,20 +1734,6 @@ namespace aux {
reopen_listen_sockets();
}
namespace
{
struct listen_endpoint_t
{
listen_endpoint_t(address adr, int p, std::string dev, bool s)
: addr(adr), port(p), device(dev), ssl(s) {}
address addr;
int port;
std::string device;
bool ssl;
};
}
void session_impl::reopen_listen_sockets()
{
#ifndef TORRENT_DISABLE_LOGGING
@ -1810,51 +1828,21 @@ namespace aux {
}
}
int const port_retries = m_settings.get_int(settings_pack::max_retry_port_bind);
auto remove_iter = partition_listen_sockets(eps, m_listen_sockets);
// sockets we are keeping get moved to this list to prevent a socket from matching
// multiple endpoints
std::list<listen_socket_t> keep;
// remove any sockets which are no longer in the set of endpoints
// to listen on
// warning: O(n^2) operation!
// hopefully the system doesn't have too many interfaces
for (auto sock = m_listen_sockets.begin()
; sock != m_listen_sockets.end();)
while (remove_iter != m_listen_sockets.end())
{
auto match = std::find_if(eps.begin(), eps.end()
, [sock, port_retries](listen_endpoint_t const& ep)
{ return ep.ssl == sock->ssl
&& (ep.port == 0 || (sock->local_endpoint.port() >= ep.port
&& sock->local_endpoint.port() - ep.port < port_retries))
&& ep.device == sock->device
&& ep.addr == sock->local_endpoint.address(); });
if (match != eps.end())
{
// we don't need to create a new listen socket for this endpoint
// so remove it from the list
eps.erase(match);
keep.splice(keep.end(), m_listen_sockets, sock++);
continue;
}
// this socket's local_endpoint is not on the list of endpoints to listen on
// it has got to go
// TODO notify interested parties of this socket's demise
#ifndef TORRENT_DISABLE_LOGGING
session_log("Closing listen socket for %s on device \"%s\""
, print_endpoint(sock->local_endpoint).c_str(), sock->device.c_str());
, print_endpoint(remove_iter->local_endpoint).c_str()
, remove_iter->device.c_str());
#endif
if (sock->sock) sock->sock->close(ec);
if (sock->udp_sock) sock->udp_sock->close();
sock = m_listen_sockets.erase(sock);
if (remove_iter->sock) remove_iter->sock->close(ec);
if (remove_iter->udp_sock) remove_iter->udp_sock->close();
remove_iter = m_listen_sockets.erase(remove_iter);
}
TORRENT_ASSERT(m_listen_sockets.empty());
m_listen_sockets.swap(keep);
// open new sockets on any endpoints that didn't match with
// an existing socket
for (auto const& ep : eps)

View File

@ -146,6 +146,7 @@ test-suite libtorrent :
test_enum_net.cpp
test_linked_list.cpp
test_stack_allocator.cpp
test_listen_socket.cpp
test_file_progress.cpp ]
[ run test_piece_picker.cpp ]

View File

@ -155,6 +155,7 @@ test_primitives_SOURCES = \
test_resolve_links.cpp \
test_crc32.cpp \
test_heterogeneous_queue.cpp \
test_listen_socket.cpp \
test_ip_voter.cpp \
test_sliding_average.cpp \
test_socket_io.cpp \

201
test/test_listen_socket.cpp Normal file
View File

@ -0,0 +1,201 @@
/*
Copyright (c) 2016, Steven Siloti
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the distribution.
* Neither the name of the author nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
*/
#include "test.hpp"
#include "libtorrent/aux_/session_impl.hpp"
using namespace libtorrent;
namespace
{
void test_equal(listen_socket_t const& s, address addr, int port, std::string dev, bool ssl)
{
TEST_EQUAL(s.ssl, ssl);
TEST_EQUAL(s.local_endpoint.address(), addr);
TEST_EQUAL(s.original_port, port);
TEST_EQUAL(s.device, dev);
}
void test_equal(aux::listen_endpoint_t const& e1, address addr, int port, std::string dev, bool ssl)
{
TEST_EQUAL(e1.ssl, ssl);
TEST_EQUAL(e1.port, port);
TEST_EQUAL(e1.addr, addr);
TEST_EQUAL(e1.device, dev);
}
}
TORRENT_TEST(partition_listen_sockets)
{
{
std::list<listen_socket_t> sockets;
listen_socket_t s;
s.local_endpoint = tcp::endpoint(tcp::v4(), 6881);
s.original_port = 6881;
sockets.push_back(s);
s.local_endpoint = tcp::endpoint(tcp::v6(), 6881);
sockets.push_back(s);
// remove the wildcard v6 socket and replace it with a specific global IP
std::vector<aux::listen_endpoint_t> eps;
eps.emplace_back(address_v4(), 6881, "", false);
eps.emplace_back(address_v6::from_string("2001::1"), 6881, "", false);
auto remove_iter = aux::partition_listen_sockets(eps, sockets);
TEST_EQUAL(eps.size(), 1);
TEST_EQUAL(std::distance(sockets.begin(), remove_iter), 1);
TEST_EQUAL(std::distance(remove_iter, sockets.end()), 1);
test_equal(sockets.front(), address_v4(), 6881, "", false);
test_equal(sockets.back(), address_v6(), 6881, "", false);
test_equal(eps.front(), address_v6::from_string("2001::1"), 6881, "", false);
}
{
std::list<listen_socket_t> sockets;
listen_socket_t s;
s.local_endpoint = tcp::endpoint(tcp::v4(), 6881);
s.original_port = 6881;
sockets.push_back(s);
s.local_endpoint = tcp::endpoint(tcp::v6(), 6881);
sockets.push_back(s);
// change the ports
std::vector<aux::listen_endpoint_t> eps;
eps.emplace_back(address_v4(), 6882, "", false);
eps.emplace_back(address_v6(), 6882, "", false);
auto remove_iter = aux::partition_listen_sockets(eps, sockets);
TEST_CHECK(sockets.begin() == remove_iter);
TEST_EQUAL(eps.size(), 2);
}
{
std::list<listen_socket_t> sockets;
listen_socket_t s;
s.local_endpoint = tcp::endpoint(address_v6::from_string("2001::1"), 6881);
s.original_port = 6881;
sockets.push_back(s);
s.local_endpoint = tcp::endpoint(tcp::v4(), 6881);
sockets.push_back(s);
// replace the IPv6 socket with a pair of device bound sockets
std::vector<aux::listen_endpoint_t> eps;
eps.emplace_back(address_v4(), 6881, "", false);
eps.emplace_back(address_v6::from_string("2001::1"), 6881, "eth1", false);
eps.emplace_back(address_v6::from_string("2001::2"), 6881, "eth1", false);
auto remove_iter = aux::partition_listen_sockets(eps, sockets);
TEST_EQUAL(std::distance(sockets.begin(), remove_iter), 1);
TEST_EQUAL(std::distance(remove_iter, sockets.end()), 1);
test_equal(sockets.front(), address_v4(), 6881, "", false);
test_equal(sockets.back(), address_v6::from_string("2001::1"), 6881, "", false);
TEST_EQUAL(eps.size(), 2);
}
{
std::list<listen_socket_t> sockets;
listen_socket_t s;
s.local_endpoint = tcp::endpoint(address_v6::from_string("fe80::d250:99ff:fe0c:9b74"), 6881);
s.device = "enp3s0";
s.original_port = 6881;
sockets.push_back(s);
s.local_endpoint = tcp::endpoint(address_v6::from_string("2001::1"), 6881);
sockets.push_back(s);
// change the global IP of a device bound socket
std::vector<aux::listen_endpoint_t> eps;
eps.emplace_back(address_v6::from_string("fe80::d250:99ff:fe0c:9b74"), 6881, "enp3s0", false);
eps.emplace_back(address_v6::from_string("2001::2"), 6881, "enp3s0", false);
auto remove_iter = aux::partition_listen_sockets(eps, sockets);
TEST_EQUAL(std::distance(sockets.begin(), remove_iter), 1);
TEST_EQUAL(std::distance(remove_iter, sockets.end()), 1);
test_equal(sockets.front(), address_v6::from_string("fe80::d250:99ff:fe0c:9b74"), 6881, "enp3s0", false);
test_equal(sockets.back(), address_v6::from_string("2001::1"), 6881, "enp3s0", false);
TEST_EQUAL(eps.size(), 1);
test_equal(eps.front(), address_v6::from_string("2001::2"), 6881, "enp3s0", false);
}
{
std::list<listen_socket_t> sockets;
listen_socket_t s;
s.local_endpoint = tcp::endpoint(tcp::v4(), 6883);
s.original_port = 6881;
sockets.push_back(s);
s.local_endpoint = tcp::endpoint(tcp::v6(), 6883);
sockets.push_back(s);
// make sure all sockets are kept when the actual port is different from the original
std::vector<aux::listen_endpoint_t> eps;
eps.emplace_back(address_v4(), 6881, "", false);
eps.emplace_back(address_v6(), 6881, "", false);
auto remove_iter = aux::partition_listen_sockets(eps, sockets);
TEST_CHECK(remove_iter == sockets.end());
TEST_CHECK(eps.empty());
}
{
std::list<listen_socket_t> sockets;
listen_socket_t s;
s.local_endpoint = tcp::endpoint(tcp::v4(), 6881);
s.original_port = 6881;
sockets.push_back(s);
s.local_endpoint = tcp::endpoint(tcp::v6(), 6881);
sockets.push_back(s);
// add ssl sockets
std::vector<aux::listen_endpoint_t> eps;
eps.emplace_back(address_v4(), 6881, "", false);
eps.emplace_back(address_v6(), 6881, "", false);
eps.emplace_back(address_v4(), 6881, "", true);
eps.emplace_back(address_v6(), 6881, "", true);
auto remove_iter = aux::partition_listen_sockets(eps, sockets);
TEST_CHECK(remove_iter == sockets.end());
TEST_EQUAL(eps.size(), 2);
}
{
std::list<listen_socket_t> sockets;
listen_socket_t s;
s.local_endpoint = tcp::endpoint(tcp::v4(), 6881);
s.original_port = 0;
sockets.push_back(s);
s.local_endpoint = tcp::endpoint(tcp::v6(), 6881);
sockets.push_back(s);
// replace OS assigned ports with explicit ports
std::vector<aux::listen_endpoint_t> eps;
eps.emplace_back(address_v4(), 6882, "", false);
eps.emplace_back(address_v6(), 6882, "", false);
auto remove_iter = aux::partition_listen_sockets(eps, sockets);
TEST_CHECK(remove_iter == sockets.begin());
TEST_EQUAL(eps.size(), 2);
}
}