Use std::unique_ptr to store DHTMessages instead of std::shared_ptr

pull/103/head
Tatsuhiro Tsujikawa 2013-07-02 22:58:20 +09:00
parent 4f7d1c395b
commit 1a5d75e819
53 changed files with 833 additions and 872 deletions

View File

@ -46,17 +46,16 @@
namespace aria2 {
DHTAbstractMessage::DHTAbstractMessage(const std::shared_ptr<DHTNode>& localNode,
const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID):
DHTMessage(localNode, remoteNode, transactionID),
connection_(0),
dispatcher_(0),
factory_(0),
routingTable_(0)
{}
DHTAbstractMessage::~DHTAbstractMessage() {}
DHTAbstractMessage::DHTAbstractMessage
(const std::shared_ptr<DHTNode>& localNode,
const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID)
: DHTMessage{localNode, remoteNode, transactionID},
connection_{nullptr},
dispatcher_{nullptr},
factory_{nullptr},
routingTable_{nullptr}
{}
std::string DHTAbstractMessage::getBencodedMessage()
{

View File

@ -60,8 +60,6 @@ public:
const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID = A2STR::NIL);
virtual ~DHTAbstractMessage();
virtual bool send();
virtual const std::string& getType() const = 0;

View File

@ -67,32 +67,29 @@ class DHTAbstractNodeLookupTask:public DHTAbstractTask {
private:
unsigned char targetID_[DHT_ID_LENGTH];
std::deque<std::shared_ptr<DHTNodeLookupEntry> > entries_;
std::deque<std::unique_ptr<DHTNodeLookupEntry>> entries_;
size_t inFlightMessage_;
template<typename Container>
void toEntries
(Container& entries, const std::vector<std::shared_ptr<DHTNode> >& nodes) const
(Container& entries,
const std::vector<std::shared_ptr<DHTNode>>& nodes) const
{
for(std::vector<std::shared_ptr<DHTNode> >::const_iterator i = nodes.begin(),
eoi = nodes.end(); i != eoi; ++i) {
std::shared_ptr<DHTNodeLookupEntry> e(new DHTNodeLookupEntry(*i));
entries.push_back(e);
for(auto& node : nodes) {
entries.push_back(make_unique<DHTNodeLookupEntry>(node));
}
}
void sendMessage()
{
for(std::deque<std::shared_ptr<DHTNodeLookupEntry> >::iterator i =
entries_.begin(), eoi = entries_.end();
for(auto i = std::begin(entries_), eoi = std::end(entries_);
i != eoi && inFlightMessage_ < ALPHA; ++i) {
if((*i)->used == false) {
++inFlightMessage_;
(*i)->used = true;
std::shared_ptr<DHTMessage> m = createMessage((*i)->node);
std::shared_ptr<DHTMessageCallback> callback(createCallback());
getMessageDispatcher()->addMessageToQueue(m, callback);
getMessageDispatcher()->addMessageToQueue
(createMessage((*i)->node), createCallback());
}
}
}
@ -122,13 +119,13 @@ protected:
return targetID_;
}
const std::deque<std::shared_ptr<DHTNodeLookupEntry> >& getEntries() const
const std::deque<std::unique_ptr<DHTNodeLookupEntry>>& getEntries() const
{
return entries_;
}
virtual void getNodesFromMessage
(std::vector<std::shared_ptr<DHTNode> >& nodes,
(std::vector<std::shared_ptr<DHTNode>>& nodes,
const ResponseMessage* message) = 0;
virtual void onReceivedInternal
@ -138,10 +135,10 @@ protected:
virtual void onFinish() {}
virtual std::shared_ptr<DHTMessage> createMessage
virtual std::unique_ptr<DHTMessage> createMessage
(const std::shared_ptr<DHTNode>& remoteNode) = 0;
virtual std::shared_ptr<DHTMessageCallback> createCallback() = 0;
virtual std::unique_ptr<DHTMessageCallback> createCallback() = 0;
public:
DHTAbstractNodeLookupTask(const unsigned char* targetID):
inFlightMessage_(0)
@ -153,7 +150,7 @@ public:
virtual void startup()
{
std::vector<std::shared_ptr<DHTNode> > nodes;
std::vector<std::shared_ptr<DHTNode>> nodes;
getRoutingTable()->getClosestKNodes(nodes, targetID_);
entries_.clear();
toEntries(entries_, nodes);
@ -174,43 +171,42 @@ public:
{
--inFlightMessage_;
// Replace old Node ID with new Node ID.
for(std::deque<std::shared_ptr<DHTNodeLookupEntry> >::iterator i =
entries_.begin(), eoi = entries_.end(); i != eoi; ++i) {
if((*i)->node->getIPAddress() == message->getRemoteNode()->getIPAddress()
&& (*i)->node->getPort() == message->getRemoteNode()->getPort()) {
(*i)->node = message->getRemoteNode();
for(auto& entry : entries_) {
if(entry->node->getIPAddress() == message->getRemoteNode()->getIPAddress()
&& entry->node->getPort() == message->getRemoteNode()->getPort()) {
entry->node = message->getRemoteNode();
}
}
onReceivedInternal(message);
std::vector<std::shared_ptr<DHTNode> > nodes;
std::vector<std::shared_ptr<DHTNode>> nodes;
getNodesFromMessage(nodes, message);
std::vector<std::shared_ptr<DHTNodeLookupEntry> > newEntries;
std::vector<std::unique_ptr<DHTNodeLookupEntry>> newEntries;
toEntries(newEntries, nodes);
size_t count = 0;
for(std::vector<std::shared_ptr<DHTNodeLookupEntry> >::const_iterator i =
newEntries.begin(), eoi = newEntries.end(); i != eoi; ++i) {
if(memcmp(getLocalNode()->getID(), (*i)->node->getID(),
for(auto& ne : newEntries) {
if(memcmp(getLocalNode()->getID(), ne->node->getID(),
DHT_ID_LENGTH) != 0) {
entries_.push_front(*i);
++count;
A2_LOG_DEBUG(fmt("Received nodes: id=%s, ip=%s",
util::toHex((*i)->node->getID(),
util::toHex(ne->node->getID(),
DHT_ID_LENGTH).c_str(),
(*i)->node->getIPAddress().c_str()));
ne->node->getIPAddress().c_str()));
entries_.push_front(std::move(ne));
++count;
}
}
A2_LOG_DEBUG(fmt("%lu node lookup entries added.",
static_cast<unsigned long>(count)));
std::stable_sort(entries_.begin(), entries_.end(), DHTIDCloser(targetID_));
std::stable_sort(std::begin(entries_), std::end(entries_),
DHTIDCloser(targetID_));
entries_.erase
(std::unique(entries_.begin(), entries_.end(),
DerefEqualTo<std::shared_ptr<DHTNodeLookupEntry> >()),
entries_.end());
(std::unique(std::begin(entries_), std::end(entries_),
DerefEqualTo<std::unique_ptr<DHTNodeLookupEntry>>{}),
std::end(entries_));
A2_LOG_DEBUG(fmt("%lu node lookup entries are unique.",
static_cast<unsigned long>(entries_.size())));
if(entries_.size() > DHTBucket::K) {
entries_.erase(entries_.begin()+DHTBucket::K, entries_.end());
entries_.erase(std::begin(entries_)+DHTBucket::K, std::end(entries_));
}
sendMessageAndCheckFinish();
}
@ -220,8 +216,8 @@ public:
A2_LOG_DEBUG(fmt("node lookup message timeout for node ID=%s",
util::toHex(node->getID(), DHT_ID_LENGTH).c_str()));
--inFlightMessage_;
for(std::deque<std::shared_ptr<DHTNodeLookupEntry> >::iterator i =
entries_.begin(), eoi = entries_.end(); i != eoi; ++i) {
for(auto i = std::begin(entries_), eoi = std::end(entries_);
i != eoi; ++i) {
if(*(*i)->node == *node) {
entries_.erase(i);
break;

View File

@ -44,6 +44,7 @@
#include "util.h"
#include "DHTPeerAnnounceStorage.h"
#include "DHTTokenTracker.h"
#include "DHTAnnouncePeerReplyMessage.h"
#include "DlAbortEx.h"
#include "BtConstants.h"
#include "fmt.h"
@ -65,32 +66,29 @@ DHTAnnouncePeerMessage::DHTAnnouncePeerMessage
const unsigned char* infoHash,
uint16_t tcpPort,
const std::string& token,
const std::string& transactionID):
DHTQueryMessage(localNode, remoteNode, transactionID),
token_(token),
tcpPort_(tcpPort),
peerAnnounceStorage_(0),
tokenTracker_(0)
const std::string& transactionID)
: DHTQueryMessage{localNode, remoteNode, transactionID},
token_{token},
tcpPort_{tcpPort},
peerAnnounceStorage_{nullptr},
tokenTracker_{nullptr}
{
memcpy(infoHash_, infoHash, DHT_ID_LENGTH);
}
DHTAnnouncePeerMessage::~DHTAnnouncePeerMessage() {}
void DHTAnnouncePeerMessage::doReceivedAction()
{
peerAnnounceStorage_->addPeerAnnounce
(infoHash_, getRemoteNode()->getIPAddress(), tcpPort_);
std::shared_ptr<DHTMessage> reply =
getMessageFactory()->createAnnouncePeerReplyMessage
(getRemoteNode(), getTransactionID());
getMessageDispatcher()->addMessageToQueue(reply);
getMessageDispatcher()->addMessageToQueue
(getMessageFactory()->createAnnouncePeerReplyMessage
(getRemoteNode(), getTransactionID()));
}
std::shared_ptr<Dict> DHTAnnouncePeerMessage::getArgument()
{
std::shared_ptr<Dict> aDict = Dict::g();
auto aDict = Dict::g();
aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH));
aDict->put(INFO_HASH, String::g(infoHash_, DHT_ID_LENGTH));
aDict->put(PORT, Integer::g(tcpPort_));

View File

@ -65,8 +65,6 @@ public:
const std::string& token,
const std::string& transactionID = A2STR::NIL);
virtual ~DHTAnnouncePeerMessage();
virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getArgument();

View File

@ -41,6 +41,7 @@
#include "DHTMessageFactory.h"
#include "DHTMessageDispatcher.h"
#include "DHTMessageCallback.h"
#include "DHTFindNodeReplyMessage.h"
#include "util.h"
namespace aria2 {
@ -49,30 +50,28 @@ const std::string DHTFindNodeMessage::FIND_NODE("find_node");
const std::string DHTFindNodeMessage::TARGET_NODE("target");
DHTFindNodeMessage::DHTFindNodeMessage(const std::shared_ptr<DHTNode>& localNode,
const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* targetNodeID,
const std::string& transactionID):
DHTQueryMessage(localNode, remoteNode, transactionID)
DHTFindNodeMessage::DHTFindNodeMessage
(const std::shared_ptr<DHTNode>& localNode,
const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* targetNodeID,
const std::string& transactionID)
: DHTQueryMessage{localNode, remoteNode, transactionID}
{
memcpy(targetNodeID_, targetNodeID, DHT_ID_LENGTH);
}
DHTFindNodeMessage::~DHTFindNodeMessage() {}
void DHTFindNodeMessage::doReceivedAction()
{
std::vector<std::shared_ptr<DHTNode> > nodes;
std::vector<std::shared_ptr<DHTNode>> nodes;
getRoutingTable()->getClosestKNodes(nodes, targetNodeID_);
std::shared_ptr<DHTMessage> reply =
getMessageFactory()->createFindNodeReplyMessage
(getRemoteNode(), nodes, getTransactionID());
getMessageDispatcher()->addMessageToQueue(reply);
getMessageDispatcher()->addMessageToQueue
(getMessageFactory()->createFindNodeReplyMessage
(getRemoteNode(), std::move(nodes), getTransactionID()));
}
std::shared_ptr<Dict> DHTFindNodeMessage::getArgument()
{
std::shared_ptr<Dict> aDict = Dict::g();
auto aDict = Dict::g();
aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH));
aDict->put(TARGET_NODE, String::g(targetNodeID_, DHT_ID_LENGTH));
return aDict;

View File

@ -51,8 +51,6 @@ public:
const unsigned char* targetNodeID,
const std::string& transactionID = A2STR::NIL);
virtual ~DHTFindNodeMessage();
virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getArgument();

View File

@ -57,25 +57,23 @@ DHTFindNodeReplyMessage::DHTFindNodeReplyMessage
(int family,
const std::shared_ptr<DHTNode>& localNode,
const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID):
DHTResponseMessage(localNode, remoteNode, transactionID),
family_(family) {}
DHTFindNodeReplyMessage::~DHTFindNodeReplyMessage() {}
const std::string& transactionID)
: DHTResponseMessage{localNode, remoteNode, transactionID},
family_{family}
{}
void DHTFindNodeReplyMessage::doReceivedAction()
{
for(std::vector<std::shared_ptr<DHTNode> >::iterator i = closestKNodes_.begin(),
eoi = closestKNodes_.end(); i != eoi; ++i) {
if(memcmp((*i)->getID(), getLocalNode()->getID(), DHT_ID_LENGTH) != 0) {
getRoutingTable()->addNode(*i);
for(auto& node : closestKNodes_) {
if(memcmp(node->getID(), getLocalNode()->getID(), DHT_ID_LENGTH) != 0) {
getRoutingTable()->addNode(node);
}
}
}
std::shared_ptr<Dict> DHTFindNodeReplyMessage::getResponse()
{
std::shared_ptr<Dict> aDict = Dict::g();
auto aDict = Dict::g();
aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH));
unsigned char buffer[DHTBucket::K*38];
const int clen = bittorrent::getCompactLength(family_);
@ -83,14 +81,12 @@ std::shared_ptr<Dict> DHTFindNodeReplyMessage::getResponse()
assert(unit <= 38);
size_t offset = 0;
size_t k = 0;
for(std::vector<std::shared_ptr<DHTNode> >::const_iterator i =
closestKNodes_.begin(), eoi = closestKNodes_.end();
for(auto i = std::begin(closestKNodes_), eoi = std::end(closestKNodes_);
i != eoi && k < DHTBucket::K; ++i) {
std::shared_ptr<DHTNode> node = *i;
memcpy(buffer+offset, node->getID(), DHT_ID_LENGTH);
memcpy(buffer+offset, (*i)->getID(), DHT_ID_LENGTH);
unsigned char compact[COMPACT_LEN_IPV6];
int compactlen = bittorrent::packcompact
(compact, node->getIPAddress(), node->getPort());
int compactlen = bittorrent::packcompact(compact, (*i)->getIPAddress(),
(*i)->getPort());
if(compactlen == clen) {
memcpy(buffer+20+offset, compact, compactlen);
offset += unit;
@ -112,9 +108,9 @@ void DHTFindNodeReplyMessage::accept(DHTMessageCallback* callback)
}
void DHTFindNodeReplyMessage::setClosestKNodes
(const std::vector<std::shared_ptr<DHTNode> >& closestKNodes)
(std::vector<std::shared_ptr<DHTNode>> closestKNodes)
{
closestKNodes_ = closestKNodes;
closestKNodes_ = std::move(closestKNodes);
}
std::string DHTFindNodeReplyMessage::toStringOptional() const

View File

@ -44,7 +44,7 @@ class DHTFindNodeReplyMessage:public DHTResponseMessage {
private:
int family_;
std::vector<std::shared_ptr<DHTNode> > closestKNodes_;
std::vector<std::shared_ptr<DHTNode>> closestKNodes_;
protected:
virtual std::string toStringOptional() const;
public:
@ -53,8 +53,6 @@ public:
const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID);
virtual ~DHTFindNodeReplyMessage();
virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getResponse();
@ -63,13 +61,12 @@ public:
virtual void accept(DHTMessageCallback* callback);
const std::vector<std::shared_ptr<DHTNode> >& getClosestKNodes() const
const std::vector<std::shared_ptr<DHTNode>>& getClosestKNodes() const
{
return closestKNodes_;
}
void setClosestKNodes
(const std::vector<std::shared_ptr<DHTNode> >& closestKNodes);
void setClosestKNodes(std::vector<std::shared_ptr<DHTNode>> closestKNodes);
static const std::string FIND_NODE;

View File

@ -44,6 +44,7 @@
#include "DHTPeerAnnounceStorage.h"
#include "Peer.h"
#include "DHTTokenTracker.h"
#include "DHTGetPeersReplyMessage.h"
#include "util.h"
namespace aria2 {
@ -52,37 +53,36 @@ const std::string DHTGetPeersMessage::GET_PEERS("get_peers");
const std::string DHTGetPeersMessage::INFO_HASH("info_hash");
DHTGetPeersMessage::DHTGetPeersMessage(const std::shared_ptr<DHTNode>& localNode,
const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash,
const std::string& transactionID):
DHTQueryMessage(localNode, remoteNode, transactionID),
peerAnnounceStorage_(0),
tokenTracker_(0)
DHTGetPeersMessage::DHTGetPeersMessage
(const std::shared_ptr<DHTNode>& localNode,
const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash,
const std::string& transactionID)
: DHTQueryMessage{localNode, remoteNode, transactionID},
peerAnnounceStorage_{nullptr},
tokenTracker_{nullptr}
{
memcpy(infoHash_, infoHash, DHT_ID_LENGTH);
}
DHTGetPeersMessage::~DHTGetPeersMessage() {}
void DHTGetPeersMessage::doReceivedAction()
{
std::string token = tokenTracker_->generateToken
(infoHash_, getRemoteNode()->getIPAddress(), getRemoteNode()->getPort());
// Check to see localhost has the contents which has same infohash
std::vector<std::shared_ptr<Peer> > peers;
std::vector<std::shared_ptr<Peer>> peers;
peerAnnounceStorage_->getPeers(peers, infoHash_);
std::vector<std::shared_ptr<DHTNode> > nodes;
std::vector<std::shared_ptr<DHTNode>> nodes;
getRoutingTable()->getClosestKNodes(nodes, infoHash_);
std::shared_ptr<DHTMessage> reply =
getMessageFactory()->createGetPeersReplyMessage
(getRemoteNode(), nodes, peers, token, getTransactionID());
getMessageDispatcher()->addMessageToQueue(reply);
getMessageDispatcher()->addMessageToQueue
(getMessageFactory()->createGetPeersReplyMessage
(getRemoteNode(), std::move(nodes), std::move(peers), token,
getTransactionID()));
}
std::shared_ptr<Dict> DHTGetPeersMessage::getArgument()
{
std::shared_ptr<Dict> aDict = Dict::g();
auto aDict = Dict::g();
aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH));
aDict->put(INFO_HASH, String::g(infoHash_, DHT_ID_LENGTH));
return aDict;

View File

@ -59,8 +59,6 @@ public:
const unsigned char* infoHash,
const std::string& transactionID = A2STR::NIL);
virtual ~DHTGetPeersMessage();
virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getArgument();

View File

@ -64,12 +64,11 @@ DHTGetPeersReplyMessage::DHTGetPeersReplyMessage
const std::shared_ptr<DHTNode>& localNode,
const std::shared_ptr<DHTNode>& remoteNode,
const std::string& token,
const std::string& transactionID):
DHTResponseMessage(localNode, remoteNode, transactionID),
family_(family),
token_(token) {}
DHTGetPeersReplyMessage::~DHTGetPeersReplyMessage() {}
const std::string& transactionID)
: DHTResponseMessage{localNode, remoteNode, transactionID},
family_{family},
token_{token}
{}
void DHTGetPeersReplyMessage::doReceivedAction()
{
@ -78,7 +77,7 @@ void DHTGetPeersReplyMessage::doReceivedAction()
std::shared_ptr<Dict> DHTGetPeersReplyMessage::getResponse()
{
std::shared_ptr<Dict> rDict = Dict::g();
auto rDict = Dict::g();
rDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH));
rDict->put(TOKEN, token_);
// TODO want parameter
@ -88,14 +87,12 @@ std::shared_ptr<Dict> DHTGetPeersReplyMessage::getResponse()
const int unit = clen+20;
size_t offset = 0;
size_t k = 0;
for(std::vector<std::shared_ptr<DHTNode> >::const_iterator i =
closestKNodes_.begin(), eoi = closestKNodes_.end();
for(auto i = std::begin(closestKNodes_), eoi = std::end(closestKNodes_);
i != eoi && k < DHTBucket::K; ++i) {
std::shared_ptr<DHTNode> node = *i;
memcpy(buffer+offset, node->getID(), DHT_ID_LENGTH);
memcpy(buffer+offset, (*i)->getID(), DHT_ID_LENGTH);
unsigned char compact[COMPACT_LEN_IPV6];
int compactlen = bittorrent::packcompact
(compact, node->getIPAddress(), node->getPort());
(compact, (*i)->getIPAddress(), (*i)->getPort());
if(compactlen == clen) {
memcpy(buffer+20+offset, compact, compactlen);
offset += unit;
@ -128,15 +125,13 @@ std::shared_ptr<Dict> DHTGetPeersReplyMessage::getResponse()
// template may get bigger than 395 bytes. So we use 25 as maximum
// number of peer info that a message can carry.
static const size_t MAX_VALUES_SIZE = 25;
std::shared_ptr<List> valuesList = List::g();
for(std::vector<std::shared_ptr<Peer> >::const_iterator i = values_.begin(),
eoi = values_.end(); i != eoi && valuesList->size() < MAX_VALUES_SIZE;
++i) {
const std::shared_ptr<Peer>& peer = *i;
auto valuesList = List::g();
for(auto i = std::begin(values_), eoi = std::end(values_);
i != eoi && valuesList->size() < MAX_VALUES_SIZE; ++i) {
unsigned char compact[COMPACT_LEN_IPV6];
const int clen = bittorrent::getCompactLength(family_);
int compactlen = bittorrent::packcompact
(compact, peer->getIPAddress(), peer->getPort());
(compact, (*i)->getIPAddress(), (*i)->getPort());
if(compactlen == clen) {
valuesList->append(String::g(compact, compactlen));
}
@ -164,4 +159,16 @@ std::string DHTGetPeersReplyMessage::toStringOptional() const
static_cast<unsigned long>(closestKNodes_.size()));
}
void DHTGetPeersReplyMessage::setClosestKNodes
(std::vector<std::shared_ptr<DHTNode>> closestKNodes)
{
closestKNodes_ = std::move(closestKNodes);
}
void DHTGetPeersReplyMessage::setValues
(std::vector<std::shared_ptr<Peer>> peers)
{
values_ = std::move(peers);
}
} // namespace aria2

View File

@ -51,9 +51,9 @@ private:
std::string token_;
std::vector<std::shared_ptr<DHTNode> > closestKNodes_;
std::vector<std::shared_ptr<DHTNode>> closestKNodes_;
std::vector<std::shared_ptr<Peer> > values_;
std::vector<std::shared_ptr<Peer>> values_;
protected:
virtual std::string toStringOptional() const;
public:
@ -63,8 +63,6 @@ public:
const std::string& token,
const std::string& transactionID);
virtual ~DHTGetPeersReplyMessage();
virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getResponse();
@ -73,26 +71,19 @@ public:
virtual void accept(DHTMessageCallback* callback);
const std::vector<std::shared_ptr<DHTNode> >& getClosestKNodes() const
const std::vector<std::shared_ptr<DHTNode>>& getClosestKNodes() const
{
return closestKNodes_;
}
const std::vector<std::shared_ptr<Peer> >& getValues() const
const std::vector<std::shared_ptr<Peer>>& getValues() const
{
return values_;
}
void setClosestKNodes
(const std::vector<std::shared_ptr<DHTNode> >& closestKNodes)
{
closestKNodes_ = closestKNodes;
}
void setClosestKNodes(std::vector<std::shared_ptr<DHTNode>> closestKNodes);
void setValues(const std::vector<std::shared_ptr<Peer> >& peers)
{
values_ = peers;
}
void setValues(std::vector<std::shared_ptr<Peer>> peers);
const std::string& getToken() const
{

View File

@ -37,6 +37,7 @@
#include "common.h"
#include "DHTNodeLookupEntry.h"
#include "DHTNode.h"
#include "DHTConstants.h"
#include "XORCloser.h"
@ -46,10 +47,12 @@ class DHTIDCloser {
private:
XORCloser closer_;
public:
DHTIDCloser(const unsigned char* targetID):closer_(targetID, DHT_ID_LENGTH) {}
DHTIDCloser(const unsigned char* targetID)
: closer_{targetID, DHT_ID_LENGTH}
{}
bool operator()(const std::shared_ptr<DHTNodeLookupEntry>& m1,
const std::shared_ptr<DHTNodeLookupEntry>& m2) const
bool operator()(const std::unique_ptr<DHTNodeLookupEntry>& m1,
const std::unique_ptr<DHTNodeLookupEntry>& m2) const
{
return closer_(m1->node->getID(), m2->node->getID());
}

View File

@ -48,8 +48,10 @@ const std::string DHTMessage::ID("id");
DHTMessage::DHTMessage(const std::shared_ptr<DHTNode>& localNode,
const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID):
localNode_(localNode), remoteNode_(remoteNode), transactionID_(transactionID)
const std::string& transactionID)
: localNode_{localNode},
remoteNode_{remoteNode},
transactionID_{transactionID}
{
if(transactionID.empty()) {
generateTransactionID();

View File

@ -53,7 +53,7 @@ class DHTMessageCallback {
public:
virtual ~DHTMessageCallback() {}
void onReceived(const std::shared_ptr<DHTResponseMessage>& message)
void onReceived(DHTResponseMessage* message)
{
message->accept(this);
}

View File

@ -51,15 +51,15 @@ public:
virtual ~DHTMessageDispatcher() {}
virtual void
addMessageToQueue(const std::shared_ptr<DHTMessage>& message,
addMessageToQueue(std::unique_ptr<DHTMessage> message,
time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback =
std::shared_ptr<DHTMessageCallback>()) = 0;
std::unique_ptr<DHTMessageCallback> callback =
std::unique_ptr<DHTMessageCallback>{}) = 0;
virtual void
addMessageToQueue(const std::shared_ptr<DHTMessage>& message,
const std::shared_ptr<DHTMessageCallback>& callback =
std::shared_ptr<DHTMessageCallback>()) = 0;
addMessageToQueue(std::unique_ptr<DHTMessage> message,
std::unique_ptr<DHTMessageCallback> callback =
std::unique_ptr<DHTMessageCallback>{}) = 0;
virtual void sendMessages() = 0;

View File

@ -43,44 +43,41 @@
#include "DHTConstants.h"
#include "fmt.h"
#include "DHTNode.h"
#include "a2functional.h"
namespace aria2 {
DHTMessageDispatcherImpl::DHTMessageDispatcherImpl
(const std::shared_ptr<DHTMessageTracker>& tracker)
: tracker_(tracker),
timeout_(DHT_MESSAGE_TIMEOUT)
: tracker_{tracker},
timeout_{DHT_MESSAGE_TIMEOUT}
{}
DHTMessageDispatcherImpl::~DHTMessageDispatcherImpl() {}
void
DHTMessageDispatcherImpl::addMessageToQueue
(const std::shared_ptr<DHTMessage>& message,
(std::unique_ptr<DHTMessage> message,
time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback)
std::unique_ptr<DHTMessageCallback> callback)
{
std::shared_ptr<DHTMessageEntry> e
(new DHTMessageEntry(message, timeout, callback));
messageQueue_.push_back(e);
messageQueue_.push_back(make_unique<DHTMessageEntry>
(std::move(message), timeout, std::move(callback)));
}
void
DHTMessageDispatcherImpl::addMessageToQueue
(const std::shared_ptr<DHTMessage>& message,
const std::shared_ptr<DHTMessageCallback>& callback)
(std::unique_ptr<DHTMessage> message,
std::unique_ptr<DHTMessageCallback> callback)
{
addMessageToQueue(message, timeout_, callback);
addMessageToQueue(std::move(message), timeout_, std::move(callback));
}
bool
DHTMessageDispatcherImpl::sendMessage
(const std::shared_ptr<DHTMessageEntry>& entry)
bool DHTMessageDispatcherImpl::sendMessage(DHTMessageEntry* entry)
{
try {
if(entry->message->send()) {
if(!entry->message->isReply()) {
tracker_->addMessage(entry->message, entry->timeout, entry->callback);
tracker_->addMessage(entry->message.get(), entry->timeout,
std::move(entry->callback));
}
A2_LOG_INFO(fmt("Message sent: %s", entry->message->toString().c_str()));
} else {
@ -95,7 +92,8 @@ DHTMessageDispatcherImpl::sendMessage
// DHTTask(such as DHTAbstractNodeLookupTask) don't finish
// forever.
if(!entry->message->isReply()) {
tracker_->addMessage(entry->message, 0, entry->callback);
tracker_->addMessage(entry->message.get(), 0,
std::move(entry->callback));
}
}
return true;
@ -103,13 +101,13 @@ DHTMessageDispatcherImpl::sendMessage
void DHTMessageDispatcherImpl::sendMessages()
{
auto itr = messageQueue_.begin();
for(; itr != messageQueue_.end(); ++itr) {
if(!sendMessage(*itr)) {
auto itr = std::begin(messageQueue_);
for(; itr != std::end(messageQueue_); ++itr) {
if(!sendMessage((*itr).get())) {
break;
}
}
messageQueue_.erase(messageQueue_.begin(), itr);
messageQueue_.erase(std::begin(messageQueue_), itr);
A2_LOG_DEBUG(fmt("%lu dht messages remaining in the queue.",
static_cast<unsigned long>(messageQueue_.size())));
}

View File

@ -47,26 +47,24 @@ class DHTMessageDispatcherImpl:public DHTMessageDispatcher {
private:
std::shared_ptr<DHTMessageTracker> tracker_;
std::deque<std::shared_ptr<DHTMessageEntry> > messageQueue_;
std::deque<std::unique_ptr<DHTMessageEntry>> messageQueue_;
time_t timeout_;
bool sendMessage(const std::shared_ptr<DHTMessageEntry>& msg);
bool sendMessage(DHTMessageEntry* msg);
public:
DHTMessageDispatcherImpl(const std::shared_ptr<DHTMessageTracker>& tracker);
virtual ~DHTMessageDispatcherImpl();
virtual void
addMessageToQueue(const std::shared_ptr<DHTMessage>& message,
addMessageToQueue(std::unique_ptr<DHTMessage> message,
time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback =
std::shared_ptr<DHTMessageCallback>());
std::unique_ptr<DHTMessageCallback> callback =
std::unique_ptr<DHTMessageCallback>{});
virtual void
addMessageToQueue(const std::shared_ptr<DHTMessage>& message,
const std::shared_ptr<DHTMessageCallback>& callback =
std::shared_ptr<DHTMessageCallback>());
addMessageToQueue(std::unique_ptr<DHTMessage> message,
std::unique_ptr<DHTMessageCallback> callback =
std::unique_ptr<DHTMessageCallback>{});
virtual void sendMessages();

View File

@ -40,13 +40,12 @@
namespace aria2 {
DHTMessageEntry::DHTMessageEntry
(const std::shared_ptr<DHTMessage>& message,
(std::unique_ptr<DHTMessage> message,
time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback):
message(message),
timeout(timeout),
callback(callback) {}
DHTMessageEntry::~DHTMessageEntry() {}
std::unique_ptr<DHTMessageCallback> callback)
: message{std::move(message)},
timeout{timeout},
callback{std::move(callback)}
{}
} // namespace aria2

View File

@ -47,15 +47,13 @@ class DHTMessage;
class DHTMessageCallback;
struct DHTMessageEntry {
std::shared_ptr<DHTMessage> message;
std::unique_ptr<DHTMessage> message;
time_t timeout;
std::shared_ptr<DHTMessageCallback> callback;
std::unique_ptr<DHTMessageCallback> callback;
DHTMessageEntry(const std::shared_ptr<DHTMessage>& message,
DHTMessageEntry(std::unique_ptr<DHTMessage> message,
time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback);
~DHTMessageEntry();
std::unique_ptr<DHTMessageCallback> callback);
};
} // namespace aria2

View File

@ -49,6 +49,15 @@ namespace aria2 {
class DHTMessage;
class DHTQueryMessage;
class DHTResponseMessage;
class DHTPingMessage;
class DHTPingReplyMessage;
class DHTFindNodeMessage;
class DHTFindNodeReplyMessage;
class DHTGetPeersMessage;
class DHTGetPeersReplyMessage;
class DHTAnnouncePeerMessage;
class DHTAnnouncePeerReplyMessage;
class DHTUnknownMessage;
class DHTNode;
class Peer;
@ -56,60 +65,60 @@ class DHTMessageFactory {
public:
virtual ~DHTMessageFactory() {}
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTQueryMessage>
createQueryMessage(const Dict* dict,
const std::string& ipaddr, uint16_t port) = 0;
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTResponseMessage>
createResponseMessage(const std::string& messageType,
const Dict* dict,
const std::string& ipaddr, uint16_t port) = 0;
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTPingMessage>
createPingMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID = A2STR::NIL) = 0;
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTPingReplyMessage>
createPingReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* id,
const std::string& transactionID) = 0;
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTFindNodeMessage>
createFindNodeMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* targetNodeID,
const std::string& transactionID = A2STR::NIL) = 0;
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTFindNodeReplyMessage>
createFindNodeReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes,
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::string& transactionID) = 0;
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTGetPeersMessage>
createGetPeersMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash,
const std::string& transactionID = A2STR::NIL) = 0;
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTGetPeersReplyMessage>
createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes,
const std::vector<std::shared_ptr<Peer> >& peers,
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
std::vector<std::shared_ptr<Peer>> peers,
const std::string& token,
const std::string& transactionID) = 0;
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTAnnouncePeerMessage>
createAnnouncePeerMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash,
uint16_t tcpPort,
const std::string& token,
const std::string& transactionID = A2STR::NIL) = 0;
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTAnnouncePeerReplyMessage>
createAnnouncePeerReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID) = 0;
virtual std::shared_ptr<DHTMessage>
virtual std::unique_ptr<DHTUnknownMessage>
createUnknownMessage(const unsigned char* data, size_t length,
const std::string& ipaddr, uint16_t port) = 0;
};

View File

@ -65,23 +65,21 @@
namespace aria2 {
DHTMessageFactoryImpl::DHTMessageFactoryImpl(int family)
: family_(family),
connection_(0),
dispatcher_(0),
routingTable_(0),
peerAnnounceStorage_(0),
tokenTracker_(0)
: family_{family},
connection_{nullptr},
dispatcher_{nullptr},
routingTable_{nullptr},
peerAnnounceStorage_{nullptr},
tokenTracker_{nullptr}
{}
DHTMessageFactoryImpl::~DHTMessageFactoryImpl() {}
std::shared_ptr<DHTNode>
DHTMessageFactoryImpl::getRemoteNode
(const unsigned char* id, const std::string& ipaddr, uint16_t port) const
{
std::shared_ptr<DHTNode> node = routingTable_->getNode(id, ipaddr, port);
auto node = routingTable_->getNode(id, ipaddr, port);
if(!node) {
node.reset(new DHTNode(id));
node = std::make_shared<DHTNode>(id);
node->setIPAddress(ipaddr);
node->setPort(port);
}
@ -188,7 +186,7 @@ void DHTMessageFactoryImpl::validatePort(const Integer* port) const
}
namespace {
void setVersion(const std::shared_ptr<DHTMessage>& msg, const Dict* dict)
void setVersion(DHTMessage* msg, const Dict* dict)
{
const String* v = downcast<String>(dict->get(DHTMessage::V));
if(v) {
@ -199,7 +197,7 @@ void setVersion(const std::shared_ptr<DHTMessage>& msg, const Dict* dict)
}
} // namespace
std::shared_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage
std::unique_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage
(const Dict* dict, const std::string& ipaddr, uint16_t port)
{
const String* messageType = getString(dict, DHTQueryMessage::Q);
@ -211,8 +209,8 @@ std::shared_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage
}
const String* id = getString(aDict, DHTMessage::ID);
validateID(id);
std::shared_ptr<DHTNode> remoteNode = getRemoteNode(id->uc(), ipaddr, port);
std::shared_ptr<DHTQueryMessage> msg;
auto remoteNode = getRemoteNode(id->uc(), ipaddr, port);
auto msg = std::unique_ptr<DHTQueryMessage>{};
if(messageType->s() == DHTPingMessage::PING) {
msg = createPingMessage(remoteNode, transactionID->s());
} else if(messageType->s() == DHTFindNodeMessage::FIND_NODE) {
@ -238,11 +236,11 @@ std::shared_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage
throw DL_ABORT_EX(fmt("Unsupported message type: %s",
messageType->s().c_str()));
}
setVersion(msg, dict);
setVersion(msg.get(), dict);
return msg;
}
std::shared_ptr<DHTResponseMessage>
std::unique_ptr<DHTResponseMessage>
DHTMessageFactoryImpl::createResponseMessage
(const std::string& messageType,
const Dict* dict,
@ -270,8 +268,8 @@ DHTMessageFactoryImpl::createResponseMessage
const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
const String* id = getString(rDict, DHTMessage::ID);
validateID(id);
std::shared_ptr<DHTNode> remoteNode = getRemoteNode(id->uc(), ipaddr, port);
std::shared_ptr<DHTResponseMessage> msg;
auto remoteNode = getRemoteNode(id->uc(), ipaddr, port);
auto msg = std::unique_ptr<DHTResponseMessage>{};
if(messageType == DHTPingReplyMessage::PING) {
msg = createPingReplyMessage(remoteNode, id->uc(), transactionID->s());
} else if(messageType == DHTFindNodeReplyMessage::FIND_NODE) {
@ -284,7 +282,7 @@ DHTMessageFactoryImpl::createResponseMessage
throw DL_ABORT_EX
(fmt("Unsupported message type: %s", messageType.c_str()));
}
setVersion(msg, dict);
setVersion(msg.get(), dict);
return msg;
}
@ -312,51 +310,53 @@ void DHTMessageFactoryImpl::setCommonProperty(DHTAbstractMessage* m)
m->setVersion(getDefaultVersion());
}
std::shared_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createPingMessage
std::unique_ptr<DHTPingMessage> DHTMessageFactoryImpl::createPingMessage
(const std::shared_ptr<DHTNode>& remoteNode, const std::string& transactionID)
{
DHTPingMessage* m(new DHTPingMessage(localNode_, remoteNode, transactionID));
setCommonProperty(m);
return std::shared_ptr<DHTQueryMessage>(m);
auto m = make_unique<DHTPingMessage>(localNode_, remoteNode, transactionID);
setCommonProperty(m.get());
return m;
}
std::shared_ptr<DHTResponseMessage> DHTMessageFactoryImpl::createPingReplyMessage
std::unique_ptr<DHTPingReplyMessage>
DHTMessageFactoryImpl::createPingReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* id,
const std::string& transactionID)
{
DHTPingReplyMessage* m
(new DHTPingReplyMessage(localNode_, remoteNode, id, transactionID));
setCommonProperty(m);
return std::shared_ptr<DHTResponseMessage>(m);
auto m = make_unique<DHTPingReplyMessage>(localNode_, remoteNode, id,
transactionID);
setCommonProperty(m.get());
return m;
}
std::shared_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createFindNodeMessage
std::unique_ptr<DHTFindNodeMessage>
DHTMessageFactoryImpl::createFindNodeMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* targetNodeID,
const std::string& transactionID)
{
DHTFindNodeMessage* m(new DHTFindNodeMessage
(localNode_, remoteNode, targetNodeID, transactionID));
setCommonProperty(m);
return std::shared_ptr<DHTQueryMessage>(m);
auto m = make_unique<DHTFindNodeMessage>(localNode_, remoteNode,
targetNodeID, transactionID);
setCommonProperty(m.get());
return m;
}
std::shared_ptr<DHTResponseMessage>
std::unique_ptr<DHTFindNodeReplyMessage>
DHTMessageFactoryImpl::createFindNodeReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes,
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::string& transactionID)
{
DHTFindNodeReplyMessage* m(new DHTFindNodeReplyMessage
(family_, localNode_, remoteNode, transactionID));
m->setClosestKNodes(closestKNodes);
setCommonProperty(m);
return std::shared_ptr<DHTResponseMessage>(m);
auto m = make_unique<DHTFindNodeReplyMessage>(family_, localNode_,
remoteNode, transactionID);
m->setClosestKNodes(std::move(closestKNodes));
setCommonProperty(m.get());
return m;
}
void DHTMessageFactoryImpl::extractNodes
(std::vector<std::shared_ptr<DHTNode> >& nodes,
(std::vector<std::shared_ptr<DHTNode>>& nodes,
const unsigned char* src, size_t length)
{
int unit = bittorrent::getCompactLength(family_)+20;
@ -365,19 +365,18 @@ void DHTMessageFactoryImpl::extractNodes
(fmt("Nodes length is not multiple of %d", unit));
}
for(size_t offset = 0; offset < length; offset += unit) {
std::shared_ptr<DHTNode> node(new DHTNode(src+offset));
std::pair<std::string, uint16_t> addr =
bittorrent::unpackcompact(src+offset+DHT_ID_LENGTH, family_);
auto node = std::make_shared<DHTNode>(src+offset);
auto addr = bittorrent::unpackcompact(src+offset+DHT_ID_LENGTH, family_);
if(addr.first.empty()) {
continue;
}
node->setIPAddress(addr.first);
node->setPort(addr.second);
nodes.push_back(node);
nodes.push_back(std::move(node));
}
}
std::shared_ptr<DHTResponseMessage>
std::unique_ptr<DHTFindNodeReplyMessage>
DHTMessageFactoryImpl::createFindNodeReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const Dict* dict,
@ -387,28 +386,29 @@ DHTMessageFactoryImpl::createFindNodeReplyMessage
downcast<String>(getDictionary(dict, DHTResponseMessage::R)->
get(family_ == AF_INET?DHTFindNodeReplyMessage::NODES:
DHTFindNodeReplyMessage::NODES6));
std::vector<std::shared_ptr<DHTNode> > nodes;
std::vector<std::shared_ptr<DHTNode>> nodes;
if(nodesData) {
extractNodes(nodes, nodesData->uc(), nodesData->s().size());
}
return createFindNodeReplyMessage(remoteNode, nodes, transactionID);
return createFindNodeReplyMessage(remoteNode, std::move(nodes),
transactionID);
}
std::shared_ptr<DHTQueryMessage>
std::unique_ptr<DHTGetPeersMessage>
DHTMessageFactoryImpl::createGetPeersMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash,
const std::string& transactionID)
{
DHTGetPeersMessage* m
(new DHTGetPeersMessage(localNode_, remoteNode, infoHash, transactionID));
auto m = make_unique<DHTGetPeersMessage>(localNode_, remoteNode, infoHash,
transactionID);
m->setPeerAnnounceStorage(peerAnnounceStorage_);
m->setTokenTracker(tokenTracker_);
setCommonProperty(m);
return std::shared_ptr<DHTQueryMessage>(m);
setCommonProperty(m.get());
return m;
}
std::shared_ptr<DHTResponseMessage>
std::unique_ptr<DHTGetPeersReplyMessage>
DHTMessageFactoryImpl::createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const Dict* dict,
@ -416,54 +416,53 @@ DHTMessageFactoryImpl::createGetPeersReplyMessage
{
const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
const String* nodesData =
downcast<String>(rDict->get(family_ == AF_INET?DHTGetPeersReplyMessage::NODES:
DHTGetPeersReplyMessage::NODES6));
std::vector<std::shared_ptr<DHTNode> > nodes;
downcast<String>(rDict->get(family_ == AF_INET ?
DHTGetPeersReplyMessage::NODES :
DHTGetPeersReplyMessage::NODES6));
std::vector<std::shared_ptr<DHTNode>> nodes;
if(nodesData) {
extractNodes(nodes, nodesData->uc(), nodesData->s().size());
}
const List* valuesList =
downcast<List>(rDict->get(DHTGetPeersReplyMessage::VALUES));
std::vector<std::shared_ptr<Peer> > peers;
std::vector<std::shared_ptr<Peer>> peers;
size_t clen = bittorrent::getCompactLength(family_);
if(valuesList) {
for(List::ValueType::const_iterator i = valuesList->begin(),
eoi = valuesList->end(); i != eoi; ++i) {
for(auto i = valuesList->begin(), eoi = valuesList->end(); i != eoi; ++i) {
const String* data = downcast<String>(*i);
if(data && data->s().size() == clen) {
std::pair<std::string, uint16_t> addr =
bittorrent::unpackcompact(data->uc(), family_);
auto addr = bittorrent::unpackcompact(data->uc(), family_);
if(addr.first.empty()) {
continue;
}
std::shared_ptr<Peer> peer(new Peer(addr.first, addr.second));
peers.push_back(peer);
peers.push_back(std::make_shared<Peer>(addr.first, addr.second));
}
}
}
const String* token = getString(rDict, DHTGetPeersReplyMessage::TOKEN);
return createGetPeersReplyMessage
(remoteNode, nodes, peers, token->s(), transactionID);
(remoteNode, std::move(nodes), std::move(peers), token->s(),
transactionID);
}
std::shared_ptr<DHTResponseMessage>
std::unique_ptr<DHTGetPeersReplyMessage>
DHTMessageFactoryImpl::createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes,
const std::vector<std::shared_ptr<Peer> >& values,
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
std::vector<std::shared_ptr<Peer>> values,
const std::string& token,
const std::string& transactionID)
{
DHTGetPeersReplyMessage* m(new DHTGetPeersReplyMessage
(family_, localNode_, remoteNode, token,
transactionID));
m->setClosestKNodes(closestKNodes);
m->setValues(values);
setCommonProperty(m);
return std::shared_ptr<DHTResponseMessage>(m);
auto m = make_unique<DHTGetPeersReplyMessage>(family_, localNode_,
remoteNode, token,
transactionID);
m->setClosestKNodes(std::move(closestKNodes));
m->setValues(std::move(values));
setCommonProperty(m.get());
return m;
}
std::shared_ptr<DHTQueryMessage>
std::unique_ptr<DHTAnnouncePeerMessage>
DHTMessageFactoryImpl::createAnnouncePeerMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash,
@ -471,34 +470,33 @@ DHTMessageFactoryImpl::createAnnouncePeerMessage
const std::string& token,
const std::string& transactionID)
{
DHTAnnouncePeerMessage* m(new DHTAnnouncePeerMessage
(localNode_, remoteNode, infoHash, tcpPort, token,
transactionID));
auto m = make_unique<DHTAnnouncePeerMessage>(localNode_, remoteNode,
infoHash, tcpPort, token,
transactionID);
m->setPeerAnnounceStorage(peerAnnounceStorage_);
m->setTokenTracker(tokenTracker_);
setCommonProperty(m);
return std::shared_ptr<DHTQueryMessage>(m);
setCommonProperty(m.get());
return m;
}
std::shared_ptr<DHTResponseMessage>
std::unique_ptr<DHTAnnouncePeerReplyMessage>
DHTMessageFactoryImpl::createAnnouncePeerReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode, const std::string& transactionID)
{
DHTAnnouncePeerReplyMessage* m
(new DHTAnnouncePeerReplyMessage(localNode_, remoteNode, transactionID));
setCommonProperty(m);
return std::shared_ptr<DHTResponseMessage>(m);
auto m = make_unique<DHTAnnouncePeerReplyMessage>(localNode_, remoteNode,
transactionID);
setCommonProperty(m.get());
return m;
}
std::shared_ptr<DHTMessage>
std::unique_ptr<DHTUnknownMessage>
DHTMessageFactoryImpl::createUnknownMessage
(const unsigned char* data, size_t length,
const std::string& ipaddr, uint16_t port)
{
DHTUnknownMessage* m
(new DHTUnknownMessage(localNode_, data, length, ipaddr, port));
return std::shared_ptr<DHTMessage>(m);
return make_unique<DHTUnknownMessage>(localNode_, data, length,
ipaddr, port);
}
void DHTMessageFactoryImpl::setRoutingTable(DHTRoutingTable* routingTable)

View File

@ -81,74 +81,72 @@ private:
public:
DHTMessageFactoryImpl(int family);
virtual ~DHTMessageFactoryImpl();
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTQueryMessage>
createQueryMessage(const Dict* dict,
const std::string& ipaddr, uint16_t port);
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTResponseMessage>
createResponseMessage(const std::string& messageType,
const Dict* dict,
const std::string& ipaddr, uint16_t port);
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTPingMessage>
createPingMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID = A2STR::NIL);
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTPingReplyMessage>
createPingReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* id,
const std::string& transactionID);
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTFindNodeMessage>
createFindNodeMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* targetNodeID,
const std::string& transactionID = A2STR::NIL);
std::shared_ptr<DHTResponseMessage>
std::unique_ptr<DHTFindNodeReplyMessage>
createFindNodeReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const Dict* dict,
const std::string& transactionID);
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTFindNodeReplyMessage>
createFindNodeReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes,
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::string& transactionID);
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTGetPeersMessage>
createGetPeersMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash,
const std::string& transactionID = A2STR::NIL);
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTGetPeersReplyMessage>
createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes,
const std::vector<std::shared_ptr<Peer> >& peers,
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
std::vector<std::shared_ptr<Peer>> peers,
const std::string& token,
const std::string& transactionID);
std::shared_ptr<DHTResponseMessage>
std::unique_ptr<DHTGetPeersReplyMessage>
createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const Dict* dict,
const std::string& transactionID);
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTAnnouncePeerMessage>
createAnnouncePeerMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash,
uint16_t tcpPort,
const std::string& token,
const std::string& transactionID = A2STR::NIL);
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTAnnouncePeerReplyMessage>
createAnnouncePeerReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID);
virtual std::shared_ptr<DHTMessage>
virtual std::unique_ptr<DHTUnknownMessage>
createUnknownMessage(const unsigned char* data, size_t length,
const std::string& ipaddr, uint16_t port);

