/* <!-- copyright */ /* * aria2 - The high speed download utility * * Copyright (C) 2006 Tatsuhiro Tsujikawa * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA * * In addition, as a special exception, the copyright holders give * permission to link the code of portions of this program with the * OpenSSL library under certain conditions as described in each * individual source file, and distribute linked combinations * including the two. * You must obey the GNU General Public License in all respects * for all of the code used other than OpenSSL. If you modify * file(s) with this exception, you may extend this exception to your * version of the file(s), but you are not obligated to do so. If you * do not wish to do so, delete this exception statement from your * version. If you delete this exception statement from all source * files in the program, then also delete it here. */ /* copyright --> */ #include "DHTMessageFactoryImpl.h" #include <cstring> #include <utility> #include "LogFactory.h" #include "DlAbortEx.h" #include "DHTNode.h" #include "DHTRoutingTable.h" #include "DHTPingMessage.h" #include "DHTPingReplyMessage.h" #include "DHTFindNodeMessage.h" #include "DHTFindNodeReplyMessage.h" #include "DHTGetPeersMessage.h" #include "DHTGetPeersReplyMessage.h" #include "DHTAnnouncePeerMessage.h" #include "DHTAnnouncePeerReplyMessage.h" #include "DHTUnknownMessage.h" #include "DHTConnection.h" #include "DHTMessageDispatcher.h" #include "DHTPeerAnnounceStorage.h" #include "DHTTokenTracker.h" #include "DHTMessageCallback.h" #include "PeerMessageUtil.h" #include "BtRuntime.h" #include "Util.h" #include "Peer.h" #include "Logger.h" #include "StringFormat.h" #include "bencode.h" namespace aria2 { DHTMessageFactoryImpl::DHTMessageFactoryImpl(): _logger(LogFactory::getInstance()) {} DHTMessageFactoryImpl::~DHTMessageFactoryImpl() {} SharedHandle<DHTNode> DHTMessageFactoryImpl::getRemoteNode(const unsigned char* id, const std::string& ipaddr, uint16_t port) const { SharedHandle<DHTNode> node = _routingTable->getNode(id, ipaddr, port); if(node.isNull()) { node.reset(new DHTNode(id)); node->setIPAddress(ipaddr); node->setPort(port); } return node; } static const BDE& getDictionary(const BDE& dict, const std::string& key) { const BDE& d = dict[key]; if(d.isDict()) { return d; } else { throw DlAbortEx (StringFormat("Malformed DHT message. Missing %s", key.c_str()).str()); } } static const BDE& getString(const BDE& dict, const std::string& key) { const BDE& c = dict[key]; if(c.isString()) { return c; } else { throw DlAbortEx (StringFormat("Malformed DHT message. Missing %s", key.c_str()).str()); } } static const BDE& getInteger(const BDE& dict, const std::string& key) { const BDE& c = dict[key]; if(c.isInteger()) { return c; } else { throw DlAbortEx (StringFormat("Malformed DHT message. Missing %s", key.c_str()).str()); } } static const BDE& getString(const BDE& list, size_t index) { const BDE& c = list[index]; if(c.isString()) { return c; } else { throw DlAbortEx (StringFormat("Malformed DHT message. element[%u] is not String.", index).str()); } } static const BDE& getInteger(const BDE& list, size_t index) { const BDE& c = list[index]; if(c.isInteger()) { return c; } else { throw DlAbortEx (StringFormat("Malformed DHT message. element[%u] is not Integer.", index).str()); } } static const BDE& getList(const BDE& dict, const std::string& key) { const BDE& l = dict[key]; if(l.isList()) { return l; } else { throw DlAbortEx (StringFormat("Malformed DHT message. Missing %s", key.c_str()).str()); } } void DHTMessageFactoryImpl::validateID(const BDE& id) const { if(id.s().size() != DHT_ID_LENGTH) { throw DlAbortEx (StringFormat("Malformed DHT message. Invalid ID length." " Expected:%d, Actual:%d", DHT_ID_LENGTH, id.s().size()).str()); } } void DHTMessageFactoryImpl::validatePort(const BDE& i) const { BDE::Integer port = i.i(); if(!(0 < port && port < UINT16_MAX)) { throw DlAbortEx (StringFormat("Malformed DHT message. Invalid port=%s", Util::itos(port).c_str()).str()); } } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createQueryMessage (const BDE& dict, const std::string& ipaddr, uint16_t port) { const BDE& messageType = getString(dict, DHTQueryMessage::Q); const BDE& transactionID = getString(dict, DHTMessage::T); const BDE& y = getString(dict, DHTMessage::Y); const BDE& aDict = getDictionary(dict, DHTQueryMessage::A); if(y.s() != DHTQueryMessage::Q) { throw DlAbortEx("Malformed DHT message. y != q"); } const BDE& id = getString(aDict, DHTMessage::ID); validateID(id); SharedHandle<DHTNode> remoteNode = getRemoteNode(id.uc(), ipaddr, port); if(messageType.s() == DHTPingMessage::PING) { return createPingMessage(remoteNode, transactionID.s()); } else if(messageType.s() == DHTFindNodeMessage::FIND_NODE) { const BDE& targetNodeID = getString(aDict, DHTFindNodeMessage::TARGET_NODE); validateID(targetNodeID); return createFindNodeMessage(remoteNode, targetNodeID.uc(), transactionID.s()); } else if(messageType.s() == DHTGetPeersMessage::GET_PEERS) { const BDE& infoHash = getString(aDict, DHTGetPeersMessage::INFO_HASH); validateID(infoHash); return createGetPeersMessage(remoteNode, infoHash.uc(), transactionID.s()); } else if(messageType.s() == DHTAnnouncePeerMessage::ANNOUNCE_PEER) { const BDE& infoHash = getString(aDict, DHTAnnouncePeerMessage::INFO_HASH); validateID(infoHash); const BDE& port = getInteger(aDict, DHTAnnouncePeerMessage::PORT); validatePort(port); const BDE& token = getString(aDict, DHTAnnouncePeerMessage::TOKEN); return createAnnouncePeerMessage(remoteNode, infoHash.uc(), static_cast<uint16_t>(port.i()), token.s(), transactionID.s()); } else { throw DlAbortEx (StringFormat("Unsupported message type: %s", messageType.s().c_str()).str()); } } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createResponseMessage(const std::string& messageType, const BDE& dict, const std::string& ipaddr, uint16_t port) { const BDE& transactionID = getString(dict, DHTMessage::T); const BDE& y = getString(dict, DHTMessage::Y); if(y.s() == DHTUnknownMessage::E) { // for now, just report error message arrived and throw exception. const BDE& e = getList(dict, DHTUnknownMessage::E); if(e.size() == 2) { _logger->info("Received Error DHT message. code=%s, msg=%s", Util::itos(getInteger(e, 0).i()).c_str(), Util::urlencode(getString(e, 1).s()).c_str()); } else { _logger->debug("e doesn't have 2 elements."); } throw DlAbortEx("Received Error DHT message."); } else if(y.s() != DHTResponseMessage::R) { throw DlAbortEx (StringFormat("Malformed DHT message. y != r: y=%s", Util::urlencode(y.s()).c_str()).str()); } const BDE& rDict = getDictionary(dict, DHTResponseMessage::R); const BDE& id = getString(rDict, DHTMessage::ID); validateID(id); SharedHandle<DHTNode> remoteNode = getRemoteNode(id.uc(), ipaddr, port); if(messageType == DHTPingReplyMessage::PING) { return createPingReplyMessage(remoteNode, id.uc(), transactionID.s()); } else if(messageType == DHTFindNodeReplyMessage::FIND_NODE) { return createFindNodeReplyMessage(remoteNode, dict, transactionID.s()); } else if(messageType == DHTGetPeersReplyMessage::GET_PEERS) { const BDE& valuesList = rDict[DHTGetPeersReplyMessage::VALUES]; if(valuesList.isList()) { return createGetPeersReplyMessageWithValues(remoteNode, dict, transactionID.s()); } else { const BDE& nodes = rDict[DHTGetPeersReplyMessage::NODES]; if(nodes.isString()) { return createGetPeersReplyMessageWithNodes(remoteNode, dict, transactionID.s()); } else { throw DlAbortEx("Malformed DHT message: missing nodes/values"); } } } else if(messageType == DHTAnnouncePeerReplyMessage::ANNOUNCE_PEER) { return createAnnouncePeerReplyMessage(remoteNode, transactionID.s()); } else { throw DlAbortEx (StringFormat("Unsupported message type: %s", messageType.c_str()).str()); } } void DHTMessageFactoryImpl::setCommonProperty(const SharedHandle<DHTAbstractMessage>& m) { m->setConnection(_connection); m->setMessageDispatcher(_dispatcher); m->setRoutingTable(_routingTable); WeakHandle<DHTMessageFactory> factory(this); m->setMessageFactory(factory); } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createPingMessage(const SharedHandle<DHTNode>& remoteNode, const std::string& transactionID) { SharedHandle<DHTPingMessage> m(new DHTPingMessage(_localNode, remoteNode, transactionID)); setCommonProperty(m); return m; } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createPingReplyMessage(const SharedHandle<DHTNode>& remoteNode, const unsigned char* id, const std::string& transactionID) { SharedHandle<DHTPingReplyMessage> m(new DHTPingReplyMessage(_localNode, remoteNode, id, transactionID)); setCommonProperty(m); return m; } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createFindNodeMessage(const SharedHandle<DHTNode>& remoteNode, const unsigned char* targetNodeID, const std::string& transactionID) { SharedHandle<DHTFindNodeMessage> m(new DHTFindNodeMessage(_localNode, remoteNode, targetNodeID, transactionID)); setCommonProperty(m); return m; } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createFindNodeReplyMessage(const SharedHandle<DHTNode>& remoteNode, const std::deque<SharedHandle<DHTNode> >& closestKNodes, const std::string& transactionID) { SharedHandle<DHTFindNodeReplyMessage> m(new DHTFindNodeReplyMessage(_localNode, remoteNode, transactionID)); m->setClosestKNodes(closestKNodes); setCommonProperty(m); return m; } std::deque<SharedHandle<DHTNode> > DHTMessageFactoryImpl::extractNodes(const unsigned char* src, size_t length) { if(length%26 != 0) { throw DlAbortEx("Nodes length is not multiple of 26"); } std::deque<SharedHandle<DHTNode> > nodes; for(size_t offset = 0; offset < length; offset += 26) { SharedHandle<DHTNode> node(new DHTNode(src+offset)); std::pair<std::string, uint16_t> addr = PeerMessageUtil::unpackcompact(src+offset+DHT_ID_LENGTH); if(addr.first.empty()) { continue; } node->setIPAddress(addr.first); node->setPort(addr.second); nodes.push_back(node); } return nodes; } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createFindNodeReplyMessage (const SharedHandle<DHTNode>& remoteNode, const BDE& dict, const std::string& transactionID) { const BDE& nodesData = getString(getDictionary(dict, DHTResponseMessage::R), DHTFindNodeReplyMessage::NODES); std::deque<SharedHandle<DHTNode> > nodes = extractNodes(nodesData.uc(), nodesData.s().size()); return createFindNodeReplyMessage(remoteNode, nodes, transactionID); } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createGetPeersMessage(const SharedHandle<DHTNode>& remoteNode, const unsigned char* infoHash, const std::string& transactionID) { SharedHandle<DHTGetPeersMessage> m(new DHTGetPeersMessage(_localNode, remoteNode, infoHash, transactionID)); m->setPeerAnnounceStorage(_peerAnnounceStorage); m->setTokenTracker(_tokenTracker); setCommonProperty(m); return m; } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createGetPeersReplyMessageWithNodes (const SharedHandle<DHTNode>& remoteNode, const BDE& dict, const std::string& transactionID) { const BDE& rDict = getDictionary(dict, DHTResponseMessage::R); const BDE& nodesData = getString(rDict, DHTGetPeersReplyMessage::NODES); std::deque<SharedHandle<DHTNode> > nodes = extractNodes(nodesData.uc(), nodesData.s().size()); const BDE& token = getString(rDict, DHTGetPeersReplyMessage::TOKEN); return createGetPeersReplyMessage(remoteNode, nodes, token.s(), transactionID); } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createGetPeersReplyMessage(const SharedHandle<DHTNode>& remoteNode, const std::deque<SharedHandle<DHTNode> >& closestKNodes, const std::string& token, const std::string& transactionID) { SharedHandle<DHTGetPeersReplyMessage> m (new DHTGetPeersReplyMessage(_localNode, remoteNode, token, transactionID)); m->setClosestKNodes(closestKNodes); setCommonProperty(m); return m; } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createGetPeersReplyMessageWithValues (const SharedHandle<DHTNode>& remoteNode, const BDE& dict, const std::string& transactionID) { const BDE& rDict = getDictionary(dict, DHTResponseMessage::R); const BDE& valuesList = getList(rDict, DHTGetPeersReplyMessage::VALUES); std::deque<SharedHandle<Peer> > peers; for(BDE::List::const_iterator i = valuesList.listBegin(); i != valuesList.listEnd(); ++i) { const BDE& data = *i; if(data.isString() && data.s().size() == 6) { std::pair<std::string, uint16_t> addr = PeerMessageUtil::unpackcompact(data.uc()); PeerHandle peer(new Peer(addr.first, addr.second)); peers.push_back(peer); } } const BDE& token = getString(rDict, DHTGetPeersReplyMessage::TOKEN); return createGetPeersReplyMessage(remoteNode, peers, token.s(), transactionID); } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createGetPeersReplyMessage(const SharedHandle<DHTNode>& remoteNode, const std::deque<SharedHandle<Peer> >& values, const std::string& token, const std::string& transactionID) { SharedHandle<DHTGetPeersReplyMessage> m(new DHTGetPeersReplyMessage(_localNode, remoteNode, token, transactionID)); m->setValues(values); setCommonProperty(m); return m; } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createAnnouncePeerMessage(const SharedHandle<DHTNode>& remoteNode, const unsigned char* infoHash, uint16_t tcpPort, const std::string& token, const std::string& transactionID) { SharedHandle<DHTAnnouncePeerMessage> m (new DHTAnnouncePeerMessage(_localNode, remoteNode, infoHash, tcpPort, token, transactionID)); m->setPeerAnnounceStorage(_peerAnnounceStorage); m->setTokenTracker(_tokenTracker); setCommonProperty(m); return m; } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createAnnouncePeerReplyMessage(const SharedHandle<DHTNode>& remoteNode, const std::string& transactionID) { SharedHandle<DHTAnnouncePeerReplyMessage> m (new DHTAnnouncePeerReplyMessage(_localNode, remoteNode, transactionID)); setCommonProperty(m); return m; } SharedHandle<DHTMessage> DHTMessageFactoryImpl::createUnknownMessage(const unsigned char* data, size_t length, const std::string& ipaddr, uint16_t port) { SharedHandle<DHTUnknownMessage> m (new DHTUnknownMessage(_localNode, data, length, ipaddr, port)); return m; } void DHTMessageFactoryImpl::setRoutingTable(const WeakHandle<DHTRoutingTable>& routingTable) { _routingTable = routingTable; } void DHTMessageFactoryImpl::setConnection(const WeakHandle<DHTConnection>& connection) { _connection = connection; } void DHTMessageFactoryImpl::setMessageDispatcher(const WeakHandle<DHTMessageDispatcher>& dispatcher) { _dispatcher = dispatcher; } void DHTMessageFactoryImpl::setPeerAnnounceStorage(const WeakHandle<DHTPeerAnnounceStorage>& storage) { _peerAnnounceStorage = storage; } void DHTMessageFactoryImpl::setTokenTracker(const WeakHandle<DHTTokenTracker>& tokenTracker) { _tokenTracker = tokenTracker; } void DHTMessageFactoryImpl::setLocalNode(const SharedHandle<DHTNode>& localNode) { _localNode = localNode; } } // namespace aria2