mirror of https://github.com/aria2/aria2
533 lines
17 KiB
C++
533 lines
17 KiB
C++
/* <!-- copyright */
|
|
/*
|
|
* aria2 - The high speed download utility
|
|
*
|
|
* Copyright (C) 2006 Tatsuhiro Tsujikawa
|
|
*
|
|
* This program is free software; you can redistribute it and/or modify
|
|
* it under the terms of the GNU General Public License as published by
|
|
* the Free Software Foundation; either version 2 of the License, or
|
|
* (at your option) any later version.
|
|
*
|
|
* This program is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU General Public License
|
|
* along with this program; if not, write to the Free Software
|
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
|
*
|
|
* In addition, as a special exception, the copyright holders give
|
|
* permission to link the code of portions of this program with the
|
|
* OpenSSL library under certain conditions as described in each
|
|
* individual source file, and distribute linked combinations
|
|
* including the two.
|
|
* You must obey the GNU General Public License in all respects
|
|
* for all of the code used other than OpenSSL. If you modify
|
|
* file(s) with this exception, you may extend this exception to your
|
|
* version of the file(s), but you are not obligated to do so. If you
|
|
* do not wish to do so, delete this exception statement from your
|
|
* version. If you delete this exception statement from all source
|
|
* files in the program, then also delete it here.
|
|
*/
|
|
/* copyright --> */
|
|
#include "DHTMessageFactoryImpl.h"
|
|
|
|
#include <cstring>
|
|
#include <utility>
|
|
|
|
#include "LogFactory.h"
|
|
#include "DlAbortEx.h"
|
|
#include "DHTNode.h"
|
|
#include "DHTRoutingTable.h"
|
|
#include "DHTPingMessage.h"
|
|
#include "DHTPingReplyMessage.h"
|
|
#include "DHTFindNodeMessage.h"
|
|
#include "DHTFindNodeReplyMessage.h"
|
|
#include "DHTGetPeersMessage.h"
|
|
#include "DHTGetPeersReplyMessage.h"
|
|
#include "DHTAnnouncePeerMessage.h"
|
|
#include "DHTAnnouncePeerReplyMessage.h"
|
|
#include "DHTUnknownMessage.h"
|
|
#include "DHTConnection.h"
|
|
#include "DHTMessageDispatcher.h"
|
|
#include "DHTPeerAnnounceStorage.h"
|
|
#include "DHTTokenTracker.h"
|
|
#include "DHTMessageCallback.h"
|
|
#include "bittorrent_helper.h"
|
|
#include "BtRuntime.h"
|
|
#include "util.h"
|
|
#include "Peer.h"
|
|
#include "Logger.h"
|
|
#include "fmt.h"
|
|
|
|
namespace aria2 {
|
|
|
|
DHTMessageFactoryImpl::DHTMessageFactoryImpl(int family)
|
|
: family_{family},
|
|
connection_{nullptr},
|
|
dispatcher_{nullptr},
|
|
routingTable_{nullptr},
|
|
peerAnnounceStorage_{nullptr},
|
|
tokenTracker_{nullptr}
|
|
{
|
|
}
|
|
|
|
std::shared_ptr<DHTNode> DHTMessageFactoryImpl::getRemoteNode(
|
|
const unsigned char* id, const std::string& ipaddr, uint16_t port) const
|
|
{
|
|
auto node = routingTable_->getNode(id, ipaddr, port);
|
|
if (!node) {
|
|
node = std::make_shared<DHTNode>(id);
|
|
node->setIPAddress(ipaddr);
|
|
node->setPort(port);
|
|
}
|
|
return node;
|
|
}
|
|
|
|
namespace {
|
|
const Dict* getDictionary(const Dict* dict, const std::string& key)
|
|
{
|
|
const Dict* d = downcast<Dict>(dict->get(key));
|
|
if (d) {
|
|
return d;
|
|
}
|
|
else {
|
|
throw DL_ABORT_EX(fmt("Malformed DHT message. Missing %s", key.c_str()));
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
namespace {
|
|
const String* getString(const Dict* dict, const std::string& key)
|
|
{
|
|
const String* c = downcast<String>(dict->get(key));
|
|
if (c) {
|
|
return c;
|
|
}
|
|
else {
|
|
throw DL_ABORT_EX(fmt("Malformed DHT message. Missing %s", key.c_str()));
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
namespace {
|
|
const Integer* getInteger(const Dict* dict, const std::string& key)
|
|
{
|
|
const Integer* c = downcast<Integer>(dict->get(key));
|
|
if (c) {
|
|
return c;
|
|
}
|
|
else {
|
|
throw DL_ABORT_EX(fmt("Malformed DHT message. Missing %s", key.c_str()));
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
namespace {
|
|
const String* getString(const List* list, size_t index)
|
|
{
|
|
const String* c = downcast<String>(list->get(index));
|
|
if (c) {
|
|
return c;
|
|
}
|
|
else {
|
|
throw DL_ABORT_EX(fmt("Malformed DHT message. element[%lu] is not String.",
|
|
static_cast<unsigned long>(index)));
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
namespace {
|
|
const Integer* getInteger(const List* list, size_t index)
|
|
{
|
|
const Integer* c = downcast<Integer>(list->get(index));
|
|
if (c) {
|
|
return c;
|
|
}
|
|
else {
|
|
throw DL_ABORT_EX(fmt("Malformed DHT message. element[%lu] is not Integer.",
|
|
static_cast<unsigned long>(index)));
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
namespace {
|
|
const List* getList(const Dict* dict, const std::string& key)
|
|
{
|
|
const List* l = downcast<List>(dict->get(key));
|
|
if (l) {
|
|
return l;
|
|
}
|
|
else {
|
|
throw DL_ABORT_EX(fmt("Malformed DHT message. Missing %s", key.c_str()));
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
void DHTMessageFactoryImpl::validateID(const String* id) const
|
|
{
|
|
if (id->s().size() != DHT_ID_LENGTH) {
|
|
throw DL_ABORT_EX(fmt("Malformed DHT message. Invalid ID length."
|
|
" Expected:%lu, Actual:%lu",
|
|
static_cast<unsigned long>(DHT_ID_LENGTH),
|
|
static_cast<unsigned long>(id->s().size())));
|
|
}
|
|
}
|
|
|
|
void DHTMessageFactoryImpl::validatePort(const Integer* port) const
|
|
{
|
|
if (!(0 < port->i() && port->i() < UINT16_MAX)) {
|
|
throw DL_ABORT_EX(
|
|
fmt("Malformed DHT message. Invalid port=%" PRId64 "", port->i()));
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
void setVersion(DHTMessage* msg, const Dict* dict)
|
|
{
|
|
const String* v = downcast<String>(dict->get(DHTMessage::V));
|
|
if (v) {
|
|
msg->setVersion(v->s());
|
|
}
|
|
else {
|
|
msg->setVersion(A2STR::NIL);
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
std::unique_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage(
|
|
const Dict* dict, const std::string& ipaddr, uint16_t port)
|
|
{
|
|
const String* messageType = getString(dict, DHTQueryMessage::Q);
|
|
const String* transactionID = getString(dict, DHTMessage::T);
|
|
const String* y = getString(dict, DHTMessage::Y);
|
|
const Dict* aDict = getDictionary(dict, DHTQueryMessage::A);
|
|
if (y->s() != DHTQueryMessage::Q) {
|
|
throw DL_ABORT_EX("Malformed DHT message. y != q");
|
|
}
|
|
const String* id = getString(aDict, DHTMessage::ID);
|
|
validateID(id);
|
|
auto remoteNode = getRemoteNode(id->uc(), ipaddr, port);
|
|
auto msg = std::unique_ptr<DHTQueryMessage>{};
|
|
if (messageType->s() == DHTPingMessage::PING) {
|
|
msg = createPingMessage(remoteNode, transactionID->s());
|
|
}
|
|
else if (messageType->s() == DHTFindNodeMessage::FIND_NODE) {
|
|
const String* targetNodeID =
|
|
getString(aDict, DHTFindNodeMessage::TARGET_NODE);
|
|
validateID(targetNodeID);
|
|
msg = createFindNodeMessage(remoteNode, targetNodeID->uc(),
|
|
transactionID->s());
|
|
}
|
|
else if (messageType->s() == DHTGetPeersMessage::GET_PEERS) {
|
|
const String* infoHash = getString(aDict, DHTGetPeersMessage::INFO_HASH);
|
|
validateID(infoHash);
|
|
msg = createGetPeersMessage(remoteNode, infoHash->uc(), transactionID->s());
|
|
}
|
|
else if (messageType->s() == DHTAnnouncePeerMessage::ANNOUNCE_PEER) {
|
|
const String* infoHash =
|
|
getString(aDict, DHTAnnouncePeerMessage::INFO_HASH);
|
|
validateID(infoHash);
|
|
const Integer* port = getInteger(aDict, DHTAnnouncePeerMessage::PORT);
|
|
validatePort(port);
|
|
const String* token = getString(aDict, DHTAnnouncePeerMessage::TOKEN);
|
|
msg = createAnnouncePeerMessage(remoteNode, infoHash->uc(),
|
|
static_cast<uint16_t>(port->i()),
|
|
token->s(), transactionID->s());
|
|
}
|
|
else {
|
|
throw DL_ABORT_EX(
|
|
fmt("Unsupported message type: %s", messageType->s().c_str()));
|
|
}
|
|
setVersion(msg.get(), dict);
|
|
return msg;
|
|
}
|
|
|
|
std::unique_ptr<DHTResponseMessage>
|
|
DHTMessageFactoryImpl::createResponseMessage(const std::string& messageType,
|
|
const Dict* dict,
|
|
const std::string& ipaddr,
|
|
uint16_t port)
|
|
{
|
|
const String* transactionID = getString(dict, DHTMessage::T);
|
|
const String* y = getString(dict, DHTMessage::Y);
|
|
if (y->s() == DHTUnknownMessage::E) {
|
|
// for now, just report error message arrived and throw exception.
|
|
const List* e = getList(dict, DHTUnknownMessage::E);
|
|
if (e->size() == 2) {
|
|
A2_LOG_INFO(fmt("Received Error DHT message. code=%" PRId64 ", msg=%s",
|
|
getInteger(e, 0)->i(),
|
|
util::percentEncode(getString(e, 1)->s()).c_str()));
|
|
}
|
|
else {
|
|
A2_LOG_DEBUG("e doesn't have 2 elements.");
|
|
}
|
|
throw DL_ABORT_EX("Received Error DHT message.");
|
|
}
|
|
else if (y->s() != DHTResponseMessage::R) {
|
|
throw DL_ABORT_EX(fmt("Malformed DHT message. y != r: y=%s",
|
|
util::percentEncode(y->s()).c_str()));
|
|
}
|
|
const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
|
|
const String* id = getString(rDict, DHTMessage::ID);
|
|
validateID(id);
|
|
auto remoteNode = getRemoteNode(id->uc(), ipaddr, port);
|
|
auto msg = std::unique_ptr<DHTResponseMessage>{};
|
|
if (messageType == DHTPingReplyMessage::PING) {
|
|
msg = createPingReplyMessage(remoteNode, id->uc(), transactionID->s());
|
|
}
|
|
else if (messageType == DHTFindNodeReplyMessage::FIND_NODE) {
|
|
msg = createFindNodeReplyMessage(remoteNode, dict, transactionID->s());
|
|
}
|
|
else if (messageType == DHTGetPeersReplyMessage::GET_PEERS) {
|
|
msg = createGetPeersReplyMessage(remoteNode, dict, transactionID->s());
|
|
}
|
|
else if (messageType == DHTAnnouncePeerReplyMessage::ANNOUNCE_PEER) {
|
|
msg = createAnnouncePeerReplyMessage(remoteNode, transactionID->s());
|
|
}
|
|
else {
|
|
throw DL_ABORT_EX(fmt("Unsupported message type: %s", messageType.c_str()));
|
|
}
|
|
setVersion(msg.get(), dict);
|
|
return msg;
|
|
}
|
|
|
|
namespace {
|
|
const std::string& getDefaultVersion()
|
|
{
|
|
static std::string version;
|
|
if (version.empty()) {
|
|
uint16_t vnum16 = htons(DHT_VERSION);
|
|
unsigned char buf[] = {'A', '2', 0, 0};
|
|
char* vnump = reinterpret_cast<char*>(&vnum16);
|
|
memcpy(buf + 2, vnump, 2);
|
|
version.assign(&buf[0], &buf[4]);
|
|
}
|
|
return version;
|
|
}
|
|
} // namespace
|
|
|
|
void DHTMessageFactoryImpl::setCommonProperty(DHTAbstractMessage* m)
|
|
{
|
|
m->setConnection(connection_);
|
|
m->setMessageDispatcher(dispatcher_);
|
|
m->setRoutingTable(routingTable_);
|
|
m->setMessageFactory(this);
|
|
m->setVersion(getDefaultVersion());
|
|
}
|
|
|
|
std::unique_ptr<DHTPingMessage> DHTMessageFactoryImpl::createPingMessage(
|
|
const std::shared_ptr<DHTNode>& remoteNode,
|
|
const std::string& transactionID)
|
|
{
|
|
auto m = make_unique<DHTPingMessage>(localNode_, remoteNode, transactionID);
|
|
setCommonProperty(m.get());
|
|
return m;
|
|
}
|
|
|
|
std::unique_ptr<DHTPingReplyMessage>
|
|
DHTMessageFactoryImpl::createPingReplyMessage(
|
|
const std::shared_ptr<DHTNode>& remoteNode, const unsigned char* id,
|
|
const std::string& transactionID)
|
|
{
|
|
auto m = make_unique<DHTPingReplyMessage>(localNode_, remoteNode, id,
|
|
transactionID);
|
|
setCommonProperty(m.get());
|
|
return m;
|
|
}
|
|
|
|
std::unique_ptr<DHTFindNodeMessage>
|
|
DHTMessageFactoryImpl::createFindNodeMessage(
|
|
const std::shared_ptr<DHTNode>& remoteNode,
|
|
const unsigned char* targetNodeID, const std::string& transactionID)
|
|
{
|
|
auto m = make_unique<DHTFindNodeMessage>(localNode_, remoteNode, targetNodeID,
|
|
transactionID);
|
|
setCommonProperty(m.get());
|
|
return m;
|
|
}
|
|
|
|
std::unique_ptr<DHTFindNodeReplyMessage>
|
|
DHTMessageFactoryImpl::createFindNodeReplyMessage(
|
|
const std::shared_ptr<DHTNode>& remoteNode,
|
|
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
|
|
const std::string& transactionID)
|
|
{
|
|
auto m = make_unique<DHTFindNodeReplyMessage>(family_, localNode_, remoteNode,
|
|
transactionID);
|
|
m->setClosestKNodes(std::move(closestKNodes));
|
|
setCommonProperty(m.get());
|
|
return m;
|
|
}
|
|
|
|
void DHTMessageFactoryImpl::extractNodes(
|
|
std::vector<std::shared_ptr<DHTNode>>& nodes, const unsigned char* src,
|
|
size_t length)
|
|
{
|
|
int unit = bittorrent::getCompactLength(family_) + 20;
|
|
if (length % unit != 0) {
|
|
throw DL_ABORT_EX(fmt("Nodes length is not multiple of %d", unit));
|
|
}
|
|
for (size_t offset = 0; offset < length; offset += unit) {
|
|
auto node = std::make_shared<DHTNode>(src + offset);
|
|
auto addr =
|
|
bittorrent::unpackcompact(src + offset + DHT_ID_LENGTH, family_);
|
|
if (addr.first.empty()) {
|
|
continue;
|
|
}
|
|
node->setIPAddress(addr.first);
|
|
node->setPort(addr.second);
|
|
nodes.push_back(std::move(node));
|
|
}
|
|
}
|
|
|
|
std::unique_ptr<DHTFindNodeReplyMessage>
|
|
DHTMessageFactoryImpl::createFindNodeReplyMessage(
|
|
const std::shared_ptr<DHTNode>& remoteNode, const Dict* dict,
|
|
const std::string& transactionID)
|
|
{
|
|
const String* nodesData = downcast<String>(
|
|
getDictionary(dict, DHTResponseMessage::R)
|
|
->get(family_ == AF_INET ? DHTFindNodeReplyMessage::NODES
|
|
: DHTFindNodeReplyMessage::NODES6));
|
|
std::vector<std::shared_ptr<DHTNode>> nodes;
|
|
if (nodesData) {
|
|
extractNodes(nodes, nodesData->uc(), nodesData->s().size());
|
|
}
|
|
return createFindNodeReplyMessage(remoteNode, std::move(nodes),
|
|
transactionID);
|
|
}
|
|
|
|
std::unique_ptr<DHTGetPeersMessage>
|
|
DHTMessageFactoryImpl::createGetPeersMessage(
|
|
const std::shared_ptr<DHTNode>& remoteNode, const unsigned char* infoHash,
|
|
const std::string& transactionID)
|
|
{
|
|
auto m = make_unique<DHTGetPeersMessage>(localNode_, remoteNode, infoHash,
|
|
transactionID);
|
|
m->setPeerAnnounceStorage(peerAnnounceStorage_);
|
|
m->setTokenTracker(tokenTracker_);
|
|
setCommonProperty(m.get());
|
|
return m;
|
|
}
|
|
|
|
std::unique_ptr<DHTGetPeersReplyMessage>
|
|
DHTMessageFactoryImpl::createGetPeersReplyMessage(
|
|
const std::shared_ptr<DHTNode>& remoteNode, const Dict* dict,
|
|
const std::string& transactionID)
|
|
{
|
|
const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
|
|
const String* nodesData = downcast<String>(
|
|
rDict->get(family_ == AF_INET ? DHTGetPeersReplyMessage::NODES
|
|
: DHTGetPeersReplyMessage::NODES6));
|
|
std::vector<std::shared_ptr<DHTNode>> nodes;
|
|
if (nodesData) {
|
|
extractNodes(nodes, nodesData->uc(), nodesData->s().size());
|
|
}
|
|
const List* valuesList =
|
|
downcast<List>(rDict->get(DHTGetPeersReplyMessage::VALUES));
|
|
std::vector<std::shared_ptr<Peer>> peers;
|
|
size_t clen = bittorrent::getCompactLength(family_);
|
|
if (valuesList) {
|
|
for (auto& elem : *valuesList) {
|
|
const String* data = downcast<String>(elem);
|
|
if (data && data->s().size() == clen) {
|
|
auto addr = bittorrent::unpackcompact(data->uc(), family_);
|
|
if (addr.first.empty()) {
|
|
continue;
|
|
}
|
|
peers.push_back(std::make_shared<Peer>(addr.first, addr.second));
|
|
}
|
|
}
|
|
}
|
|
const String* token = getString(rDict, DHTGetPeersReplyMessage::TOKEN);
|
|
return createGetPeersReplyMessage(remoteNode, std::move(nodes),
|
|
std::move(peers), token->s(),
|
|
transactionID);
|
|
}
|
|
|
|
std::unique_ptr<DHTGetPeersReplyMessage>
|
|
DHTMessageFactoryImpl::createGetPeersReplyMessage(
|
|
const std::shared_ptr<DHTNode>& remoteNode,
|
|
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
|
|
std::vector<std::shared_ptr<Peer>> values, const std::string& token,
|
|
const std::string& transactionID)
|
|
{
|
|
auto m = make_unique<DHTGetPeersReplyMessage>(family_, localNode_, remoteNode,
|
|
token, transactionID);
|
|
m->setClosestKNodes(std::move(closestKNodes));
|
|
m->setValues(std::move(values));
|
|
setCommonProperty(m.get());
|
|
return m;
|
|
}
|
|
|
|
std::unique_ptr<DHTAnnouncePeerMessage>
|
|
DHTMessageFactoryImpl::createAnnouncePeerMessage(
|
|
const std::shared_ptr<DHTNode>& remoteNode, const unsigned char* infoHash,
|
|
uint16_t tcpPort, const std::string& token,
|
|
const std::string& transactionID)
|
|
{
|
|
auto m = make_unique<DHTAnnouncePeerMessage>(localNode_, remoteNode, infoHash,
|
|
tcpPort, token, transactionID);
|
|
m->setPeerAnnounceStorage(peerAnnounceStorage_);
|
|
m->setTokenTracker(tokenTracker_);
|
|
setCommonProperty(m.get());
|
|
return m;
|
|
}
|
|
|
|
std::unique_ptr<DHTAnnouncePeerReplyMessage>
|
|
DHTMessageFactoryImpl::createAnnouncePeerReplyMessage(
|
|
const std::shared_ptr<DHTNode>& remoteNode,
|
|
const std::string& transactionID)
|
|
{
|
|
auto m = make_unique<DHTAnnouncePeerReplyMessage>(localNode_, remoteNode,
|
|
transactionID);
|
|
setCommonProperty(m.get());
|
|
return m;
|
|
}
|
|
|
|
std::unique_ptr<DHTUnknownMessage> DHTMessageFactoryImpl::createUnknownMessage(
|
|
const unsigned char* data, size_t length, const std::string& ipaddr,
|
|
uint16_t port)
|
|
|
|
{
|
|
return make_unique<DHTUnknownMessage>(localNode_, data, length, ipaddr, port);
|
|
}
|
|
|
|
void DHTMessageFactoryImpl::setRoutingTable(DHTRoutingTable* routingTable)
|
|
{
|
|
routingTable_ = routingTable;
|
|
}
|
|
|
|
void DHTMessageFactoryImpl::setConnection(DHTConnection* connection)
|
|
{
|
|
connection_ = connection;
|
|
}
|
|
|
|
void DHTMessageFactoryImpl::setMessageDispatcher(
|
|
DHTMessageDispatcher* dispatcher)
|
|
{
|
|
dispatcher_ = dispatcher;
|
|
}
|
|
|
|
void DHTMessageFactoryImpl::setPeerAnnounceStorage(
|
|
DHTPeerAnnounceStorage* storage)
|
|
{
|
|
peerAnnounceStorage_ = storage;
|
|
}
|
|
|
|
void DHTMessageFactoryImpl::setTokenTracker(DHTTokenTracker* tokenTracker)
|
|
{
|
|
tokenTracker_ = tokenTracker;
|
|
}
|
|
|
|
void DHTMessageFactoryImpl::setLocalNode(
|
|
const std::shared_ptr<DHTNode>& localNode)
|
|
{
|
|
localNode_ = localNode;
|
|
}
|
|
|
|
} // namespace aria2
|