correctly check incoming DHT error messages

This commit is contained in:
arvidn 2015-11-30 18:02:00 -05:00
parent d0da753fb8
commit 7540a6e5cc
3 changed files with 61 additions and 16 deletions

View File

@ -283,11 +283,22 @@ void node::incoming(msg const& m)
case 'e':
{
#ifndef TORRENT_DISABLE_LOGGING
bdecode_node err = m.message.dict_find_list("e");
if (err && err.list_size() >= 2 && m_observer)
if (m_observer)
{
m_observer->log(dht_logger::node, "INCOMING ERROR: %s"
, err.list_string_value_at(1).c_str());
bdecode_node err = m.message.dict_find_list("e");
if (err && err.list_size() >= 2
&& err.list_at(0).type() == bdecode_node::int_t
&& err.list_at(1).type() == bdecode_node::string_t
&& m_observer)
{
m_observer->log(dht_logger::node, "INCOMING ERROR: (%" PRId64 ") %s"
, err.list_int_value_at(0)
, err.list_string_value_at(1).c_str());
}
else
{
m_observer->log(dht_logger::node, "INCOMING ERROR (malformed)");
}
}
#endif
node_id id;
@ -918,7 +929,7 @@ void node::incoming_request(msg const& m, entry& e)
}
reply["token"] = generate_token(m.addr, msg_keys[0].string_ptr());
m_counters.inc_stats_counter(counters::dht_get_peers_in);
sha1_hash info_hash(msg_keys[0].string_ptr());

View File

@ -308,14 +308,25 @@ bool rpc_manager::incoming(msg const& m, node_id* id)
, total_milliseconds(now - o->sent()), print_endpoint(m.addr).c_str());
#endif
if (m.message.dict_find_string_value("y") == "e")
if (m.message.dict_find_string_value("y") == "e")
{
// It's an error.
#ifndef TORRENT_DISABLE_LOGGING
bdecode_node err_ent = m.message.dict_find("e");
TORRENT_ASSERT(err_ent);
m_log->log(dht_logger::rpc_manager, "reply with error from %s: %s"
, print_endpoint(m.addr).c_str(), err_ent.list_string_value_at(1).c_str());
bdecode_node err = m.message.dict_find("e");
if (err && err.list_size() >= 2
&& err.list_at(0).type() == bdecode_node::int_t
&& err.list_at(1).type() == bdecode_node::string_t)
{
m_log->log(dht_logger::rpc_manager, "reply with error from %s: (%" PRId64 ") %s"
, print_endpoint(m.addr).c_str()
, err.list_int_value_at(0)
, err.list_string_value_at(1).c_str());
}
else
{
m_log->log(dht_logger::rpc_manager, "reply with (malformed) error from %s"
, print_endpoint(m.addr).c_str());
}
#endif
// Logically, we should call o->reply(m) since we get a reply.
// a reply could be "response" or "error", here the reply is an "error".

View File

@ -99,6 +99,8 @@ struct mock_socket : udp_socket_interface
bool has_quota() { return true; }
bool send_packet(entry& msg, udp::endpoint const& ep, int flags)
{
// TODO: ideally the mock_socket would contain this queue of packets, to
// make tests independent
g_sent_packets.push_back(std::make_pair(ep, msg));
return true;
}
@ -274,10 +276,6 @@ void send_dht_response(node& node, bdecode_node const& request, udp::endpoint co
e["r"].dict().insert(std::make_pair("id", generate_next().to_string()));
char msg_buf[1500];
int size = bencode(msg_buf, e);
#if defined TORRENT_DEBUG && TORRENT_USE_IOSTREAM
// this yields a lot of output. too much
// std::cerr << "sending: " << e << "\n";
#endif
#ifdef TORRENT_USE_VALGRIND
VALGRIND_CHECK_MEM_IS_DEFINED(msg_buf, size);
@ -762,7 +760,7 @@ TORRENT_TEST(dht)
fprintf(stderr, "msg: %s\n", print_entry(response).c_str());
fprintf(stderr, " invalid error response: %s\n", error_string);
}
// a node with invalid node-id shouldn't be added to routing table.
TEST_EQUAL(node.size().get<0>(), nodes_num);
@ -2280,7 +2278,7 @@ TORRENT_TEST(read_only_node)
mock_socket s;
obs observer;
counters cnt;
dht::node node(&s, sett, node_id(0), &observer, cnt);
udp::endpoint source(address::from_string("10.0.0.1"), 20);
bdecode_node response;
@ -2361,5 +2359,30 @@ TORRENT_TEST(read_only_node)
TEST_CHECK(!parsed[3]);
}
TORRENT_TEST(invalid_error_msg)
{
dht_settings sett = test_settings();
mock_socket s;
obs observer;
counters cnt;
dht::node node(&s, sett, node_id(0), &observer, cnt);
udp::endpoint source(address::from_string("10.0.0.1"), 20);
entry e;
e["y"] = "e";
e["e"].string() = "Malformed Error";
char msg_buf[1500];
int size = bencode(msg_buf, e);
bdecode_node decoded;
error_code ec;
bdecode(msg_buf, msg_buf + size, decoded, ec);
if (ec) fprintf(stderr, "bdecode failed: %s\n", ec.message().c_str());
dht::msg m(decoded, source);
node.incoming(m);
}
#endif