View File

@ -58,18 +58,16 @@ namespace aria2 {
DHTMessageReceiver::DHTMessageReceiver
(const std::shared_ptr<DHTMessageTracker>& tracker)
: tracker_(tracker)
: tracker_{tracker}
{}
DHTMessageReceiver::~DHTMessageReceiver() {}
std::shared_ptr<DHTMessage> DHTMessageReceiver::receiveMessage
std::unique_ptr<DHTMessage> DHTMessageReceiver::receiveMessage
(const std::string& remoteAddr, uint16_t remotePort, unsigned char *data,
size_t length)
{
try {
bool isReply = false;
std::shared_ptr<ValueBase> decoded = bencode2::decode(data, length);
auto decoded = bencode2::decode(data, length);
const Dict* dict = downcast<Dict>(decoded);
if(dict) {
const String* y = downcast<String>(dict->get(DHTMessage::Y));
@ -89,28 +87,26 @@ std::shared_ptr<DHTMessage> DHTMessageReceiver::receiveMessage
return handleUnknownMessage(data, length, remoteAddr, remotePort);
}
if(isReply) {
std::pair<std::shared_ptr<DHTResponseMessage>,
std::shared_ptr<DHTMessageCallback> > p =
tracker_->messageArrived(dict, remoteAddr, remotePort);
auto p = tracker_->messageArrived(dict, remoteAddr, remotePort);
if(!p.first) {
// timeout or malicious? message
return handleUnknownMessage(data, length, remoteAddr, remotePort);
}
onMessageReceived(p.first);
onMessageReceived(p.first.get());
if(p.second) {
p.second->onReceived(p.first);
p.second->onReceived(p.first.get());
}
return p.first;
return std::move(p.first);
} else {
std::shared_ptr<DHTQueryMessage> message =
auto message =
factory_->createQueryMessage(dict, remoteAddr, remotePort);
if(*message->getLocalNode() == *message->getRemoteNode()) {
// drop message from localnode
A2_LOG_INFO("Received DHT message from localnode.");
return handleUnknownMessage(data, length, remoteAddr, remotePort);
}
onMessageReceived(message);
return message;
onMessageReceived(message.get());
return std::move(message);
}
} catch(RecoverableException& e) {
A2_LOG_INFO_EX("Exception thrown while receiving DHT message.", e);
@ -118,8 +114,7 @@ std::shared_ptr<DHTMessage> DHTMessageReceiver::receiveMessage
}
}
void DHTMessageReceiver::onMessageReceived
(const std::shared_ptr<DHTMessage>& message)
void DHTMessageReceiver::onMessageReceived(DHTMessage* message)
{
A2_LOG_INFO(fmt("Message received: %s", message->toString().c_str()));
message->validate();
@ -134,13 +129,13 @@ void DHTMessageReceiver::handleTimeout()
tracker_->handleTimeout();
}
std::shared_ptr<DHTMessage>
std::unique_ptr<DHTUnknownMessage>
DHTMessageReceiver::handleUnknownMessage(const unsigned char* data,
size_t length,
const std::string& remoteAddr,
uint16_t remotePort)
{
std::shared_ptr<DHTMessage> m =
auto m =
factory_->createUnknownMessage(data, length, remoteAddr, remotePort);
A2_LOG_INFO(fmt("Message received: %s", m->toString().c_str()));
return m;

View File

@ -47,6 +47,7 @@ class DHTMessage;
class DHTConnection;
class DHTMessageFactory;
class DHTRoutingTable;
class DHTUnknownMessage;
class DHTMessageReceiver {
private:
@ -58,17 +59,15 @@ private:
std::shared_ptr<DHTRoutingTable> routingTable_;
std::shared_ptr<DHTMessage>
std::unique_ptr<DHTUnknownMessage>
handleUnknownMessage(const unsigned char* data, size_t length,
const std::string& remoteAddr, uint16_t remotePort);
void onMessageReceived(const std::shared_ptr<DHTMessage>& message);
void onMessageReceived(DHTMessage* message);
public:
DHTMessageReceiver(const std::shared_ptr<DHTMessageTracker>& tracker);
~DHTMessageReceiver();
std::shared_ptr<DHTMessage> receiveMessage
std::unique_ptr<DHTMessage> receiveMessage
(const std::string& remoteAddr, uint16_t remotePort, unsigned char *data,
size_t length);

View File

@ -51,17 +51,24 @@
namespace aria2 {
DHTMessageTracker::DHTMessageTracker() {}
DHTMessageTracker::DHTMessageTracker()
: factory_{nullptr}
{}
DHTMessageTracker::~DHTMessageTracker() {}
void DHTMessageTracker::addMessage(const std::shared_ptr<DHTMessage>& message, time_t timeout, const std::shared_ptr<DHTMessageCallback>& callback)
void DHTMessageTracker::addMessage
(DHTMessage* message,
time_t timeout,
std::unique_ptr<DHTMessageCallback> callback)
{
std::shared_ptr<DHTMessageTrackerEntry> e(new DHTMessageTrackerEntry(message, timeout, callback));
entries_.push_back(e);
entries_.push_back(make_unique<DHTMessageTrackerEntry>
(message->getRemoteNode(),
message->getTransactionID(),
message->getMessageType(),
timeout, std::move(callback)));
}
std::pair<std::shared_ptr<DHTResponseMessage>, std::shared_ptr<DHTMessageCallback> >
std::pair<std::unique_ptr<DHTResponseMessage>,
std::unique_ptr<DHTMessageCallback>>
DHTMessageTracker::messageArrived
(const Dict* dict, const std::string& ipaddr, uint16_t port)
{
@ -73,15 +80,14 @@ DHTMessageTracker::messageArrived
A2_LOG_DEBUG(fmt("Searching tracker entry for TransactionID=%s, Remote=%s:%u",
util::toHex(tid->s()).c_str(),
ipaddr.c_str(), port));
for(std::deque<std::shared_ptr<DHTMessageTrackerEntry> >::iterator i =
entries_.begin(), eoi = entries_.end(); i != eoi; ++i) {
for(auto i = std::begin(entries_), eoi = std::end(entries_); i != eoi; ++i) {
if((*i)->match(tid->s(), ipaddr, port)) {
std::shared_ptr<DHTMessageTrackerEntry> entry = *i;
auto entry = std::move(*i);
entries_.erase(i);
A2_LOG_DEBUG("Tracker entry found.");
std::shared_ptr<DHTNode> targetNode = entry->getTargetNode();
auto& targetNode = entry->getTargetNode();
try {
std::shared_ptr<DHTResponseMessage> message =
auto message =
factory_->createResponseMessage(entry->getMessageType(), dict,
targetNode->getIPAddress(),
targetNode->getPort());
@ -89,8 +95,7 @@ DHTMessageTracker::messageArrived
int64_t rtt = entry->getElapsedMillis();
A2_LOG_DEBUG(fmt("RTT is %" PRId64 "", rtt));
message->getRemoteNode()->updateRTT(rtt);
std::shared_ptr<DHTMessageCallback> callback = entry->getCallback();
if(!(*targetNode == *message->getRemoteNode())) {
if(*targetNode != *message->getRemoteNode()) {
// Node ID has changed. Drop previous node ID from
// DHTRoutingTable
A2_LOG_DEBUG
@ -100,23 +105,22 @@ DHTMessageTracker::messageArrived
DHT_ID_LENGTH).c_str()));
routingTable_->dropNode(targetNode);
}
return std::make_pair(message, callback);
return std::make_pair(std::move(message), entry->popCallback());
} catch(RecoverableException& e) {
handleTimeoutEntry(entry);
handleTimeoutEntry(entry.get());
throw;
}
}
}
A2_LOG_DEBUG("Tracker entry not found.");
return std::pair<std::shared_ptr<DHTResponseMessage>,
std::shared_ptr<DHTMessageCallback> >();
return std::pair<std::unique_ptr<DHTResponseMessage>,
std::unique_ptr<DHTMessageCallback>>{};
}
void DHTMessageTracker::handleTimeoutEntry
(const std::shared_ptr<DHTMessageTrackerEntry>& entry)
void DHTMessageTracker::handleTimeoutEntry(DHTMessageTrackerEntry* entry)
{
try {
std::shared_ptr<DHTNode> node = entry->getTargetNode();
auto& node = entry->getTargetNode();
A2_LOG_DEBUG(fmt("Message timeout: To:%s:%u",
node->getIPAddress().c_str(), node->getPort()));
node->updateRTT(entry->getElapsedMillis());
@ -126,7 +130,7 @@ void DHTMessageTracker::handleTimeoutEntry
node->getIPAddress().c_str(), node->getPort()));
routingTable_->dropNode(node);
}
std::shared_ptr<DHTMessageCallback> callback = entry->getCallback();
auto& callback = entry->getCallback();
if(callback) {
callback->onTimeout(node);
}
@ -135,43 +139,33 @@ void DHTMessageTracker::handleTimeoutEntry
}
}
namespace {
struct HandleTimeout {
HandleTimeout(DHTMessageTracker* tracker)
: tracker(tracker)
{}
bool operator()(const std::shared_ptr<DHTMessageTrackerEntry>& ent) const
{
if(ent->isTimeout()) {
tracker->handleTimeoutEntry(ent);
return true;
} else {
return false;
}
}
DHTMessageTracker* tracker;
};
} // namespace
void DHTMessageTracker::handleTimeout()
{
entries_.erase(std::remove_if(entries_.begin(), entries_.end(),
HandleTimeout(this)),
entries_.end());
entries_.erase
(std::remove_if(std::begin(entries_), std::end(entries_),
[&](const std::unique_ptr<DHTMessageTrackerEntry>& ent)
{
if(ent->isTimeout()) {
handleTimeoutEntry(ent.get());
return true;
} else {
return false;
}
}),
std::end(entries_));
}
std::shared_ptr<DHTMessageTrackerEntry>
DHTMessageTracker::getEntryFor(const std::shared_ptr<DHTMessage>& message) const
const DHTMessageTrackerEntry*
DHTMessageTracker::getEntryFor(const DHTMessage* message) const
{
for(std::deque<std::shared_ptr<DHTMessageTrackerEntry> >::const_iterator i =
entries_.begin(), eoi = entries_.end(); i != eoi; ++i) {
if((*i)->match(message->getTransactionID(),
message->getRemoteNode()->getIPAddress(),
message->getRemoteNode()->getPort())) {
return *i;
for(auto& ent : entries_) {
if(ent->match(message->getTransactionID(),
message->getRemoteNode()->getIPAddress(),
message->getRemoteNode()->getPort())) {
return ent.get();
}
}
return std::shared_ptr<DHTMessageTrackerEntry>();
return nullptr;
}
size_t DHTMessageTracker::countEntry() const
@ -185,8 +179,7 @@ void DHTMessageTracker::setRoutingTable
routingTable_ = routingTable;
}
void DHTMessageTracker::setMessageFactory
(const std::shared_ptr<DHTMessageFactory>& factory)
void DHTMessageTracker::setMessageFactory(DHTMessageFactory* factory)
{
factory_ = factory;
}

View File

@ -55,38 +55,37 @@ class DHTMessageTrackerEntry;
class DHTMessageTracker {
private:
std::deque<std::shared_ptr<DHTMessageTrackerEntry> > entries_;
std::deque<std::unique_ptr<DHTMessageTrackerEntry>> entries_;
std::shared_ptr<DHTRoutingTable> routingTable_;
std::shared_ptr<DHTMessageFactory> factory_;
DHTMessageFactory* factory_;
public:
DHTMessageTracker();
~DHTMessageTracker();
void addMessage(const std::shared_ptr<DHTMessage>& message,
void addMessage(DHTMessage* message,
time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback =
std::shared_ptr<DHTMessageCallback>());
std::unique_ptr<DHTMessageCallback> callback =
std::unique_ptr<DHTMessageCallback>{});
std::pair<std::shared_ptr<DHTResponseMessage>, std::shared_ptr<DHTMessageCallback> >
std::pair<std::unique_ptr<DHTResponseMessage>,
std::unique_ptr<DHTMessageCallback>>
messageArrived(const Dict* dict,
const std::string& ipaddr, uint16_t port);
void handleTimeout();
// Made public so that unnamed functor can access this
void handleTimeoutEntry(const std::shared_ptr<DHTMessageTrackerEntry>& entry);
void handleTimeoutEntry(DHTMessageTrackerEntry* entry);
std::shared_ptr<DHTMessageTrackerEntry> getEntryFor
(const std::shared_ptr<DHTMessage>& message) const;
// // For unittest only
const DHTMessageTrackerEntry* getEntryFor(const DHTMessage* message) const;
size_t countEntry() const;
void setRoutingTable(const std::shared_ptr<DHTRoutingTable>& routingTable);
void setMessageFactory(const std::shared_ptr<DHTMessageFactory>& factory);
void setMessageFactory(DHTMessageFactory* factory);
};
} // namespace aria2

