refactor to use std::shared_ptr with dht observer (#1057)

refactor to use std::shared_ptr with dht observer
This commit is contained in:
Alden Torres 2016-09-02 21:05:11 -04:00 committed by Arvid Norberg
parent fd4c09d7b2
commit 29a4075555
17 changed files with 79 additions and 111 deletions

View File

@ -71,7 +71,7 @@ struct find_data : traversal_algorithm
protected: protected:
virtual void done(); virtual void done();
virtual observer_ptr new_observer(void* ptr, udp::endpoint const& ep virtual observer_ptr new_observer(udp::endpoint const& ep
, node_id const& id); , node_id const& id);
nodes_callback m_nodes_callback; nodes_callback m_nodes_callback;

View File

@ -66,7 +66,8 @@ public:
virtual char const* name() const; virtual char const* name() const;
protected: protected:
virtual observer_ptr new_observer(void* ptr, udp::endpoint const& ep, node_id const& id); virtual observer_ptr new_observer(udp::endpoint const& ep
, node_id const& id);
virtual bool invoke(observer_ptr o); virtual bool invoke(observer_ptr o);
virtual void done(); virtual void done();

View File

@ -53,7 +53,8 @@ struct get_peers : find_data
protected: protected:
virtual bool invoke(observer_ptr o); virtual bool invoke(observer_ptr o);
virtual observer_ptr new_observer(void* ptr, udp::endpoint const& ep, node_id const& id); virtual observer_ptr new_observer(udp::endpoint const& ep
, node_id const& id);
data_callback m_data_callback; data_callback m_data_callback;
bool m_noseeds; bool m_noseeds;
@ -72,7 +73,7 @@ struct obfuscated_get_peers : get_peers
protected: protected:
virtual observer_ptr new_observer(void* ptr, udp::endpoint const& ep, virtual observer_ptr new_observer(udp::endpoint const& ep,
node_id const& id); node_id const& id);
virtual bool invoke(observer_ptr o); virtual bool invoke(observer_ptr o);
virtual void done(); virtual void done();

View File

@ -34,14 +34,11 @@ POSSIBILITY OF SUCH DAMAGE.
#define OBSERVER_HPP #define OBSERVER_HPP
#include <cstdint> #include <cstdint>
#include <memory>
#include <libtorrent/time.hpp> #include <libtorrent/time.hpp>
#include <libtorrent/address.hpp> #include <libtorrent/address.hpp>
#include "libtorrent/aux_/disable_warnings_push.hpp"
#include <boost/intrusive_ptr.hpp>
#include "libtorrent/aux_/disable_warnings_pop.hpp"
namespace libtorrent { namespace libtorrent {
namespace dht { namespace dht {
@ -50,21 +47,14 @@ struct observer;
struct msg; struct msg;
struct traversal_algorithm; struct traversal_algorithm;
// defined in rpc_manager.cpp
TORRENT_EXTRA_EXPORT void intrusive_ptr_add_ref(observer const*);
TORRENT_EXTRA_EXPORT void intrusive_ptr_release(observer const*);
struct TORRENT_EXTRA_EXPORT observer : boost::noncopyable struct TORRENT_EXTRA_EXPORT observer : boost::noncopyable
, std::enable_shared_from_this<observer>
{ {
friend TORRENT_EXTRA_EXPORT void intrusive_ptr_add_ref(observer const*);
friend TORRENT_EXTRA_EXPORT void intrusive_ptr_release(observer const*);
observer(std::shared_ptr<traversal_algorithm> const& a observer(std::shared_ptr<traversal_algorithm> const& a
, udp::endpoint const& ep, node_id const& id) , udp::endpoint const& ep, node_id const& id)
: m_sent() : m_sent()
, m_algorithm(a) , m_algorithm(a)
, m_id(id) , m_id(id)
, m_refs(0)
, m_port(0) , m_port(0)
, m_transaction_id() , m_transaction_id()
, flags(0) , flags(0)
@ -137,6 +127,9 @@ protected:
private: private:
std::shared_ptr<observer> self()
{ return shared_from_this(); }
time_point m_sent; time_point m_sent;
const std::shared_ptr<traversal_algorithm> m_algorithm; const std::shared_ptr<traversal_algorithm> m_algorithm;
@ -151,9 +144,6 @@ private:
address_v4::bytes_type v4; address_v4::bytes_type v4;
} m_addr; } m_addr;
// reference counter for intrusive_ptr
mutable std::uint16_t m_refs;
std::uint16_t m_port; std::uint16_t m_port;
// the transaction ID for this call // the transaction ID for this call
@ -169,7 +159,7 @@ public:
#endif #endif
}; };
typedef boost::intrusive_ptr<observer> observer_ptr; typedef std::shared_ptr<observer> observer_ptr;
} } } }

