mirror of https://github.com/aria2/aria2
543 lines
17 KiB
C++
543 lines
17 KiB
C++
/* <!-- copyright */
|
|
/*
|
|
* aria2 - The high speed download utility
|
|
*
|
|
* Copyright (C) 2006 Tatsuhiro Tsujikawa
|
|
*
|
|
* This program is free software; you can redistribute it and/or modify
|
|
* it under the terms of the GNU General Public License as published by
|
|
* the Free Software Foundation; either version 2 of the License, or
|
|
* (at your option) any later version.
|
|
*
|
|
* This program is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU General Public License
|
|
* along with this program; if not, write to the Free Software
|
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
|
*
|
|
* In addition, as a special exception, the copyright holders give
|
|
* permission to link the code of portions of this program with the
|
|
* OpenSSL library under certain conditions as described in each
|
|
* individual source file, and distribute linked combinations
|
|
* including the two.
|
|
* You must obey the GNU General Public License in all respects
|
|
* for all of the code used other than OpenSSL. If you modify
|
|
* file(s) with this exception, you may extend this exception to your
|
|
* version of the file(s), but you are not obligated to do so. If you
|
|
* do not wish to do so, delete this exception statement from your
|
|
* version. If you delete this exception statement from all source
|
|
* files in the program, then also delete it here.
|
|
*/
|
|
/* copyright --> */
|
|
#include "DHTMessageFactoryImpl.h"
|
|
|
|
#include <cstring>
|
|
#include <utility>
|
|
|
|
#include "LogFactory.h"
|
|
#include "DlAbortEx.h"
|
|
#include "DHTNode.h"
|
|
#include "DHTRoutingTable.h"
|
|
#include "DHTPingMessage.h"
|
|
#include "DHTPingReplyMessage.h"
|
|
#include "DHTFindNodeMessage.h"
|
|
#include "DHTFindNodeReplyMessage.h"
|
|
#include "DHTGetPeersMessage.h"
|
|
#include "DHTGetPeersReplyMessage.h"
|
|
#include "DHTAnnouncePeerMessage.h"
|
|
#include "DHTAnnouncePeerReplyMessage.h"
|
|
#include "DHTUnknownMessage.h"
|
|
#include "DHTConnection.h"
|
|
#include "DHTMessageDispatcher.h"
|
|
#include "DHTPeerAnnounceStorage.h"
|
|
#include "DHTTokenTracker.h"
|
|
#include "DHTMessageCallback.h"
|
|
#include "bittorrent_helper.h"
|
|
#include "BtRuntime.h"
|
|
#include "util.h"
|
|
#include "Peer.h"
|
|
#include "Logger.h"
|
|
#include "fmt.h"
|
|
|
|
namespace aria2 {
|
|
|
|
// Creates a factory for the given address family (compared against AF_INET
// elsewhere in this file to pick IPv4 vs IPv6 encodings). All collaborator
// pointers start out null; they are supplied later through the setters.
DHTMessageFactoryImpl::DHTMessageFactoryImpl(int family)
  : family_(family), connection_(0), dispatcher_(0), routingTable_(0),
    peerAnnounceStorage_(0), tokenTracker_(0)
{
}
|
|
|
|
// Intentionally empty: the raw collaborator pointers handed to the setters
// are not deleted here (ownership presumably lies with the caller — confirm).
DHTMessageFactoryImpl::~DHTMessageFactoryImpl()
{
}
|
|
|
|
// Returns the routing table's node matching (id, ipaddr, port) if there is
// one; otherwise returns a newly created DHTNode carrying that id, address
// and port. The new node is not inserted into the routing table here.
SharedHandle<DHTNode>
DHTMessageFactoryImpl::getRemoteNode
(const unsigned char* id, const std::string& ipaddr, uint16_t port) const
{
  SharedHandle<DHTNode> known = routingTable_->getNode(id, ipaddr, port);
  if(known) {
    return known;
  }
  SharedHandle<DHTNode> fresh(new DHTNode(id));
  fresh->setIPAddress(ipaddr);
  fresh->setPort(port);
  return fresh;
}
|
|
|
|
namespace {
|
|
const Dict* getDictionary(const Dict* dict, const std::string& key)
|
|
{
|
|
const Dict* d = downcast<Dict>(dict->get(key));
|
|
if(d) {
|
|
return d;
|
|
} else {
|
|
throw DL_ABORT_EX
|
|
(fmt("Malformed DHT message. Missing %s", key.c_str()));
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
namespace {
|
|
const String* getString(const Dict* dict, const std::string& key)
|
|
{
|
|
const String* c = downcast<String>(dict->get(key));
|
|
if(c) {
|
|
return c;
|
|
} else {
|
|
throw DL_ABORT_EX
|
|
(fmt("Malformed DHT message. Missing %s", key.c_str()));
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
namespace {
|
|
const Integer* getInteger(const Dict* dict, const std::string& key)
|
|
{
|
|
const Integer* c = downcast<Integer>(dict->get(key));
|
|
if(c) {
|
|
return c;
|
|
} else {
|
|
throw DL_ABORT_EX
|
|
(fmt("Malformed DHT message. Missing %s", key.c_str()));
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
namespace {
|
|
const String* getString(const List* list, size_t index)
|
|
{
|
|
const String* c = downcast<String>(list->get(index));
|
|
if(c) {
|
|
return c;
|
|
} else {
|
|
throw DL_ABORT_EX
|
|
(fmt("Malformed DHT message. element[%lu] is not String.",
|
|
static_cast<unsigned long>(index)));
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
namespace {
|
|
const Integer* getInteger(const List* list, size_t index)
|
|
{
|
|
const Integer* c = downcast<Integer>(list->get(index));
|
|
if(c) {
|
|
return c;
|
|
} else {
|
|
throw DL_ABORT_EX
|
|
(fmt("Malformed DHT message. element[%lu] is not Integer.",
|
|
static_cast<unsigned long>(index)));
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
namespace {
|
|
const List* getList(const Dict* dict, const std::string& key)
|
|
{
|
|
const List* l = downcast<List>(dict->get(key));
|
|
if(l) {
|
|
return l;
|
|
} else {
|
|
throw DL_ABORT_EX
|
|
(fmt("Malformed DHT message. Missing %s", key.c_str()));
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
void DHTMessageFactoryImpl::validateID(const String* id) const
|
|
{
|
|
if(id->s().size() != DHT_ID_LENGTH) {
|
|
throw DL_ABORT_EX
|
|
(fmt("Malformed DHT message. Invalid ID length."
|
|
" Expected:%lu, Actual:%lu",
|
|
static_cast<unsigned long>(DHT_ID_LENGTH),
|
|
static_cast<unsigned long>(id->s().size())));
|
|
}
|
|
}
|
|
|
|
void DHTMessageFactoryImpl::validatePort(const Integer* port) const
|
|
{
|
|
if(!(0 < port->i() && port->i() < UINT16_MAX)) {
|
|
throw DL_ABORT_EX
|
|
(fmt("Malformed DHT message. Invalid port=%lld",
|
|
static_cast<long long int>(port->i())));
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
void setVersion(const SharedHandle<DHTMessage>& msg, const Dict* dict)
|
|
{
|
|
const String* v = downcast<String>(dict->get(DHTMessage::V));
|
|
if(v) {
|
|
msg->setVersion(v->s());
|
|
} else {
|
|
msg->setVersion(A2STR::NIL);
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
// Builds a query message (ping, find_node, get_peers or announce_peer) from
// a decoded incoming dict. ipaddr/port identify the sender and are used to
// look up or create the remote DHTNode.
// Throws DlAbortEx when a required field is missing or mistyped, when an
// ID/info-hash has the wrong length, when the announced port is invalid,
// when y != "q", or when the query type is unsupported.
SharedHandle<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage
(const Dict* dict, const std::string& ipaddr, uint16_t port)
{
  const String* messageType = getString(dict, DHTQueryMessage::Q);
  const String* transactionID = getString(dict, DHTMessage::T);
  const String* y = getString(dict, DHTMessage::Y);
  const Dict* aDict = getDictionary(dict, DHTQueryMessage::A);
  if(y->s() != DHTQueryMessage::Q) {
    throw DL_ABORT_EX("Malformed DHT message. y != q");
  }
  const String* id = getString(aDict, DHTMessage::ID);
  validateID(id);
  SharedHandle<DHTNode> remoteNode = getRemoteNode(id->uc(), ipaddr, port);
  SharedHandle<DHTQueryMessage> msg;
  if(messageType->s() == DHTPingMessage::PING) {
    msg = createPingMessage(remoteNode, transactionID->s());
  } else if(messageType->s() == DHTFindNodeMessage::FIND_NODE) {
    const String* targetNodeID =
      getString(aDict, DHTFindNodeMessage::TARGET_NODE);
    validateID(targetNodeID);
    msg = createFindNodeMessage(remoteNode, targetNodeID->uc(),
                                transactionID->s());
  } else if(messageType->s() == DHTGetPeersMessage::GET_PEERS) {
    const String* infoHash = getString(aDict, DHTGetPeersMessage::INFO_HASH);
    validateID(infoHash);
    msg = createGetPeersMessage(remoteNode, infoHash->uc(), transactionID->s());
  } else if(messageType->s() == DHTAnnouncePeerMessage::ANNOUNCE_PEER) {
    const String* infoHash =
      getString(aDict, DHTAnnouncePeerMessage::INFO_HASH);
    validateID(infoHash);
    // Named tcpPort (matching createAnnouncePeerMessage's parameter) so it
    // does not shadow the `port` parameter that identifies the sender.
    const Integer* tcpPort = getInteger(aDict, DHTAnnouncePeerMessage::PORT);
    validatePort(tcpPort);
    const String* token = getString(aDict, DHTAnnouncePeerMessage::TOKEN);
    msg = createAnnouncePeerMessage(remoteNode, infoHash->uc(),
                                    static_cast<uint16_t>(tcpPort->i()),
                                    token->s(), transactionID->s());
  } else {
    throw DL_ABORT_EX(fmt("Unsupported message type: %s",
                          messageType->s().c_str()));
  }
  setVersion(msg, dict);
  return msg;
}
|
|
|
|
// Builds a response message from a decoded incoming dict.
// messageType is supplied by the caller and selects which reply type to
// build (presumably the type of the query being answered — confirm at the
// call site). ipaddr/port identify the sender.
// Throws DlAbortEx for error replies (y == "e", which are only logged), for
// y values other than "r", for missing/malformed fields, and for an
// unsupported messageType.
SharedHandle<DHTResponseMessage>
DHTMessageFactoryImpl::createResponseMessage
(const std::string& messageType,
 const Dict* dict,
 const std::string& ipaddr,
 uint16_t port)
{
  const String* tid = getString(dict, DHTMessage::T);
  const String* y = getString(dict, DHTMessage::Y);
  const std::string& kind = y->s();

  if(kind == DHTUnknownMessage::E) {
    // For now, an error reply is reported in the log and then rejected.
    const List* errorList = getList(dict, DHTUnknownMessage::E);
    if(errorList->size() != 2) {
      A2_LOG_DEBUG("e doesn't have 2 elements.");
    } else {
      A2_LOG_INFO(fmt("Received Error DHT message. code=%lld, msg=%s",
                      static_cast<long long int>(getInteger(errorList, 0)->i()),
                      util::percentEncode(getString(errorList, 1)->s()).c_str()));
    }
    throw DL_ABORT_EX("Received Error DHT message.");
  }
  if(kind != DHTResponseMessage::R) {
    throw DL_ABORT_EX
      (fmt("Malformed DHT message. y != r: y=%s",
           util::percentEncode(kind).c_str()));
  }

  const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
  const String* id = getString(rDict, DHTMessage::ID);
  validateID(id);
  SharedHandle<DHTNode> remoteNode = getRemoteNode(id->uc(), ipaddr, port);

  SharedHandle<DHTResponseMessage> reply;
  if(messageType == DHTPingReplyMessage::PING) {
    reply = createPingReplyMessage(remoteNode, id->uc(), tid->s());
  } else if(messageType == DHTFindNodeReplyMessage::FIND_NODE) {
    reply = createFindNodeReplyMessage(remoteNode, dict, tid->s());
  } else if(messageType == DHTGetPeersReplyMessage::GET_PEERS) {
    reply = createGetPeersReplyMessage(remoteNode, dict, tid->s());
  } else if(messageType == DHTAnnouncePeerReplyMessage::ANNOUNCE_PEER) {
    reply = createAnnouncePeerReplyMessage(remoteNode, tid->s());
  } else {
    throw DL_ABORT_EX
      (fmt("Unsupported message type: %s", messageType.c_str()));
  }
  setVersion(reply, dict);
  return reply;
}
|
|
|
|
namespace {
|
|
const std::string& getDefaultVersion()
|
|
{
|
|
static std::string version;
|
|
if(version.empty()) {
|
|
uint16_t vnum16 = htons(DHT_VERSION);
|
|
unsigned char buf[] = { 'A' , '2', 0, 0 };
|
|
char* vnump = reinterpret_cast<char*>(&vnum16);
|
|
memcpy(buf+2, vnump, 2);
|
|
version.assign(&buf[0], &buf[4]);
|
|
}
|
|
return version;
|
|
}
|
|
} // namespace
|
|
|
|
void DHTMessageFactoryImpl::setCommonProperty
|
|
(const SharedHandle<DHTAbstractMessage>& m)
|
|
{
|
|
m->setConnection(connection_);
|
|
m->setMessageDispatcher(dispatcher_);
|
|
m->setRoutingTable(routingTable_);
|
|
m->setMessageFactory(this);
|
|
m->setVersion(getDefaultVersion());
|
|
}
|
|
|
|
// Creates a ping query from the local node to remoteNode, with the common
// properties applied.
SharedHandle<DHTQueryMessage> DHTMessageFactoryImpl::createPingMessage
(const SharedHandle<DHTNode>& remoteNode, const std::string& transactionID)
{
  SharedHandle<DHTPingMessage> ping
    (new DHTPingMessage(localNode_, remoteNode, transactionID));
  setCommonProperty(ping);
  return ping;
}
|
|
|
|
// Creates a ping reply carrying id, with the common properties applied.
SharedHandle<DHTResponseMessage> DHTMessageFactoryImpl::createPingReplyMessage
(const SharedHandle<DHTNode>& remoteNode,
 const unsigned char* id,
 const std::string& transactionID)
{
  SharedHandle<DHTPingReplyMessage> pong
    (new DHTPingReplyMessage(localNode_, remoteNode, id, transactionID));
  setCommonProperty(pong);
  return pong;
}
|
|
|
|
// Creates a find_node query for targetNodeID, with the common properties
// applied.
SharedHandle<DHTQueryMessage> DHTMessageFactoryImpl::createFindNodeMessage
(const SharedHandle<DHTNode>& remoteNode,
 const unsigned char* targetNodeID,
 const std::string& transactionID)
{
  SharedHandle<DHTFindNodeMessage> query
    (new DHTFindNodeMessage(localNode_, remoteNode, targetNodeID,
                            transactionID));
  setCommonProperty(query);
  return query;
}
|
|
|
|
// Creates a find_node reply listing closestKNodes, using this factory's
// address family, with the common properties applied.
SharedHandle<DHTResponseMessage>
DHTMessageFactoryImpl::createFindNodeReplyMessage
(const SharedHandle<DHTNode>& remoteNode,
 const std::vector<SharedHandle<DHTNode> >& closestKNodes,
 const std::string& transactionID)
{
  SharedHandle<DHTFindNodeReplyMessage> reply
    (new DHTFindNodeReplyMessage(family_, localNode_, remoteNode,
                                 transactionID));
  reply->setClosestKNodes(closestKNodes);
  setCommonProperty(reply);
  return reply;
}
|
|
|
|
void DHTMessageFactoryImpl::extractNodes
|
|
(std::vector<SharedHandle<DHTNode> >& nodes,
|
|
const unsigned char* src, size_t length)
|
|
{
|
|
int unit = bittorrent::getCompactLength(family_)+20;
|
|
if(length%unit != 0) {
|
|
throw DL_ABORT_EX
|
|
(fmt("Nodes length is not multiple of %d", unit));
|
|
}
|
|
for(size_t offset = 0; offset < length; offset += unit) {
|
|
SharedHandle<DHTNode> node(new DHTNode(src+offset));
|
|
std::pair<std::string, uint16_t> addr =
|
|
bittorrent::unpackcompact(src+offset+DHT_ID_LENGTH, family_);
|
|
if(addr.first.empty()) {
|
|
continue;
|
|
}
|
|
node->setIPAddress(addr.first);
|
|
node->setPort(addr.second);
|
|
nodes.push_back(node);
|
|
}
|
|
}
|
|
|
|
// Builds a find_node reply from a decoded incoming dict. Reads the optional
// compact node string from the "r" dict (the NODES key when family_ is
// AF_INET, NODES6 otherwise) and decodes it with extractNodes().
// Throws DlAbortEx if "r" is missing or the node data is malformed.
SharedHandle<DHTResponseMessage>
DHTMessageFactoryImpl::createFindNodeReplyMessage
(const SharedHandle<DHTNode>& remoteNode,
 const Dict* dict,
 const std::string& transactionID)
{
  const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
  const String* compactNodes =
    downcast<String>(rDict->get(family_ == AF_INET ?
                                DHTFindNodeReplyMessage::NODES :
                                DHTFindNodeReplyMessage::NODES6));
  std::vector<SharedHandle<DHTNode> > closest;
  if(compactNodes) {
    extractNodes(closest, compactNodes->uc(), compactNodes->s().size());
  }
  return createFindNodeReplyMessage(remoteNode, closest, transactionID);
}
|
|
|
|
// Creates a get_peers query for infoHash, wired to the peer-announce storage
// and token tracker, with the common properties applied.
SharedHandle<DHTQueryMessage>
DHTMessageFactoryImpl::createGetPeersMessage
(const SharedHandle<DHTNode>& remoteNode,
 const unsigned char* infoHash,
 const std::string& transactionID)
{
  SharedHandle<DHTGetPeersMessage> query
    (new DHTGetPeersMessage(localNode_, remoteNode, infoHash, transactionID));
  query->setPeerAnnounceStorage(peerAnnounceStorage_);
  query->setTokenTracker(tokenTracker_);
  setCommonProperty(query);
  return query;
}
|
|
|
|
// Builds a get_peers reply from a decoded incoming dict.
// From the "r" dict it reads:
// - the optional compact node string (NODES for AF_INET, NODES6 otherwise);
// - the optional VALUES list of compact peer addresses. Entries that are not
//   Strings, do not have the compact length for family_, or fail to unpack
//   are skipped;
// - the required TOKEN string.
// Throws DlAbortEx if "r" or the token is missing, or the node data is
// malformed.
SharedHandle<DHTResponseMessage>
DHTMessageFactoryImpl::createGetPeersReplyMessage
(const SharedHandle<DHTNode>& remoteNode,
 const Dict* dict,
 const std::string& transactionID)
{
  const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);

  std::vector<SharedHandle<DHTNode> > nodes;
  const String* compactNodes =
    downcast<String>(rDict->get(family_ == AF_INET ?
                                DHTGetPeersReplyMessage::NODES :
                                DHTGetPeersReplyMessage::NODES6));
  if(compactNodes) {
    extractNodes(nodes, compactNodes->uc(), compactNodes->s().size());
  }

  std::vector<SharedHandle<Peer> > peers;
  const List* values =
    downcast<List>(rDict->get(DHTGetPeersReplyMessage::VALUES));
  if(values) {
    const size_t compactLength = bittorrent::getCompactLength(family_);
    for(List::ValueType::const_iterator itr = values->begin(),
          last = values->end(); itr != last; ++itr) {
      const String* entry = downcast<String>(*itr);
      if(!entry || entry->s().size() != compactLength) {
        continue;
      }
      std::pair<std::string, uint16_t> addr =
        bittorrent::unpackcompact(entry->uc(), family_);
      if(addr.first.empty()) {
        continue;
      }
      peers.push_back(SharedHandle<Peer>(new Peer(addr.first, addr.second)));
    }
  }

  const String* token = getString(rDict, DHTGetPeersReplyMessage::TOKEN);
  return createGetPeersReplyMessage
    (remoteNode, nodes, peers, token->s(), transactionID);
}
|
|
|
|
// Creates a get_peers reply carrying closestKNodes, the peer values and the
// token, using this factory's address family, with the common properties
// applied.
SharedHandle<DHTResponseMessage>
DHTMessageFactoryImpl::createGetPeersReplyMessage
(const SharedHandle<DHTNode>& remoteNode,
 const std::vector<SharedHandle<DHTNode> >& closestKNodes,
 const std::vector<SharedHandle<Peer> >& values,
 const std::string& token,
 const std::string& transactionID)
{
  SharedHandle<DHTGetPeersReplyMessage> reply
    (new DHTGetPeersReplyMessage(family_, localNode_, remoteNode, token,
                                 transactionID));
  reply->setClosestKNodes(closestKNodes);
  reply->setValues(values);
  setCommonProperty(reply);
  return reply;
}
|
|
|
|
// Creates an announce_peer query for infoHash/tcpPort/token, wired to the
// peer-announce storage and token tracker, with the common properties
// applied.
SharedHandle<DHTQueryMessage>
DHTMessageFactoryImpl::createAnnouncePeerMessage
(const SharedHandle<DHTNode>& remoteNode,
 const unsigned char* infoHash,
 uint16_t tcpPort,
 const std::string& token,
 const std::string& transactionID)
{
  SharedHandle<DHTAnnouncePeerMessage> query
    (new DHTAnnouncePeerMessage(localNode_, remoteNode, infoHash, tcpPort,
                                token, transactionID));
  query->setPeerAnnounceStorage(peerAnnounceStorage_);
  query->setTokenTracker(tokenTracker_);
  setCommonProperty(query);
  return query;
}
|
|
|
|
// Creates an announce_peer reply, with the common properties applied.
SharedHandle<DHTResponseMessage>
DHTMessageFactoryImpl::createAnnouncePeerReplyMessage
(const SharedHandle<DHTNode>& remoteNode, const std::string& transactionID)
{
  SharedHandle<DHTAnnouncePeerReplyMessage> reply
    (new DHTAnnouncePeerReplyMessage(localNode_, remoteNode, transactionID));
  setCommonProperty(reply);
  return reply;
}
|
|
|
|
// Wraps raw, unrecognized message bytes from ipaddr:port in a
// DHTUnknownMessage. Unlike the other factory methods, this one does not
// call setCommonProperty().
SharedHandle<DHTMessage>
DHTMessageFactoryImpl::createUnknownMessage
(const unsigned char* data, size_t length,
 const std::string& ipaddr, uint16_t port)
{
  return SharedHandle<DHTUnknownMessage>
    (new DHTUnknownMessage(localNode_, data, length, ipaddr, port));
}
|
|
|
|
// Sets the routing table consulted by getRemoteNode() and attached to each
// created message. The destructor does not delete it.
void DHTMessageFactoryImpl::setRoutingTable(DHTRoutingTable* routingTable)
{
  this->routingTable_ = routingTable;
}
|
|
|
|
// Sets the connection attached to each created message. The destructor
// does not delete it.
void DHTMessageFactoryImpl::setConnection(DHTConnection* connection)
{
  this->connection_ = connection;
}
|
|
|
|
void DHTMessageFactoryImpl::setMessageDispatcher
|
|
(DHTMessageDispatcher* dispatcher)
|
|
{
|
|
dispatcher_ = dispatcher;
|
|
}
|
|
|
|
void DHTMessageFactoryImpl::setPeerAnnounceStorage
|
|
(DHTPeerAnnounceStorage* storage)
|
|
{
|
|
peerAnnounceStorage_ = storage;
|
|
}
|
|
|
|
// Sets the token tracker given to get_peers and announce_peer queries. The
// destructor does not delete it.
void DHTMessageFactoryImpl::setTokenTracker(DHTTokenTracker* tokenTracker)
{
  this->tokenTracker_ = tokenTracker;
}
|
|
|
|
void DHTMessageFactoryImpl::setLocalNode
|
|
(const SharedHandle<DHTNode>& localNode)
|
|
{
|
|
localNode_ = localNode;
|
|
}
|
|
|
|
} // namespace aria2
|