aria2/src/DHTMessageFactoryImpl.cc

543 lines
17 KiB
C++

/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2006 Tatsuhiro Tsujikawa
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#include "DHTMessageFactoryImpl.h"
#include <cstring>
#include <utility>
#include "LogFactory.h"
#include "DlAbortEx.h"
#include "DHTNode.h"
#include "DHTRoutingTable.h"
#include "DHTPingMessage.h"
#include "DHTPingReplyMessage.h"
#include "DHTFindNodeMessage.h"
#include "DHTFindNodeReplyMessage.h"
#include "DHTGetPeersMessage.h"
#include "DHTGetPeersReplyMessage.h"
#include "DHTAnnouncePeerMessage.h"
#include "DHTAnnouncePeerReplyMessage.h"
#include "DHTUnknownMessage.h"
#include "DHTConnection.h"
#include "DHTMessageDispatcher.h"
#include "DHTPeerAnnounceStorage.h"
#include "DHTTokenTracker.h"
#include "DHTMessageCallback.h"
#include "bittorrent_helper.h"
#include "BtRuntime.h"
#include "util.h"
#include "Peer.h"
#include "Logger.h"
#include "fmt.h"
namespace aria2 {
// Creates a factory bound to the given address family.  All collaborator
// pointers start out null (0); they are assigned later through the
// corresponding set*() member functions.
DHTMessageFactoryImpl::DHTMessageFactoryImpl(int family)
  : family_(family),
    connection_(0),
    dispatcher_(0),
    routingTable_(0),
    peerAnnounceStorage_(0),
    tokenTracker_(0)
{
}
// Nothing to release explicitly: the destructor body is empty and does
// not delete any of the raw collaborator pointers.
DHTMessageFactoryImpl::~DHTMessageFactoryImpl()
{
}
// Returns the node the routing table reports for (id, ipaddr, port).
// When the routing table yields no node, a new DHTNode carrying the given
// ID, address and port is created and returned instead; this function
// does not add it to the routing table.
SharedHandle<DHTNode>
DHTMessageFactoryImpl::getRemoteNode
(const unsigned char* id, const std::string& ipaddr, uint16_t port) const
{
  SharedHandle<DHTNode> known = routingTable_->getNode(id, ipaddr, port);
  if(!known) {
    SharedHandle<DHTNode> fresh(new DHTNode(id));
    fresh->setIPAddress(ipaddr);
    fresh->setPort(port);
    return fresh;
  }
  return known;
}
namespace {
// Fetches the entry stored under |key| in |dict| as a Dict.
// Throws DL_ABORT_EX when downcast<Dict>() yields a null pointer for
// that entry.
const Dict* getDictionary(const Dict* dict, const std::string& key)
{
  const Dict* found = downcast<Dict>(dict->get(key));
  if(!found) {
    throw DL_ABORT_EX
      (fmt("Malformed DHT message. Missing %s", key.c_str()));
  }
  return found;
}
} // namespace
namespace {
// Fetches the entry stored under |key| in |dict| as a String.
// Throws DL_ABORT_EX when downcast<String>() yields a null pointer for
// that entry.
const String* getString(const Dict* dict, const std::string& key)
{
  const String* found = downcast<String>(dict->get(key));
  if(!found) {
    throw DL_ABORT_EX
      (fmt("Malformed DHT message. Missing %s", key.c_str()));
  }
  return found;
}
} // namespace
namespace {
// Fetches the entry stored under |key| in |dict| as an Integer.
// Throws DL_ABORT_EX when downcast<Integer>() yields a null pointer for
// that entry.
const Integer* getInteger(const Dict* dict, const std::string& key)
{
  const Integer* found = downcast<Integer>(dict->get(key));
  if(!found) {
    throw DL_ABORT_EX
      (fmt("Malformed DHT message. Missing %s", key.c_str()));
  }
  return found;
}
} // namespace
namespace {
// Fetches element |index| of |list| as a String.
// Throws DL_ABORT_EX when downcast<String>() yields a null pointer for
// that element.
const String* getString(const List* list, size_t index)
{
  const String* element = downcast<String>(list->get(index));
  if(!element) {
    throw DL_ABORT_EX
      (fmt("Malformed DHT message. element[%lu] is not String.",
           static_cast<unsigned long>(index)));
  }
  return element;
}
} // namespace
namespace {
// Fetches element |index| of |list| as an Integer.
// Throws DL_ABORT_EX when downcast<Integer>() yields a null pointer for
// that element.
const Integer* getInteger(const List* list, size_t index)
{
  const Integer* element = downcast<Integer>(list->get(index));
  if(!element) {
    throw DL_ABORT_EX
      (fmt("Malformed DHT message. element[%lu] is not Integer.",
           static_cast<unsigned long>(index)));
  }
  return element;
}
} // namespace
namespace {
// Fetches the entry stored under |key| in |dict| as a List.
// Throws DL_ABORT_EX when downcast<List>() yields a null pointer for
// that entry.
const List* getList(const Dict* dict, const std::string& key)
{
  const List* found = downcast<List>(dict->get(key));
  if(!found) {
    throw DL_ABORT_EX
      (fmt("Malformed DHT message. Missing %s", key.c_str()));
  }
  return found;
}
} // namespace
// Checks that |id| holds exactly DHT_ID_LENGTH bytes.
// Throws DL_ABORT_EX reporting the expected and actual lengths otherwise.
void DHTMessageFactoryImpl::validateID(const String* id) const
{
  if(id->s().size() == DHT_ID_LENGTH) {
    return;
  }
  throw DL_ABORT_EX
    (fmt("Malformed DHT message. Invalid ID length."
         " Expected:%lu, Actual:%lu",
         static_cast<unsigned long>(DHT_ID_LENGTH),
         static_cast<unsigned long>(id->s().size())));
}
// Checks that |port| is a usable port number, i.e. lies in the inclusive
// range [1, UINT16_MAX].  The upper bound is inclusive because 65535 is a
// legitimate port and still fits the uint16_t the caller casts the value
// to; the previous strict '<' comparison rejected it.
// Throws DL_ABORT_EX reporting the offending value otherwise.
void DHTMessageFactoryImpl::validatePort(const Integer* port) const
{
  if(!(0 < port->i() && port->i() <= UINT16_MAX)) {
    throw DL_ABORT_EX
      (fmt("Malformed DHT message. Invalid port=%lld",
           static_cast<long long int>(port->i())));
  }
}
namespace {
void setVersion(const SharedHandle<DHTMessage>& msg, const Dict* dict)
{
const String* v = downcast<String>(dict->get(DHTMessage::V));
if(v) {
msg->setVersion(v->s());
} else {
msg->setVersion(A2STR::NIL);
}
}
} // namespace
// Builds a query message object from the dictionary form of a received
// DHT query.
//
// dict   - top-level message dictionary; must contain the query name,
//          transaction ID, message kind and argument dictionary.
// ipaddr - address the message came from, used to resolve the remote node.
// port   - port the message came from, used to resolve the remote node.
//
// Supported queries are ping, find_node, get_peers and announce_peer.
// Throws DL_ABORT_EX when a required entry is missing, when the message
// kind is not a query, when an ID or port fails validation, or when the
// query name is not one of the supported ones.
SharedHandle<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage
(const Dict* dict, const std::string& ipaddr, uint16_t port)
{
  // The lookup order below determines which error is reported first for
  // a message with several missing entries.
  const String* queryName = getString(dict, DHTQueryMessage::Q);
  const String* txnID = getString(dict, DHTMessage::T);
  const String* msgKind = getString(dict, DHTMessage::Y);
  const Dict* args = getDictionary(dict, DHTQueryMessage::A);
  if(msgKind->s() != DHTQueryMessage::Q) {
    throw DL_ABORT_EX("Malformed DHT message. y != q");
  }

  const String* senderID = getString(args, DHTMessage::ID);
  validateID(senderID);
  SharedHandle<DHTNode> remoteNode =
    getRemoteNode(senderID->uc(), ipaddr, port);

  const std::string& method = queryName->s();
  SharedHandle<DHTQueryMessage> msg;
  if(method == DHTPingMessage::PING) {
    msg = createPingMessage(remoteNode, txnID->s());
  } else if(method == DHTFindNodeMessage::FIND_NODE) {
    const String* target = getString(args, DHTFindNodeMessage::TARGET_NODE);
    validateID(target);
    msg = createFindNodeMessage(remoteNode, target->uc(), txnID->s());
  } else if(method == DHTGetPeersMessage::GET_PEERS) {
    const String* hash = getString(args, DHTGetPeersMessage::INFO_HASH);
    validateID(hash);
    msg = createGetPeersMessage(remoteNode, hash->uc(), txnID->s());
  } else if(method == DHTAnnouncePeerMessage::ANNOUNCE_PEER) {
    const String* hash = getString(args, DHTAnnouncePeerMessage::INFO_HASH);
    validateID(hash);
    const Integer* peerPort = getInteger(args, DHTAnnouncePeerMessage::PORT);
    validatePort(peerPort);
    const String* token = getString(args, DHTAnnouncePeerMessage::TOKEN);
    msg = createAnnouncePeerMessage(remoteNode, hash->uc(),
                                    static_cast<uint16_t>(peerPort->i()),
                                    token->s(), txnID->s());
  } else {
    throw DL_ABORT_EX(fmt("Unsupported message type: %s",
                          method.c_str()));
  }
  // Overrides the default version with the one carried by the message.
  setVersion(msg, dict);
  return msg;
}
// Builds a response message object from the dictionary form of a
// received DHT reply.
//
// messageType - name of the query this reply answers; selects which
//               reply class is instantiated.
// dict        - top-level message dictionary.
// ipaddr/port - origin of the message, used to resolve the remote node.
//
// An error message (kind == DHTUnknownMessage::E) is only logged and
// then reported by throwing.  Throws DL_ABORT_EX for error messages,
// for messages whose kind is not a response, for missing entries, for an
// invalid node ID and for unsupported |messageType| values.
SharedHandle<DHTResponseMessage>
DHTMessageFactoryImpl::createResponseMessage
(const std::string& messageType,
 const Dict* dict,
 const std::string& ipaddr,
 uint16_t port)
{
  const String* txnID = getString(dict, DHTMessage::T);
  const String* msgKind = getString(dict, DHTMessage::Y);

  if(msgKind->s() == DHTUnknownMessage::E) {
    // Error messages are not modelled yet: log what arrived, then abort.
    const List* err = getList(dict, DHTUnknownMessage::E);
    if(err->size() != 2) {
      A2_LOG_DEBUG("e doesn't have 2 elements.");
    } else {
      A2_LOG_INFO(fmt("Received Error DHT message. code=%lld, msg=%s",
                      static_cast<long long int>(getInteger(err, 0)->i()),
                      util::percentEncode(getString(err, 1)->s()).c_str()));
    }
    throw DL_ABORT_EX("Received Error DHT message.");
  }
  if(msgKind->s() != DHTResponseMessage::R) {
    throw DL_ABORT_EX
      (fmt("Malformed DHT message. y != r: y=%s",
           util::percentEncode(msgKind->s()).c_str()));
  }

  const Dict* reply = getDictionary(dict, DHTResponseMessage::R);
  const String* senderID = getString(reply, DHTMessage::ID);
  validateID(senderID);
  SharedHandle<DHTNode> remoteNode =
    getRemoteNode(senderID->uc(), ipaddr, port);

  SharedHandle<DHTResponseMessage> msg;
  if(messageType == DHTPingReplyMessage::PING) {
    msg = createPingReplyMessage(remoteNode, senderID->uc(), txnID->s());
  } else if(messageType == DHTFindNodeReplyMessage::FIND_NODE) {
    msg = createFindNodeReplyMessage(remoteNode, dict, txnID->s());
  } else if(messageType == DHTGetPeersReplyMessage::GET_PEERS) {
    msg = createGetPeersReplyMessage(remoteNode, dict, txnID->s());
  } else if(messageType == DHTAnnouncePeerReplyMessage::ANNOUNCE_PEER) {
    msg = createAnnouncePeerReplyMessage(remoteNode, txnID->s());
  } else {
    throw DL_ABORT_EX
      (fmt("Unsupported message type: %s", messageType.c_str()));
  }
  // Overrides the default version with the one carried by the message.
  setVersion(msg, dict);
  return msg;
}
namespace {
// Returns the 4-byte version string attached to messages created by this
// factory: the characters 'A' and '2' followed by DHT_VERSION as a 16-bit
// value in network byte order.  The string is built on the first call and
// cached in a function-local static.
const std::string& getDefaultVersion()
{
  static std::string version;
  if(!version.empty()) {
    return version;
  }
  const uint16_t netOrderVersion = htons(DHT_VERSION);
  unsigned char raw[] = { 'A' , '2', 0, 0 };
  // Overwrite the two trailing placeholder bytes with the version number.
  memcpy(raw+2, &netOrderVersion, sizeof(netOrderVersion));
  version.assign(raw, raw+sizeof(raw));
  return version;
}
} // namespace
// Wires the collaborators shared by every factory-made message into |m|:
// connection, message dispatcher, routing table and this factory.  Also
// stamps the message with the default version string.
void DHTMessageFactoryImpl::setCommonProperty
(const SharedHandle<DHTAbstractMessage>& m)
{
  // Collaborators held by this factory.
  m->setConnection(connection_);
  m->setMessageDispatcher(dispatcher_);
  m->setRoutingTable(routingTable_);
  m->setMessageFactory(this);

  // Default version; parsed messages get theirs replaced afterwards.
  m->setVersion(getDefaultVersion());
}
// Creates a ping query between the local node and |remoteNode| with the
// given transaction ID and applies the common message properties.
SharedHandle<DHTQueryMessage> DHTMessageFactoryImpl::createPingMessage
(const SharedHandle<DHTNode>& remoteNode, const std::string& transactionID)
{
  SharedHandle<DHTPingMessage> ping
    (new DHTPingMessage(localNode_, remoteNode, transactionID));
  setCommonProperty(ping);
  return ping;
}
// Creates a ping reply between the local node and |remoteNode|.
// |id| is forwarded unchanged to the DHTPingReplyMessage constructor.
// The common message properties are applied before returning.
SharedHandle<DHTResponseMessage> DHTMessageFactoryImpl::createPingReplyMessage
(const SharedHandle<DHTNode>& remoteNode,
 const unsigned char* id,
 const std::string& transactionID)
{
  SharedHandle<DHTPingReplyMessage> reply
    (new DHTPingReplyMessage(localNode_, remoteNode, id, transactionID));
  setCommonProperty(reply);
  return reply;
}
// Creates a find_node query for |targetNodeID| between the local node
// and |remoteNode| and applies the common message properties.
SharedHandle<DHTQueryMessage> DHTMessageFactoryImpl::createFindNodeMessage
(const SharedHandle<DHTNode>& remoteNode,
 const unsigned char* targetNodeID,
 const std::string& transactionID)
{
  SharedHandle<DHTFindNodeMessage> query
    (new DHTFindNodeMessage
     (localNode_, remoteNode, targetNodeID, transactionID));
  setCommonProperty(query);
  return query;
}
// Creates a find_node reply for this factory's address family that
// carries |closestKNodes|, then applies the common message properties.
SharedHandle<DHTResponseMessage>
DHTMessageFactoryImpl::createFindNodeReplyMessage
(const SharedHandle<DHTNode>& remoteNode,
 const std::vector<SharedHandle<DHTNode> >& closestKNodes,
 const std::string& transactionID)
{
  SharedHandle<DHTFindNodeReplyMessage> reply
    (new DHTFindNodeReplyMessage
     (family_, localNode_, remoteNode, transactionID));
  reply->setClosestKNodes(closestKNodes);
  setCommonProperty(reply);
  return reply;
}
// Decodes a packed node list and appends the resulting nodes to |nodes|.
//
// nodes  - output vector; decoded nodes are appended to it.
// src    - packed data: a sequence of fixed-size records, each made of a
//          DHT_ID_LENGTH-byte node ID followed by a compact address for
//          this factory's address family.
// length - number of bytes available at |src|.
//
// Records whose compact address unpacks to an empty host are skipped.
// Throws DL_ABORT_EX when |length| is not a multiple of the record size.
void DHTMessageFactoryImpl::extractNodes
(std::vector<SharedHandle<DHTNode> >& nodes,
 const unsigned char* src, size_t length)
{
  // Record size is expressed with DHT_ID_LENGTH, the same constant used
  // for the address offset below, so the stride and the layout cannot
  // drift apart.  Kept unsigned to match |length| and |offset|.
  const size_t unit = bittorrent::getCompactLength(family_)+DHT_ID_LENGTH;
  if(length%unit != 0) {
    throw DL_ABORT_EX
      (fmt("Nodes length is not multiple of %d", static_cast<int>(unit)));
  }
  for(size_t offset = 0; offset < length; offset += unit) {
    std::pair<std::string, uint16_t> addr =
      bittorrent::unpackcompact(src+offset+DHT_ID_LENGTH, family_);
    if(addr.first.empty()) {
      continue;
    }
    // Allocate the node only once the address is known to be usable.
    SharedHandle<DHTNode> node(new DHTNode(src+offset));
    node->setIPAddress(addr.first);
    node->setPort(addr.second);
    nodes.push_back(node);
  }
}
// Creates a find_node reply from the dictionary form of a received
// message.  The node list is read from the reply dictionary under the
// NODES key for AF_INET and the NODES6 key otherwise; when that entry is
// not a String the reply carries no nodes.
// Throws DL_ABORT_EX when the reply dictionary is missing or the node
// data has an invalid length.
SharedHandle<DHTResponseMessage>
DHTMessageFactoryImpl::createFindNodeReplyMessage
(const SharedHandle<DHTNode>& remoteNode,
 const Dict* dict,
 const std::string& transactionID)
{
  const Dict* reply = getDictionary(dict, DHTResponseMessage::R);
  const String* packedNodes =
    downcast<String>(reply->get(family_ == AF_INET?
                                DHTFindNodeReplyMessage::NODES:
                                DHTFindNodeReplyMessage::NODES6));
  std::vector<SharedHandle<DHTNode> > closest;
  if(packedNodes) {
    extractNodes(closest, packedNodes->uc(), packedNodes->s().size());
  }
  return createFindNodeReplyMessage(remoteNode, closest, transactionID);
}
// Creates a get_peers query for |infoHash| between the local node and
// |remoteNode|.  The message receives the peer announce storage and
// token tracker held by this factory plus the common message properties.
SharedHandle<DHTQueryMessage>
DHTMessageFactoryImpl::createGetPeersMessage
(const SharedHandle<DHTNode>& remoteNode,
 const unsigned char* infoHash,
 const std::string& transactionID)
{
  SharedHandle<DHTGetPeersMessage> query
    (new DHTGetPeersMessage(localNode_, remoteNode, infoHash, transactionID));
  query->setPeerAnnounceStorage(peerAnnounceStorage_);
  query->setTokenTracker(tokenTracker_);
  setCommonProperty(query);
  return query;
}
// Creates a get_peers reply from the dictionary form of a received
// message.
//
// Reads from the reply dictionary:
//  - the packed node list (NODES key for AF_INET, NODES6 otherwise), if
//    it is present as a String;
//  - the VALUES list, if present as a List; only String elements whose
//    size equals the compact address length of this address family and
//    whose address unpacks to a non-empty host become peers;
//  - the token, which is mandatory.
//
// Throws DL_ABORT_EX when the reply dictionary or the token is missing,
// or when the node data has an invalid length.
SharedHandle<DHTResponseMessage>
DHTMessageFactoryImpl::createGetPeersReplyMessage
(const SharedHandle<DHTNode>& remoteNode,
 const Dict* dict,
 const std::string& transactionID)
{
  const Dict* reply = getDictionary(dict, DHTResponseMessage::R);

  // Closest nodes, if any were sent.
  const String* packedNodes =
    downcast<String>(reply->get(family_ == AF_INET?
                                DHTGetPeersReplyMessage::NODES:
                                DHTGetPeersReplyMessage::NODES6));
  std::vector<SharedHandle<DHTNode> > closest;
  if(packedNodes) {
    extractNodes(closest, packedNodes->uc(), packedNodes->s().size());
  }

  // Peers, if any were sent.
  const List* values =
    downcast<List>(reply->get(DHTGetPeersReplyMessage::VALUES));
  std::vector<SharedHandle<Peer> > peers;
  size_t compactLen = bittorrent::getCompactLength(family_);
  if(values) {
    for(List::ValueType::const_iterator it = values->begin(),
          last = values->end(); it != last; ++it) {
      const String* packedPeer = downcast<String>(*it);
      if(!packedPeer || packedPeer->s().size() != compactLen) {
        continue;
      }
      std::pair<std::string, uint16_t> addr =
        bittorrent::unpackcompact(packedPeer->uc(), family_);
      if(addr.first.empty()) {
        continue;
      }
      peers.push_back(SharedHandle<Peer>(new Peer(addr.first, addr.second)));
    }
  }

  // The token is looked up last, after nodes and peers were decoded.
  const String* token = getString(reply, DHTGetPeersReplyMessage::TOKEN);
  return createGetPeersReplyMessage
    (remoteNode, closest, peers, token->s(), transactionID);
}
// Creates a get_peers reply for this factory's address family carrying
// |token|, the given closest nodes and the given peer values, then
// applies the common message properties.
SharedHandle<DHTResponseMessage>
DHTMessageFactoryImpl::createGetPeersReplyMessage
(const SharedHandle<DHTNode>& remoteNode,
 const std::vector<SharedHandle<DHTNode> >& closestKNodes,
 const std::vector<SharedHandle<Peer> >& values,
 const std::string& token,
 const std::string& transactionID)
{
  SharedHandle<DHTGetPeersReplyMessage> reply
    (new DHTGetPeersReplyMessage
     (family_, localNode_, remoteNode, token, transactionID));
  reply->setClosestKNodes(closestKNodes);
  reply->setValues(values);
  setCommonProperty(reply);
  return reply;
}
// Creates an announce_peer query for |infoHash| with the given TCP port
// and token.  The message receives the peer announce storage and token
// tracker held by this factory plus the common message properties.
SharedHandle<DHTQueryMessage>
DHTMessageFactoryImpl::createAnnouncePeerMessage
(const SharedHandle<DHTNode>& remoteNode,
 const unsigned char* infoHash,
 uint16_t tcpPort,
 const std::string& token,
 const std::string& transactionID)
{
  SharedHandle<DHTAnnouncePeerMessage> query
    (new DHTAnnouncePeerMessage
     (localNode_, remoteNode, infoHash, tcpPort, token, transactionID));
  query->setPeerAnnounceStorage(peerAnnounceStorage_);
  query->setTokenTracker(tokenTracker_);
  setCommonProperty(query);
  return query;
}
// Creates an announce_peer reply between the local node and |remoteNode|
// with the given transaction ID and applies the common message
// properties.
SharedHandle<DHTResponseMessage>
DHTMessageFactoryImpl::createAnnouncePeerReplyMessage
(const SharedHandle<DHTNode>& remoteNode, const std::string& transactionID)
{
  SharedHandle<DHTAnnouncePeerReplyMessage> reply
    (new DHTAnnouncePeerReplyMessage(localNode_, remoteNode, transactionID));
  setCommonProperty(reply);
  return reply;
}
// Wraps |length| bytes of raw |data| received from ipaddr:port in a
// DHTUnknownMessage.  Unlike the other create*() functions, this one
// does not call setCommonProperty() on the result.
SharedHandle<DHTMessage>
DHTMessageFactoryImpl::createUnknownMessage
(const unsigned char* data, size_t length,
 const std::string& ipaddr, uint16_t port)
{
  SharedHandle<DHTUnknownMessage> unknown
    (new DHTUnknownMessage(localNode_, data, length, ipaddr, port));
  return unknown;
}
// Stores the routing table pointer used for remote node lookup and
// handed to every created message.  Only the pointer is stored.
void DHTMessageFactoryImpl::setRoutingTable(DHTRoutingTable* routingTable)
{
  this->routingTable_ = routingTable;
}
// Stores the connection pointer handed to every created message.
// Only the pointer is stored.
void DHTMessageFactoryImpl::setConnection(DHTConnection* connection)
{
  this->connection_ = connection;
}
// Stores the message dispatcher pointer handed to every created message.
// Only the pointer is stored.
void DHTMessageFactoryImpl::setMessageDispatcher
(DHTMessageDispatcher* dispatcher)
{
  this->dispatcher_ = dispatcher;
}
// Stores the peer announce storage pointer handed to get_peers and
// announce_peer queries.  Only the pointer is stored.
void DHTMessageFactoryImpl::setPeerAnnounceStorage
(DHTPeerAnnounceStorage* storage)
{
  this->peerAnnounceStorage_ = storage;
}
// Stores the token tracker pointer handed to get_peers and
// announce_peer queries.  Only the pointer is stored.
void DHTMessageFactoryImpl::setTokenTracker(DHTTokenTracker* tokenTracker)
{
  this->tokenTracker_ = tokenTracker;
}
void DHTMessageFactoryImpl::setLocalNode
(const SharedHandle<DHTNode>& localNode)
{
localNode_ = localNode;
}
} // namespace aria2