diff --git a/CMakeLists.txt b/CMakeLists.txt index a0af8b741..210d6c42d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -78,6 +78,7 @@ set(sources session_handle session_impl session_settings + session_udp_sockets proxy_settings session_stats settings_pack diff --git a/Jamfile b/Jamfile index 3352c6735..c6ab85598 100644 --- a/Jamfile +++ b/Jamfile @@ -619,6 +619,7 @@ SOURCES = session_handle session_impl session_call + session_udp_sockets settings_pack sha1_hash socket_io diff --git a/include/libtorrent/alert_types.hpp b/include/libtorrent/alert_types.hpp index ebfcecef5..00eb1fcd9 100644 --- a/include/libtorrent/alert_types.hpp +++ b/include/libtorrent/alert_types.hpp @@ -1188,8 +1188,8 @@ namespace libtorrent { virtual std::string message() const override; }; - // This alert is posted when there is an error on the UDP socket. The - // UDP socket is used for all uTP, DHT and UDP tracker traffic. It's + // This alert is posted when there is an error on a UDP socket. The + // UDP sockets are used for all uTP, DHT and UDP tracker traffic. They are // global to the session. struct TORRENT_EXPORT udp_error_alert final : alert { diff --git a/include/libtorrent/aux_/session_impl.hpp b/include/libtorrent/aux_/session_impl.hpp index 9a080e908..df7d09fbb 100644 --- a/include/libtorrent/aux_/session_impl.hpp +++ b/include/libtorrent/aux_/session_impl.hpp @@ -281,6 +281,7 @@ namespace aux { void on_ip_change(error_code const& ec); void reopen_listen_sockets(); + void reopen_outgoing_sockets(); torrent_peer_allocator_interface* get_peer_allocator() override { return &m_peer_allocator; } @@ -913,6 +914,8 @@ namespace aux { // we might need more than one listen socket std::list m_listen_sockets; + outgoing_sockets m_outgoing_sockets; + #if TORRENT_USE_I2P i2p_connection m_i2p_conn; std::shared_ptr m_i2p_listen_socket; @@ -1059,7 +1062,14 @@ namespace aux { int m_outstanding_router_lookups = 0; #endif - void send_udp_packet_hostname(char const* hostname + void send_udp_packet_hostname_deprecated(char const* hostname + , int port + , span p + , error_code& ec + , int flags); + + void send_udp_packet_hostname(std::weak_ptr sock + , char const* hostname , int port , span p , error_code& ec @@ -1070,9 +1080,19 @@ namespace aux { , int port , span p , error_code& ec + , int flags) + { + listen_socket_t* s = static_cast(sock); + send_udp_packet_hostname(s->udp_sock, hostname, port, p, ec, flags); + } + + void send_udp_packet_deprecated(bool ssl + , udp::endpoint const& ep + , span p + , error_code& ec , int flags); - void send_udp_packet(bool ssl + void send_udp_packet(std::weak_ptr sock , udp::endpoint const& ep , span p , error_code& ec @@ -1082,7 +1102,11 @@ namespace aux { , udp::endpoint const& ep , span p , error_code& ec - , int flags); + , int flags) + { + listen_socket_t* s = static_cast(sock); + send_udp_packet(s->udp_sock, ep, p, ec, flags); + } void on_udp_writeable(std::weak_ptr s, error_code const& ec); diff --git a/include/libtorrent/aux_/session_udp_sockets.hpp b/include/libtorrent/aux_/session_udp_sockets.hpp index 43a7c0567..2fe2c3272 100644 --- a/include/libtorrent/aux_/session_udp_sockets.hpp +++ b/include/libtorrent/aux_/session_udp_sockets.hpp @@ -33,19 +33,22 @@ POSSIBILITY OF SUCH DAMAGE. #ifndef TORRENT_SESSION_UDP_SOCKETS_HPP_INCLUDED #define TORRENT_SESSION_UDP_SOCKETS_HPP_INCLUDED -#include "libtorrent/udp_socket.hpp" +#include "libtorrent/utp_socket_manager.hpp" #include "libtorrent/config.hpp" #include #include namespace libtorrent { namespace aux { - struct session_udp_socket + struct listen_endpoint_t; + struct proxy_settings; + + struct session_udp_socket : utp_socket_interface { explicit session_udp_socket(io_service& ios) : sock(ios) {} - udp::endpoint local_endpoint() { return sock.local_endpoint(); } + udp::endpoint local_endpoint() override { return sock.local_endpoint(); } udp_socket sock; @@ -56,6 +59,43 @@ namespace libtorrent { namespace aux { bool write_blocked = false; }; + struct outgoing_udp_socket final : session_udp_socket + { + outgoing_udp_socket(io_service& ios, std::string const& dev, bool ssl_) + : session_udp_socket(ios), device(dev), ssl(ssl_) {} + + // the name of the device the socket is bound to, may be empty + // if the socket is not bound to a device + std::string const device; + + // set to true if this is an SSL socket + bool const ssl; + }; + + // sockets used for outoing utp connections + struct TORRENT_EXTRA_EXPORT outgoing_sockets + { + // partitions sockets based on whether they match one of the given endpoints + // all matched sockets are ordered before unmatched sockets + // matched endpoints are removed from the vector + // returns an iterator to the first unmatched socket + std::vector>::iterator + partition_outgoing_sockets(std::vector& eps); + + tcp::endpoint bind(socket_type& s, address const& remote_address) const; + + void update_proxy(proxy_settings const& settings); + + // close all sockets + void close(); + + std::vector> sockets; + private: + // round-robin index into sockets + // one dimention for IPv4/IPv6 and a second for SSL/non-SSL + mutable std::uint8_t index[2][2] = { {0, 0}, {0, 0} }; + }; + } } #endif diff --git a/include/libtorrent/utp_socket_manager.hpp b/include/libtorrent/utp_socket_manager.hpp index 882238f25..5296e09b7 100644 --- a/include/libtorrent/utp_socket_manager.hpp +++ b/include/libtorrent/utp_socket_manager.hpp @@ -49,9 +49,18 @@ namespace libtorrent { struct utp_socket_impl; struct counters; - struct utp_socket_manager final + // interface/handle to the underlying udp socket + struct TORRENT_EXTRA_EXPORT utp_socket_interface { - typedef std::function + , udp::endpoint const& , span , error_code&, int)> send_fun_t; @@ -67,7 +76,8 @@ namespace libtorrent { ~utp_socket_manager(); // return false if this is not a uTP packet - bool incoming_packet(udp::endpoint const& ep, span p); + bool incoming_packet(std::weak_ptr socket + , udp::endpoint const& ep, span p); // if the UDP socket failed with an EAGAIN or EWOULDBLOCK, this will be // called once the socket is writeable again @@ -82,10 +92,13 @@ namespace libtorrent { // flags for send_packet enum { dont_fragment = 1 }; - void send_packet(udp::endpoint const& ep, char const* p, int len + void send_packet(std::weak_ptr sock, udp::endpoint const& ep + , char const* p, int len , error_code& ec, int flags = 0); void subscribe_writable(utp_socket_impl* s); + void remove_udp_socket(std::weak_ptr sock); + // internal, used by utp_stream void remove_socket(std::uint16_t id); diff --git a/include/libtorrent/utp_stream.hpp b/include/libtorrent/utp_stream.hpp index bccdad2f4..413dc6996 100644 --- a/include/libtorrent/utp_stream.hpp +++ b/include/libtorrent/utp_stream.hpp @@ -148,15 +148,19 @@ struct utp_header }; struct utp_socket_impl; +struct utp_socket_interface; utp_socket_impl* construct_utp_impl(std::uint16_t recv_id , std::uint16_t send_id, void* userdata , utp_socket_manager& sm); void detach_utp_impl(utp_socket_impl* s); void delete_utp_impl(utp_socket_impl* s); +void utp_abort(utp_socket_impl* s); bool should_delete(utp_socket_impl* s); +bool bound_to_udp_socket(utp_socket_impl* s, std::weak_ptr sock); void tick_utp_impl(utp_socket_impl* s, time_point now); void utp_init_mtu(utp_socket_impl* s, int link_mtu, int utp_mtu); +void utp_init_socket(utp_socket_impl* s, std::weak_ptr sock); bool utp_incoming_packet(utp_socket_impl* s, span p , udp::endpoint const& ep, time_point receive_time); bool utp_match(utp_socket_impl* s, udp::endpoint const& ep, std::uint16_t id); diff --git a/src/Makefile.am b/src/Makefile.am index f1a7ed145..2ad890dab 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -123,6 +123,7 @@ libtorrent_rasterbar_la_SOURCES = \ session_handle.cpp \ session_impl.cpp \ session_settings.cpp \ + session_udp_sockets.cpp \ proxy_settings.cpp \ settings_pack.cpp \ sha1_hash.cpp \ diff --git a/src/bt_peer_connection.cpp b/src/bt_peer_connection.cpp index a084d6541..9c58507e5 100644 --- a/src/bt_peer_connection.cpp +++ b/src/bt_peer_connection.cpp @@ -1808,7 +1808,8 @@ namespace libtorrent { { address_v4::bytes_type bytes; std::copy(myip.begin(), myip.end(), bytes.begin()); - m_ses.set_external_address(address_v4(bytes) + m_ses.set_external_address(local_endpoint() + , address_v4(bytes) , aux::session_interface::source_peer, remote().address()); } #if TORRENT_USE_IPV6 @@ -1818,10 +1819,12 @@ namespace libtorrent { std::copy(myip.begin(), myip.end(), bytes.begin()); address_v6 ipv6_address(bytes); if (ipv6_address.is_v4_mapped()) - m_ses.set_external_address(ipv6_address.to_v4() + m_ses.set_external_address(local_endpoint() + , ipv6_address.to_v4() , aux::session_interface::source_peer, remote().address()); else - m_ses.set_external_address(ipv6_address + m_ses.set_external_address(local_endpoint() + , ipv6_address , aux::session_interface::source_peer, remote().address()); } #endif diff --git a/src/session_impl.cpp b/src/session_impl.cpp index 3f4e4f332..bd2d32c0c 100644 --- a/src/session_impl.cpp +++ b/src/session_impl.cpp @@ -423,8 +423,8 @@ namespace aux { , m_upload_rate(peer_connection::upload_channel) , m_host_resolver(m_io_service) , m_tracker_manager( - std::bind(&session_impl::send_udp_packet, this, false, _1, _2, _3, _4) - , std::bind(&session_impl::send_udp_packet_hostname, this, _1, _2, _3, _4, _5) + std::bind(&session_impl::send_udp_packet_deprecated, this, false, _1, _2, _3, _4) + , std::bind(&session_impl::send_udp_packet_hostname_deprecated, this, _1, _2, _3, _4, _5) , m_stats_counters , m_host_resolver , m_settings @@ -446,13 +446,13 @@ namespace aux { , m_dht_announce_timer(m_io_service) #endif , m_utp_socket_manager( - std::bind(&session_impl::send_udp_packet, this, false, _1, _2, _3, _4) + std::bind(&session_impl::send_udp_packet, this, _1, _2, _3, _4, _5) , std::bind(&session_impl::incoming_connection, this, _1) , m_io_service , m_settings, m_stats_counters, nullptr) #ifdef TORRENT_USE_OPENSSL , m_ssl_utp_socket_manager( - std::bind(&session_impl::send_udp_packet, this, true, _1, _2, _3, _4) + std::bind(&session_impl::send_udp_packet, this, _1, _2, _3, _4, _5) , std::bind(&session_impl::on_incoming_utp_ssl, this, _1) , m_io_service , m_settings, m_stats_counters @@ -943,6 +943,8 @@ namespace aux { } } + m_outgoing_sockets.close(); + #if TORRENT_USE_I2P if (m_i2p_listen_socket && m_i2p_listen_socket->is_open()) { @@ -1351,6 +1353,11 @@ namespace { && pack.get_str(settings_pack::listen_interfaces) != m_settings.get_str(settings_pack::listen_interfaces)); + bool const reopen_outgoing_port = + (pack.has_val(settings_pack::outgoing_interfaces) + && pack.get_str(settings_pack::outgoing_interfaces) + != m_settings.get_str(settings_pack::outgoing_interfaces)); + #ifndef TORRENT_DISABLE_LOGGING session_log("applying settings pack, init=%s, reopen_listen_port=%s" , init ? "true" : "false", reopen_listen_port ? "true" : "false"); @@ -1370,6 +1377,9 @@ namespace { { reopen_listen_sockets(); } + + if (init || reopen_outgoing_port) + reopen_outgoing_sockets(); } // TODO: 3 try to remove these functions. They are misleading and not very @@ -1761,6 +1771,7 @@ namespace { m_ip_notifier->async_wait([this] (error_code const& e) { this->wrap(&session_impl::on_ip_change, e); }); reopen_listen_sockets(); + reopen_outgoing_sockets(); } void session_impl::interface_to_endpoints(std::string const& device, int const port @@ -1973,6 +1984,143 @@ namespace { #endif } + void session_impl::reopen_outgoing_sockets() + { + // first build a list of endpoints we should be listening on + // we need to remove any unneeded sockets first to avoid the possibility + // of a new socket failing to bind due to a conflict with a stale socket + std::vector eps; + + for (auto const& iface : m_outgoing_interfaces) + { + interface_to_endpoints(iface, 0, false, eps); +#ifdef TORRENT_USE_OPENSSL + interface_to_endpoints(iface, 0, true, eps); +#endif + } + + // if no outgoing interfaces are specified, create sockets to use + // any interface + if (eps.empty()) + { + eps.emplace_back(address_v4(), 0, "", false); +#if TORRENT_USE_IPV6 + eps.emplace_back(address_v6(), 0, "", false); +#endif +#ifdef TORRENT_USE_OPENSSL + eps.emplace_back(address_v4(), 0, "", true); +#if TORRENT_USE_IPV6 + eps.emplace_back(address_v6(), 0, "", true); +#endif +#endif + } + + auto remove_iter = m_outgoing_sockets.partition_outgoing_sockets(eps); + + for (auto i = remove_iter; i != m_outgoing_sockets.sockets.end(); ++i) + { + auto& remove_sock = *i; + m_utp_socket_manager.remove_udp_socket(remove_sock); + +#ifndef TORRENT_DISABLE_LOGGING + if (should_log()) + { + session_log("Closing outgoing UDP socket for %s on device \"%s\"" + , print_endpoint(remove_sock->local_endpoint()).c_str() + , remove_sock->device.c_str()); + } +#endif + remove_sock->sock.close(); + } + + m_outgoing_sockets.sockets.erase(remove_iter, m_outgoing_sockets.sockets.end()); + + // open new sockets on any endpoints that didn't match with + // an existing socket + for (auto const& ep : eps) + { + error_code ec; + udp::endpoint const udp_bind_ep(ep.addr, 0); + + auto udp_sock = std::make_shared(m_io_service, ep.device, ep.ssl); + udp_sock->sock.open(udp_bind_ep.protocol(), ec); + if (ec) + { +#ifndef TORRENT_DISABLE_LOGGING + if (should_log()) + { + session_log("failed to open UDP socket: %s: %s" + , ep.device.c_str(), ec.message().c_str()); + } +#endif + if (m_alerts.should_post()) + m_alerts.emplace_alert(udp_bind_ep, ec); + continue; + } + +#if TORRENT_HAS_BINDTODEVICE + if (!ep.device.empty()) + { + udp_sock->sock.set_option(bind_to_device(ep.device.c_str()), ec); + if (ec) + { +#ifndef TORRENT_DISABLE_LOGGING + if (should_log()) + { + session_log("bind to device failed (device: %s): %s" + , ep.device.c_str(), ec.message().c_str()); + } +#endif // TORRENT_DISABLE_LOGGING + + if (m_alerts.should_post()) + m_alerts.emplace_alert(udp_bind_ep, ec); + continue; + } + } +#endif + udp_sock->sock.bind(udp_bind_ep, ec); + + if (ec) + { +#ifndef TORRENT_DISABLE_LOGGING + if (should_log()) + { + session_log("failed to bind UDP socket: %s: %s" + , ep.device.c_str(), ec.message().c_str()); + } +#endif + if (m_alerts.should_post()) + m_alerts.emplace_alert(udp_bind_ep, ec); + continue; + } + + error_code err; + set_socket_buffer_size(udp_sock->sock, m_settings, err); + if (err) + { + if (m_alerts.should_post()) + m_alerts.emplace_alert(udp_sock->sock.local_endpoint(ec), err); + } + + udp_sock->sock.set_force_proxy(m_settings.get_bool(settings_pack::force_proxy)); + // this call is necessary here because, unless the settings actually + // change after the session is up and listening, at no other point + // set_proxy_settings is called with the correct proxy configuration, + // internally, this method handle the SOCKS5's connection logic + udp_sock->sock.set_proxy_settings(proxy()); + + // TODO: 2 use a handler allocator here + ADD_OUTSTANDING_ASYNC("session_impl::on_udp_packet"); + udp_sock->sock.async_read(std::bind(&session_impl::on_udp_packet + , this, udp_sock, ep.ssl, _1)); + + if (!ec && udp_sock) + { + m_outgoing_sockets.sockets.push_back(udp_sock); + } + } + } + namespace { template void map_port(MapProtocol& m, ProtoType protocol, EndpointType const& ep @@ -2101,7 +2249,7 @@ namespace { } #endif - void session_impl::send_udp_packet_hostname(char const* hostname + void session_impl::send_udp_packet_hostname_deprecated(char const* hostname , int const port , span p , error_code& ec @@ -2114,20 +2262,27 @@ namespace { if (!i.udp_sock) continue; if (i.ssl) continue; - send_udp_packet_hostname_listen(&i, hostname, port, p, ec, flags); + send_udp_packet_hostname(i.udp_sock, hostname, port, p, ec, flags); return; } ec = boost::asio::error::operation_not_supported; } - void session_impl::send_udp_packet_hostname_listen(aux::session_listen_socket* sock + void session_impl::send_udp_packet_hostname(std::weak_ptr sock , char const* hostname , int const port , span p , error_code& ec , int const flags) { - auto s = static_cast(sock)->udp_sock; + auto si = sock.lock(); + if (!si) + { + ec = boost::asio::error::bad_descriptor; + return; + } + + auto s = std::static_pointer_cast(si); s->sock.send_hostname(hostname, port, p, ec, flags); @@ -2141,7 +2296,7 @@ namespace { } } - void session_impl::send_udp_packet(bool const ssl + void session_impl::send_udp_packet_deprecated(bool const ssl , udp::endpoint const& ep , span p , error_code& ec @@ -2157,27 +2312,32 @@ namespace { if (i.local_endpoint.address().is_v4() != ep.address().is_v4()) continue; - send_udp_packet_listen(&i, ep, p, ec, flags); + send_udp_packet(i.udp_sock, ep, p, ec, flags); return; } ec = boost::asio::error::operation_not_supported; } - void session_impl::send_udp_packet_listen(aux::session_listen_socket* sock + void session_impl::send_udp_packet(std::weak_ptr sock , udp::endpoint const& ep , span p , error_code& ec , int const flags) { - auto s = static_cast(sock)->udp_sock; + auto si = sock.lock(); + if (!si) + { + ec = boost::asio::error::bad_descriptor; + return; + } + + auto s = std::static_pointer_cast(si); TORRENT_ASSERT(s->sock.local_endpoint().protocol() == ep.protocol()); s->sock.send(ep, p, ec, flags); - if ((ec == error::would_block - || ec == error::try_again) - && !s->write_blocked) + if ((ec == error::would_block || ec == error::try_again) && !s->write_blocked) { s->write_blocked = true; ADD_OUTSTANDING_ASYNC("session_impl::on_udp_writeable"); @@ -2280,7 +2440,7 @@ namespace { // give the uTP socket manager first dis on the packet. Presumably // the majority of packets are uTP packets. - if (!mgr.incoming_packet(packet.from, buf)) + if (!mgr.incoming_packet(socket, packet.from, buf)) { // if it wasn't a uTP packet, try the other users of the UDP // socket @@ -4835,6 +4995,13 @@ namespace { bind_ep.port(std::uint16_t(next_port())); } + if (is_utp(s)) + { + auto ep = m_outgoing_sockets.bind(s, remote_address); + if (ep.port() != 0) + return ep; + } + if (!m_outgoing_interfaces.empty()) { if (m_interface_index >= m_outgoing_interfaces.size()) m_interface_index = 0; @@ -5097,6 +5264,7 @@ namespace { { for (auto& i : m_listen_sockets) i.udp_sock->sock.set_proxy_settings(proxy()); + m_outgoing_sockets.update_proxy(proxy()); } void session_impl::update_upnp() diff --git a/src/session_udp_sockets.cpp b/src/session_udp_sockets.cpp new file mode 100644 index 000000000..94a6b7b11 --- /dev/null +++ b/src/session_udp_sockets.cpp @@ -0,0 +1,117 @@ +/* + +Copyright (c) 2017, Arvid Norberg, Steven Siloti +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +* Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the distribution. +* Neither the name of the author nor the names of its +contributors may be used to endorse or promote products derived +from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "libtorrent/aux_/session_udp_sockets.hpp" +#include "libtorrent/aux_/session_impl.hpp" + +namespace libtorrent { namespace aux { + + std::vector>::iterator + outgoing_sockets::partition_outgoing_sockets(std::vector& eps) + { + return std::partition(sockets.begin(), sockets.end() + , [&eps](std::shared_ptr const& sock) + { + auto match = std::find_if(eps.begin(), eps.end() + , [&sock](listen_endpoint_t const& ep) + { + return ep.device == sock->device + && ep.addr == sock->sock.local_endpoint().address() + && ep.ssl == sock->ssl; + }); + + if (match != eps.end()) + { + // remove the matched endpoint to signal the caller that it + // doesn't need to create a socket for the endpoint + eps.erase(match); + return true; + } + else + { + return false; + } + }); + } + + tcp::endpoint outgoing_sockets::bind(socket_type& s, address const& remote_address) const + { + TORRENT_ASSERT(!sockets.empty()); + + utp_socket_impl* impl = nullptr; + bool ssl = false; +#ifdef TORRENT_USE_OPENSSL + if (s.get>() != nullptr) + { + impl = s.get>()->next_layer().get_impl(); + ssl = true; + } + else +#endif + impl = s.get()->get_impl(); + + auto& idx = index[remote_address.is_v4() ? 0 : 1][ssl ? 1 : 0]; + auto const index_begin = idx; + + for (;;) + { + if (++idx >= sockets.size()) + idx = 0; + + if (sockets[idx]->local_endpoint().address().is_v4() != remote_address.is_v4() + || sockets[idx]->ssl != ssl) + { + if (idx == index_begin) break; + continue; + } + + utp_init_socket(impl, sockets[idx]); + auto udp_ep = sockets[idx]->local_endpoint(); + return tcp::endpoint(udp_ep.address(), udp_ep.port()); + } + + return tcp::endpoint(); + } + + void outgoing_sockets::update_proxy(proxy_settings const& settings) + { + for (auto const& i : sockets) + i->sock.set_proxy_settings(settings); + } + + void outgoing_sockets::close() + { + for (auto const& l : sockets) + l->sock.close(); + } + +} } diff --git a/src/utp_socket_manager.cpp b/src/utp_socket_manager.cpp index d03cd77b9..400d62d91 100644 --- a/src/utp_socket_manager.cpp +++ b/src/utp_socket_manager.cpp @@ -138,8 +138,9 @@ namespace libtorrent { utp_mtu = std::min(mtu, restrict_mtu()); } - void utp_socket_manager::send_packet(udp::endpoint const& ep, char const* p - , int const len, error_code& ec, int const flags) + void utp_socket_manager::send_packet(std::weak_ptr sock + , udp::endpoint const& ep, char const* p + , int const len, error_code& ec, int flags) { #if !defined TORRENT_HAS_DONT_FRAGMENT && !defined TORRENT_DEBUG_MTU TORRENT_UNUSED(flags); @@ -150,13 +151,13 @@ namespace libtorrent { if ((flags & dont_fragment) && len > TORRENT_DEBUG_MTU) return; #endif - m_send_fun(ep, {p, std::size_t(len)}, ec + m_send_fun(sock, ep, {p, std::size_t(len)}, ec , ((flags & dont_fragment) ? udp_socket::dont_fragment : 0) | udp_socket::peer_connection); } - bool utp_socket_manager::incoming_packet(udp::endpoint const& ep - , span p) + bool utp_socket_manager::incoming_packet(std::weak_ptr socket + , udp::endpoint const& ep, span p) { // UTP_LOGV("incoming packet size:%d\n", size); @@ -230,6 +231,7 @@ namespace libtorrent { int link_mtu, utp_mtu; mtu_for_dest(ep.address(), link_mtu, utp_mtu); utp_init_mtu(str->get_impl(), link_mtu, utp_mtu); + utp_init_socket(str->get_impl(), socket); bool ret = utp_incoming_packet(str->get_impl(), p, ep, receive_time); if (!ret) return false; m_cb(c); @@ -304,6 +306,17 @@ namespace libtorrent { m_drained_event.push_back(s); } + void utp_socket_manager::remove_udp_socket(std::weak_ptr sock) + { + for (auto& s : m_utp_sockets) + { + if (!bound_to_udp_socket(s.second, sock)) + continue; + + utp_abort(s.second); + } + } + void utp_socket_manager::remove_socket(std::uint16_t id) { socket_map_t::iterator i = m_utp_sockets.find(id); diff --git a/src/utp_stream.cpp b/src/utp_stream.cpp index 244a582d6..4b84717f5 100644 --- a/src/utp_stream.cpp +++ b/src/utp_stream.cpp @@ -322,6 +322,7 @@ public: #endif utp_socket_manager& m_sm; + std::weak_ptr m_sock; // userdata pointer passed along // with any callback. This is initialized to 0 @@ -658,11 +659,23 @@ void delete_utp_impl(utp_socket_impl* s) delete s; } +void utp_abort(utp_socket_impl* s) +{ + s->m_error = boost::asio::error::connection_aborted; + s->set_state(utp_socket_impl::UTP_STATE_ERROR_WAIT); + s->test_socket_state(); +} + bool should_delete(utp_socket_impl* s) { return s->should_delete(); } +bool bound_to_udp_socket(utp_socket_impl* s, std::weak_ptr sock) +{ + return s->m_sock.lock() == sock.lock(); +} + void tick_utp_impl(utp_socket_impl* s, time_point now) { s->tick(now); @@ -673,6 +686,11 @@ void utp_init_mtu(utp_socket_impl* s, int link_mtu, int utp_mtu) s->init_mtu(link_mtu, utp_mtu); } +void utp_init_socket(utp_socket_impl* s, std::weak_ptr sock) +{ + s->m_sock = sock; +} + bool utp_incoming_packet(utp_socket_impl* s , span p , udp::endpoint const& ep, time_point const receive_time) @@ -812,8 +830,18 @@ utp_stream::endpoint_type utp_stream::local_endpoint(error_code& ec) const if (m_impl == nullptr) { ec = boost::asio::error::not_connected; + return endpoint_type(); } - return endpoint_type(); + + auto s = m_impl->m_sock.lock(); + if (!s) + { + ec = boost::asio::error::not_connected; + return endpoint_type(); + } + + udp::endpoint ep = s->local_endpoint(); + return endpoint_type(ep.address(), ep.port()); } utp_stream::~utp_stream() @@ -1307,7 +1335,7 @@ void utp_socket_impl::send_syn() #endif error_code ec; - m_sm.send_packet(udp::endpoint(m_remote_address, m_port) + m_sm.send_packet(m_sock, udp::endpoint(m_remote_address, m_port) , reinterpret_cast(h) , sizeof(utp_header), ec); if (ec == error::would_block || ec == error::try_again) @@ -1398,7 +1426,7 @@ void utp_socket_impl::send_reset(utp_header const* ph) // ignore errors here error_code ec; - m_sm.send_packet(udp::endpoint(m_remote_address, m_port) + m_sm.send_packet(m_sock, udp::endpoint(m_remote_address, m_port) , reinterpret_cast(&h), sizeof(h), ec); if (ec) { @@ -1934,7 +1962,7 @@ bool utp_socket_impl::send_pkt(int const flags) #endif error_code ec; - m_sm.send_packet(udp::endpoint(m_remote_address, m_port) + m_sm.send_packet(m_sock, udp::endpoint(m_remote_address, m_port) , reinterpret_cast(h), p->size, ec , p->mtu_probe ? utp_socket_manager::dont_fragment : 0); @@ -1963,7 +1991,7 @@ bool utp_socket_impl::send_pkt(int const flags) #if TORRENT_UTP_LOG UTP_LOGV("%8p: re-sending\n", static_cast(this)); #endif - m_sm.send_packet(udp::endpoint(m_remote_address, m_port) + m_sm.send_packet(m_sock, udp::endpoint(m_remote_address, m_port) , reinterpret_cast(h), p->size, ec, 0); } @@ -2135,7 +2163,7 @@ bool utp_socket_impl::resend_packet(packet* p, bool fast_resend) h->ack_nr = m_ack_nr; error_code ec; - m_sm.send_packet(udp::endpoint(m_remote_address, m_port) + m_sm.send_packet(m_sock, udp::endpoint(m_remote_address, m_port) , reinterpret_cast(p->buf), p->size, ec); ++m_out_packets; m_sm.inc_stats_counter(counters::utp_packets_out);