View File

@ -42,19 +42,20 @@
namespace aria2 {
DHTMessageTrackerEntry::DHTMessageTrackerEntry(const std::shared_ptr<DHTMessage>& sentMessage,
time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback):
targetNode_(sentMessage->getRemoteNode()),
transactionID_(sentMessage->getTransactionID()),
messageType_(sentMessage->getMessageType()),
callback_(callback),
dispatchedTime_(global::wallclock()),
timeout_(timeout)
DHTMessageTrackerEntry::DHTMessageTrackerEntry
(std::shared_ptr<DHTNode> targetNode,
std::string transactionID,
std::string messageType,
time_t timeout,
std::unique_ptr<DHTMessageCallback> callback)
: targetNode_{std::move(targetNode)},
transactionID_{std::move(transactionID)},
messageType_{std::move(messageType)},
callback_{std::move(callback)},
dispatchedTime_{global::wallclock()},
timeout_{timeout}
{}
DHTMessageTrackerEntry::~DHTMessageTrackerEntry() {}
bool DHTMessageTrackerEntry::isTimeout() const
{
return dispatchedTime_.difference(global::wallclock()) >= timeout_;
@ -84,4 +85,25 @@ int64_t DHTMessageTrackerEntry::getElapsedMillis() const
return dispatchedTime_.differenceInMillis(global::wallclock());
}
const std::shared_ptr<DHTNode>& DHTMessageTrackerEntry::getTargetNode() const
{
return targetNode_;
}
const std::string& DHTMessageTrackerEntry::getMessageType() const
{
return messageType_;
}
const std::unique_ptr<DHTMessageCallback>&
DHTMessageTrackerEntry::getCallback() const
{
return callback_;
}
std::unique_ptr<DHTMessageCallback> DHTMessageTrackerEntry::popCallback()
{
return std::move(callback_);
}
} // namespace aria2

