diff --git a/src/kademlia/rpc_manager.cpp b/src/kademlia/rpc_manager.cpp index 69a900f77..613619d44 100644 --- a/src/kademlia/rpc_manager.cpp +++ b/src/kademlia/rpc_manager.cpp @@ -312,7 +312,7 @@ bool rpc_manager::incoming(msg const& m, node_id* id) { // It's an error. #ifndef TORRENT_DISABLE_LOGGING - bdecode_node err = m.message.dict_find("e"); + bdecode_node err = m.message.dict_find_list("e"); if (err && err.list_size() >= 2 && err.list_at(0).type() == bdecode_node::int_t && err.list_at(1).type() == bdecode_node::string_t) diff --git a/src/web_peer_connection.cpp b/src/web_peer_connection.cpp index 39017ea61..022da0b29 100644 --- a/src/web_peer_connection.cpp +++ b/src/web_peer_connection.cpp @@ -98,7 +98,7 @@ web_peer_connection::web_peer_connection(peer_connection_args const& pack if (!web.supports_keepalive) preferred_size *= 4; prefer_contiguous_blocks((std::max)(preferred_size / tor->block_size(), 1)); - + // we want large blocks as well, so // we can request more bytes at once // this setting will merge adjacent requests @@ -196,7 +196,7 @@ void web_peer_connection::disconnect(error_code const& ec peer_connection::disconnect(ec, op, error); if (t) t->disconnect_web_seed(this); } - + boost::optional web_peer_connection::downloading_piece_progress() const { @@ -948,7 +948,7 @@ void web_peer_connection::on_receive(error_code const& error // block from the http receive buffer and then // (if it completed) call incoming_piece() with // m_piece as buffer. - + int piece_size = int(m_piece.size()); int copy_size = (std::min)((std::min)(front_request.length - piece_size , recv_buffer.left()), int(range_end - range_start - m_received_body)); @@ -1061,7 +1061,7 @@ void web_peer_connection::on_receive(error_code const& error m_received_body = 0; m_chunk_pos = 0; m_partial_chunk_header = 0; - + if (!t->need_loaded()) { disconnect(errors::torrent_aborted, op_bittorrent); diff --git a/test/test_dht.cpp b/test/test_dht.cpp index 8eb7b0e4e..aa6fcc389 100644 --- a/test/test_dht.cpp +++ b/test/test_dht.cpp @@ -2397,14 +2397,75 @@ TORRENT_TEST(invalid_error_msg) bool found = false; for (int i = 0; i < int(observer.m_log.size()); ++i) { - if (observer.m_log[i].find("INCOMING ERROR") - && observer.m_log[i].find("(malformed)")) + if (observer.m_log[i].find("INCOMING ERROR") != std::string::npos + && observer.m_log[i].find("(malformed)") != std::string::npos) found = true; printf("%s\n", observer.m_log[i].c_str()); } - TEST_EQUAL(found, false); + TEST_EQUAL(found, true); +} + +TORRENT_TEST(rpc_invalid_error_msg) +{ + dht_settings sett = test_settings(); + mock_socket s; + obs observer; + counters cnt; + + dht::routing_table table(node_id(), 8, sett, &observer); + dht::rpc_manager rpc(node_id(), sett, table, &s, &observer); + dht::node node(&s, sett, node_id(0), &observer, cnt); + + udp::endpoint source(address::from_string("10.0.0.1"), 20); + + // we need this to create an entry for this transaction ID, otherwise the + // incoming message will just be dropped + entry req; + req["y"] = "q"; + req["q"] = "bogus_query"; + req["t"] = "\0\0\0\0"; + + g_sent_packets.clear(); + boost::intrusive_ptr algo(new dht::traversal_algorithm( + node, node_id())); + + observer_ptr o(new (rpc.allocate_observer()) null_observer(algo, source, node_id())); +#if defined TORRENT_DEBUG || defined TORRENT_RELEASE_ASSERTS + o->m_in_constructor = false; +#endif + rpc.invoke(req, source, o); + + // here's the incoming (malformed) error message + entry err; + err["y"] = "e"; + err["e"].string() = "Malformed Error"; + err["t"] = g_sent_packets.begin()->second["t"].string(); + char msg_buf[1500]; + int size = bencode(msg_buf, err); + + bdecode_node decoded; + error_code ec; + bdecode(msg_buf, msg_buf + size, decoded, ec); + if (ec) fprintf(stderr, "bdecode failed: %s\n", ec.message().c_str()); + + dht::msg m(decoded, source); + node_id nid; + rpc.incoming(m, &nid); + + bool found = false; + for (int i = 0; i < int(observer.m_log.size()); ++i) + { + if (observer.m_log[i].find("reply with") != std::string::npos + && observer.m_log[i].find("(malformed)") != std::string::npos + && observer.m_log[i].find("error") != std::string::npos) + found = true; + + printf("%s\n", observer.m_log[i].c_str()); + } + + TEST_EQUAL(found, true); } #endif