diff --git a/test/test_fast_extension.cpp b/test/test_fast_extension.cpp index 17081d8e9..91256fa0c 100644 --- a/test/test_fast_extension.cpp +++ b/test/test_fast_extension.cpp @@ -82,27 +82,27 @@ void print_session_log(lt::session& ses) print_alerts(ses, "ses", true); } -int read_message(tcp::socket& s, char* buffer, int max_size) +int read_message(tcp::socket& s, span buffer) { using namespace lt::detail; error_code ec; - boost::asio::read(s, boost::asio::buffer(buffer, 4) + boost::asio::read(s, boost::asio::buffer(buffer.data(), 4) , boost::asio::transfer_all(), ec); if (ec) { TEST_ERROR(ec.message()); return -1; } - char* ptr = buffer; - int length = read_int32(ptr); - if (length > max_size) + char const* ptr = buffer.data(); + int const length = read_int32(ptr); + if (length > buffer.size()) { log("message size: %d", length); TEST_ERROR("message size exceeds max limit"); return -1; } - boost::asio::read(s, boost::asio::buffer(buffer, std::size_t(length)) + boost::asio::read(s, boost::asio::buffer(buffer.data(), std::size_t(length)) , boost::asio::transfer_all(), ec); if (ec) { @@ -112,7 +112,7 @@ int read_message(tcp::socket& s, char* buffer, int max_size) return length; } -void print_message(char const* buffer, int len) +void print_message(span buffer) { char const* message_name[] = {"choke", "unchoke", "interested", "not_interested" , "have", "bitfield", "request", "piece", "cancel", "dht_port", "", "", "" @@ -121,13 +121,13 @@ void print_message(char const* buffer, int len) std::stringstream message; char extra[300]; extra[0] = 0; - if (len == 0) + if (buffer.empty()) { message << "keepalive"; } else { - int msg = buffer[0]; + int const msg = buffer[0]; if (msg >= 0 && msg < int(sizeof(message_name)/sizeof(message_name[0]))) message << message_name[msg]; else if (msg == 20) @@ -135,26 +135,26 @@ void print_message(char const* buffer, int len) else message << "unknown[" << msg << "]"; - if (msg == 0x6 && len == 13) + if (msg == 0x6 && buffer.size() == 13) { peer_request r; - const char* ptr = buffer + 1; + const char* ptr = buffer.data() + 1; r.piece = piece_index_t(detail::read_int32(ptr)); r.start = detail::read_int32(ptr); r.length = detail::read_int32(ptr); std::snprintf(extra, sizeof(extra), "p: %d s: %d l: %d" , static_cast(r.piece), r.start, r.length); } - else if (msg == 0x11 && len == 5) + else if (msg == 0x11 && buffer.size() == 5) { - const char* ptr = buffer + 1; + const char* ptr = buffer.data() + 1; int index = detail::read_int32(ptr); std::snprintf(extra, sizeof(extra), "p: %d", index); } - else if (msg == 20 && len > 4 && buffer[1] == 0 ) + else if (msg == 20 && buffer.size() > 4 && buffer[1] == 0 ) { std::snprintf(extra, sizeof(extra), "%s" - , print_entry(bdecode({buffer + 2, len - 2})).c_str()); + , print_entry(bdecode(buffer.subspan(2))).c_str()); } } @@ -338,17 +338,18 @@ void send_request(tcp::socket& s, peer_request req) if (ec) TEST_ERROR(ec.message()); } -entry read_extension_handshake(tcp::socket& s, char* recv_buffer, int size) +entry read_extension_handshake(tcp::socket& s, span recv_buffer) { for (;;) { - int len = read_message(s, recv_buffer, size); + int const len = read_message(s, recv_buffer); if (len == -1) { TEST_ERROR("failed to read message"); return entry(); } - print_message(recv_buffer, len); + recv_buffer = recv_buffer.first(len); + print_message(recv_buffer); if (len < 4) continue; int msg = recv_buffer[0]; @@ -356,7 +357,7 @@ entry read_extension_handshake(tcp::socket& s, char* recv_buffer, int size) int extmsg = recv_buffer[1]; if (extmsg != 0) continue; - return bdecode({recv_buffer + 2, len - 2}); + return bdecode(recv_buffer.subspan(2)); } } @@ -391,25 +392,26 @@ void send_ut_metadata_msg(tcp::socket& s, int ut_metadata_msg, int type, int pie if (ec) TEST_ERROR(ec.message()); } -entry read_ut_metadata_msg(tcp::socket& s, char* recv_buffer, int size) +entry read_ut_metadata_msg(tcp::socket& s, span recv_buffer) { for (;;) { - int len = read_message(s, recv_buffer, size); + int const len = read_message(s, recv_buffer); if (len == -1) { TEST_ERROR("failed to read message"); return entry(); } - print_message(recv_buffer, len); + auto const buffer = recv_buffer.first(len); + print_message(buffer); if (len < 4) continue; - int msg = recv_buffer[0]; + int const msg = buffer[0]; if (msg != 20) continue; - int extmsg = recv_buffer[1]; + int const extmsg = buffer[1]; if (extmsg != 1) continue; - return bdecode({recv_buffer + 2, len - 2}); + return bdecode(buffer.subspan(2)); } } #endif // TORRENT_DISABLE_EXTENSIONS @@ -509,15 +511,16 @@ TORRENT_TEST(reject_fast) while (!allowed_fast.empty()) { print_session_log(*ses); - int len = read_message(s, recv_buffer, sizeof(recv_buffer)); + int const len = read_message(s, recv_buffer); if (len == -1) break; - print_message(recv_buffer, len); - int msg = recv_buffer[0]; + auto buffer = span(recv_buffer).first(len); + print_message(buffer); + int msg = buffer[0]; if (msg != 0x6) continue; using namespace lt::detail; - char* ptr = recv_buffer + 1; - int piece = read_int32(ptr); + char const* ptr = buffer.data() + 1; + int const piece = read_int32(ptr); std::vector::iterator i = std::find(allowed_fast.begin() , allowed_fast.end(), piece); @@ -572,17 +575,19 @@ TORRENT_TEST(invalid_suggest) std::this_thread::sleep_for(lt::milliseconds(500)); print_session_log(*ses); - int len = read_message(s, recv_buffer, sizeof(recv_buffer)); + int len = read_message(s, recv_buffer); + auto buffer = span(recv_buffer).first(len); int idx = -1; while (len > 0) { - if (recv_buffer[0] == 6) + if (buffer[0] == 6) { - char* ptr = recv_buffer + 1; + char const* ptr = buffer.data() + 1; idx = detail::read_int32(ptr); break; } - len = read_message(s, recv_buffer, sizeof(recv_buffer)); + len = read_message(s, recv_buffer); + buffer = span(recv_buffer).first(len); } TEST_CHECK(idx != -234); TEST_CHECK(idx != -1); @@ -625,15 +630,16 @@ TORRENT_TEST(reject_suggest) while (!suggested.empty() && fail_counter > 0) { print_session_log(*ses); - int len = read_message(s, recv_buffer, sizeof(recv_buffer)); + int const len = read_message(s, recv_buffer); if (len == -1) break; - print_message(recv_buffer, len); - int msg = recv_buffer[0]; + auto buffer = span(recv_buffer).first(len); + print_message(buffer); + int const msg = buffer[0]; fail_counter--; if (msg != 0x6) continue; using namespace lt::detail; - char* ptr = recv_buffer + 1; + char const* ptr = buffer.data() + 1; int const piece = read_int32(ptr); std::vector::iterator i = std::find(suggested.begin() @@ -701,17 +707,18 @@ TORRENT_TEST(suggest_order) while (!suggested.empty() && fail_counter > 0) { print_session_log(*ses); - int len = read_message(s, recv_buffer, sizeof(recv_buffer)); + int const len = read_message(s, recv_buffer); if (len == -1) break; - print_message(recv_buffer, len); - int msg = recv_buffer[0]; + auto const buffer = span(recv_buffer).first(len); + print_message({recv_buffer, len}); + int const msg = recv_buffer[0]; fail_counter--; // we're just interested in requests if (msg != 0x6) continue; using namespace lt::detail; - char* ptr = recv_buffer + 1; + char const* ptr = buffer.data() + 1; int const piece = read_int32(ptr); // make sure we receive the requests inverse order of sending the suggest @@ -831,18 +838,19 @@ TORRENT_TEST(dont_have) { print_session_log(*ses); - int const len = read_message(s, recv_buffer, sizeof(recv_buffer)); + int const len = read_message(s, recv_buffer); if (len == -1) break; - print_message(recv_buffer, len); + auto const buffer = span(recv_buffer).first(len); + print_message(buffer); if (len == 0) continue; - int msg = recv_buffer[0]; + int const msg = buffer[0]; if (msg != 20) continue; - int ext_msg = recv_buffer[1]; + int const ext_msg = buffer[1]; if (ext_msg != 0) continue; int pos = 0; ec.clear(); - bdecode_node e = bdecode({recv_buffer + 2, len - 2}, ec, &pos); + bdecode_node e = bdecode(buffer.subspan(2), ec, &pos); if (ec) { log("failed to parse extension handshake: %s at pos %d" @@ -912,7 +920,7 @@ TORRENT_TEST(extension_handshake) entry extensions; send_extension_handshake(s, extensions); - extensions = read_extension_handshake(s, recv_buffer, sizeof(recv_buffer)); + extensions = read_extension_handshake(s, recv_buffer); std::cout << extensions << '\n'; @@ -954,7 +962,7 @@ TORRENT_TEST(invalid_metadata_request) extensions["m"]["ut_metadata"] = 1; send_extension_handshake(s, extensions); - extensions = read_extension_handshake(s, recv_buffer, sizeof(recv_buffer)); + extensions = read_extension_handshake(s, recv_buffer); int ut_metadata = int(extensions["m"]["ut_metadata"].integer()); @@ -970,15 +978,13 @@ TORRENT_TEST(invalid_metadata_request) // we assume we were not disconnected because of the invalid one send_ut_metadata_msg(s, ut_metadata, 0, 0); - entry ut_metadata_msg = read_ut_metadata_msg(s, recv_buffer - , sizeof(recv_buffer)); + entry ut_metadata_msg = read_ut_metadata_msg(s, recv_buffer); // the first response should be "dont-have" TEST_EQUAL(ut_metadata_msg["msg_type"].integer(), 2); TEST_EQUAL(ut_metadata_msg["piece"].integer(), 1); - ut_metadata_msg = read_ut_metadata_msg(s, recv_buffer - , sizeof(recv_buffer)); + ut_metadata_msg = read_ut_metadata_msg(s, recv_buffer); // the second response should be the payload TEST_EQUAL(ut_metadata_msg["msg_type"].integer(), 1); @@ -1029,14 +1035,15 @@ void have_all_test(bool const incoming) // since we advertised support for FAST extensions for (;;) { - int const len = read_message(s, recv_buffer, sizeof(recv_buffer)); + int const len = read_message(s, recv_buffer); if (len == -1) { TEST_ERROR("failed to receive have-all despite advertising support for FAST"); break; } - print_message(recv_buffer, len); - int const msg = recv_buffer[0]; + auto const buffer = span(recv_buffer).first(len); + print_message(buffer); + int const msg = buffer[0]; if (msg == 0xe) // have-all { // success!