// aria2/src/DHTSetup.cc
//
// 310 lines
// 12 KiB
// C++

/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2006 Tatsuhiro Tsujikawa
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#include "DHTSetup.h"
#include <algorithm>
#include "LogFactory.h"
#include "Logger.h"
#include "util.h"
#include "DHTNode.h"
#include "DHTConnectionImpl.h"
#include "DHTRoutingTable.h"
#include "DHTMessageFactoryImpl.h"
#include "DHTMessageTracker.h"
#include "DHTMessageDispatcherImpl.h"
#include "DHTMessageReceiver.h"
#include "DHTTaskQueueImpl.h"
#include "DHTTaskFactoryImpl.h"
#include "DHTPeerAnnounceStorage.h"
#include "DHTTokenTracker.h"
#include "DHTInteractionCommand.h"
#include "DHTTokenUpdateCommand.h"
#include "DHTBucketRefreshCommand.h"
#include "DHTPeerAnnounceCommand.h"
#include "DHTEntryPointNameResolveCommand.h"
#include "DHTAutoSaveCommand.h"
#include "DHTTask.h"
#include "DHTRoutingTableDeserializer.h"
#include "DHTRegistry.h"
#include "DHTBucketRefreshTask.h"
#include "DHTMessageCallback.h"
#include "DHTMessageTrackerEntry.h"
#include "DHTMessageEntry.h"
#include "UDPTrackerClient.h"
#include "BtRegistry.h"
#include "prefs.h"
#include "Option.h"
#include "SocketCore.h"
#include "DlAbortEx.h"
#include "RecoverableException.h"
#include "a2functional.h"
#include "DownloadEngine.h"
#include "fmt.h"
namespace aria2 {
// Constructs a DHTSetup; there is nothing to initialize.
DHTSetup::DHTSetup() = default;
// Destroys a DHTSetup; there is nothing to release.
DHTSetup::~DHTSetup() = default;
// Creates the DHT components for the given address family, wires them
// together, stores them in DHTRegistry and builds the commands that use them.
//
// e is the engine that supplies the options, the BtRegistry and the CUIDs;
// family is expected to be AF_INET or AF_INET6.
//
// Returns a pair of command vectors: first holds the entry point name
// resolve (only when an entry point host is configured), token update,
// bucket refresh, peer announce and auto save commands; second holds the
// DHT interaction command. Both vectors are empty when family is neither
// AF_INET nor AF_INET6, when DHTRegistry reports that family as already
// initialized, or when a RecoverableException is caught during setup. In
// the last case the error is logged and the registry data for that family
// is cleared.
std::pair<std::vector<std::unique_ptr<Command>>,
          std::vector<std::unique_ptr<Command>>>
DHTSetup::setup(DownloadEngine* e, int family)
{
  if (family != AF_INET && family != AF_INET6) {
    return {};
  }
  const bool isIPv4 = family == AF_INET;
  if (isIPv4 ? DHTRegistry::isInitialized() : DHTRegistry::isInitialized6()) {
    return {};
  }

  std::vector<std::unique_ptr<Command>> commands;
  std::vector<std::unique_ptr<Command>> routineCommands;

  try {
    // Load the routing table and the local node from the configured DHT
    // file. A failure to load is logged and a fresh local node is used.
    std::shared_ptr<DHTNode> localNode;
    DHTRoutingTableDeserializer tableLoader(family);
    const std::string& tablePath = e->getOption()->get(
        isIPv4 ? PREF_DHT_FILE_PATH : PREF_DHT_FILE_PATH6);
    try {
      tableLoader.deserialize(tablePath);
      localNode = tableLoader.getLocalNode();
    }
    catch (RecoverableException& loadError) {
      A2_LOG_ERROR_EX(
          fmt("Exception caught while loading DHT routing table from %s",
              tablePath.c_str()),
          loadError);
    }
    if (!localNode) {
      localNode = std::make_shared<DHTNode>();
    }

    // Bind the UDP socket used for DHT traffic.
    auto connection = make_unique<DHTConnectionImpl>(family);
    uint16_t udpPort = e->getBtRegistry()->getUdpPort();
    const std::string& listenAddr = e->getOption()->get(
        isIPv4 ? PREF_DHT_LISTEN_ADDR : PREF_DHT_LISTEN_ADDR6);
    // If a UDP port is already in use, reuse the same port number. Normally
    // when the IPv4 port is available the IPv6 port is too (especially for
    // port >= 1024), so we don't lose much by doing this. The TCP socket is
    // handled the same way; see BtSetup.cc.
    bool bound;
    if (udpPort != 0) {
      bound = connection->bind(udpPort, listenAddr);
    }
    else {
      auto portRange =
          util::parseIntSegments(e->getOption()->get(PREF_DHT_LISTEN_PORT));
      portRange.normalize();
      bound = connection->bind(udpPort, listenAddr, portRange);
    }
    if (!bound) {
      throw DL_ABORT_EX("Error occurred while binding UDP port for DHT");
    }
    localNode->setPort(udpPort);

    A2_LOG_DEBUG(fmt("Initialized local node ID=%s",
                     util::toHex(localNode->getID(), DHT_ID_LENGTH).c_str()));

    // Create the components.
    auto messageTracker = std::make_shared<DHTMessageTracker>();
    auto routingTable = make_unique<DHTRoutingTable>(localNode);
    auto messageFactory = make_unique<DHTMessageFactoryImpl>(family);
    auto messageDispatcher =
        make_unique<DHTMessageDispatcherImpl>(messageTracker);
    auto messageReceiver = make_unique<DHTMessageReceiver>(messageTracker);
    auto taskQueue = make_unique<DHTTaskQueueImpl>();
    auto taskFactory = make_unique<DHTTaskFactoryImpl>();
    auto peerAnnounceStorage = make_unique<DHTPeerAnnounceStorage>();
    auto tokenTracker = make_unique<DHTTokenTracker>();
    // For now, UDPTrackerClient is enabled along with DHT.
    auto udpTrackerClient = std::make_shared<UDPTrackerClient>();
    const std::chrono::seconds messageTimeout(
        e->getOption()->getAsInt(PREF_DHT_MESSAGE_TIMEOUT));

    // Wire the components to each other. Only raw pointers are handed out
    // here; ownership is transferred to DHTRegistry further below.
    messageTracker->setRoutingTable(routingTable.get());
    messageTracker->setMessageFactory(messageFactory.get());

    messageDispatcher->setTimeout(messageTimeout);

    messageReceiver->setMessageFactory(messageFactory.get());
    messageReceiver->setRoutingTable(routingTable.get());

    taskFactory->setLocalNode(localNode);
    taskFactory->setRoutingTable(routingTable.get());
    taskFactory->setMessageDispatcher(messageDispatcher.get());
    taskFactory->setMessageFactory(messageFactory.get());
    taskFactory->setTaskQueue(taskQueue.get());
    taskFactory->setTimeout(messageTimeout);

    routingTable->setTaskQueue(taskQueue.get());
    routingTable->setTaskFactory(taskFactory.get());

    peerAnnounceStorage->setTaskQueue(taskQueue.get());
    peerAnnounceStorage->setTaskFactory(taskFactory.get());

    messageFactory->setRoutingTable(routingTable.get());
    messageFactory->setConnection(connection.get());
    messageFactory->setMessageDispatcher(messageDispatcher.get());
    messageFactory->setPeerAnnounceStorage(peerAnnounceStorage.get());
    messageFactory->setTokenTracker(tokenTracker.get());
    messageFactory->setLocalNode(localNode);

    // Bootstrap from the configured entry point, if there is one.
    const PrefPtr entryHostPref =
        isIPv4 ? PREF_DHT_ENTRY_POINT_HOST : PREF_DHT_ENTRY_POINT_HOST6;
    const std::string& entryHost = e->getOption()->get(entryHostPref);
    if (entryHost.empty()) {
      A2_LOG_INFO("No DHT entry point specified.");
    }
    else {
      const PrefPtr entryPortPref =
          isIPv4 ? PREF_DHT_ENTRY_POINT_PORT : PREF_DHT_ENTRY_POINT_PORT6;
      std::vector<std::pair<std::string, uint16_t>> entryPoints;
      entryPoints.emplace_back(
          entryHost,
          static_cast<uint16_t>(e->getOption()->getAsInt(entryPortPref)));

      auto resolveCommand = make_unique<DHTEntryPointNameResolveCommand>(
          e->newCUID(), e, entryPoints);
      resolveCommand->setBootstrapEnabled(true);
      resolveCommand->setTaskQueue(taskQueue.get());
      resolveCommand->setTaskFactory(taskFactory.get());
      resolveCommand->setRoutingTable(routingTable.get());
      resolveCommand->setLocalNode(localNode);
      commands.push_back(std::move(resolveCommand));
    }

    // The interaction command takes ownership of the connection, so the
    // socket is read from it before it is moved.
    auto interactionCommand =
        make_unique<DHTInteractionCommand>(e->newCUID(), e);
    interactionCommand->setMessageDispatcher(messageDispatcher.get());
    interactionCommand->setMessageReceiver(messageReceiver.get());
    interactionCommand->setTaskQueue(taskQueue.get());
    interactionCommand->setReadCheckSocket(connection->getSocket());
    interactionCommand->setConnection(std::move(connection));
    interactionCommand->setUDPTrackerClient(udpTrackerClient);
    routineCommands.push_back(std::move(interactionCommand));

    auto tokenCommand = make_unique<DHTTokenUpdateCommand>(
        e->newCUID(), e, DHT_TOKEN_UPDATE_INTERVAL);
    tokenCommand->setTokenTracker(tokenTracker.get());
    commands.push_back(std::move(tokenCommand));

    auto refreshCommand = make_unique<DHTBucketRefreshCommand>(
        e->newCUID(), e, DHT_BUCKET_REFRESH_CHECK_INTERVAL);
    refreshCommand->setTaskQueue(taskQueue.get());
    refreshCommand->setRoutingTable(routingTable.get());
    refreshCommand->setTaskFactory(taskFactory.get());
    commands.push_back(std::move(refreshCommand));

    auto announceCommand = make_unique<DHTPeerAnnounceCommand>(
        e->newCUID(), e, DHT_PEER_ANNOUNCE_CHECK_INTERVAL);
    announceCommand->setPeerAnnounceStorage(peerAnnounceStorage.get());
    commands.push_back(std::move(announceCommand));

    auto autoSaveCommand =
        make_unique<DHTAutoSaveCommand>(e->newCUID(), e, family, 30_min);
    autoSaveCommand->setLocalNode(localNode);
    autoSaveCommand->setRoutingTable(routingTable.get());
    commands.push_back(std::move(autoSaveCommand));

    // Add the deserialized nodes to the routing table and, if there were
    // any, queue a forced bucket refresh.
    auto& restoredNodes = tableLoader.getNodes();
    for (auto& restored : restoredNodes) {
      routingTable->addNode(restored);
    }
    if (!restoredNodes.empty()) {
      auto refreshTask = std::static_pointer_cast<DHTBucketRefreshTask>(
          taskFactory->createBucketRefreshTask());
      refreshTask->setForceRefresh(true);
      taskQueue->addPeriodicTask1(refreshTask);
    }

    // Hand the components over to DHTRegistry for this address family.
    if (isIPv4) {
      auto& registry = DHTRegistry::getMutableData();
      registry.localNode = localNode;
      registry.routingTable = std::move(routingTable);
      registry.taskQueue = std::move(taskQueue);
      registry.taskFactory = std::move(taskFactory);
      registry.peerAnnounceStorage = std::move(peerAnnounceStorage);
      registry.tokenTracker = std::move(tokenTracker);
      registry.messageDispatcher = std::move(messageDispatcher);
      registry.messageReceiver = std::move(messageReceiver);
      registry.messageFactory = std::move(messageFactory);
      e->getBtRegistry()->setUDPTrackerClient(udpTrackerClient);
      DHTRegistry::setInitialized(true);
    }
    else {
      auto& registry = DHTRegistry::getMutableData6();
      registry.localNode = localNode;
      registry.routingTable = std::move(routingTable);
      registry.taskQueue = std::move(taskQueue);
      registry.taskFactory = std::move(taskFactory);
      registry.peerAnnounceStorage = std::move(peerAnnounceStorage);
      registry.tokenTracker = std::move(tokenTracker);
      registry.messageDispatcher = std::move(messageDispatcher);
      registry.messageReceiver = std::move(messageReceiver);
      registry.messageFactory = std::move(messageFactory);
      DHTRegistry::setInitialized6(true);
    }

    if (e->getBtRegistry()->getUdpPort() == 0) {
      // We assign port last so that no exception gets in the way
      e->getBtRegistry()->setUdpPort(udpPort);
    }
  }
  catch (RecoverableException& ex) {
    A2_LOG_ERROR_EX(fmt("Exception caught while initializing DHT functionality."
                        " DHT is disabled."),
                    ex);
    // Drop everything built so far and clear the registry for this family.
    commands.clear();
    routineCommands.clear();
    if (isIPv4) {
      DHTRegistry::clearData();
      e->getBtRegistry()->setUDPTrackerClient(
          std::shared_ptr<UDPTrackerClient>{});
    }
    else {
      DHTRegistry::clearData6();
    }
  }

  return std::make_pair(std::move(commands), std::move(routineCommands));
}
} // namespace aria2