diff --git a/src/kademlia/node.cpp b/src/kademlia/node.cpp index b364caa47..d315cb803 100644 --- a/src/kademlia/node.cpp +++ b/src/kademlia/node.cpp @@ -651,6 +651,57 @@ namespace } } } + + struct key_desc_t + { + char const* name; + int type; + int size; + int flags; + + enum { optional = 1}; + }; + + // verifies that a message has all the required + // entries and returns them in ret + bool verify_message(lazy_entry const* msg, key_desc_t const desc[], lazy_entry const* ret[] + , int size , char* error, int error_size) + { + for (int i = 0; i < size; ++i) ret[i] = 0; + + if (msg->type() != lazy_entry::dict_t) + { + snprintf(error, error_size, "not a dictionary"); + return false; + } + for (int i = 0; i < size; ++i) + { + key_desc_t const& k = desc[i]; + ret[i] = msg->dict_find(k.name); + if (ret[i] && ret[i]->type() != k.type) ret[i] = 0; + if (ret[i] == 0 && (k.flags & key_desc_t::optional) == 0) + { + // the key was not found, and it's not an optiona key + snprintf(error, error_size, "missing '%s' key", k.name); + return false; + } + + if (k.size > 0 + && ret[i] + && k.type == lazy_entry::string_t + && ret[i]->string_length() != k.size) + { + // the string was not of the required size + ret[i] = 0; + if ((k.flags & key_desc_t::optional) == 0) + { + snprintf(error, error_size, "invalid value for '%s'", k.name); + return false; + } + } + } + return true; + } } void incoming_error(entry& e, char const* msg) @@ -668,21 +719,22 @@ void node_impl::incoming_request(msg const& m, entry& e) e["y"] = "r"; e["t"] = m.message.dict_find_string_value("t"); - lazy_entry const* query_ent = m.message.dict_find_string("q"); - if (query_ent == 0) + key_desc_t top_desc[] = { + {"q", lazy_entry::string_t, 0, 0}, + {"a", lazy_entry::dict_t, 0, 0}, + }; + + lazy_entry const* top_level[2]; + char error_string[200]; + if (!verify_message(&m.message, top_desc, top_level, 2, error_string, sizeof(error_string))) { - incoming_error(e, "missing 'q' key"); + incoming_error(e, error_string); return; } - char const* query = query_ent->string_cstr(); + char const* query = top_level[0]->string_cstr(); - lazy_entry const* arg_ent = m.message.dict_find_dict("a"); - if (arg_ent == 0) - { - incoming_error(e, "missing 'a' key"); - return; - } + lazy_entry const* arg_ent = top_level[1]; lazy_entry const* node_id_ent = arg_ent->dict_find_string("id"); if (node_id_ent == 0 || node_id_ent->string_length() != 20) @@ -705,16 +757,20 @@ void node_impl::incoming_request(msg const& m, entry& e) } else if (strcmp(query, "get_peers") == 0) { - lazy_entry const* info_hash_ent = arg_ent->dict_find_string("info_hash"); - if (info_hash_ent == 0 || info_hash_ent->string_length() != 20) + key_desc_t msg_desc[] = { + {"info_hash", lazy_entry::string_t, 20, 0}, + }; + + lazy_entry const* msg_keys[1]; + if (!verify_message(arg_ent, msg_desc, msg_keys, 1, error_string, sizeof(error_string))) { - incoming_error(e, "missing 'info-hash' key"); + incoming_error(e, error_string); return; } - reply["token"] = generate_token(m.addr, info_hash_ent->string_ptr()); + reply["token"] = generate_token(m.addr, msg_keys[0]->string_ptr()); - sha1_hash info_hash(info_hash_ent->string_ptr()); + sha1_hash info_hash(msg_keys[0]->string_ptr()); nodes_t n; // always return nodes as well as peers m_table.find_node(info_hash, n, 0); @@ -727,48 +783,53 @@ void node_impl::incoming_request(msg const& m, entry& e) } else if (strcmp(query, "find_node") == 0) { - lazy_entry const* target_ent = arg_ent->dict_find_string("target"); - if (target_ent == 0 || target_ent->string_length() != 20) + key_desc_t msg_desc[] = { + {"target", lazy_entry::string_t, 20, 0}, + }; + + lazy_entry const* msg_keys[1]; + if (!verify_message(arg_ent, msg_desc, msg_keys, 1, error_string, sizeof(error_string))) { - incoming_error(e, "missing 'target' key"); + incoming_error(e, error_string); return; } - sha1_hash target(target_ent->string_ptr()); + sha1_hash target(msg_keys[0]->string_ptr()); + + // TODO: find_node should write directly to the response entry nodes_t n; m_table.find_node(target, n, 0); write_nodes_entry(reply, n); } else if (strcmp(query, "announce_peer") == 0) { - lazy_entry const* info_hash_ent = arg_ent->dict_find_string("info_hash"); - if (info_hash_ent == 0 || info_hash_ent->string_length() != 20) + key_desc_t msg_desc[] = { + {"info_hash", lazy_entry::string_t, 20, 0}, + {"port", lazy_entry::int_t, 0, 0}, + {"token", lazy_entry::string_t, 0, 0}, + }; + + lazy_entry const* msg_keys[3]; + if (!verify_message(arg_ent, msg_desc, msg_keys, 3, error_string, sizeof(error_string))) { - incoming_error(e, "missing 'info-hash' key"); + incoming_error(e, error_string); return; } - int port = arg_ent->dict_find_int_value("port", -1); + int port = msg_keys[1]->int_value(); if (port < 0 || port >= 65536) { incoming_error(e, "invalid 'port' in announce"); return; } - sha1_hash info_hash(info_hash_ent->string_ptr()); + sha1_hash info_hash(msg_keys[0]->string_ptr()); if (m_ses.m_alerts.should_post()) m_ses.m_alerts.post_alert(dht_announce_alert( m.addr.address(), port, info_hash)); - lazy_entry const* token = arg_ent->dict_find_string("token"); - if (!token) - { - incoming_error(e, "missing 'token' key in announce"); - return; - } - - if (!verify_token(token->string_value(), info_hash_ent->string_ptr(), m.addr)) + if (!verify_token(msg_keys[2]->string_value(), msg_keys[0]->string_ptr(), m.addr)) { incoming_error(e, "invalid token in announce"); return; @@ -789,57 +850,42 @@ void node_impl::incoming_request(msg const& m, entry& e) } else if (strcmp(query, "find_torrent") == 0) { - lazy_entry const* target_ent = arg_ent->dict_find_string("target"); - if (target_ent == 0 || target_ent->string_length() != 20) + key_desc_t msg_desc[] = { + {"target", lazy_entry::string_t, 20, 0}, + {"tags", lazy_entry::string_t, 0, 0}, + }; + + lazy_entry const* msg_keys[2]; + if (!verify_message(arg_ent, msg_desc, msg_keys, 2, error_string, sizeof(error_string))) { - incoming_error(e, "missing 'target' key"); + incoming_error(e, error_string); return; } - lazy_entry const* tags_ent = arg_ent->dict_find_string("tags"); - if (tags_ent == 0) - { - incoming_error(e, "missing 'tags' key"); - return; - } + reply["token"] = generate_token(m.addr, msg_keys[0]->string_ptr()); - reply["token"] = generate_token(m.addr, target_ent->string_ptr()); - - sha1_hash target(target_ent->string_ptr()); + sha1_hash target(msg_keys[0]->string_ptr()); nodes_t n; // always return nodes as well as torrents m_table.find_node(target, n, 0); write_nodes_entry(reply, n); - lookup_torrents(target, reply, (char*)tags_ent->string_cstr()); + lookup_torrents(target, reply, (char*)msg_keys[1]->string_cstr()); } else if (strcmp(query, "announce_torrent") == 0) { - lazy_entry const* target_ent = arg_ent->dict_find_string("target"); - if (target_ent == 0 || target_ent->string_length() != 20) - { - incoming_error(e, "missing 'target' key"); - return; - } + key_desc_t msg_desc[] = { + {"target", lazy_entry::string_t, 20, 0}, + {"info_hash", lazy_entry::string_t, 20, 0}, + {"name", lazy_entry::string_t, 0, 0}, + {"tags", lazy_entry::string_t, 0, 0}, + {"token", lazy_entry::string_t, 0, 0}, + }; - lazy_entry const* info_hash_ent = arg_ent->dict_find_string("info_hash"); - if (info_hash_ent == 0 || info_hash_ent->string_length() != 20) + lazy_entry const* msg_keys[4]; + if (!verify_message(arg_ent, msg_desc, msg_keys, 4, error_string, sizeof(error_string))) { - incoming_error(e, "missing 'target' key"); - return; - } - - lazy_entry const* name_ent = arg_ent->dict_find_string("name"); - if (name_ent == 0) - { - incoming_error(e, "missing 'name' key"); - return; - } - - lazy_entry const* tags_ent = arg_ent->dict_find_string("tags"); - if (tags_ent == 0) - { - incoming_error(e, "missing 'tags' key"); + incoming_error(e, error_string); return; } @@ -847,21 +893,14 @@ void node_impl::incoming_request(msg const& m, entry& e) // m_ses.m_alerts.post_alert(dht_announce_torrent_alert( // m.addr.address(), name, tags, info_hash)); - lazy_entry const* token = arg_ent->dict_find_string("token"); - if (!token) - { - incoming_error(e, "missing 'token' key in announce"); - return; - } - - if (!verify_token(token->string_value(), target_ent->string_ptr(), m.addr)) + if (!verify_token(msg_keys[4]->string_value(), msg_keys[0]->string_ptr(), m.addr)) { incoming_error(e, "invalid token in announce"); return; } - sha1_hash target(target_ent->string_ptr()); - sha1_hash info_hash(info_hash_ent->string_ptr()); + sha1_hash target(msg_keys[0]->string_ptr()); + sha1_hash info_hash(msg_keys[1]->string_ptr()); // the token was correct. That means this // node is not spoofing its address. So, let @@ -878,9 +917,9 @@ void node_impl::incoming_request(msg const& m, entry& e) char const* in_tags[20]; int num_tags = 0; - num_tags = split_string(in_tags, 20, (char*)tags_ent->string_cstr()); + num_tags = split_string(in_tags, 20, (char*)msg_keys[3]->string_cstr()); - i->second.publish(name_ent->string_value(), in_tags, num_tags); + i->second.publish(msg_keys[2]->string_value(), in_tags, num_tags); } else {