View File

@ -57,18 +57,18 @@ private:
std::string messageType_;
std::shared_ptr<DHTMessageCallback> callback_;
std::unique_ptr<DHTMessageCallback> callback_;
Timer dispatchedTime_;
time_t timeout_;
public:
DHTMessageTrackerEntry(const std::shared_ptr<DHTMessage>& sentMessage,
DHTMessageTrackerEntry(std::shared_ptr<DHTNode> targetNode,
std::string transactionID,
std::string messageType,
time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback =
std::shared_ptr<DHTMessageCallback>());
~DHTMessageTrackerEntry();
std::unique_ptr<DHTMessageCallback> callback =
std::unique_ptr<DHTMessageCallback>{});
bool isTimeout() const;
@ -76,21 +76,10 @@ public:
bool match(const std::string& transactionID, const std::string& ipaddr, uint16_t port) const;
const std::shared_ptr<DHTNode>& getTargetNode() const
{
return targetNode_;
}
const std::string& getMessageType() const
{
return messageType_;
}
const std::shared_ptr<DHTMessageCallback>& getCallback() const
{
return callback_;
}
const std::shared_ptr<DHTNode>& getTargetNode() const;
const std::string& getMessageType() const;
const std::unique_ptr<DHTMessageCallback>& getCallback() const;
std::unique_ptr<DHTMessageCallback> popCallback();
int64_t getElapsedMillis() const;
};

