diff --git a/include/libtorrent/upnp.hpp b/include/libtorrent/upnp.hpp index 47a3b7961..c62aceb6f 100644 --- a/include/libtorrent/upnp.hpp +++ b/include/libtorrent/upnp.hpp @@ -101,8 +101,7 @@ namespace libtorrent struct parse_state { - parse_state(): in_service(false) {} - bool in_service; + bool in_service = false; std::list tag_stack; std::string control_url; std::string service_type; @@ -120,9 +119,28 @@ struct parse_state } }; +struct error_code_parse_state +{ + bool in_error_code = false; + bool exit = false; + int error_code = -1; +}; + +struct ip_address_parse_state: error_code_parse_state +{ + bool in_ip_address = false; + std::string ip_address; +}; + TORRENT_EXTRA_EXPORT void find_control_url(int type, char const* string , int str_len, parse_state& state); +TORRENT_EXTRA_EXPORT void find_error_code(int type, char const* string + , int str_len, error_code_parse_state& state); + +TORRENT_EXTRA_EXPORT void find_ip_address(int type, char const* string + , int str_len, ip_address_parse_state& state); + // TODO: support using the windows API for UPnP operations as well struct TORRENT_EXTRA_EXPORT upnp final : std::enable_shared_from_this diff --git a/src/session_impl.cpp b/src/session_impl.cpp index f84130f1b..9e40f86ad 100644 --- a/src/session_impl.cpp +++ b/src/session_impl.cpp @@ -1959,7 +1959,7 @@ namespace aux { map_handle = -1; // only update this mapping if we actually have a socket listening - if (ep.address() != address()) + if (ep != EndpointType()) map_handle = m.add_mapping(protocol, ep.port(), ep.port()); } } diff --git a/src/upnp.cpp b/src/upnp.cpp index 6f2f07ad1..b66eaf042 100644 --- a/src/upnp.cpp +++ b/src/upnp.cpp @@ -881,7 +881,7 @@ namespace } } -TORRENT_EXTRA_EXPORT void find_control_url(int type, char const* string +void find_control_url(int type, char const* string , int str_len, parse_state& state) { if (type == xml_start_tag) @@ -1115,53 +1115,39 @@ void upnp::disable(error_code const& ec) m_socket.close(); } +void find_error_code(int type, char const* string, int str_len, error_code_parse_state& state) +{ + if (state.exit) return; + if (type == xml_start_tag && !std::strncmp("errorCode", string, size_t(str_len))) + { + state.in_error_code = true; + } + else if (type == xml_string && state.in_error_code) + { + std::string error_code_str(string, str_len); + state.error_code = std::atoi(error_code_str.c_str()); + state.exit = true; + } +} + +void find_ip_address(int type, char const* string, int str_len, ip_address_parse_state& state) +{ + find_error_code(type, string, str_len, state); + if (state.exit) return; + + if (type == xml_start_tag && !std::strncmp("NewExternalIPAddress", string, size_t(str_len))) + { + state.in_ip_address = true; + } + else if (type == xml_string && state.in_ip_address) + { + state.ip_address.assign(string, str_len); + state.exit = true; + } +} + namespace { - struct error_code_parse_state - { - error_code_parse_state(): in_error_code(false), exit(false), error_code(-1) {} - bool in_error_code; - bool exit; - int error_code; - }; - - void find_error_code(int type, char const* string, error_code_parse_state& state) - { - if (state.exit) return; - if (type == xml_start_tag && !std::strcmp("errorCode", string)) - { - state.in_error_code = true; - } - else if (type == xml_string && state.in_error_code) - { - state.error_code = std::atoi(string); - state.exit = true; - } - } - - struct ip_address_parse_state: public error_code_parse_state - { - ip_address_parse_state(): in_ip_address(false) {} - bool in_ip_address; - std::string ip_address; - }; - - void find_ip_address(int type, char const* string, ip_address_parse_state& state) - { - find_error_code(type, string, state); - if (state.exit) return; - - if (type == xml_start_tag && !std::strcmp("NewExternalIPAddress", string)) - { - state.in_ip_address = true; - } - else if (type == xml_string && state.in_ip_address) - { - state.ip_address = string; - state.exit = true; - } - } - struct error_code_t { int code; @@ -1295,7 +1281,7 @@ void upnp::on_upnp_get_ip_address_response(error_code const& e #endif ip_address_parse_state s; - xml_parse(body, std::bind(&find_ip_address, _1, _2, std::ref(s))); + xml_parse(body, std::bind(&find_ip_address, _1, _2, _3, std::ref(s))); #ifndef TORRENT_DISABLE_LOGGING if (s.error_code != -1) { @@ -1396,7 +1382,7 @@ void upnp::on_upnp_map_response(error_code const& e error_code_parse_state s; span body = p.get_body(); - xml_parse(body, std::bind(&find_error_code, _1, _2, std::ref(s))); + xml_parse(body, std::bind(&find_error_code, _1, _2, _3, std::ref(s))); if (s.error_code != -1) { @@ -1550,7 +1536,7 @@ void upnp::on_upnp_unmap_response(error_code const& e error_code_parse_state s; if (p.header_finished()) { - xml_parse(p.get_body(), std::bind(&find_error_code, _1, _2, std::ref(s))); + xml_parse(p.get_body(), std::bind(&find_error_code, _1, _2, _3, std::ref(s))); } portmap_protocol const proto = m_mappings[mapping].protocol; diff --git a/test/test_xml.cpp b/test/test_xml.cpp index d60506502..cccad6541 100644 --- a/test/test_xml.cpp +++ b/test/test_xml.cpp @@ -233,6 +233,36 @@ char upnp_xml2[] = "" ""; +char upnp_xml3[] = +"" +"" +"" +"s:Client" +"UPnPError" +"" +"" +"402" +"Invalid Args" +"" +"" +"" +"" +""; + +char upnp_xml4[] = +"" +"" +"" +"" +"123.10.20.30" +"" +"" +""; + using namespace libtorrent; using namespace std::placeholders; @@ -307,6 +337,26 @@ TORRENT_TEST(upnp_parser2) TEST_EQUAL(xml_s.model, "Wireless-G ADSL Home Gateway"); } +TORRENT_TEST(upnp_parser3) +{ + error_code_parse_state xml_s; + xml_parse(upnp_xml3, std::bind(&find_error_code, _1, _2, _3, std::ref(xml_s))); + + std::cout << "error_code " << xml_s.error_code << std::endl; + TEST_EQUAL(xml_s.error_code, 402); +} + +TORRENT_TEST(upnp_parser4) +{ + ip_address_parse_state xml_s; + xml_parse(upnp_xml4, std::bind(&find_ip_address, _1, _2, _3, std::ref(xml_s))); + + std::cout << "error_code " << xml_s.error_code << std::endl; + std::cout << "ip_address " << xml_s.ip_address << std::endl; + TEST_EQUAL(xml_s.error_code, -1); + TEST_EQUAL(xml_s.ip_address, "123.10.20.30"); +} + TORRENT_TEST(tags) { test_parse("foobar", "BaSfooEbSbarFa");