View File

@ -52,7 +52,7 @@ public:
, done_callback const& callback); , done_callback const& callback);
virtual char const* name() const; virtual char const* name() const;
observer_ptr new_observer(void* ptr, udp::endpoint const& ep observer_ptr new_observer(udp::endpoint const& ep
, node_id const& id); , node_id const& id);
void trim_seed_nodes(); void trim_seed_nodes();

View File

@ -94,8 +94,19 @@ public:
void check_invariant() const; void check_invariant() const;
#endif #endif
void* allocate_observer(); template <typename T, typename... Args>
void free_observer(void* ptr); std::shared_ptr<T> allocate_observer(Args&&... args)
{
void* ptr = allocate_observer();
if (ptr == nullptr) return std::shared_ptr<T>();
auto deleter = [this](observer* o)
{
o->~observer();
free_observer(o);
};
return std::shared_ptr<T>(new (ptr) T(std::forward<Args>(args)...), deleter);
}
int num_allocated_observers() const { return m_allocated_observers; } int num_allocated_observers() const { return m_allocated_observers; }
@ -103,6 +114,9 @@ public:
private: private:
void* allocate_observer();
void free_observer(void* ptr);
std::uint32_t calc_connection_id(udp::endpoint addr); std::uint32_t calc_connection_id(udp::endpoint addr);
mutable boost::pool<> m_pool_allocator; mutable boost::pool<> m_pool_allocator;

View File

@ -35,6 +35,7 @@ POSSIBILITY OF SUCH DAMAGE.
#include <vector> #include <vector>
#include <set> #include <set>
#include <memory>
#include <libtorrent/kademlia/node_id.hpp> #include <libtorrent/kademlia/node_id.hpp>
#include <libtorrent/kademlia/routing_table.hpp> #include <libtorrent/kademlia/routing_table.hpp>
@ -64,9 +65,6 @@ struct TORRENT_EXTRA_EXPORT traversal_algorithm : boost::noncopyable
virtual ~traversal_algorithm(); virtual ~traversal_algorithm();
void status(dht_lookup& l); void status(dht_lookup& l);
void* allocate_observer();
void free_observer(void* ptr);
virtual char const* name() const; virtual char const* name() const;
virtual void start(); virtual void start();
@ -95,8 +93,8 @@ protected:
virtual void done(); virtual void done();
// should construct an algorithm dependent // should construct an algorithm dependent
// observer in ptr. // observer in ptr.
virtual observer_ptr new_observer(void* ptr virtual observer_ptr new_observer(udp::endpoint const& ep
, udp::endpoint const& ep, node_id const& id); , node_id const& id);
virtual bool invoke(observer_ptr) { return false; } virtual bool invoke(observer_ptr) { return false; }

View File

@ -118,12 +118,12 @@ void find_data::got_write_token(node_id const& n, std::string write_token)
m_write_tokens[n] = std::move(write_token); m_write_tokens[n] = std::move(write_token);
} }
observer_ptr find_data::new_observer(void* ptr observer_ptr find_data::new_observer(udp::endpoint const& ep
, udp::endpoint const& ep, node_id const& id) , node_id const& id)
{ {
observer_ptr o(new (ptr) find_data_observer(self(), ep, id)); auto o = m_node.m_rpc.allocate_observer<find_data_observer>(self(), ep, id);
#if TORRENT_USE_ASSERTS #if TORRENT_USE_ASSERTS
o->m_in_constructor = false; if (o) o->m_in_constructor = false;
#endif #endif
return o; return o;
} }

View File