View File

@ -65,6 +65,11 @@ bool DHTNode::operator==(const DHTNode& node) const
return memcmp(id_, node.id_, DHT_ID_LENGTH) == 0;
}
bool DHTNode::operator!=(const DHTNode& node) const
{
return !(*this == node);
}
bool DHTNode::operator<(const DHTNode& node) const
{
for(size_t i = 0; i < DHT_ID_LENGTH; ++i) {

View File

@ -115,6 +115,8 @@ public:
bool operator==(const DHTNode& node) const;
bool operator!=(const DHTNode& node) const;
bool operator<(const DHTNode& node) const;
std::string toString() const;

View File

@ -41,6 +41,7 @@
#include "util.h"
#include "DHTNodeLookupTaskCallback.h"
#include "DHTQueryMessage.h"
#include "DHTFindNodeMessage.h"
namespace aria2 {
@ -53,21 +54,19 @@ DHTNodeLookupTask::getNodesFromMessage
(std::vector<std::shared_ptr<DHTNode> >& nodes,
const DHTFindNodeReplyMessage* message)
{
const std::vector<std::shared_ptr<DHTNode> >& knodes =
message->getClosestKNodes();
nodes.insert(nodes.end(), knodes.begin(), knodes.end());
auto& knodes = message->getClosestKNodes();
nodes.insert(std::end(nodes), std::begin(knodes), std::end(knodes));
}
std::shared_ptr<DHTMessage>
std::unique_ptr<DHTMessage>
DHTNodeLookupTask::createMessage(const std::shared_ptr<DHTNode>& remoteNode)
{
return getMessageFactory()->createFindNodeMessage(remoteNode, getTargetID());
}
std::shared_ptr<DHTMessageCallback> DHTNodeLookupTask::createCallback()
std::unique_ptr<DHTMessageCallback> DHTNodeLookupTask::createCallback()
{
return std::shared_ptr<DHTNodeLookupTaskCallback>
(new DHTNodeLookupTaskCallback(this));
return make_unique<DHTNodeLookupTaskCallback>(this);
}
} // namespace aria2

View File

@ -50,10 +50,10 @@ public:
(std::vector<std::shared_ptr<DHTNode> >& nodes,
const DHTFindNodeReplyMessage* message);
virtual std::shared_ptr<DHTMessage> createMessage
virtual std::unique_ptr<DHTMessage> createMessage
(const std::shared_ptr<DHTNode>& remoteNode);
virtual std::shared_ptr<DHTMessageCallback> createCallback();
virtual std::unique_ptr<DHTMessageCallback> createCallback();
};
} // namespace aria2

View File

@ -48,6 +48,9 @@
#include "bittorrent_helper.h"
#include "DHTPeerLookupTaskCallback.h"
#include "DHTQueryMessage.h"
#include "DHTGetPeersMessage.h"
#include "DHTAnnouncePeerMessage.h"
#include "fmt.h"
namespace aria2 {
@ -62,12 +65,11 @@ DHTPeerLookupTask::DHTPeerLookupTask
void
DHTPeerLookupTask::getNodesFromMessage
(std::vector<std::shared_ptr<DHTNode> >& nodes,
(std::vector<std::shared_ptr<DHTNode>>& nodes,
const DHTGetPeersReplyMessage* message)
{
const std::vector<std::shared_ptr<DHTNode> >& knodes =
message->getClosestKNodes();
nodes.insert(nodes.end(), knodes.begin(), knodes.end());
auto& knodes = message->getClosestKNodes();
nodes.insert(std::end(nodes), std::begin(knodes), std::end(knodes));
}
void DHTPeerLookupTask::onReceivedInternal
@ -81,16 +83,15 @@ void DHTPeerLookupTask::onReceivedInternal
static_cast<unsigned long>(message->getValues().size())));
}
std::shared_ptr<DHTMessage> DHTPeerLookupTask::createMessage
std::unique_ptr<DHTMessage> DHTPeerLookupTask::createMessage
(const std::shared_ptr<DHTNode>& remoteNode)
{
return getMessageFactory()->createGetPeersMessage(remoteNode, getTargetID());
}
std::shared_ptr<DHTMessageCallback> DHTPeerLookupTask::createCallback()
std::unique_ptr<DHTMessageCallback> DHTPeerLookupTask::createCallback()
{
return std::shared_ptr<DHTPeerLookupTaskCallback>
(new DHTPeerLookupTaskCallback(this));
return make_unique<DHTPeerLookupTaskCallback>(this);
}
void DHTPeerLookupTask::onFinish()
@ -99,26 +100,24 @@ void DHTPeerLookupTask::onFinish()
util::toHex(getTargetID(), DHT_ID_LENGTH).c_str()));
// send announce_peer message to K closest nodes
size_t num = DHTBucket::K;
for(std::deque<std::shared_ptr<DHTNodeLookupEntry> >::const_iterator i =
getEntries().begin(), eoi = getEntries().end();
for(auto i = std::begin(getEntries()), eoi = std::end(getEntries());
i != eoi && num > 0; ++i) {
if(!(*i)->used) {
continue;
}
const std::shared_ptr<DHTNode>& node = (*i)->node;
auto& node = (*i)->node;
std::string idHex = util::toHex(node->getID(), DHT_ID_LENGTH);
std::string token = tokenStorage_[idHex];
if(token.empty()) {
A2_LOG_DEBUG(fmt("Token is empty for ID:%s", idHex.c_str()));
continue;
}
std::shared_ptr<DHTMessage> m =
getMessageFactory()->createAnnouncePeerMessage
(node,
getTargetID(), // this is infoHash
tcpPort_,
token);
getMessageDispatcher()->addMessageToQueue(m);
getMessageDispatcher()->addMessageToQueue
(getMessageFactory()->createAnnouncePeerMessage
(node,
getTargetID(), // this is infoHash
tcpPort_,
token));
--num;
}
}

View File

@ -58,15 +58,15 @@ public:
uint16_t tcpPort);
virtual void getNodesFromMessage
(std::vector<std::shared_ptr<DHTNode> >& nodes,
(std::vector<std::shared_ptr<DHTNode>>& nodes,
const DHTGetPeersReplyMessage* message);
virtual void onReceivedInternal(const DHTGetPeersReplyMessage* message);
virtual std::shared_ptr<DHTMessage> createMessage
virtual std::unique_ptr<DHTMessage> createMessage
(const std::shared_ptr<DHTNode>& remoteNode);
virtual std::shared_ptr<DHTMessageCallback> createCallback();
virtual std::unique_ptr<DHTMessageCallback> createCallback();
virtual void onFinish();

View File

@ -38,6 +38,7 @@
#include "DHTMessageDispatcher.h"
#include "DHTMessageFactory.h"
#include "DHTMessageCallback.h"
#include "DHTPingReplyMessage.h"
namespace aria2 {
@ -48,20 +49,17 @@ DHTPingMessage::DHTPingMessage(const std::shared_ptr<DHTNode>& localNode,
const std::string& transactionID):
DHTQueryMessage(localNode, remoteNode, transactionID) {}
DHTPingMessage::~DHTPingMessage() {}
void DHTPingMessage::doReceivedAction()
{
// send back ping reply
std::shared_ptr<DHTMessage> reply =
getMessageFactory()->createPingReplyMessage
(getRemoteNode(), getLocalNode()->getID(), getTransactionID());
getMessageDispatcher()->addMessageToQueue(reply);
getMessageDispatcher()->addMessageToQueue
(getMessageFactory()->createPingReplyMessage
(getRemoteNode(), getLocalNode()->getID(), getTransactionID()));
}
std::shared_ptr<Dict> DHTPingMessage::getArgument()
{
std::shared_ptr<Dict> aDict = Dict::g();
auto aDict = Dict::g();
aDict->put(DHTMessage::ID, String::g(getLocalNode()->getID(), DHT_ID_LENGTH));
return aDict;
}

View File

@ -46,8 +46,6 @@ public:
const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID = A2STR::NIL);
virtual ~DHTPingMessage();
virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getArgument();

View File

@ -53,13 +53,11 @@ DHTPingReplyMessage::DHTPingReplyMessage
memcpy(id_, id, DHT_ID_LENGTH);
}
DHTPingReplyMessage::~DHTPingReplyMessage() {}
void DHTPingReplyMessage::doReceivedAction() {}
std::shared_ptr<Dict> DHTPingReplyMessage::getResponse()
{
std::shared_ptr<Dict> rDict = Dict::g();
auto rDict = Dict::g();
rDict->put(DHTMessage::ID, String::g(id_, DHT_ID_LENGTH));
return rDict;
}

View File

@ -49,8 +49,6 @@ public:
const unsigned char* id,
const std::string& transactionID);
virtual ~DHTPingReplyMessage();
virtual void doReceivedAction();
virtual std::shared_ptr<Dict> getResponse();

View File

@ -40,6 +40,7 @@
#include "DHTConstants.h"
#include "DHTPingReplyMessageCallback.h"
#include "DHTQueryMessage.h"
#include "DHTPingMessage.h"
namespace aria2 {
@ -56,11 +57,10 @@ DHTPingTask::~DHTPingTask() {}
void DHTPingTask::addMessage()
{
std::shared_ptr<DHTMessage> m =
getMessageFactory()->createPingMessage(remoteNode_);
std::shared_ptr<DHTMessageCallback> callback
(new DHTPingReplyMessageCallback<DHTPingTask>(this));
getMessageDispatcher()->addMessageToQueue(m, timeout_, callback);
getMessageDispatcher()->addMessageToQueue
(getMessageFactory()->createPingMessage(remoteNode_),
timeout_,
make_unique<DHTPingReplyMessageCallback<DHTPingTask>>(this));
}
void DHTPingTask::startup()

View File

@ -42,6 +42,7 @@
#include "LogFactory.h"
#include "DHTPingReplyMessageCallback.h"
#include "DHTQueryMessage.h"
#include "DHTPingMessage.h"
#include "fmt.h"
namespace aria2 {
@ -67,11 +68,10 @@ void DHTReplaceNodeTask::sendMessage()
if(!questionableNode) {
setFinished(true);
} else {
std::shared_ptr<DHTMessage> m =
getMessageFactory()->createPingMessage(questionableNode);
std::shared_ptr<DHTMessageCallback> callback
(new DHTPingReplyMessageCallback<DHTReplaceNodeTask>(this));
getMessageDispatcher()->addMessageToQueue(m, timeout_, callback);
getMessageDispatcher()->addMessageToQueue
(getMessageFactory()->createPingMessage(questionableNode),
timeout_,
make_unique<DHTPingReplyMessageCallback<DHTReplaceNodeTask>>(this));
}
}

View File

@ -61,6 +61,8 @@
#include "DHTRegistry.h"
#include "DHTBucketRefreshTask.h"
#include "DHTMessageCallback.h"
#include "DHTMessageTrackerEntry.h"
#include "DHTMessageEntry.h"
#include "UDPTrackerClient.h"
#include "BtRegistry.h"
#include "prefs.h"
@ -137,27 +139,19 @@ std::vector<std::unique_ptr<Command>> DHTSetup::setup
util::toHex(localNode->getID(), DHT_ID_LENGTH).c_str()));
std::shared_ptr<DHTRoutingTable> routingTable(new DHTRoutingTable(localNode));
std::shared_ptr<DHTMessageFactoryImpl> factory
(new DHTMessageFactoryImpl(family));
std::shared_ptr<DHTMessageTracker> tracker(new DHTMessageTracker());
std::shared_ptr<DHTMessageDispatcherImpl> dispatcher(new DHTMessageDispatcherImpl(tracker));
std::shared_ptr<DHTMessageReceiver> receiver(new DHTMessageReceiver(tracker));
std::shared_ptr<DHTTaskQueue> taskQueue(new DHTTaskQueueImpl());
std::shared_ptr<DHTTaskFactoryImpl> taskFactory(new DHTTaskFactoryImpl());
std::shared_ptr<DHTPeerAnnounceStorage> peerAnnounceStorage(new DHTPeerAnnounceStorage());
std::shared_ptr<DHTTokenTracker> tokenTracker(new DHTTokenTracker());
const time_t messageTimeout = e->getOption()->getAsInt(PREF_DHT_MESSAGE_TIMEOUT);
auto factory = std::make_shared<DHTMessageFactoryImpl>(family);
auto tracker = std::make_shared<DHTMessageTracker>();
auto dispatcher = std::make_shared<DHTMessageDispatcherImpl>(tracker);
auto receiver = std::make_shared<DHTMessageReceiver>(tracker);
auto taskQueue = std::make_shared<DHTTaskQueueImpl>();
auto taskFactory = std::make_shared<DHTTaskFactoryImpl>();
auto peerAnnounceStorage = std::make_shared<DHTPeerAnnounceStorage>();
auto tokenTracker = std::make_shared<DHTTokenTracker>();
const time_t messageTimeout =
e->getOption()->getAsInt(PREF_DHT_MESSAGE_TIMEOUT);
// wiring up
tracker->setRoutingTable(routingTable);
tracker->setMessageFactory(factory);
tracker->setMessageFactory(factory.get());
dispatcher->setTimeout(messageTimeout);
@ -186,7 +180,7 @@ std::vector<std::unique_ptr<Command>> DHTSetup::setup
factory->setLocalNode(localNode);
// For now, UDPTrackerClient was enabled along with DHT
std::shared_ptr<UDPTrackerClient> udpTrackerClient(new UDPTrackerClient());
auto udpTrackerClient = std::make_shared<UDPTrackerClient>();
// assign them into DHTRegistry
if(family == AF_INET) {
DHTRegistry::getMutableData().localNode = localNode;
@ -211,11 +205,9 @@ std::vector<std::unique_ptr<Command>> DHTSetup::setup
DHTRegistry::getMutableData6().messageFactory = factory;
}
// add deserialized nodes to routing table
const std::vector<std::shared_ptr<DHTNode> >& desnodes =
deserializer.getNodes();
for(std::vector<std::shared_ptr<DHTNode> >::const_iterator i =
desnodes.begin(), eoi = desnodes.end(); i != eoi; ++i) {
routingTable->addNode(*i);
auto& desnodes = deserializer.getNodes();
for(auto& node : desnodes) {
routingTable->addNode(node);
}
if(!desnodes.empty()) {
auto task = std::static_pointer_cast<DHTBucketRefreshTask>
@ -234,7 +226,7 @@ std::vector<std::unique_ptr<Command>> DHTSetup::setup
std::pair<std::string, uint16_t> addr
(e->getOption()->get(prefEntryPointHost),
e->getOption()->getAsInt(prefEntryPointPort));
std::vector<std::pair<std::string, uint16_t> > entryPoints;
std::vector<std::pair<std::string, uint16_t>> entryPoints;
entryPoints.push_back(addr);
auto command = make_unique<DHTEntryPointNameResolveCommand>
(e->newCUID(), e, entryPoints);
@ -302,7 +294,7 @@ std::vector<std::unique_ptr<Command>> DHTSetup::setup
if(family == AF_INET) {
DHTRegistry::clearData();
e->getBtRegistry()->setUDPTrackerClient
(std::shared_ptr<UDPTrackerClient>());
(std::shared_ptr<UDPTrackerClient>{});
} else {
DHTRegistry::clearData6();
}

View File

@ -20,7 +20,14 @@ class DHTAnnouncePeerMessageTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testDoReceivedAction);
CPPUNIT_TEST_SUITE_END();
public:
void setUp() {}
std::shared_ptr<DHTNode> localNode_;
std::shared_ptr<DHTNode> remoteNode_;
void setUp()
{
localNode_ = std::make_shared<DHTNode>();
remoteNode_ = std::make_shared<DHTNode>();
}
void tearDown() {}
@ -28,13 +35,12 @@ public:
void testDoReceivedAction();
class MockDHTMessageFactory2:public MockDHTMessageFactory {
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTAnnouncePeerReplyMessage>
createAnnouncePeerReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID)
const std::string& transactionID) override
{
return std::shared_ptr<DHTResponseMessage>
(new MockDHTResponseMessage
(localNode_, remoteNode, "announce_peer", transactionID));
return make_unique<DHTAnnouncePeerReplyMessage>(localNode_, remoteNode,
transactionID);
}
};
};
@ -44,9 +50,6 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DHTAnnouncePeerMessageTest);
void DHTAnnouncePeerMessageTest::testGetBencodedMessage()
{
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]);
@ -57,7 +60,8 @@ void DHTAnnouncePeerMessageTest::testGetBencodedMessage()
std::string token = "token";
uint16_t port = 6881;
DHTAnnouncePeerMessage msg(localNode, remoteNode, infoHash, port, token, transactionID);
DHTAnnouncePeerMessage msg(localNode_, remoteNode_, infoHash, port, token,
transactionID);
msg.setVersion("A200");
std::string msgbody = msg.getBencodedMessage();
@ -66,8 +70,8 @@ void DHTAnnouncePeerMessageTest::testGetBencodedMessage()
dict.put("v", "A200");
dict.put("y", "q");
dict.put("q", "announce_peer");
std::shared_ptr<Dict> aDict = Dict::g();
aDict->put("id", String::g(localNode->getID(), DHT_ID_LENGTH));
auto aDict = Dict::g();
aDict->put("id", String::g(localNode_->getID(), DHT_ID_LENGTH));
aDict->put("info_hash", String::g(infoHash, DHT_ID_LENGTH));
aDict->put("port", Integer::g(port));
aDict->put("token", token);
@ -79,10 +83,8 @@ void DHTAnnouncePeerMessageTest::testGetBencodedMessage()
void DHTAnnouncePeerMessageTest::testDoReceivedAction()
{
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
remoteNode_->setIPAddress("192.168.0.1");
remoteNode_->setPort(6881);
unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
@ -96,10 +98,11 @@ void DHTAnnouncePeerMessageTest::testDoReceivedAction()
DHTPeerAnnounceStorage peerAnnounceStorage;
MockDHTMessageFactory2 factory;
factory.setLocalNode(localNode);
factory.setLocalNode(localNode_);
MockDHTMessageDispatcher dispatcher;
DHTAnnouncePeerMessage msg(localNode, remoteNode, infoHash, port, token, transactionID);
DHTAnnouncePeerMessage msg(localNode_, remoteNode_, infoHash, port, token,
transactionID);
msg.setPeerAnnounceStorage(&peerAnnounceStorage);
msg.setMessageFactory(&factory);
msg.setMessageDispatcher(&dispatcher);
@ -107,10 +110,10 @@ void DHTAnnouncePeerMessageTest::testDoReceivedAction()
msg.doReceivedAction();
CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size());
auto m = std::dynamic_pointer_cast<MockDHTResponseMessage>
(dispatcher.messageQueue_[0].message_);
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
auto m = dynamic_cast<DHTAnnouncePeerReplyMessage*>
(dispatcher.messageQueue_[0].message_.get());
CPPUNIT_ASSERT(*localNode_ == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("announce_peer"), m->getMessageType());
CPPUNIT_ASSERT_EQUAL(transactionID, m->getTransactionID());
std::vector<std::shared_ptr<Peer> > peers;

