diff --git a/src/ut_pex.cpp b/src/ut_pex.cpp index d1248637c..adf09a6fe 100644 --- a/src/ut_pex.cpp +++ b/src/ut_pex.cpp @@ -113,7 +113,7 @@ namespace libtorrent { namespace for (torrent::peer_iterator i = m_torrent.begin() , end(m_torrent.end()); i != end; ++i) { - peer_connection* peer = *i; + peer_connection* peer = *i; if (!send_peer(*peer)) continue; tcp::endpoint const& remote = peer->remote(); @@ -198,18 +198,15 @@ namespace libtorrent { namespace virtual bool on_extension_handshake(entry const& h) { - entry const& messages = h["m"]; + m_message_index = 0; + entry const* messages = h.find_key("m"); + if (!messages || messages->type() != entry::dictionary_t) return false; - if (entry const* index = messages.find_key(extension_name)) - { - m_message_index = index->integer(); - return true; - } - else - { - m_message_index = 0; - return false; - } + entry const* index = messages->find_key(extension_name); + if (!index || index->type() != entry::int_t) return false; + + m_message_index = index->integer(); + return true; } virtual bool on_extended(int length, int msg, buffer::const_interval body) @@ -222,11 +219,15 @@ namespace libtorrent { namespace if (body.left() < length) return true; - try + entry pex_msg = bdecode(body.begin, body.end); + + entry const* p = pex_msg.find_key("added"); + entry const* pf = pex_msg.find_key("added.f"); + + if (p != 0 && pf != 0 && p->type() == entry::string_t && pf->type() == entry::string_t) { - entry pex_msg = bdecode(body.begin, body.end); - std::string const& peers = pex_msg["added"].string(); - std::string const& peer_flags = pex_msg["added.f"].string(); + std::string const& peers = p->string(); + std::string const& peer_flags = pf->string(); int num_peers = peers.length() / 6; char const* in = peers.c_str(); @@ -243,32 +244,30 @@ namespace libtorrent { namespace char flags = detail::read_uint8(fin); p.peer_from_tracker(adr, pid, peer_info::pex, flags); } - - if (entry const* p6 = pex_msg.find_key("added6")) - { - std::string const& peers6 = p6->string(); - std::string const& peer6_flags = pex_msg["added6.f"].string(); - - int num_peers = peers6.length() / 18; - char const* in = peers6.c_str(); - char const* fin = peer6_flags.c_str(); - - if (int(peer6_flags.size()) != num_peers) - return true; - - peer_id pid(0); - policy& p = m_torrent.get_policy(); - for (int i = 0; i < num_peers; ++i) - { - tcp::endpoint adr = detail::read_v6_endpoint(in); - char flags = detail::read_uint8(fin); - p.peer_from_tracker(adr, pid, peer_info::pex, flags); - } - } } - catch (std::exception&) + + entry const* p6 = pex_msg.find_key("added6"); + entry const* p6f = pex_msg.find_key("added6.f"); + if (p6 && p6f && p6->type() == entry::string_t && p6f->type() == entry::string_t) { - throw protocol_error("invalid uT peer exchange message"); + std::string const& peers6 = p6->string(); + std::string const& peer6_flags = p6f->string(); + + int num_peers = peers6.length() / 18; + char const* in = peers6.c_str(); + char const* fin = peer6_flags.c_str(); + + if (int(peer6_flags.size()) != num_peers) + return true; + + peer_id pid(0); + policy& p = m_torrent.get_policy(); + for (int i = 0; i < num_peers; ++i) + { + tcp::endpoint adr = detail::read_v6_endpoint(in); + char flags = detail::read_uint8(fin); + p.peer_from_tracker(adr, pid, peer_info::pex, flags); + } } return true; }