improve type-safety of observer_flags and traversal_flags

This commit is contained in:
arvidn 2017-08-05 15:45:28 +02:00 committed by Arvid Norberg
parent f2d1a283bc
commit 7b98af8145
6 changed files with 74 additions and 46 deletions

View File

@ -38,15 +38,21 @@ POSSIBILITY OF SUCH DAMAGE.
#include <libtorrent/time.hpp>
#include <libtorrent/address.hpp>
#include <libtorrent/flags.hpp>
namespace libtorrent {
namespace dht {
namespace libtorrent { namespace dht {
struct dht_observer;
struct observer;
struct msg;
struct traversal_algorithm;
struct TORRENT_EXTRA_EXPORT observer : boost::noncopyable
, std::enable_shared_from_this<observer>
struct observer_flags_tag;
using observer_flags_t = libtorrent::flags::bitfield_flag<std::uint8_t, observer_flags_tag>;
struct TORRENT_EXTRA_EXPORT observer
: std::enable_shared_from_this<observer>
{
observer(std::shared_ptr<traversal_algorithm> const& a
, udp::endpoint const& ep, node_id const& id)
@ -55,7 +61,7 @@ struct TORRENT_EXTRA_EXPORT observer : boost::noncopyable
, m_id(id)
, m_port(0)
, m_transaction_id()
, flags(0)
, flags{}
{
TORRENT_ASSERT(a);
#if TORRENT_USE_ASSERTS
@ -67,6 +73,9 @@ struct TORRENT_EXTRA_EXPORT observer : boost::noncopyable
set_target(ep);
}
observer(observer const&) = delete;
observer& operator=(observer const&) = delete;
// defined in rpc_manager.cpp
virtual ~observer();
@ -77,7 +86,7 @@ struct TORRENT_EXTRA_EXPORT observer : boost::noncopyable
// a few seconds, before the request has timed out
void short_timeout();
bool has_short_timeout() const { return (flags & flag_short_timeout) != 0; }
bool has_short_timeout() const { return bool(flags & flag_short_timeout); }
// this is called when no reply has been received within
// some timeout, or a reply with incorrect format.
@ -108,16 +117,14 @@ struct TORRENT_EXTRA_EXPORT observer : boost::noncopyable
std::uint16_t transaction_id() const
{ return m_transaction_id; }
enum {
flag_queried = 1,
flag_initial = 2,
flag_no_id = 4,
flag_short_timeout = 8,
flag_failed = 16,
flag_ipv6_address = 32,
flag_alive = 64,
flag_done = 128
};
static constexpr observer_flags_t flag_queried = 0_bit;
static constexpr observer_flags_t flag_initial = 1_bit;
static constexpr observer_flags_t flag_no_id = 2_bit;
static constexpr observer_flags_t flag_short_timeout = 3_bit;
static constexpr observer_flags_t flag_failed = 4_bit;
static constexpr observer_flags_t flag_ipv6_address = 5_bit;
static constexpr observer_flags_t flag_alive = 6_bit;
static constexpr observer_flags_t flag_done = 7_bit;
protected:
@ -130,7 +137,7 @@ private:
time_point m_sent;
const std::shared_ptr<traversal_algorithm> m_algorithm;
std::shared_ptr<traversal_algorithm> const m_algorithm;
node_id m_id;
@ -147,7 +154,7 @@ private:
// the transaction ID for this call
std::uint16_t m_transaction_id;
public:
std::uint8_t flags;
observer_flags_t flags;
#if TORRENT_USE_ASSERTS
bool m_in_constructor:1;
@ -157,8 +164,9 @@ public:
#endif
};
typedef std::shared_ptr<observer> observer_ptr;
using observer_ptr = std::shared_ptr<observer>;
} }
}
}
#endif

View File