View File

@ -20,7 +20,14 @@ class DHTFindNodeMessageTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testDoReceivedAction);
CPPUNIT_TEST_SUITE_END();
public:
void setUp() {}
std::shared_ptr<DHTNode> localNode_;
std::shared_ptr<DHTNode> remoteNode_;
void setUp()
{
localNode_ = std::make_shared<DHTNode>();
remoteNode_ = std::make_shared<DHTNode>();
}
void tearDown() {}
@ -29,16 +36,15 @@ public:
class MockDHTMessageFactory2:public MockDHTMessageFactory {
public:
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTFindNodeReplyMessage>
createFindNodeReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes,
const std::string& transactionID)
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::string& transactionID) override
{
std::shared_ptr<MockDHTResponseMessage> m
(new MockDHTResponseMessage
(localNode_, remoteNode, "find_node", transactionID));
m->nodes_ = closestKNodes;
auto m = make_unique<DHTFindNodeReplyMessage>
(AF_INET, localNode_, remoteNode, transactionID);
m->setClosestKNodes(std::move(closestKNodes));
return m;
}
};
@ -49,16 +55,14 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DHTFindNodeMessageTest);
void DHTFindNodeMessageTest::testGetBencodedMessage()
{
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]);
std::shared_ptr<DHTNode> targetNode(new DHTNode());
auto targetNode = std::make_shared<DHTNode>();
DHTFindNodeMessage msg(localNode, remoteNode, targetNode->getID(), transactionID);
DHTFindNodeMessage msg(localNode_, remoteNode_, targetNode->getID(),
transactionID);
msg.setVersion("A200");
std::string msgbody = msg.getBencodedMessage();
@ -67,8 +71,8 @@ void DHTFindNodeMessageTest::testGetBencodedMessage()
dict.put("v", "A200");
dict.put("y", "q");
dict.put("q", "find_node");
std::shared_ptr<Dict> aDict = Dict::g();
aDict->put("id", String::g(localNode->getID(), DHT_ID_LENGTH));
auto aDict = Dict::g();
aDict->put("id", String::g(localNode_->getID(), DHT_ID_LENGTH));
aDict->put("target", String::g(targetNode->getID(), DHT_ID_LENGTH));
dict.put("a", aDict);
@ -77,22 +81,20 @@ void DHTFindNodeMessageTest::testGetBencodedMessage()
void DHTFindNodeMessageTest::testDoReceivedAction()
{
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]);
std::shared_ptr<DHTNode> targetNode(new DHTNode());
auto targetNode = std::make_shared<DHTNode>();
MockDHTMessageDispatcher dispatcher;
MockDHTMessageFactory2 factory;
factory.setLocalNode(localNode);
DHTRoutingTable routingTable(localNode);
factory.setLocalNode(localNode_);
DHTRoutingTable routingTable(localNode_);
routingTable.addNode(targetNode);
DHTFindNodeMessage msg(localNode, remoteNode, targetNode->getID(), transactionID);
DHTFindNodeMessage msg(localNode_, remoteNode_, targetNode->getID(),
transactionID);
msg.setMessageDispatcher(&dispatcher);
msg.setMessageFactory(&factory);
msg.setRoutingTable(&routingTable);
@ -100,13 +102,13 @@ void DHTFindNodeMessageTest::testDoReceivedAction()
msg.doReceivedAction();
CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size());
auto m = std::dynamic_pointer_cast<MockDHTResponseMessage>
(dispatcher.messageQueue_[0].message_);
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
auto m = dynamic_cast<DHTFindNodeReplyMessage*>
(dispatcher.messageQueue_[0].message_.get());
CPPUNIT_ASSERT(*localNode_ == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("find_node"), m->getMessageType());
CPPUNIT_ASSERT_EQUAL(msg.getTransactionID(), m->getTransactionID());
CPPUNIT_ASSERT_EQUAL((size_t)1, m->nodes_.size());
CPPUNIT_ASSERT_EQUAL((size_t)1, m->getClosestKNodes().size());
}
} // namespace aria2

View File

@ -22,7 +22,14 @@ class DHTGetPeersMessageTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testDoReceivedAction);
CPPUNIT_TEST_SUITE_END();
public:
void setUp() {}
std::shared_ptr<DHTNode> localNode_;
std::shared_ptr<DHTNode> remoteNode_;
void setUp()
{
localNode_ = std::make_shared<DHTNode>();
remoteNode_ = std::make_shared<DHTNode>();
}
void tearDown() {}
@ -31,20 +38,19 @@ public:
class MockDHTMessageFactory2:public MockDHTMessageFactory {
public:
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTGetPeersReplyMessage>
createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes,
const std::vector<std::shared_ptr<Peer> >& peers,
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
std::vector<std::shared_ptr<Peer>> peers,
const std::string& token,
const std::string& transactionID)
const std::string& transactionID) override
{
std::shared_ptr<MockDHTResponseMessage> m
(new MockDHTResponseMessage
(localNode_, remoteNode, "get_peers", transactionID));
m->nodes_ = closestKNodes;
m->peers_ = peers;
m->token_ = token;
auto m = make_unique<DHTGetPeersReplyMessage>(AF_INET, localNode_,
remoteNode, token,
transactionID);
m->setClosestKNodes(closestKNodes);
m->setValues(peers);
return m;
}
};
@ -55,9 +61,6 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DHTGetPeersMessageTest);
void DHTGetPeersMessageTest::testGetBencodedMessage()
{
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]);
@ -65,7 +68,7 @@ void DHTGetPeersMessageTest::testGetBencodedMessage()
unsigned char infoHash[DHT_ID_LENGTH];
util::generateRandomData(infoHash, DHT_ID_LENGTH);
DHTGetPeersMessage msg(localNode, remoteNode, infoHash, transactionID);
DHTGetPeersMessage msg(localNode_, remoteNode_, infoHash, transactionID);
msg.setVersion("A200");
std::string msgbody = msg.getBencodedMessage();
@ -75,8 +78,8 @@ void DHTGetPeersMessageTest::testGetBencodedMessage()
dict.put("v", "A200");
dict.put("y", "q");
dict.put("q", "get_peers");
std::shared_ptr<Dict> aDict = Dict::g();
aDict->put("id", String::g(localNode->getID(), DHT_ID_LENGTH));
auto aDict = Dict::g();
aDict->put("id", String::g(localNode_->getID(), DHT_ID_LENGTH));
aDict->put("info_hash", String::g(infoHash, DHT_ID_LENGTH));
dict.put("a", aDict);
@ -86,10 +89,8 @@ void DHTGetPeersMessageTest::testGetBencodedMessage()
void DHTGetPeersMessageTest::testDoReceivedAction()
{
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
remoteNode_->setIPAddress("192.168.0.1");
remoteNode_->setPort(6881);
unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
@ -101,10 +102,10 @@ void DHTGetPeersMessageTest::testDoReceivedAction()
DHTTokenTracker tokenTracker;
MockDHTMessageDispatcher dispatcher;
MockDHTMessageFactory2 factory;
factory.setLocalNode(localNode);
DHTRoutingTable routingTable(localNode);
factory.setLocalNode(localNode_);
DHTRoutingTable routingTable(localNode_);
DHTGetPeersMessage msg(localNode, remoteNode, infoHash, transactionID);
DHTGetPeersMessage msg(localNode_, remoteNode_, infoHash, transactionID);
msg.setRoutingTable(&routingTable);
msg.setTokenTracker(&tokenTracker);
msg.setMessageDispatcher(&dispatcher);
@ -120,22 +121,25 @@ void DHTGetPeersMessageTest::testDoReceivedAction()
msg.doReceivedAction();
CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size());
auto m = std::dynamic_pointer_cast<MockDHTResponseMessage>
(dispatcher.messageQueue_[0].message_);
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
auto m = dynamic_cast<DHTGetPeersReplyMessage*>
(dispatcher.messageQueue_[0].message_.get());
CPPUNIT_ASSERT(*localNode_ == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("get_peers"), m->getMessageType());
CPPUNIT_ASSERT_EQUAL(msg.getTransactionID(), m->getTransactionID());
CPPUNIT_ASSERT_EQUAL(tokenTracker.generateToken(infoHash, remoteNode->getIPAddress(), remoteNode->getPort()), m->token_);
CPPUNIT_ASSERT_EQUAL((size_t)0, m->nodes_.size());
CPPUNIT_ASSERT_EQUAL((size_t)2, m->peers_.size());
CPPUNIT_ASSERT_EQUAL(tokenTracker.generateToken
(infoHash, remoteNode_->getIPAddress(),
remoteNode_->getPort()),
m->getToken());
CPPUNIT_ASSERT_EQUAL((size_t)0, m->getClosestKNodes().size());
CPPUNIT_ASSERT_EQUAL((size_t)2, m->getValues().size());
{
std::shared_ptr<Peer> peer = m->peers_[0];
auto peer = m->getValues()[0];
CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.100"), peer->getIPAddress());
CPPUNIT_ASSERT_EQUAL((uint16_t)6888, peer->getPort());
}
{
std::shared_ptr<Peer> peer = m->peers_[1];
auto peer = m->getValues()[1];
CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.101"), peer->getIPAddress());
CPPUNIT_ASSERT_EQUAL((uint16_t)6889, peer->getPort());
}
@ -144,7 +148,7 @@ void DHTGetPeersMessageTest::testDoReceivedAction()
{
// localhost doesn't have peer contact information for that infohash.
DHTPeerAnnounceStorage peerAnnounceStorage;
DHTRoutingTable routingTable(localNode);
DHTRoutingTable routingTable(localNode_);
std::shared_ptr<DHTNode> returnNode1(new DHTNode());
routingTable.addNode(returnNode1);
@ -154,16 +158,19 @@ void DHTGetPeersMessageTest::testDoReceivedAction()
msg.doReceivedAction();
CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size());
auto m = std::dynamic_pointer_cast<MockDHTResponseMessage>
(dispatcher.messageQueue_[0].message_);
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
auto m = dynamic_cast<DHTGetPeersReplyMessage*>
(dispatcher.messageQueue_[0].message_.get());
CPPUNIT_ASSERT(*localNode_ == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("get_peers"), m->getMessageType());
CPPUNIT_ASSERT_EQUAL(msg.getTransactionID(), m->getTransactionID());
CPPUNIT_ASSERT_EQUAL(tokenTracker.generateToken(infoHash, remoteNode->getIPAddress(), remoteNode->getPort()), m->token_);
CPPUNIT_ASSERT_EQUAL((size_t)1, m->nodes_.size());
CPPUNIT_ASSERT(*returnNode1 == *m->nodes_[0]);
CPPUNIT_ASSERT_EQUAL((size_t)0, m->peers_.size());
CPPUNIT_ASSERT_EQUAL(tokenTracker.generateToken
(infoHash, remoteNode_->getIPAddress(),
remoteNode_->getPort()),
m->getToken());
CPPUNIT_ASSERT_EQUAL((size_t)1, m->getClosestKNodes().size());
CPPUNIT_ASSERT(*returnNode1 == *m->getClosestKNodes()[0]);
CPPUNIT_ASSERT_EQUAL((size_t)0, m->getValues().size());
}
}

View File

@ -1,12 +1,15 @@
#include "DHTNode.h"
#include "DHTNodeLookupEntry.h"
#include "DHTIDCloser.h"
#include "Exception.h"
#include "util.h"
#include <cstring>
#include <algorithm>
#include <cppunit/extensions/HelperMacros.h>
#include "DHTNode.h"
#include "DHTNodeLookupEntry.h"
#include "Exception.h"
#include "util.h"
namespace aria2 {
class DHTIDCloserTest:public CppUnit::TestFixture {
@ -30,39 +33,40 @@ void DHTIDCloserTest::testOperator()
unsigned char id[DHT_ID_LENGTH];
memset(id, 0xf0, DHT_ID_LENGTH);
std::shared_ptr<DHTNodeLookupEntry> e1
(new DHTNodeLookupEntry(std::shared_ptr<DHTNode>(new DHTNode(id))));
auto e1 = make_unique<DHTNodeLookupEntry>(std::make_shared<DHTNode>(id));
auto ep1 = e1.get();
id[0] = 0xb0;
std::shared_ptr<DHTNodeLookupEntry> e2
(new DHTNodeLookupEntry(std::shared_ptr<DHTNode>(new DHTNode(id))));
auto e2 = make_unique<DHTNodeLookupEntry>(std::make_shared<DHTNode>(id));
auto ep2 = e2.get();
id[0] = 0xa0;
std::shared_ptr<DHTNodeLookupEntry> e3
(new DHTNodeLookupEntry(std::shared_ptr<DHTNode>(new DHTNode(id))));
auto e3 = make_unique<DHTNodeLookupEntry>(std::make_shared<DHTNode>(id));
auto ep3 = e3.get();
id[0] = 0x80;
std::shared_ptr<DHTNodeLookupEntry> e4
(new DHTNodeLookupEntry(std::shared_ptr<DHTNode>(new DHTNode(id))));
auto e4 = make_unique<DHTNodeLookupEntry>(std::make_shared<DHTNode>(id));
auto ep4 = e4.get();
id[0] = 0x00;
std::shared_ptr<DHTNodeLookupEntry> e5
(new DHTNodeLookupEntry(std::shared_ptr<DHTNode>(new DHTNode(id))));
auto e5 = make_unique<DHTNodeLookupEntry>(std::make_shared<DHTNode>(id));
auto ep5 = e5.get();
std::deque<std::shared_ptr<DHTNodeLookupEntry> > entries;
entries.push_back(e1);
entries.push_back(e2);
entries.push_back(e3);
entries.push_back(e4);
entries.push_back(e5);
auto entries = std::vector<std::unique_ptr<DHTNodeLookupEntry>>{};
entries.push_back(std::move(e1));
entries.push_back(std::move(e2));
entries.push_back(std::move(e3));
entries.push_back(std::move(e4));
entries.push_back(std::move(e5));
std::sort(entries.begin(), entries.end(), DHTIDCloser(e3->node->getID()));
std::sort(std::begin(entries), std::end(entries),
DHTIDCloser(ep3->node->getID()));
CPPUNIT_ASSERT(*e3 == *entries[0]);
CPPUNIT_ASSERT(*e2 == *entries[1]);
CPPUNIT_ASSERT(*e4 == *entries[2]);
CPPUNIT_ASSERT(*e1 == *entries[3]);
CPPUNIT_ASSERT(*e5 == *entries[4]);
CPPUNIT_ASSERT(*ep3 == *entries[0]);
CPPUNIT_ASSERT(*ep2 == *entries[1]);
CPPUNIT_ASSERT(*ep4 == *entries[2]);
CPPUNIT_ASSERT(*ep1 == *entries[3]);
CPPUNIT_ASSERT(*ep5 == *entries[4]);
}
} // namespace aria2