@ -125,12 +125,12 @@ get_item::get_item(
char const* get_item::name() const { return "get"; } char const* get_item::name() const { return "get"; }
observer_ptr get_item::new_observer(void* ptr observer_ptr get_item::new_observer(udp::endpoint const& ep
, udp::endpoint const& ep, node_id const& id) , node_id const& id)
{ {
observer_ptr o(new (ptr) get_item_observer(self(), ep, id)); auto o = m_node.m_rpc.allocate_observer<get_item_observer>(self(), ep, id);
#if TORRENT_USE_ASSERTS #if TORRENT_USE_ASSERTS
o->m_in_constructor = false; if (o) o->m_in_constructor = false;
#endif #endif
return o; return o;
} }

View File

@ -160,12 +160,12 @@ bool get_peers::invoke(observer_ptr o)
return m_node.m_rpc.invoke(e, o->target_ep(), o); return m_node.m_rpc.invoke(e, o->target_ep(), o);
} }
observer_ptr get_peers::new_observer(void* ptr observer_ptr get_peers::new_observer(udp::endpoint const& ep
, udp::endpoint const& ep, node_id const& id) , node_id const& id)
{ {
observer_ptr o(new (ptr) get_peers_observer(self(), ep, id)); auto o = m_node.m_rpc.allocate_observer<get_peers_observer>(self(), ep, id);
#if TORRENT_USE_ASSERTS #if TORRENT_USE_ASSERTS
o->m_in_constructor = false; if (o) o->m_in_constructor = false;
#endif #endif
return o; return o;
} }
@ -184,22 +184,24 @@ obfuscated_get_peers::obfuscated_get_peers(
char const* obfuscated_get_peers::name() const char const* obfuscated_get_peers::name() const
{ return !m_obfuscated ? get_peers::name() : "get_peers [obfuscated]"; } { return !m_obfuscated ? get_peers::name() : "get_peers [obfuscated]"; }
observer_ptr obfuscated_get_peers::new_observer(void* ptr observer_ptr obfuscated_get_peers::new_observer(udp::endpoint const& ep
, udp::endpoint const& ep, node_id const& id) , node_id const& id)
{ {
if (m_obfuscated) if (m_obfuscated)
{ {
observer_ptr o(new (ptr) obfuscated_get_peers_observer(self(), ep, id)); auto o = m_node.m_rpc.allocate_observer<obfuscated_get_peers_observer>(self()
, ep, id);
#if TORRENT_USE_ASSERTS #if TORRENT_USE_ASSERTS
o->m_in_constructor = false; if (o) o->m_in_constructor = false;
#endif #endif
return o; return o;
} }
else else
{ {
observer_ptr o(new (ptr) get_peers_observer(self(), ep, id)); auto o = m_node.m_rpc.allocate_observer<get_peers_observer>(self()
, ep, id);
#if TORRENT_USE_ASSERTS #if TORRENT_USE_ASSERTS
o->m_in_constructor = false; if (o) o->m_in_constructor = false;
#endif #endif
return o; return o;
} }

View File

@ -373,9 +373,9 @@ namespace
} }
#endif #endif
void* ptr = node.m_rpc.allocate_observer(); auto o = node.m_rpc.allocate_observer<announce_observer>(algo
if (ptr == nullptr) return; , p.first.ep(), p.first.id);
observer_ptr o(new (ptr) announce_observer(algo, p.first.ep(), p.first.id)); if (!o) return;
#if TORRENT_USE_ASSERTS #if TORRENT_USE_ASSERTS
o->m_in_constructor = false; o->m_in_constructor = false;
#endif #endif
@ -459,9 +459,8 @@ void node::direct_request(udp::endpoint ep, entry& e
// not really a traversal // not really a traversal
auto algo = std::make_shared<direct_traversal>(*this, node_id(), f); auto algo = std::make_shared<direct_traversal>(*this, node_id(), f);
void* ptr = m_rpc.allocate_observer(); auto o = m_rpc.allocate_observer<direct_observer>(algo, ep, node_id());
if (ptr == nullptr) return; if (!o) return;
observer_ptr o(new (ptr) direct_observer(algo, ep, node_id()));
#if TORRENT_USE_ASSERTS #if TORRENT_USE_ASSERTS
o->m_in_constructor = false; o->m_in_constructor = false;
#endif #endif
@ -658,9 +657,6 @@ void node::send_single_refresh(udp::endpoint const& ep, int bucket
, node_id const& id) , node_id const& id)
{ {
TORRENT_ASSERT(id != m_id); TORRENT_ASSERT(id != m_id);
void* ptr = m_rpc.allocate_observer();
if (ptr == nullptr) return;
TORRENT_ASSERT(bucket >= 0); TORRENT_ASSERT(bucket >= 0);
TORRENT_ASSERT(bucket <= 159); TORRENT_ASSERT(bucket <= 159);
@ -675,7 +671,8 @@ void node::send_single_refresh(udp::endpoint const& ep, int bucket
// this is unfortunately necessary for the observer // this is unfortunately necessary for the observer
// to free itself from the pool when it's being released // to free itself from the pool when it's being released
auto algo = std::make_shared<traversal_algorithm>(*this, node_id()); auto algo = std::make_shared<traversal_algorithm>(*this, node_id());
observer_ptr o(new (ptr) ping_observer(algo, ep, id)); auto o = m_rpc.allocate_observer<ping_observer>(algo, ep, id);
if (!o) return;
#if TORRENT_USE_ASSERTS #if TORRENT_USE_ASSERTS
o->m_in_constructor = false; o->m_in_constructor = false;
#endif #endif

View File

@ -55,16 +55,13 @@ void put_data::start()
if (is_done) done(); if (is_done) done();
} }
void put_data::set_targets(std::vector<std::pair<node_entry, std::string> > const& targets) void put_data::set_targets(std::vector<std::pair<node_entry, std::string>> const& targets)
{ {
for (std::vector<std::pair<node_entry, std::string> >::const_iterator i = targets.begin() for (auto const& p : targets)
, end(targets.end()); i != end; ++i)
{ {
void* ptr = m_node.m_rpc.allocate_observer(); auto o = m_node.m_rpc.allocate_observer<put_data_observer>(self(), p.first.ep()
if (ptr == nullptr) return; , p.first.id, p.second);
if (!o) return;
observer_ptr o(new (ptr) put_data_observer(self(), i->first.ep()
, i->first.id, i->second));
#if TORRENT_USE_ASSERTS #if TORRENT_USE_ASSERTS
o->m_in_constructor = false; o->m_in_constructor = false;

View File

@ -41,12 +41,12 @@ POSSIBILITY OF SUCH DAMAGE.
namespace libtorrent { namespace dht namespace libtorrent { namespace dht
{ {
observer_ptr bootstrap::new_observer(void* ptr observer_ptr bootstrap::new_observer(udp::endpoint const& ep
, udp::endpoint const& ep, node_id const& id) , node_id const& id)
{ {
observer_ptr o(new (ptr) get_peers_observer(self(), ep, id)); auto o = m_node.m_rpc.allocate_observer<get_peers_observer>(self(), ep, id);
#if TORRENT_USE_ASSERTS #if TORRENT_USE_ASSERTS
o->m_in_constructor = false; if (o) o->m_in_constructor = false;
#endif #endif
return o; return o;
} }

View File

@ -69,25 +69,6 @@ namespace libtorrent { namespace dht
namespace io = libtorrent::detail; namespace io = libtorrent::detail;
void intrusive_ptr_add_ref(observer const* o)
{
TORRENT_ASSERT(o != nullptr);
TORRENT_ASSERT(o->m_refs < 0xffff);
++o->m_refs;
}
void intrusive_ptr_release(observer const* o)
{
TORRENT_ASSERT(o != nullptr);
TORRENT_ASSERT(o->m_refs > 0);
if (--o->m_refs == 0)
{
auto ta = o->algorithm()->shared_from_this();
(const_cast<observer*>(o))->~observer();
ta->free_observer(const_cast<observer*>(o));
}
}
// TODO: 3 move this into it's own .cpp file // TODO: 3 move this into it's own .cpp file
dht_observer* observer::get_observer() const dht_observer* observer::get_observer() const
{ {
@ -132,27 +113,27 @@ void observer::abort()
{ {
if (flags & flag_done) return; if (flags & flag_done) return;
flags |= flag_done; flags |= flag_done;
m_algorithm->failed(observer_ptr(this), traversal_algorithm::prevent_request); m_algorithm->failed(self(), traversal_algorithm::prevent_request);
} }
void observer::done() void observer::done()
{ {
if (flags & flag_done) return; if (flags & flag_done) return;
flags |= flag_done; flags |= flag_done;
m_algorithm->finished(observer_ptr(this)); m_algorithm->finished(self());
} }
void observer::short_timeout() void observer::short_timeout()
{ {
if (flags & flag_short_timeout) return; if (flags & flag_short_timeout) return;
m_algorithm->failed(observer_ptr(this), traversal_algorithm::short_timeout); m_algorithm->failed(self(), traversal_algorithm::short_timeout);
} }
void observer::timeout() void observer::timeout()
{ {
if (flags & flag_done) return; if (flags & flag_done) return;
flags |= flag_done; flags |= flag_done;
m_algorithm->failed(observer_ptr(this)); m_algorithm->failed(self());
} }
void observer::set_id(node_id const& id) void observer::set_id(node_id const& id)

View File

@ -76,12 +76,12 @@ bool is_sorted(It b, It e, Cmp cmp)
} }
#endif #endif
observer_ptr traversal_algorithm::new_observer(void* ptr observer_ptr traversal_algorithm::new_observer(udp::endpoint const& ep
, udp::endpoint const& ep, node_id const& id) , node_id const& id)
{ {
observer_ptr o(new (ptr) null_observer(self(), ep, id)); auto o = m_node.m_rpc.allocate_observer<null_observer>(self(), ep, id);
#if TORRENT_USE_ASSERTS #if TORRENT_USE_ASSERTS
o->m_in_constructor = false; if (o) o->m_in_constructor = false;
#endif #endif
return o; return o;
} }
@ -117,8 +117,8 @@ void traversal_algorithm::resort_results()
void traversal_algorithm::add_entry(node_id const& id, udp::endpoint addr, unsigned char flags) void traversal_algorithm::add_entry(node_id const& id, udp::endpoint addr, unsigned char flags)
{ {
TORRENT_ASSERT(m_node.m_rpc.allocation_size() >= sizeof(find_data_observer)); TORRENT_ASSERT(m_node.m_rpc.allocation_size() >= sizeof(find_data_observer));
void* ptr = m_node.m_rpc.allocate_observer(); auto o = new_observer(addr, id);
if (ptr == nullptr) if (!o)
{ {
#ifndef TORRENT_DISABLE_LOGGING #ifndef TORRENT_DISABLE_LOGGING
if (get_node().observer()) if (get_node().observer())
@ -130,7 +130,6 @@ void traversal_algorithm::add_entry(node_id const& id, udp::endpoint addr, unsig
done(); done();
return; return;
} }
observer_ptr o = new_observer(ptr, addr, id);
if (id.is_all_zeros()) if (id.is_all_zeros())
{ {
o->set_id(generate_random_id()); o->set_id(generate_random_id());
@ -246,16 +245,6 @@ void traversal_algorithm::start()
if (is_done) done(); if (is_done) done();
} }
void* traversal_algorithm::allocate_observer()
{
return m_node.m_rpc.allocate_observer();
}
void traversal_algorithm::free_observer(void* ptr)
{
m_node.m_rpc.free_observer(ptr);
}
char const* traversal_algorithm::name() const char const* traversal_algorithm::name() const
{ {
return "traversal_algorithm"; return "traversal_algorithm";

View File

@ -41,8 +41,6 @@ POSSIBILITY OF SUCH DAMAGE.
#include "libtorrent/aux_/disable_warnings_push.hpp" #include "libtorrent/aux_/disable_warnings_push.hpp"
#include <boost/version.hpp>
#if defined(__APPLE__) #if defined(__APPLE__)
// for getattrlist() // for getattrlist()
#include <sys/attr.h> #include <sys/attr.h>

View File

@ -2968,7 +2968,7 @@ TORRENT_TEST(rpc_invalid_error_msg)
g_sent_packets.clear(); g_sent_packets.clear();
auto algo = std::make_shared<dht::traversal_algorithm>(node, node_id()); auto algo = std::make_shared<dht::traversal_algorithm>(node, node_id());
observer_ptr o(new (rpc.allocate_observer()) null_observer(algo, source, node_id())); auto o = rpc.allocate_observer<null_observer>(algo, source, node_id());
#if TORRENT_USE_ASSERTS #if TORRENT_USE_ASSERTS
o->m_in_constructor = false; o->m_in_constructor = false;
#endif #endif