@ -41,26 +41,31 @@ POSSIBILITY OF SUCH DAMAGE.
#include <libtorrent/kademlia/routing_table.hpp>
#include <libtorrent/kademlia/observer.hpp>
#include <libtorrent/address.hpp>
#include <libtorrent/flags.hpp>
#include "libtorrent/aux_/disable_warnings_push.hpp"
#include <boost/noncopyable.hpp>
#include "libtorrent/aux_/disable_warnings_pop.hpp"
namespace libtorrent {
namespace libtorrent { struct dht_lookup; }
namespace libtorrent { namespace dht {
struct dht_lookup;
namespace dht {
class node;
struct node_endpoint;
struct traversal_flags_tag;
using traversal_flags_t = libtorrent::flags::bitfield_flag<std::uint8_t, traversal_flags_tag>;
// this class may not be instantiated as a stack object
struct TORRENT_EXTRA_EXPORT traversal_algorithm : boost::noncopyable
, std::enable_shared_from_this<traversal_algorithm>
struct TORRENT_EXTRA_EXPORT traversal_algorithm
: std::enable_shared_from_this<traversal_algorithm>
{
void traverse(node_id const& id, udp::endpoint const& addr);
void finished(observer_ptr o);
enum flags_t { prevent_request = 1, short_timeout = 2 };
void failed(observer_ptr o, int flags = 0);
static constexpr traversal_flags_t prevent_request = 0_bit;
static constexpr traversal_flags_t short_timeout = 1_bit;
void failed(observer_ptr o, traversal_flags_t flags = {});
virtual ~traversal_algorithm();
void status(dht_lookup& l);
@ -70,9 +75,11 @@ struct TORRENT_EXTRA_EXPORT traversal_algorithm : boost::noncopyable
node_id const& target() const { return m_target; }
void resort_result(observer*);
void add_entry(node_id const& id, udp::endpoint const& addr, unsigned char flags);
void add_entry(node_id const& id, udp::endpoint const& addr, observer_flags_t flags);
traversal_algorithm(node& dht_node, node_id const& target);
traversal_algorithm(traversal_algorithm const&) = delete;
traversal_algorithm& operator=(traversal_algorithm const&) = delete;
int invoke_count() const { TORRENT_ASSERT(m_invoke_count >= 0); return m_invoke_count; }
int branch_factor() const { TORRENT_ASSERT(m_branch_factor >= 0); return m_branch_factor; }
@ -157,6 +164,7 @@ struct traversal_observer : observer
virtual void reply(msg const&);
};
} } // namespace libtorrent::dht
} // namespace dht
} // namespace libtorrent
#endif // TRAVERSAL_ALGORITHM_050324_HPP

View File

