/* */
#include "DHTMessageFactoryImpl.h"

#include <cstring>
#include <utility>

#include "LogFactory.h"
#include "DlAbortEx.h"
#include "DHTNode.h"
#include "DHTRoutingTable.h"
#include "DHTPingMessage.h"
#include "DHTPingReplyMessage.h"
#include "DHTFindNodeMessage.h"
#include "DHTFindNodeReplyMessage.h"
#include "DHTGetPeersMessage.h"
#include "DHTGetPeersReplyMessage.h"
#include "DHTAnnouncePeerMessage.h"
#include "DHTAnnouncePeerReplyMessage.h"
#include "DHTUnknownMessage.h"
#include "DHTConnection.h"
#include "DHTMessageDispatcher.h"
#include "DHTPeerAnnounceStorage.h"
#include "DHTTokenTracker.h"
#include "DHTMessageCallback.h"
#include "bittorrent_helper.h"
#include "BtRuntime.h"
#include "util.h"
#include "Peer.h"
#include "Logger.h"
#include "fmt.h"

namespace aria2 {

// Creates a factory for the given address family. All collaborators
// (connection, dispatcher, routing table, peer announce storage, token
// tracker) start out null and must be injected through the setters
// below before messages are created.
DHTMessageFactoryImpl::DHTMessageFactoryImpl(int family)
    : family_{family},
      connection_{nullptr},
      dispatcher_{nullptr},
      routingTable_{nullptr},
      peerAnnounceStorage_{nullptr},
      tokenTracker_{nullptr}
{
}

// Returns the node the routing table holds for (id, ipaddr, port). If
// the routing table has no such node, a fresh DHTNode carrying the given
// ID, address and port is returned instead; it is not inserted into the
// routing table here.
std::shared_ptr<DHTNode> DHTMessageFactoryImpl::getRemoteNode(
    const unsigned char* id, const std::string& ipaddr, uint16_t port) const
{
  auto node = routingTable_->getNode(id, ipaddr, port);
  if (!node) {
    node = std::make_shared<DHTNode>(id);
    node->setIPAddress(ipaddr);
    node->setPort(port);
  }
  return node;
}

namespace {

// Typed accessors for bencoded containers. Each one returns the element
// downcast to the requested type, or throws DL_ABORT_EX when the element
// is absent or has a different type, so callers never see a null
// pointer.

// Returns dict[key] as a Dict; throws if missing or not a Dict.
const Dict* getDictionary(const Dict* dict, const std::string& key)
{
  const Dict* d = downcast<Dict>(dict->get(key));
  if (!d) {
    throw DL_ABORT_EX(fmt("Malformed DHT message. Missing %s", key.c_str()));
  }
  return d;
}

// Returns dict[key] as a String; throws if missing or not a String.
const String* getString(const Dict* dict, const std::string& key)
{
  const String* c = downcast<String>(dict->get(key));
  if (!c) {
    throw DL_ABORT_EX(fmt("Malformed DHT message. Missing %s", key.c_str()));
  }
  return c;
}

// Returns dict[key] as an Integer; throws if missing or not an Integer.
const Integer* getInteger(const Dict* dict, const std::string& key)
{
  const Integer* c = downcast<Integer>(dict->get(key));
  if (!c) {
    throw DL_ABORT_EX(fmt("Malformed DHT message. Missing %s", key.c_str()));
  }
  return c;
}

// Returns list[index] as a String; throws if it is not a String.
const String* getString(const List* list, size_t index)
{
  const String* c = downcast<String>(list->get(index));
  if (!c) {
    throw DL_ABORT_EX(fmt("Malformed DHT message. element[%lu] is not String.",
                          static_cast<unsigned long>(index)));
  }
  return c;
}

// Returns list[index] as an Integer; throws if it is not an Integer.
const Integer* getInteger(const List* list, size_t index)
{
  const Integer* c = downcast<Integer>(list->get(index));
  if (!c) {
    throw DL_ABORT_EX(fmt("Malformed DHT message. element[%lu] is not Integer.",
                          static_cast<unsigned long>(index)));
  }
  return c;
}

// Returns dict[key] as a List; throws if missing or not a List.
const List* getList(const Dict* dict, const std::string& key)
{
  const List* l = downcast<List>(dict->get(key));
  if (!l) {
    throw DL_ABORT_EX(fmt("Malformed DHT message. Missing %s", key.c_str()));
  }
  return l;
}

} // namespace

// Throws DL_ABORT_EX unless |id| is exactly DHT_ID_LENGTH bytes long.
void DHTMessageFactoryImpl::validateID(const String* id) const
{
  if (id->s().size() != DHT_ID_LENGTH) {
    throw DL_ABORT_EX(fmt("Malformed DHT message. Invalid ID length."
                          " Expected:%lu, Actual:%lu",
                          static_cast<unsigned long>(DHT_ID_LENGTH),
                          static_cast<unsigned long>(id->s().size())));
  }
}

// Throws DL_ABORT_EX unless |port| lies in [1, UINT16_MAX]. The upper
// bound is inclusive: 65535 is a legal port number and must not be
// rejected.
void DHTMessageFactoryImpl::validatePort(const Integer* port) const
{
  if (!(0 < port->i() && port->i() <= UINT16_MAX)) {
    throw DL_ABORT_EX(
        fmt("Malformed DHT message. Invalid port=%" PRId64 "", port->i()));
  }
}

namespace {

// Copies the DHTMessage::V entry of |dict| into |msg| when it is
// present as a String; otherwise sets the version of |msg| to
// A2STR::NIL.
void setVersion(DHTMessage* msg, const Dict* dict)
{
  const String* v = downcast<String>(dict->get(DHTMessage::V));
  if (v) {
    msg->setVersion(v->s());
  }
  else {
    msg->setVersion(A2STR::NIL);
  }
}

} // namespace

// Builds a query message object from the decoded dictionary |dict|
// received from ipaddr:port. Throws DL_ABORT_EX when a required field
// is missing or invalid, when "y" does not mark a query, or when the
// query type is not ping / find_node / get_peers / announce_peer.
std::unique_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage(
    const Dict* dict, const std::string& ipaddr, uint16_t port)
{
  const String* messageType = getString(dict, DHTQueryMessage::Q);
  const String* transactionID = getString(dict, DHTMessage::T);
  const String* y = getString(dict, DHTMessage::Y);
  const Dict* aDict = getDictionary(dict, DHTQueryMessage::A);
  if (y->s() != DHTQueryMessage::Q) {
    throw DL_ABORT_EX("Malformed DHT message. y != q");
  }
  const String* id = getString(aDict, DHTMessage::ID);
  validateID(id);
  auto remoteNode = getRemoteNode(id->uc(), ipaddr, port);
  auto msg = std::unique_ptr<DHTQueryMessage>{};
  if (messageType->s() == DHTPingMessage::PING) {
    msg = createPingMessage(remoteNode, transactionID->s());
  }
  else if (messageType->s() == DHTFindNodeMessage::FIND_NODE) {
    const String* targetNodeID =
        getString(aDict, DHTFindNodeMessage::TARGET_NODE);
    validateID(targetNodeID);
    msg = createFindNodeMessage(remoteNode, targetNodeID->uc(),
                                transactionID->s());
  }
  else if (messageType->s() == DHTGetPeersMessage::GET_PEERS) {
    const String* infoHash = getString(aDict, DHTGetPeersMessage::INFO_HASH);
    validateID(infoHash);
    msg = createGetPeersMessage(remoteNode, infoHash->uc(), transactionID->s());
  }
  else if (messageType->s() == DHTAnnouncePeerMessage::ANNOUNCE_PEER) {
    const String* infoHash =
        getString(aDict, DHTAnnouncePeerMessage::INFO_HASH);
    validateID(infoHash);
    const Integer* port = getInteger(aDict, DHTAnnouncePeerMessage::PORT);
    validatePort(port);
    const String* token = getString(aDict, DHTAnnouncePeerMessage::TOKEN);
    // validatePort() has already confined the value to the uint16_t range.
    msg = createAnnouncePeerMessage(remoteNode, infoHash->uc(),
                                    static_cast<uint16_t>(port->i()),
                                    token->s(), transactionID->s());
  }
  else {
    throw DL_ABORT_EX(
        fmt("Unsupported message type: %s", messageType->s().c_str()));
  }
  setVersion(msg.get(), dict);
  return msg;
}

// Builds a reply message object from the decoded dictionary |dict|
// received from ipaddr:port. |messageType| names the kind of reply to
// build and is supplied by the caller. An error message ("y" == "e") is
// logged and then reported by throwing DL_ABORT_EX; malformed input and
// unsupported |messageType| values also throw DL_ABORT_EX.
std::unique_ptr<DHTResponseMessage>
DHTMessageFactoryImpl::createResponseMessage(const std::string& messageType,
                                             const Dict* dict,
                                             const std::string& ipaddr,
                                             uint16_t port)
{
  const String* transactionID = getString(dict, DHTMessage::T);
  const String* y = getString(dict, DHTMessage::Y);
  if (y->s() == DHTUnknownMessage::E) {
    // for now, just report error message arrived and throw exception.
    const List* e = getList(dict, DHTUnknownMessage::E);
    if (e->size() == 2) {
      A2_LOG_INFO(fmt("Received Error DHT message. code=%" PRId64 ", msg=%s",
                      getInteger(e, 0)->i(),
                      util::percentEncode(getString(e, 1)->s()).c_str()));
    }
    else {
      A2_LOG_DEBUG("e doesn't have 2 elements.");
    }
    throw DL_ABORT_EX("Received Error DHT message.");
  }
  else if (y->s() != DHTResponseMessage::R) {
    throw DL_ABORT_EX(fmt("Malformed DHT message. y != r: y=%s",
                          util::percentEncode(y->s()).c_str()));
  }
  const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
  const String* id = getString(rDict, DHTMessage::ID);
  validateID(id);
  auto remoteNode = getRemoteNode(id->uc(), ipaddr, port);
  auto msg = std::unique_ptr<DHTResponseMessage>{};
  if (messageType == DHTPingReplyMessage::PING) {
    msg = createPingReplyMessage(remoteNode, id->uc(), transactionID->s());
  }
  else if (messageType == DHTFindNodeReplyMessage::FIND_NODE) {
    msg = createFindNodeReplyMessage(remoteNode, dict, transactionID->s());
  }
  else if (messageType == DHTGetPeersReplyMessage::GET_PEERS) {
    msg = createGetPeersReplyMessage(remoteNode, dict, transactionID->s());
  }
  else if (messageType == DHTAnnouncePeerReplyMessage::ANNOUNCE_PEER) {
    msg = createAnnouncePeerReplyMessage(remoteNode, transactionID->s());
  }
  else {
    throw DL_ABORT_EX(fmt("Unsupported message type: %s", messageType.c_str()));
  }
  setVersion(msg.get(), dict);
  return msg;
}

namespace {

// Returns the 4-byte version string stamped on messages created by this
// factory: the characters 'A' and '2' followed by DHT_VERSION as a
// 16-bit value in network byte order. The string is built once, on the
// first call.
const std::string& getDefaultVersion()
{
  static const std::string version = [] {
    const uint16_t vnum16 = htons(DHT_VERSION);
    unsigned char buf[] = {'A', '2', 0, 0};
    memcpy(buf + 2, &vnum16, sizeof(vnum16));
    return std::string(&buf[0], &buf[4]);
  }();
  return version;
}

} // namespace

// Wires the collaborators shared by every message (connection,
// dispatcher, routing table, this factory) into |m| and stamps it with
// the default version string.
void DHTMessageFactoryImpl::setCommonProperty(DHTAbstractMessage* m)
{
  m->setConnection(connection_);
  m->setMessageDispatcher(dispatcher_);
  m->setRoutingTable(routingTable_);
  m->setMessageFactory(this);
  m->setVersion(getDefaultVersion());
}

// Creates a ping query addressed to |remoteNode|.
std::unique_ptr<DHTPingMessage> DHTMessageFactoryImpl::createPingMessage(
    const std::shared_ptr<DHTNode>& remoteNode,
    const std::string& transactionID)
{
  auto m = make_unique<DHTPingMessage>(localNode_, remoteNode, transactionID);
  setCommonProperty(m.get());
  return m;
}

// Creates a ping reply for |remoteNode| carrying the node ID |id|.
std::unique_ptr<DHTPingReplyMessage>
DHTMessageFactoryImpl::createPingReplyMessage(
    const std::shared_ptr<DHTNode>& remoteNode, const unsigned char* id,
    const std::string& transactionID)
{
  auto m = make_unique<DHTPingReplyMessage>(localNode_, remoteNode, id,
                                            transactionID);
  setCommonProperty(m.get());
  return m;
}

// Creates a find_node query for |targetNodeID| addressed to |remoteNode|.
std::unique_ptr<DHTFindNodeMessage>
DHTMessageFactoryImpl::createFindNodeMessage(
    const std::shared_ptr<DHTNode>& remoteNode,
    const unsigned char* targetNodeID, const std::string& transactionID)
{
  auto m = make_unique<DHTFindNodeMessage>(localNode_, remoteNode,
                                           targetNodeID, transactionID);
  setCommonProperty(m.get());
  return m;
}

// Creates a find_node reply for |remoteNode| that carries the given
// nodes; |closestKNodes| is moved into the message.
std::unique_ptr<DHTFindNodeReplyMessage>
DHTMessageFactoryImpl::createFindNodeReplyMessage(
    const std::shared_ptr<DHTNode>& remoteNode,
    std::vector<std::shared_ptr<DHTNode>> closestKNodes,
    const std::string& transactionID)
{
  auto m = make_unique<DHTFindNodeReplyMessage>(family_, localNode_,
                                                remoteNode, transactionID);
  m->setClosestKNodes(std::move(closestKNodes));
  setCommonProperty(m.get());
  return m;
}

// Decodes the compact node list in src[0, length) and appends one
// DHTNode per record to |nodes|. Each record is a DHT_ID_LENGTH-byte
// node ID followed by a compact address for this factory's address
// family. Records whose address cannot be unpacked are skipped. Throws
// DL_ABORT_EX when |length| is not a whole number of records.
void DHTMessageFactoryImpl::extractNodes(
    std::vector<std::shared_ptr<DHTNode>>& nodes, const unsigned char* src,
    size_t length)
{
  // Use the same ID-length constant as the offset arithmetic below
  // rather than a literal.
  int unit = bittorrent::getCompactLength(family_) + DHT_ID_LENGTH;
  if (length % unit != 0) {
    throw DL_ABORT_EX(fmt("Nodes length is not multiple of %d", unit));
  }
  for (size_t offset = 0; offset < length; offset += unit) {
    auto node = std::make_shared<DHTNode>(src + offset);
    auto addr =
        bittorrent::unpackcompact(src + offset + DHT_ID_LENGTH, family_);
    if (addr.first.empty()) {
      continue;
    }
    node->setIPAddress(addr.first);
    node->setPort(addr.second);
    nodes.push_back(std::move(node));
  }
}

// Creates a find_node reply from the decoded dictionary |dict|. The
// compact node list is read from the NODES entry (AF_INET) or the NODES6
// entry (any other family) of the response dictionary; when that entry
// is absent or not a String the reply carries no nodes.
std::unique_ptr<DHTFindNodeReplyMessage>
DHTMessageFactoryImpl::createFindNodeReplyMessage(
    const std::shared_ptr<DHTNode>& remoteNode, const Dict* dict,
    const std::string& transactionID)
{
  const String* nodesData = downcast<String>(
      getDictionary(dict, DHTResponseMessage::R)
          ->get(family_ == AF_INET ? DHTFindNodeReplyMessage::NODES
                                   : DHTFindNodeReplyMessage::NODES6));
  std::vector<std::shared_ptr<DHTNode>> nodes;
  if (nodesData) {
    extractNodes(nodes, nodesData->uc(), nodesData->s().size());
  }
  return createFindNodeReplyMessage(remoteNode, std::move(nodes),
                                    transactionID);
}

// Creates a get_peers query for |infoHash| addressed to |remoteNode|.
std::unique_ptr<DHTGetPeersMessage>
DHTMessageFactoryImpl::createGetPeersMessage(
    const std::shared_ptr<DHTNode>& remoteNode, const unsigned char* infoHash,
    const std::string& transactionID)
{
  auto m = make_unique<DHTGetPeersMessage>(localNode_, remoteNode, infoHash,
                                           transactionID);
  m->setPeerAnnounceStorage(peerAnnounceStorage_);
  m->setTokenTracker(tokenTracker_);
  setCommonProperty(m.get());
  return m;
}

// Creates a get_peers reply from the decoded dictionary |dict|. Nodes
// are read the same way as for find_node replies. Peers are read from
// the VALUES list: entries that are not Strings of exactly the compact
// address length, or that fail to unpack, are ignored. Throws
// DL_ABORT_EX when the response dictionary or its TOKEN entry is
// missing.
std::unique_ptr<DHTGetPeersReplyMessage>
DHTMessageFactoryImpl::createGetPeersReplyMessage(
    const std::shared_ptr<DHTNode>& remoteNode, const Dict* dict,
    const std::string& transactionID)
{
  const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
  const String* nodesData = downcast<String>(
      rDict->get(family_ == AF_INET ? DHTGetPeersReplyMessage::NODES
                                    : DHTGetPeersReplyMessage::NODES6));
  std::vector<std::shared_ptr<DHTNode>> nodes;
  if (nodesData) {
    extractNodes(nodes, nodesData->uc(), nodesData->s().size());
  }
  const List* valuesList =
      downcast<List>(rDict->get(DHTGetPeersReplyMessage::VALUES));
  std::vector<std::shared_ptr<Peer>> peers;
  size_t clen = bittorrent::getCompactLength(family_);
  if (valuesList) {
    for (auto& elem : *valuesList) {
      const String* data = downcast<String>(elem);
      if (data && data->s().size() == clen) {
        auto addr = bittorrent::unpackcompact(data->uc(), family_);
        if (addr.first.empty()) {
          continue;
        }
        peers.push_back(std::make_shared<Peer>(addr.first, addr.second));
      }
    }
  }
  const String* token = getString(rDict, DHTGetPeersReplyMessage::TOKEN);
  return createGetPeersReplyMessage(remoteNode, std::move(nodes),
                                    std::move(peers), token->s(),
                                    transactionID);
}

// Creates a get_peers reply for |remoteNode| carrying the given nodes,
// peers and token; both vectors are moved into the message.
std::unique_ptr<DHTGetPeersReplyMessage>
DHTMessageFactoryImpl::createGetPeersReplyMessage(
    const std::shared_ptr<DHTNode>& remoteNode,
    std::vector<std::shared_ptr<DHTNode>> closestKNodes,
    std::vector<std::shared_ptr<Peer>> values, const std::string& token,
    const std::string& transactionID)
{
  auto m = make_unique<DHTGetPeersReplyMessage>(family_, localNode_,
                                                remoteNode, token,
                                                transactionID);
  m->setClosestKNodes(std::move(closestKNodes));
  m->setValues(std::move(values));
  setCommonProperty(m.get());
  return m;
}

// Creates an announce_peer query for |infoHash| and |tcpPort|, addressed
// to |remoteNode| and carrying |token|.
std::unique_ptr<DHTAnnouncePeerMessage>
DHTMessageFactoryImpl::createAnnouncePeerMessage(
    const std::shared_ptr<DHTNode>& remoteNode, const unsigned char* infoHash,
    uint16_t tcpPort, const std::string& token,
    const std::string& transactionID)
{
  auto m = make_unique<DHTAnnouncePeerMessage>(
      localNode_, remoteNode, infoHash, tcpPort, token, transactionID);
  m->setPeerAnnounceStorage(peerAnnounceStorage_);
  m->setTokenTracker(tokenTracker_);
  setCommonProperty(m.get());
  return m;
}

// Creates an announce_peer reply for |remoteNode|.
std::unique_ptr<DHTAnnouncePeerReplyMessage>
DHTMessageFactoryImpl::createAnnouncePeerReplyMessage(
    const std::shared_ptr<DHTNode>& remoteNode,
    const std::string& transactionID)
{
  auto m = make_unique<DHTAnnouncePeerReplyMessage>(localNode_, remoteNode,
                                                    transactionID);
  setCommonProperty(m.get());
  return m;
}

// Wraps raw data received from ipaddr:port in a DHTUnknownMessage.
// Unlike the other builders, no common properties are set on it.
std::unique_ptr<DHTUnknownMessage> DHTMessageFactoryImpl::createUnknownMessage(
    const unsigned char* data, size_t length, const std::string& ipaddr,
    uint16_t port)
{
  return make_unique<DHTUnknownMessage>(localNode_, data, length, ipaddr,
                                        port);
}

// Dependency setters. The factory stores the raw pointers it is given
// and passes them on to the messages it creates.

void DHTMessageFactoryImpl::setRoutingTable(DHTRoutingTable* routingTable)
{
  routingTable_ = routingTable;
}

void DHTMessageFactoryImpl::setConnection(DHTConnection* connection)
{
  connection_ = connection;
}

void DHTMessageFactoryImpl::setMessageDispatcher(
    DHTMessageDispatcher* dispatcher)
{
  dispatcher_ = dispatcher;
}

void DHTMessageFactoryImpl::setPeerAnnounceStorage(
    DHTPeerAnnounceStorage* storage)
{
  peerAnnounceStorage_ = storage;
}

void DHTMessageFactoryImpl::setTokenTracker(DHTTokenTracker* tokenTracker)
{
  tokenTracker_ = tokenTracker;
}

// Sets the node used as the local node of every message created by this
// factory.
void DHTMessageFactoryImpl::setLocalNode(
    const std::shared_ptr<DHTNode>& localNode)
{
  localNode_ = localNode;
}

} // namespace aria2