View File

@ -40,25 +40,36 @@ class DHTMessageFactoryImplTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testReceivedErrorMessage);
CPPUNIT_TEST_SUITE_END();
public:
std::shared_ptr<DHTMessageFactoryImpl> factory;
std::unique_ptr<DHTMessageFactoryImpl> factory;
std::shared_ptr<DHTRoutingTable> routingTable;
std::unique_ptr<DHTRoutingTable> routingTable;
std::shared_ptr<DHTNode> localNode;
std::unique_ptr<DHTNode> remoteNode_;
std::unique_ptr<DHTNode> remoteNode6_;
unsigned char transactionID[DHT_TRANSACTION_ID_LENGTH];
unsigned char remoteNodeID[DHT_ID_LENGTH];
void setUp()
{
localNode.reset(new DHTNode());
factory.reset(new DHTMessageFactoryImpl(AF_INET));
localNode = std::make_shared<DHTNode>();
factory = make_unique<DHTMessageFactoryImpl>(AF_INET);
factory->setLocalNode(localNode);
memset(transactionID, 0xff, DHT_TRANSACTION_ID_LENGTH);
memset(remoteNodeID, 0x0f, DHT_ID_LENGTH);
routingTable.reset(new DHTRoutingTable(localNode));
routingTable = make_unique<DHTRoutingTable>(localNode);
factory->setRoutingTable(routingTable.get());
remoteNode_ = make_unique<DHTNode>(remoteNodeID);
remoteNode_->setIPAddress("192.168.0.1");
remoteNode_->setPort(6881);
remoteNode6_ = make_unique<DHTNode>(remoteNodeID);
remoteNode6_->setIPAddress("2001::2001");
remoteNode6_->setPort(6881);
}
void tearDown() {}
@ -85,18 +96,15 @@ void DHTMessageFactoryImplTest::testCreatePingMessage()
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "q");
dict.put("q", "ping");
std::shared_ptr<Dict> aDict = Dict::g();
auto aDict = Dict::g();
aDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
dict.put("a", aDict);
auto m = std::dynamic_pointer_cast<DHTPingMessage>
(factory->createQueryMessage(&dict, "192.168.0.1", 6881));
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
auto r = factory->createQueryMessage(&dict, "192.168.0.1", 6881);
auto m = dynamic_cast<DHTPingMessage*>(r.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH),
util::toHex(m->getTransactionID()));
}
@ -106,21 +114,17 @@ void DHTMessageFactoryImplTest::testCreatePingReplyMessage()
Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "r");
std::shared_ptr<Dict> rDict = Dict::g();
auto rDict = Dict::g();
rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
dict.put("r", rDict);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
auto m = std::dynamic_pointer_cast<DHTPingReplyMessage>
(factory->createResponseMessage("ping", &dict,
remoteNode->getIPAddress(),
remoteNode->getPort()));
auto r = factory->createResponseMessage("ping", &dict,
remoteNode_->getIPAddress(),
remoteNode_->getPort());
auto m = dynamic_cast<DHTPingReplyMessage*>(r.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH),
util::toHex(m->getTransactionID()));
}
@ -131,21 +135,18 @@ void DHTMessageFactoryImplTest::testCreateFindNodeMessage()
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "q");
dict.put("q", "find_node");
std::shared_ptr<Dict> aDict = Dict::g();
auto aDict = Dict::g();
aDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
unsigned char targetNodeID[DHT_ID_LENGTH];
memset(targetNodeID, 0x11, DHT_ID_LENGTH);
aDict->put("target", String::g(targetNodeID, DHT_ID_LENGTH));
dict.put("a", aDict);
auto m = std::dynamic_pointer_cast<DHTFindNodeMessage>
(factory->createQueryMessage(&dict, "192.168.0.1", 6881));
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
auto r = factory->createQueryMessage(&dict, "192.168.0.1", 6881);
auto m = dynamic_cast<DHTFindNodeMessage*>(r.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH),
util::toHex(m->getTransactionID()));
CPPUNIT_ASSERT_EQUAL(util::toHex(targetNodeID, DHT_ID_LENGTH),
@ -158,12 +159,12 @@ void DHTMessageFactoryImplTest::testCreateFindNodeReplyMessage()
Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "r");
std::shared_ptr<Dict> rDict = Dict::g();
auto rDict = Dict::g();
rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
std::string compactNodeInfo;
std::shared_ptr<DHTNode> nodes[8];
for(size_t i = 0; i < DHTBucket::K; ++i) {
nodes[i].reset(new DHTNode());
nodes[i] = std::make_shared<DHTNode>();
nodes[i]->setIPAddress("192.168.0."+util::uitos(i+1));
nodes[i]->setPort(6881+i);
@ -179,17 +180,13 @@ void DHTMessageFactoryImplTest::testCreateFindNodeReplyMessage()
rDict->put("nodes", compactNodeInfo);
dict.put("r", rDict);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
auto m = std::dynamic_pointer_cast<DHTFindNodeReplyMessage>
(factory->createResponseMessage("find_node", &dict,
remoteNode->getIPAddress(),
remoteNode->getPort()));
auto r = factory->createResponseMessage("find_node", &dict,
remoteNode_->getIPAddress(),
remoteNode_->getPort());
auto m = dynamic_cast<DHTFindNodeReplyMessage*>(r.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL((size_t)DHTBucket::K, m->getClosestKNodes().size());
CPPUNIT_ASSERT(*nodes[0] == *m->getClosestKNodes()[0]);
CPPUNIT_ASSERT(*nodes[7] == *m->getClosestKNodes()[7]);
@ -202,19 +199,19 @@ void DHTMessageFactoryImplTest::testCreateFindNodeReplyMessage()
void DHTMessageFactoryImplTest::testCreateFindNodeReplyMessage6()
{
factory.reset(new DHTMessageFactoryImpl(AF_INET6));
factory = make_unique<DHTMessageFactoryImpl>(AF_INET6);
factory->setLocalNode(localNode);
factory->setRoutingTable(routingTable.get());
try {
Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "r");
std::shared_ptr<Dict> rDict = Dict::g();
auto rDict = Dict::g();
rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
std::string compactNodeInfo;
std::shared_ptr<DHTNode> nodes[8];
for(size_t i = 0; i < DHTBucket::K; ++i) {
nodes[i].reset(new DHTNode());
nodes[i] = std::make_shared<DHTNode>();
nodes[i]->setIPAddress("2001::000"+util::uitos(i+1));
nodes[i]->setPort(6881+i);
@ -230,17 +227,13 @@ void DHTMessageFactoryImplTest::testCreateFindNodeReplyMessage6()
rDict->put("nodes6", compactNodeInfo);
dict.put("r", rDict);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("2001::2001");
remoteNode->setPort(6881);
auto m = std::dynamic_pointer_cast<DHTFindNodeReplyMessage>
(factory->createResponseMessage("find_node", &dict,
remoteNode->getIPAddress(),
remoteNode->getPort()));
auto r = factory->createResponseMessage("find_node", &dict,
remoteNode_->getIPAddress(),
remoteNode_->getPort());
auto m = dynamic_cast<DHTFindNodeReplyMessage*>(r.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL((size_t)DHTBucket::K, m->getClosestKNodes().size());
CPPUNIT_ASSERT(*nodes[0] == *m->getClosestKNodes()[0]);
CPPUNIT_ASSERT(*nodes[7] == *m->getClosestKNodes()[7]);
@ -257,21 +250,18 @@ void DHTMessageFactoryImplTest::testCreateGetPeersMessage()
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "q");
dict.put("q", "get_peers");
std::shared_ptr<Dict> aDict = Dict::g();
auto aDict = Dict::g();
aDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
unsigned char infoHash[DHT_ID_LENGTH];
memset(infoHash, 0x11, DHT_ID_LENGTH);
aDict->put("info_hash", String::g(infoHash, DHT_ID_LENGTH));
dict.put("a", aDict);
auto m = std::dynamic_pointer_cast<DHTGetPeersMessage>
(factory->createQueryMessage(&dict, "192.168.0.1", 6881));
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
auto r = factory->createQueryMessage(&dict, "192.168.0.1", 6881);
auto m = dynamic_cast<DHTGetPeersMessage*>(r.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH),
util::toHex(m->getTransactionID()));
CPPUNIT_ASSERT_EQUAL(util::toHex(infoHash, DHT_ID_LENGTH),
@ -284,12 +274,12 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage()
Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "r");
std::shared_ptr<Dict> rDict = Dict::g();
auto rDict = Dict::g();
rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
std::string compactNodeInfo;
std::shared_ptr<DHTNode> nodes[8];
for(size_t i = 0; i < DHTBucket::K; ++i) {
nodes[i].reset(new DHTNode());
nodes[i] = std::make_shared<DHTNode>();
nodes[i]->setIPAddress("192.168.0."+util::uitos(i+1));
nodes[i]->setPort(6881+i);
@ -307,7 +297,8 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage()
std::deque<std::shared_ptr<Peer> > peers;
std::shared_ptr<List> valuesList = List::g();
for(size_t i = 0; i < 4; ++i) {
std::shared_ptr<Peer> peer(new Peer("192.168.0."+util::uitos(i+1), 6881+i));
auto peer = std::make_shared<Peer>("192.168.0."+util::uitos(i+1),
6881+i);
unsigned char buffer[COMPACT_LEN_IPV6];
CPPUNIT_ASSERT_EQUAL
(COMPACT_LEN_IPV4,
@ -321,17 +312,13 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage()
rDict->put("token", "token");
dict.put("r", rDict);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
auto m = std::dynamic_pointer_cast<DHTGetPeersReplyMessage>
(factory->createResponseMessage("get_peers", &dict,
remoteNode->getIPAddress(),
remoteNode->getPort()));
auto r = factory->createResponseMessage("get_peers", &dict,
remoteNode_->getIPAddress(),
remoteNode_->getPort());
auto m = dynamic_cast<DHTGetPeersReplyMessage*>(r.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("token"), m->getToken());
CPPUNIT_ASSERT_EQUAL((size_t)DHTBucket::K, m->getClosestKNodes().size());
CPPUNIT_ASSERT(*nodes[0] == *m->getClosestKNodes()[0]);
@ -351,19 +338,19 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage()
void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage6()
{
factory.reset(new DHTMessageFactoryImpl(AF_INET6));
factory = make_unique<DHTMessageFactoryImpl>(AF_INET6);
factory->setLocalNode(localNode);
factory->setRoutingTable(routingTable.get());
try {
Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "r");
std::shared_ptr<Dict> rDict = Dict::g();
auto rDict = Dict::g();
rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
std::string compactNodeInfo;
std::shared_ptr<DHTNode> nodes[8];
for(size_t i = 0; i < DHTBucket::K; ++i) {
nodes[i].reset(new DHTNode());
nodes[i] = std::make_shared<DHTNode>();
nodes[i]->setIPAddress("2001::000"+util::uitos(i+1));
nodes[i]->setPort(6881+i);
@ -378,10 +365,10 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage6()
}
rDict->put("nodes6", compactNodeInfo);
std::deque<std::shared_ptr<Peer> > peers;
std::shared_ptr<List> valuesList = List::g();
std::deque<std::shared_ptr<Peer>> peers;
auto valuesList = List::g();
for(size_t i = 0; i < 4; ++i) {
std::shared_ptr<Peer> peer(new Peer("2001::100"+util::uitos(i+1), 6881+i));
auto peer = std::make_shared<Peer>("2001::100"+util::uitos(i+1), 6881+i);
unsigned char buffer[COMPACT_LEN_IPV6];
CPPUNIT_ASSERT_EQUAL
(COMPACT_LEN_IPV6,
@ -395,17 +382,13 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage6()
rDict->put("token", "token");
dict.put("r", rDict);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("2001::2001");
remoteNode->setPort(6881);
auto m = std::dynamic_pointer_cast<DHTGetPeersReplyMessage>
(factory->createResponseMessage("get_peers", &dict,
remoteNode->getIPAddress(),
remoteNode->getPort()));
auto r = factory->createResponseMessage("get_peers", &dict,
remoteNode_->getIPAddress(),
remoteNode_->getPort());
auto m = dynamic_cast<DHTGetPeersReplyMessage*>(r.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("token"), m->getToken());
CPPUNIT_ASSERT_EQUAL((size_t)DHTBucket::K, m->getClosestKNodes().size());
CPPUNIT_ASSERT(*nodes[0] == *m->getClosestKNodes()[0]);
@ -430,7 +413,7 @@ void DHTMessageFactoryImplTest::testCreateAnnouncePeerMessage()
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "q");
dict.put("q", "announce_peer");
std::shared_ptr<Dict> aDict = Dict::g();
auto aDict = Dict::g();
aDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
unsigned char infoHash[DHT_ID_LENGTH];
memset(infoHash, 0x11, DHT_ID_LENGTH);
@ -441,14 +424,13 @@ void DHTMessageFactoryImplTest::testCreateAnnouncePeerMessage()
aDict->put("token", token);
dict.put("a", aDict);
auto m = std::dynamic_pointer_cast<DHTAnnouncePeerMessage>
(factory->createQueryMessage(&dict, "192.168.0.1", 6882));
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6882);
remoteNode_->setPort(6882);
auto r = factory->createQueryMessage(&dict, "192.168.0.1", 6882);
auto m = dynamic_cast<DHTAnnouncePeerMessage*>(r.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(token, m->getToken());
CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH),
util::toHex(m->getTransactionID()));
@ -465,21 +447,17 @@ void DHTMessageFactoryImplTest::testCreateAnnouncePeerReplyMessage()
Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "r");
std::shared_ptr<Dict> rDict = Dict::g();
auto rDict = Dict::g();
rDict->put("id", String::g(remoteNodeID, DHT_ID_LENGTH));
dict.put("r", rDict);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
auto m = std::dynamic_pointer_cast<DHTAnnouncePeerReplyMessage>
(factory->createResponseMessage("announce_peer", &dict,
remoteNode->getIPAddress(),
remoteNode->getPort()));
auto r = factory->createResponseMessage("announce_peer", &dict,
remoteNode_->getIPAddress(),
remoteNode_->getPort());
auto m = dynamic_cast<DHTAnnouncePeerReplyMessage*>(r.get());
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(util::toHex(transactionID, DHT_TRANSACTION_ID_LENGTH),
util::toHex(m->getTransactionID()));
}
@ -489,19 +467,15 @@ void DHTMessageFactoryImplTest::testReceivedErrorMessage()
Dict dict;
dict.put("t", String::g(transactionID, DHT_TRANSACTION_ID_LENGTH));
dict.put("y", "e");
std::shared_ptr<List> list = List::g();
auto list = List::g();
list->append(Integer::g(404));
list->append("Not found");
dict.put("e", list);
std::shared_ptr<DHTNode> remoteNode(new DHTNode(remoteNodeID));
remoteNode->setIPAddress("192.168.0.1");
remoteNode->setPort(6881);
try {
factory->createResponseMessage("announce_peer", &dict,
remoteNode->getIPAddress(),
remoteNode->getPort());
remoteNode_->getIPAddress(),
remoteNode_->getPort());
CPPUNIT_FAIL("exception must be thrown.");
} catch(RecoverableException& e) {
std::cerr << e.stackTrace() << std::endl;

View File

@ -33,14 +33,17 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DHTMessageTrackerEntryTest);
void DHTMessageTrackerEntryTest::testMatch()
{
std::shared_ptr<DHTNode> localNode(new DHTNode());
auto localNode = std::make_shared<DHTNode>();
try {
std::shared_ptr<DHTNode> node1(new DHTNode());
std::shared_ptr<MockDHTMessage> msg1(new MockDHTMessage(localNode, node1));
std::shared_ptr<DHTNode> node2(new DHTNode());
std::shared_ptr<MockDHTMessage> msg2(new MockDHTMessage(localNode, node2));
auto node1 = std::make_shared<DHTNode>();
auto msg1 = make_unique<MockDHTMessage>(localNode, node1);
auto node2 = std::make_shared<DHTNode>();
auto msg2 = make_unique<MockDHTMessage>(localNode, node2);
DHTMessageTrackerEntry entry(msg1, 30);
DHTMessageTrackerEntry entry(msg1->getRemoteNode(),
msg1->getTransactionID(),
msg1->getMessageType(),
30);
CPPUNIT_ASSERT(entry.match(msg1->getTransactionID(),
msg1->getRemoteNode()->getIPAddress(),

View File

@ -34,65 +34,62 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DHTMessageTrackerTest);
void DHTMessageTrackerTest::testMessageArrived()
{
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTRoutingTable> routingTable(new DHTRoutingTable(localNode));
std::shared_ptr<MockDHTMessageFactory> factory(new MockDHTMessageFactory());
auto localNode = std::make_shared<DHTNode>();
auto routingTable = std::make_shared<DHTRoutingTable>(localNode);
auto factory = std::make_shared<MockDHTMessageFactory>();
factory->setLocalNode(localNode);
std::shared_ptr<MockDHTMessage> m1(new MockDHTMessage(localNode,
std::shared_ptr<DHTNode>(new DHTNode())));
std::shared_ptr<MockDHTMessage> m2(new MockDHTMessage(localNode,
std::shared_ptr<DHTNode>(new DHTNode())));
std::shared_ptr<MockDHTMessage> m3(new MockDHTMessage(localNode,
std::shared_ptr<DHTNode>(new DHTNode())));
auto r1 = std::make_shared<DHTNode>();
r1->setIPAddress("192.168.0.1");
r1->setPort(6881);
auto r2 = std::make_shared<DHTNode>();
r2->setIPAddress("192.168.0.2");
r2->setPort(6882);
auto r3 = std::make_shared<DHTNode>();
r3->setIPAddress("192.168.0.3");
r3->setPort(6883);
m1->getRemoteNode()->setIPAddress("192.168.0.1");
m1->getRemoteNode()->setPort(6881);
m2->getRemoteNode()->setIPAddress("192.168.0.2");
m2->getRemoteNode()->setPort(6882);
m3->getRemoteNode()->setIPAddress("192.168.0.3");
m3->getRemoteNode()->setPort(6883);
auto m1 = make_unique<MockDHTMessage>(localNode, r1);
auto m2 = make_unique<MockDHTMessage>(localNode, r2);
auto m3 = make_unique<MockDHTMessage>(localNode, r3);
DHTMessageTracker tracker;
tracker.setRoutingTable(routingTable);
tracker.setMessageFactory(factory);
tracker.addMessage(m1, DHT_MESSAGE_TIMEOUT);
tracker.addMessage(m2, DHT_MESSAGE_TIMEOUT);
tracker.addMessage(m3, DHT_MESSAGE_TIMEOUT);
tracker.setMessageFactory(factory.get());
tracker.addMessage(m1.get(), DHT_MESSAGE_TIMEOUT);
tracker.addMessage(m2.get(), DHT_MESSAGE_TIMEOUT);
tracker.addMessage(m3.get(), DHT_MESSAGE_TIMEOUT);
{
Dict resDict;
resDict.put("t", m2->getTransactionID());
std::pair<std::shared_ptr<DHTMessage>, std::shared_ptr<DHTMessageCallback> > p =
tracker.messageArrived(&resDict, m2->getRemoteNode()->getIPAddress(),
m2->getRemoteNode()->getPort());
std::shared_ptr<DHTMessage> reply = p.first;
auto p =
tracker.messageArrived(&resDict, r2->getIPAddress(), r2->getPort());
auto& reply = p.first;
CPPUNIT_ASSERT(reply);
CPPUNIT_ASSERT(!tracker.getEntryFor(m2));
CPPUNIT_ASSERT(!tracker.getEntryFor(m2.get()));
CPPUNIT_ASSERT_EQUAL((size_t)2, tracker.countEntry());
}
{
Dict resDict;
resDict.put("t", m3->getTransactionID());
std::pair<std::shared_ptr<DHTMessage>, std::shared_ptr<DHTMessageCallback> > p =
tracker.messageArrived(&resDict, m3->getRemoteNode()->getIPAddress(),
m3->getRemoteNode()->getPort());
std::shared_ptr<DHTMessage> reply = p.first;
auto p =
tracker.messageArrived(&resDict, r3->getIPAddress(), r3->getPort());
auto& reply = p.first;
CPPUNIT_ASSERT(reply);
CPPUNIT_ASSERT(!tracker.getEntryFor(m3));
CPPUNIT_ASSERT(!tracker.getEntryFor(m3.get()));
CPPUNIT_ASSERT_EQUAL((size_t)1, tracker.countEntry());
}
{
Dict resDict;
resDict.put("t", m1->getTransactionID());
std::pair<std::shared_ptr<DHTMessage>, std::shared_ptr<DHTMessageCallback> > p =
tracker.messageArrived(&resDict, "192.168.1.100", 6889);
std::shared_ptr<DHTMessage> reply = p.first;
auto p = tracker.messageArrived(&resDict, "192.168.1.100", 6889);
auto& reply = p.first;
CPPUNIT_ASSERT(!reply);
}

View File

@ -19,7 +19,14 @@ class DHTPingMessageTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testDoReceivedAction);
CPPUNIT_TEST_SUITE_END();
public:
void setUp() {}
std::shared_ptr<DHTNode> localNode_;
std::shared_ptr<DHTNode> remoteNode_;
void setUp()
{
localNode_ = std::make_shared<DHTNode>();
remoteNode_ = std::make_shared<DHTNode>();
}
void tearDown() {}
@ -28,14 +35,15 @@ public:
class MockDHTMessageFactory2:public MockDHTMessageFactory {
public:
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTPingReplyMessage>
createPingReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* remoteNodeID,
const std::string& transactionID)
const std::string& transactionID) override
{
return std::shared_ptr<MockDHTResponseMessage>
(new MockDHTResponseMessage(localNode_, remoteNode, "ping_reply",
transactionID));
unsigned char id[DHT_ID_LENGTH];
std::fill(std::begin(id), std::end(id), '0');
return make_unique<DHTPingReplyMessage>
(localNode_, remoteNode, id, transactionID);
}
};
};
@ -45,14 +53,11 @@ CPPUNIT_TEST_SUITE_REGISTRATION(DHTPingMessageTest);
void DHTPingMessageTest::testGetBencodedMessage()
{
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]);
DHTPingMessage msg(localNode, remoteNode, transactionID);
DHTPingMessage msg(localNode_, remoteNode_, transactionID);
msg.setVersion("A200");
std::string msgbody = msg.getBencodedMessage();
@ -62,8 +67,8 @@ void DHTPingMessageTest::testGetBencodedMessage()
dict.put("v", "A200");
dict.put("y", "q");
dict.put("q", "ping");
std::shared_ptr<Dict> aDict = Dict::g();
aDict->put("id", String::g(localNode->getID(), DHT_ID_LENGTH));
auto aDict = Dict::g();
aDict->put("id", String::g(localNode_->getID(), DHT_ID_LENGTH));
dict.put("a", aDict);
CPPUNIT_ASSERT_EQUAL(bencode2::encode(&dict), msgbody);
@ -71,29 +76,26 @@ void DHTPingMessageTest::testGetBencodedMessage()
void DHTPingMessageTest::testDoReceivedAction()
{
std::shared_ptr<DHTNode> localNode(new DHTNode());
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
unsigned char tid[DHT_TRANSACTION_ID_LENGTH];
util::generateRandomData(tid, DHT_TRANSACTION_ID_LENGTH);
std::string transactionID(&tid[0], &tid[DHT_TRANSACTION_ID_LENGTH]);
MockDHTMessageDispatcher dispatcher;
MockDHTMessageFactory2 factory;
factory.setLocalNode(localNode);
factory.setLocalNode(localNode_);
DHTPingMessage msg(localNode, remoteNode, transactionID);
DHTPingMessage msg(localNode_, remoteNode_, transactionID);
msg.setMessageDispatcher(&dispatcher);
msg.setMessageFactory(&factory);
msg.doReceivedAction();
CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher.messageQueue_.size());
auto m = std::dynamic_pointer_cast<MockDHTResponseMessage>
(dispatcher.messageQueue_[0].message_);
CPPUNIT_ASSERT(*localNode == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("ping_reply"), m->getMessageType());
auto m = dynamic_cast<DHTPingReplyMessage*>
(dispatcher.messageQueue_[0].message_.get());
CPPUNIT_ASSERT(*localNode_ == *m->getLocalNode());
CPPUNIT_ASSERT(*remoteNode_ == *m->getRemoteNode());
CPPUNIT_ASSERT_EQUAL(std::string("ping"), m->getMessageType());
CPPUNIT_ASSERT_EQUAL(msg.getTransactionID(), m->getTransactionID());
}

View File

@ -10,17 +10,17 @@ namespace aria2 {
class MockDHTMessageDispatcher:public DHTMessageDispatcher {
public:
class Entry {
public:
std::shared_ptr<DHTMessage> message_;
struct Entry {
std::unique_ptr<DHTMessage> message_;
time_t timeout_;
std::shared_ptr<DHTMessageCallback> callback_;
std::unique_ptr<DHTMessageCallback> callback_;
Entry(const std::shared_ptr<DHTMessage>& message, time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback):
message_(message),
timeout_(timeout),
callback_(callback) {}
Entry(std::unique_ptr<DHTMessage> message, time_t timeout,
std::unique_ptr<DHTMessageCallback> callback)
: message_{std::move(message)},
timeout_{timeout},
callback_{std::move(callback)}
{}
};
std::deque<Entry> messageQueue_;
@ -28,23 +28,23 @@ public:
public:
MockDHTMessageDispatcher() {}
virtual ~MockDHTMessageDispatcher() {}
virtual void
addMessageToQueue(const std::shared_ptr<DHTMessage>& message,
addMessageToQueue(std::unique_ptr<DHTMessage> message,
time_t timeout,
const std::shared_ptr<DHTMessageCallback>& callback =
std::shared_ptr<DHTMessageCallback>())
std::unique_ptr<DHTMessageCallback> callback =
std::unique_ptr<DHTMessageCallback>{})
{
messageQueue_.push_back(Entry(message, timeout, callback));
messageQueue_.push_back(Entry(std::move(message), timeout,
std::move(callback)));
}
virtual void
addMessageToQueue(const std::shared_ptr<DHTMessage>& message,
const std::shared_ptr<DHTMessageCallback>& callback =
std::shared_ptr<DHTMessageCallback>())
addMessageToQueue(std::unique_ptr<DHTMessage> message,
std::unique_ptr<DHTMessageCallback> callback =
std::unique_ptr<DHTMessageCallback>{})
{
messageQueue_.push_back(Entry(message, DHT_MESSAGE_TIMEOUT, callback));
messageQueue_.push_back(Entry(std::move(message), DHT_MESSAGE_TIMEOUT,
std::move(callback)));
}
virtual void sendMessages() {}

View File

@ -4,6 +4,15 @@
#include "DHTMessageFactory.h"
#include "DHTNode.h"
#include "MockDHTMessage.h"
#include "DHTPingMessage.h"
#include "DHTPingReplyMessage.h"
#include "DHTFindNodeMessage.h"
#include "DHTFindNodeReplyMessage.h"
#include "DHTGetPeersMessage.h"
#include "DHTGetPeersReplyMessage.h"
#include "DHTAnnouncePeerMessage.h"
#include "DHTAnnouncePeerReplyMessage.h"
#include "DHTUnknownMessage.h"
namespace aria2 {
@ -13,103 +22,99 @@ protected:
public:
MockDHTMessageFactory() {}
virtual ~MockDHTMessageFactory() {}
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTQueryMessage>
createQueryMessage(const Dict* dict,
const std::string& ipaddr, uint16_t port)
{
return std::shared_ptr<DHTQueryMessage>();
return std::unique_ptr<DHTQueryMessage>{};
}
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTResponseMessage>
createResponseMessage(const std::string& messageType,
const Dict* dict,
const std::string& ipaddr, uint16_t port)
{
std::shared_ptr<DHTNode> remoteNode(new DHTNode());
auto remoteNode = std::make_shared<DHTNode>();
// TODO At this point, removeNode's ID is random.
remoteNode->setIPAddress(ipaddr);
remoteNode->setPort(port);
std::shared_ptr<MockDHTResponseMessage> m
(new MockDHTResponseMessage(localNode_, remoteNode,
downcast<String>(dict->get("t"))->s()));
return m;
return make_unique<MockDHTResponseMessage>
(localNode_, remoteNode, downcast<String>(dict->get("t"))->s());
}
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTPingMessage>
createPingMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID = "")
{
return std::shared_ptr<DHTQueryMessage>();
return std::unique_ptr<DHTPingMessage>{};
}
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTPingReplyMessage>
createPingReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* remoteNodeID,
const std::string& transactionID)
{
return std::shared_ptr<DHTResponseMessage>();
return std::unique_ptr<DHTPingReplyMessage>{};
}
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTFindNodeMessage>
createFindNodeMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* targetNodeID,
const std::string& transactionID = "")
{
return std::shared_ptr<DHTQueryMessage>();
return std::unique_ptr<DHTFindNodeMessage>{};
}
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTFindNodeReplyMessage>
createFindNodeReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes,
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::string& transactionID)
{
return std::shared_ptr<DHTResponseMessage>();
return std::unique_ptr<DHTFindNodeReplyMessage>{};
}
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTGetPeersMessage>
createGetPeersMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash,
const std::string& transactionID)
{
return std::shared_ptr<DHTQueryMessage>();
return std::unique_ptr<DHTGetPeersMessage>{};
}
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTGetPeersReplyMessage>
createGetPeersReplyMessage
(const std::shared_ptr<DHTNode>& remoteNode,
const std::vector<std::shared_ptr<DHTNode> >& closestKNodes,
const std::vector<std::shared_ptr<Peer> >& peers,
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
std::vector<std::shared_ptr<Peer>> peers,
const std::string& token,
const std::string& transactionID)
{
return std::shared_ptr<DHTResponseMessage>();
return std::unique_ptr<DHTGetPeersReplyMessage>{};
}
virtual std::shared_ptr<DHTQueryMessage>
virtual std::unique_ptr<DHTAnnouncePeerMessage>
createAnnouncePeerMessage(const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* infoHash,
uint16_t tcpPort,
const std::string& token,
const std::string& transactionID = "")
{
return std::shared_ptr<DHTQueryMessage>();
return std::unique_ptr<DHTAnnouncePeerMessage>{};
}
virtual std::shared_ptr<DHTResponseMessage>
virtual std::unique_ptr<DHTAnnouncePeerReplyMessage>
createAnnouncePeerReplyMessage(const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID)
{
return std::shared_ptr<DHTResponseMessage>();
return std::unique_ptr<DHTAnnouncePeerReplyMessage>{};
}
virtual std::shared_ptr<DHTMessage>
virtual std::unique_ptr<DHTUnknownMessage>
createUnknownMessage(const unsigned char* data, size_t length,
const std::string& ipaddr, uint16_t port)
{
return std::shared_ptr<DHTMessage>();
return std::unique_ptr<DHTUnknownMessage>{};
}
void setLocalNode(const std::shared_ptr<DHTNode>& node)