@ -151,7 +151,7 @@ void find_data::done()
, end(m_results.end()); i != end && num_results > 0; ++i)
{
observer_ptr const& o = *i;
if ((o->flags & observer::flag_alive) == 0)
if (!(o->flags & observer::flag_alive))
{
#ifndef TORRENT_DISABLE_LOGGING
if (logger != nullptr && logger->should_log(dht_logger::traversal))

View File

@ -218,7 +218,7 @@ bool obfuscated_get_peers::invoke(observer_ptr o)
// don't re-request from nodes that didn't respond
if (node->flags & observer::flag_failed) continue;
// don't interrupt with queries that are already in-flight
if ((node->flags & observer::flag_alive) == 0) continue;
if (!(node->flags & observer::flag_alive)) continue;
node->flags &= ~(observer::flag_queried | observer::flag_alive);
}
return get_peers::invoke(o);
@ -283,7 +283,7 @@ void obfuscated_get_peers::done()
// only add nodes whose node ID we know and that
// we know are alive
if (o->flags & observer::flag_no_id) continue;
if ((o->flags & observer::flag_alive) == 0) continue;
if (!(o->flags & observer::flag_alive)) continue;
ta->add_entry(o->id(), o->target_ep(), observer::flag_initial);
++num_added;

View File

@ -63,6 +63,16 @@ using namespace std::placeholders;
namespace libtorrent { namespace dht {
// TODO: 3 move this into it's own .cpp file
constexpr observer_flags_t observer::flag_queried;
constexpr observer_flags_t observer::flag_initial;
constexpr observer_flags_t observer::flag_no_id;
constexpr observer_flags_t observer::flag_short_timeout;
constexpr observer_flags_t observer::flag_failed;
constexpr observer_flags_t observer::flag_ipv6_address;
constexpr observer_flags_t observer::flag_alive;
constexpr observer_flags_t observer::flag_done;
dht_observer* observer::get_observer() const
{
return m_algorithm->get_node().observer();
@ -502,7 +512,7 @@ observer::~observer()
// reported back to the traversal_algorithm as
// well. If it wasn't sent, it cannot have been
// reported back
TORRENT_ASSERT(m_was_sent == ((flags & flag_done) != 0) || m_was_abandoned);
TORRENT_ASSERT(m_was_sent == bool(flags & flag_done) || m_was_abandoned);
TORRENT_ASSERT(!m_in_constructor);
#if TORRENT_USE_ASSERTS
TORRENT_ASSERT(m_in_use);

View File

@ -46,7 +46,11 @@ POSSIBILITY OF SUCH DAMAGE.
using namespace std::placeholders;
namespace libtorrent { namespace dht {
namespace libtorrent {
namespace dht {
constexpr traversal_flags_t traversal_algorithm::prevent_request;
constexpr traversal_flags_t traversal_algorithm::short_timeout;
#if TORRENT_USE_ASSERTS
template <class It, class Cmp>
@ -76,9 +80,7 @@ observer_ptr traversal_algorithm::new_observer(udp::endpoint const& ep
return o;
}
traversal_algorithm::traversal_algorithm(
node& dht_node
, node_id const& target)
traversal_algorithm::traversal_algorithm(node& dht_node, node_id const& target)
: m_node(dht_node)
, m_target(target)
{
@ -123,7 +125,7 @@ void traversal_algorithm::resort_result(observer* o)
}
void traversal_algorithm::add_entry(node_id const& id
, udp::endpoint const& addr, unsigned char const flags)
, udp::endpoint const& addr, observer_flags_t const flags)
{
TORRENT_ASSERT(m_node.m_rpc.allocation_size() >= sizeof(find_data_observer));
auto o = new_observer(addr, id);
@ -297,7 +299,7 @@ void traversal_algorithm::traverse(node_id const& id, udp::endpoint const& addr)
// let the routing table know this node may exist
m_node.m_table.heard_about(id, addr);
add_entry(id, addr, 0);
add_entry(id, addr, {});
}
void traversal_algorithm::finished(observer_ptr o)
@ -328,11 +330,11 @@ void traversal_algorithm::finished(observer_ptr o)
// prevent request means that the total number of requests has
// overflown. This query failed because it was the oldest one.
// So, if this is true, don't make another request
void traversal_algorithm::failed(observer_ptr o, int const flags)
void traversal_algorithm::failed(observer_ptr o, traversal_flags_t const flags)
{
// don't tell the routing table about
// node ids that we just generated ourself
if ((o->flags & observer::flag_no_id) == 0)
if (!(o->flags & observer::flag_no_id))
m_node.m_table.node_failed(o->id(), o->target_ep());
if (m_results.empty()) return;
@ -348,7 +350,7 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
// we do get a late response, keep the handler
// around for some more, but open up the slot
// by increasing the branch factor
if ((o->flags & observer::flag_short_timeout) == 0
if (!(o->flags & observer::flag_short_timeout)
&& m_branch_factor < std::numeric_limits<std::int8_t>::max())
{
++m_branch_factor;
@ -363,7 +365,7 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
o->flags |= observer::flag_failed;
// if this flag is set, it means we increased the
// branch factor for it, and we should restore it
decrement_branch_factor = (o->flags & observer::flag_short_timeout) != 0;
decrement_branch_factor = bool(o->flags & observer::flag_short_timeout);
#ifndef TORRENT_DISABLE_LOGGING
log_timeout(o,"");
@ -376,7 +378,7 @@ void traversal_algorithm::failed(observer_ptr o, int const flags)
// this is another reason to decrement the branch factor, to prevent another
// request from filling this slot. Only ever decrement once per response though
decrement_branch_factor |= (flags & prevent_request);
decrement_branch_factor |= bool(flags & prevent_request);
if (decrement_branch_factor)
{
@ -499,7 +501,7 @@ bool traversal_algorithm::add_requests()
{
// if it's queried, not alive and not failed, it
// must be currently in flight
if ((o->flags & observer::flag_failed) == 0)
if (!(o->flags & observer::flag_failed))
++outstanding;
continue;