Merge pull request #252 from ssiloti/traversal_multiple_dones

fix traversal_algorithm::done() being invoked more than once
This commit is contained in:
Arvid Norberg 2015-11-10 01:44:37 -05:00
commit 4ebca897d9
5 changed files with 103 additions and 29 deletions

View File

@ -67,7 +67,6 @@ struct traversal_algorithm : boost::noncopyable
void failed(observer_ptr o, int flags = 0);
virtual ~traversal_algorithm();
void status(dht_lookup& l);
void abort();
void* allocate_observer();
void free_observer(void* ptr);

View File

@ -135,8 +135,6 @@ char const* find_data::name() const { return "find_data"; }
void find_data::done()
{
if (m_invoke_count != 0) return;
m_done = true;
#ifndef TORRENT_DISABLE_LOGGING

View File

@ -111,7 +111,7 @@ void get_item::got_data(bdecode_node const& v,
// There can only be one true immutable item with a given id
// Now that we've got it and the user doesn't want to do a put
// there's no point in continuing to query other nodes
abort();
done();
}
}
}

View File

@ -374,14 +374,22 @@ void traversal_algorithm::done()
#ifndef TORRENT_DISABLE_LOGGING
int results_target = m_node.m_table.bucket_size();
int closest_target = 160;
#endif
// TODO: 3 it would be nice to not have to perform this loop if
// logging is disabled
for (std::vector<observer_ptr>::iterator i = m_results.begin()
, end(m_results.end()); i != end && results_target > 0; ++i)
, end(m_results.end()); i != end; ++i)
{
boost::intrusive_ptr<observer> o = *i;
if ((o->flags & observer::flag_alive) && get_node().observer())
if (o->flags & observer::flag_queried)
{
// set the done flag on any outstanding queries to prevent them from
// calling finished() or failed() after we've already declared the traversal
// done
o->flags |= observer::flag_done;
}
#ifndef TORRENT_DISABLE_LOGGING
if (results_target > 0 && (o->flags & observer::flag_alive) && get_node().observer())
{
TORRENT_ASSERT(o->flags & observer::flag_queried);
char hex_id[41];
@ -395,8 +403,10 @@ void traversal_algorithm::done()
int dist = distance_exp(m_target, o->id());
if (dist < closest_target) closest_target = dist;
}
#endif
}
#ifndef TORRENT_DISABLE_LOGGING
if (get_node().observer())
{
get_node().observer()->log(dht_logger::traversal
@ -404,9 +414,11 @@ void traversal_algorithm::done()
, static_cast<void*>(this), closest_target, name());
}
#endif
// delete all our references to the observer objects so
// they will in turn release the traversal algorithm
m_results.clear();
m_invoke_count = 0;
}
bool traversal_algorithm::add_requests()
@ -610,26 +622,5 @@ void traversal_observer::reply(msg const& m)
set_id(node_id(id.string_ptr()));
}
void traversal_algorithm::abort()
{
for (std::vector<observer_ptr>::iterator i = m_results.begin()
, end(m_results.end()); i != end; ++i)
{
observer& o = **i;
if (o.flags & observer::flag_queried)
o.flags |= observer::flag_done;
}
#ifndef TORRENT_DISABLE_LOGGING
if (get_node().observer())
{
get_node().observer()->log(dht_logger::traversal, "[%p] ABORTED type: %s"
, static_cast<void*>(this), name());
}
#endif
done();
}
} } // namespace libtorrent::dht

View File

@ -1920,6 +1920,92 @@ TORRENT_TEST(dht)
g_put_count = 0;
} while (false);
// verify that done() is only invoked once
// See PR 252
g_sent_packets.clear();
do
{
// set the branching factor to k to make this a little easier
int old_branching = sett.search_branching;
sett.search_branching = 8;
dht::node node(&s, sett, (node_id::min)(), &observer, cnt);
sha1_hash target = hasher(public_key, item_pk_len).final();
enum { num_test_nodes = 9 }; // we need K + 1 nodes to create the failing sequence
node_entry nodes[num_test_nodes] =
{ node_entry(target, udp::endpoint(address_v4::from_string("1.1.1.1"), 1231))
, node_entry(target, udp::endpoint(address_v4::from_string("2.2.2.2"), 1232))
, node_entry(target, udp::endpoint(address_v4::from_string("3.3.3.3"), 1233))
, node_entry(target, udp::endpoint(address_v4::from_string("4.4.4.4"), 1234))
, node_entry(target, udp::endpoint(address_v4::from_string("5.5.5.5"), 1235))
, node_entry(target, udp::endpoint(address_v4::from_string("6.6.6.6"), 1236))
, node_entry(target, udp::endpoint(address_v4::from_string("7.7.7.7"), 1237))
, node_entry(target, udp::endpoint(address_v4::from_string("8.8.8.8"), 1238))
, node_entry(target, udp::endpoint(address_v4::from_string("9.9.9.9"), 1239)) };
// invert the ith most significant byte so that the test nodes are
// progressivly closer to the target item
for (int i = 0; i < num_test_nodes; ++i)
nodes[i].id[i] = ~nodes[i].id[i];
// add the first k nodes to the subject's routing table
for (int i = 0; i < 8; ++i)
node.m_table.add_node(nodes[i]);
// kick off a mutable get request
g_put_item.assign(items[0].ent, empty_salt, seq, public_key, private_key);
node.get_item(target, get_item_cb);
TEST_EQUAL(g_sent_packets.size(), 8);
if (g_sent_packets.size() != 8) break;
// first send responses for the k closest nodes
for (int i = 1;; ++i)
{
// once the k closest nodes have responded, send the final response
// from the farthest node, this shouldn't trigger a second call to
// get_item_cb
if (i == num_test_nodes) i = 0;
std::list<std::pair<udp::endpoint, entry> >::iterator packet = find_packet(nodes[i].ep());
TEST_CHECK(packet != g_sent_packets.end());
if (packet == g_sent_packets.end()) continue;
lazy_from_entry(packet->second, response);
ret = verify_message(response, get_item_desc, parsed, 6, error_string
, sizeof(error_string));
if (!ret)
{
fprintf(stderr, " invalid get request: %s\n", print_entry(response).c_str());
TEST_ERROR(error_string);
continue;
}
char t[10];
snprintf(t, sizeof(t), "%02d", i);
msg_args args;
args.token(t).port(1234).nid(nodes[i].id);
// add the address of the closest node to the first response
if (i == 1)
args.nodes(nodes_t(1, nodes[8]));
send_dht_response(node, response, nodes[i].ep(), args);
g_sent_packets.erase(packet);
// once we've sent the response from the farthest node, we're done
if (i == 0) break;
}
TEST_EQUAL(g_put_count, 1);
// k nodes should now have outstanding put requests
TEST_EQUAL(g_sent_packets.size(), 8);
g_sent_packets.clear();
g_put_item.clear();
g_put_count = 0;
sett.search_branching = old_branching;
} while (false);
}
void get_test_keypair(char* public_key, char* private_key)