diff --git a/ChangeLog b/ChangeLog index a33541c5..b1a53f4c 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,54 @@ +2010-06-13 Tatsuhiro Tsujikawa + + Rewritten DHTMessageCallback using Visitor pattern. Eliminated + dynamic_pointer_cast. + * src/DHTMessageCallbackImpl.cc: Removed + * src/DHTMessageReceiver.cc + * src/DHTAbstractNodeLookupTask.h + * src/DHTAnnouncePeerReplyMessage.h + * src/DHTReplaceNodeTask.h + * src/DHTFindNodeReplyMessage.cc + * src/DHTGetPeersReplyMessage.h + * src/DHTPeerLookupTask.h + * src/DHTMessageCallbackImpl.h: Removed + * src/DHTMessageFactory.h + * src/DHTNodeLookupTaskCallback.h + * src/DHTMessageTracker.h + * src/DHTMessageCallbackListener.h: Removed + * src/DHTGetPeersReplyMessage.cc + * src/DHTMessageCallback.h + * src/DHTAnnouncePeerReplyMessage.cc + * src/DHTNodeLookupTask.h + * src/DHTReplaceNodeTask.cc + * src/DHTPeerLookupTaskCallback.cc + * src/DHTMessageTracker.cc + * src/DHTPingReplyMessage.cc + * src/DHTPingTask.cc + * src/DHTMessageFactoryImpl.h + * src/Makefile.am + * src/DHTNodeLookupTask.cc + * src/DHTPeerLookupTaskCallback.h + * src/DHTPeerLookupTask.cc + * src/DHTMessageReceiver.h + * src/DHTMessageFactoryImpl.cc + * src/DHTResponseMessage.h + * src/DHTFindNodeReplyMessage.h + * src/DHTPingReplyMessageCallback.h + * src/Makefile.in + * src/DHTBucketRefreshTask.cc + * src/DHTNodeLookupTaskCallback.cc + * src/DHTPingTask.h + * src/DHTPingReplyMessage.h + * src/DHTAbstractNodeLookupTask.cc: Removed + * test/DHTMessageTrackerTest.cc + * test/DHTPingMessageTest.cc + * test/DHTGetPeersMessageTest.cc + * test/MockDHTMessage.h + * test/MockDHTMessageFactory.h + * test/DHTFindNodeMessageTest.cc + * test/MockDHTMessageCallback.h + * test/DHTAnnouncePeerMessageTest.cc + 2010-06-13 Tatsuhiro Tsujikawa Replaced dynamic_pointer_cast with static_pointer_cast diff --git a/src/DHTAbstractNodeLookupTask.cc b/src/DHTAbstractNodeLookupTask.cc deleted file mode 100644 index dd3bf302..00000000 --- a/src/DHTAbstractNodeLookupTask.cc +++ /dev/null @@ -1,182 +0,0 @@ -/* */ -#include "DHTAbstractNodeLookupTask.h" - -#include -#include - -#include "DHTRoutingTable.h" -#include "DHTMessageDispatcher.h" -#include "DHTMessageFactory.h" -#include "DHTMessage.h" -#include "DHTNode.h" -#include "DHTMessageCallbackImpl.h" -#include "DHTBucket.h" -#include "LogFactory.h" -#include "Logger.h" -#include "util.h" -#include "DHTIDCloser.h" - -namespace aria2 { - -DHTAbstractNodeLookupTask::DHTAbstractNodeLookupTask -(const unsigned char* targetID): - _inFlightMessage(0) -{ - memcpy(_targetID, targetID, DHT_ID_LENGTH); -} - -void DHTAbstractNodeLookupTask::onReceived -(const SharedHandle& message) -{ - --_inFlightMessage; - onReceivedInternal(message); - std::vector > nodes; - getNodesFromMessage(nodes, message); - std::vector > newEntries; - toEntries(newEntries, nodes); - - size_t count = 0; - for(std::vector >::const_iterator i = - newEntries.begin(), eoi = newEntries.end(); i != eoi; ++i) { - if(memcmp(getLocalNode()->getID(), (*i)->node->getID(), - DHT_ID_LENGTH) != 0) { - _entries.push_front(*i); - ++count; - if(getLogger()->debug()) { - getLogger()->debug("Received nodes: id=%s, ip=%s", - util::toHex((*i)->node->getID(), - DHT_ID_LENGTH).c_str(), - (*i)->node->getIPAddress().c_str()); - } - } - } - if(getLogger()->debug()) { - getLogger()->debug("%u node lookup entries added.", count); - } - std::stable_sort(_entries.begin(), _entries.end(), DHTIDCloser(_targetID)); - _entries.erase(std::unique(_entries.begin(), _entries.end()), _entries.end()); - if(getLogger()->debug()) { - getLogger()->debug("%u node lookup entries are unique.", _entries.size()); - } - if(_entries.size() > DHTBucket::K) { - _entries.erase(_entries.begin()+DHTBucket::K, _entries.end()); - } - sendMessageAndCheckFinish(); -} - -void DHTAbstractNodeLookupTask::onTimeout(const SharedHandle& node) -{ - if(getLogger()->debug()) { - getLogger()->debug("node lookup message timeout for node ID=%s", - util::toHex(node->getID(), DHT_ID_LENGTH).c_str()); - } - --_inFlightMessage; - for(std::deque >::iterator i = - _entries.begin(), eoi = _entries.end(); i != eoi; ++i) { - if((*i)->node == node) { - _entries.erase(i); - break; - } - } - sendMessageAndCheckFinish(); -} - -void DHTAbstractNodeLookupTask::sendMessageAndCheckFinish() -{ - if(needsAdditionalOutgoingMessage()) { - sendMessage(); - } - if(_inFlightMessage == 0) { - if(getLogger()->debug()) { - getLogger()->debug("Finished node_lookup for node ID %s", - util::toHex(_targetID, DHT_ID_LENGTH).c_str()); - } - onFinish(); - updateBucket(); - setFinished(true); - } else { - if(getLogger()->debug()) { - getLogger()->debug("%d in flight message for node ID %s", - _inFlightMessage, - util::toHex(_targetID, DHT_ID_LENGTH).c_str()); - } - } -} - -void DHTAbstractNodeLookupTask::sendMessage() -{ - for(std::deque >::iterator i = - _entries.begin(), eoi = _entries.end(); - i != eoi && _inFlightMessage < ALPHA; ++i) { - if((*i)->used == false) { - ++_inFlightMessage; - (*i)->used = true; - SharedHandle m = createMessage((*i)->node); - WeakHandle listener(this); - SharedHandle callback - (new DHTMessageCallbackImpl(listener)); - getMessageDispatcher()->addMessageToQueue(m, callback); - } - } -} - -void DHTAbstractNodeLookupTask::updateBucket() -{ - // TODO we have to something here? -} - -void DHTAbstractNodeLookupTask::startup() -{ - std::vector > nodes; - getRoutingTable()->getClosestKNodes(nodes, _targetID); - _entries.clear(); - toEntries(_entries, nodes); - if(_entries.empty()) { - setFinished(true); - } else { - // TODO use RTT here - _inFlightMessage = 0; - sendMessage(); - if(_inFlightMessage == 0) { - if(getLogger()->debug()) { - getLogger()->debug("No message was sent in this lookup stage. Finished."); - } - setFinished(true); - } - } -} - -} // namespace aria2 diff --git a/src/DHTAbstractNodeLookupTask.h b/src/DHTAbstractNodeLookupTask.h index 0846b557..70089cd2 100644 --- a/src/DHTAbstractNodeLookupTask.h +++ b/src/DHTAbstractNodeLookupTask.h @@ -37,20 +37,31 @@ #include "DHTAbstractTask.h" +#include +#include #include #include -#include "DHTMessageCallbackListener.h" #include "DHTConstants.h" #include "DHTNodeLookupEntry.h" +#include "DHTRoutingTable.h" +#include "DHTMessageDispatcher.h" +#include "DHTMessageFactory.h" +#include "DHTMessage.h" +#include "DHTNode.h" +#include "DHTBucket.h" +#include "LogFactory.h" +#include "Logger.h" +#include "util.h" +#include "DHTIDCloser.h" namespace aria2 { class DHTNode; class DHTMessage; -class DHTAbstractNodeLookupTask:public DHTAbstractTask, - public DHTMessageCallbackListener { +template +class DHTAbstractNodeLookupTask:public DHTAbstractTask { private: unsigned char _targetID[DHT_ID_LENGTH]; @@ -69,11 +80,44 @@ private: } } - void sendMessage(); + void sendMessage() + { + for(std::deque >::iterator i = + _entries.begin(), eoi = _entries.end(); + i != eoi && _inFlightMessage < ALPHA; ++i) { + if((*i)->used == false) { + ++_inFlightMessage; + (*i)->used = true; + SharedHandle m = createMessage((*i)->node); + SharedHandle callback(createCallback()); + getMessageDispatcher()->addMessageToQueue(m, callback); + } + } + } - void updateBucket(); + void sendMessageAndCheckFinish() + { + if(needsAdditionalOutgoingMessage()) { + sendMessage(); + } + if(_inFlightMessage == 0) { + if(getLogger()->debug()) { + getLogger()->debug("Finished node_lookup for node ID %s", + util::toHex(_targetID, DHT_ID_LENGTH).c_str()); + } + onFinish(); + updateBucket(); + setFinished(true); + } else { + if(getLogger()->debug()) { + getLogger()->debug("%d in flight message for node ID %s", + _inFlightMessage, + util::toHex(_targetID, DHT_ID_LENGTH).c_str()); + } + } + } - void sendMessageAndCheckFinish(); + void updateBucket() {} protected: const unsigned char* getTargetID() const { @@ -84,21 +128,13 @@ protected: { return _entries; } -public: - DHTAbstractNodeLookupTask(const unsigned char* targetID); - static const size_t ALPHA = 3; - - virtual void startup(); - - virtual void onReceived(const SharedHandle& message); - - virtual void onTimeout(const SharedHandle& node); - - virtual void getNodesFromMessage(std::vector >& nodes, - const SharedHandle& message) = 0; + virtual void getNodesFromMessage + (std::vector >& nodes, + const ResponseMessage* message) = 0; - virtual void onReceivedInternal(const SharedHandle& message) {} + virtual void onReceivedInternal + (const ResponseMessage* message) {} virtual bool needsAdditionalOutgoingMessage() { return true; } @@ -106,6 +142,92 @@ public: virtual SharedHandle createMessage (const SharedHandle& remoteNode) = 0; + + virtual SharedHandle createCallback() = 0; +public: + DHTAbstractNodeLookupTask(const unsigned char* targetID): + _inFlightMessage(0) + { + memcpy(_targetID, targetID, DHT_ID_LENGTH); + } + + static const size_t ALPHA = 3; + + virtual void startup() + { + std::vector > nodes; + getRoutingTable()->getClosestKNodes(nodes, _targetID); + _entries.clear(); + toEntries(_entries, nodes); + if(_entries.empty()) { + setFinished(true); + } else { + // TODO use RTT here + _inFlightMessage = 0; + sendMessage(); + if(_inFlightMessage == 0) { + if(getLogger()->debug()) { + getLogger()->debug("No message was sent in this lookup stage. Finished."); + } + setFinished(true); + } + } + } + + void onReceived(const ResponseMessage* message) + { + --_inFlightMessage; + onReceivedInternal(message); + std::vector > nodes; + getNodesFromMessage(nodes, message); + std::vector > newEntries; + toEntries(newEntries, nodes); + + size_t count = 0; + for(std::vector >::const_iterator i = + newEntries.begin(), eoi = newEntries.end(); i != eoi; ++i) { + if(memcmp(getLocalNode()->getID(), (*i)->node->getID(), + DHT_ID_LENGTH) != 0) { + _entries.push_front(*i); + ++count; + if(getLogger()->debug()) { + getLogger()->debug("Received nodes: id=%s, ip=%s", + util::toHex((*i)->node->getID(), + DHT_ID_LENGTH).c_str(), + (*i)->node->getIPAddress().c_str()); + } + } + } + if(getLogger()->debug()) { + getLogger()->debug("%u node lookup entries added.", count); + } + std::stable_sort(_entries.begin(), _entries.end(), DHTIDCloser(_targetID)); + _entries.erase(std::unique(_entries.begin(), _entries.end()), _entries.end()); + if(getLogger()->debug()) { + getLogger()->debug("%u node lookup entries are unique.", _entries.size()); + } + if(_entries.size() > DHTBucket::K) { + _entries.erase(_entries.begin()+DHTBucket::K, _entries.end()); + } + sendMessageAndCheckFinish(); + } + + void onTimeout(const SharedHandle& node) + { + if(getLogger()->debug()) { + getLogger()->debug("node lookup message timeout for node ID=%s", + util::toHex(node->getID(), DHT_ID_LENGTH).c_str()); + } + --_inFlightMessage; + for(std::deque >::iterator i = + _entries.begin(), eoi = _entries.end(); i != eoi; ++i) { + if((*i)->node == node) { + _entries.erase(i); + break; + } + } + sendMessageAndCheckFinish(); + } }; } // namespace aria2 diff --git a/src/DHTAnnouncePeerReplyMessage.cc b/src/DHTAnnouncePeerReplyMessage.cc index 925c74f9..d1088103 100644 --- a/src/DHTAnnouncePeerReplyMessage.cc +++ b/src/DHTAnnouncePeerReplyMessage.cc @@ -35,6 +35,7 @@ #include "DHTAnnouncePeerReplyMessage.h" #include "DHTNode.h" #include "bencode.h" +#include "DHTMessageCallback.h" namespace aria2 { @@ -62,4 +63,9 @@ const std::string& DHTAnnouncePeerReplyMessage::getMessageType() const return ANNOUNCE_PEER; } +void DHTAnnouncePeerReplyMessage::accept(DHTMessageCallback* callback) +{ + callback->visit(this); +} + } // namespace aria2 diff --git a/src/DHTAnnouncePeerReplyMessage.h b/src/DHTAnnouncePeerReplyMessage.h index 685f601c..03d6284f 100644 --- a/src/DHTAnnouncePeerReplyMessage.h +++ b/src/DHTAnnouncePeerReplyMessage.h @@ -53,6 +53,8 @@ public: virtual const std::string& getMessageType() const; + virtual void accept(DHTMessageCallback* callback); + static const std::string ANNOUNCE_PEER; }; diff --git a/src/DHTBucketRefreshTask.cc b/src/DHTBucketRefreshTask.cc index 3640297d..84ddc49e 100644 --- a/src/DHTBucketRefreshTask.cc +++ b/src/DHTBucketRefreshTask.cc @@ -41,6 +41,7 @@ #include "DHTNodeLookupEntry.h" #include "util.h" #include "Logger.h" +#include "DHTMessageCallback.h" namespace aria2 { diff --git a/src/DHTFindNodeReplyMessage.cc b/src/DHTFindNodeReplyMessage.cc index e25e9a74..99542623 100644 --- a/src/DHTFindNodeReplyMessage.cc +++ b/src/DHTFindNodeReplyMessage.cc @@ -96,6 +96,11 @@ const std::string& DHTFindNodeReplyMessage::getMessageType() const return FIND_NODE; } +void DHTFindNodeReplyMessage::accept(DHTMessageCallback* callback) +{ + callback->visit(this); +} + void DHTFindNodeReplyMessage::setClosestKNodes (const std::vector >& closestKNodes) { diff --git a/src/DHTFindNodeReplyMessage.h b/src/DHTFindNodeReplyMessage.h index ad07bd7d..7ff88e20 100644 --- a/src/DHTFindNodeReplyMessage.h +++ b/src/DHTFindNodeReplyMessage.h @@ -58,6 +58,8 @@ public: virtual const std::string& getMessageType() const; + virtual void accept(DHTMessageCallback* callback); + const std::vector >& getClosestKNodes() const { return _closestKNodes; diff --git a/src/DHTGetPeersReplyMessage.cc b/src/DHTGetPeersReplyMessage.cc index 1aebbcb3..84720a02 100644 --- a/src/DHTGetPeersReplyMessage.cc +++ b/src/DHTGetPeersReplyMessage.cc @@ -134,6 +134,11 @@ const std::string& DHTGetPeersReplyMessage::getMessageType() const return GET_PEERS; } +void DHTGetPeersReplyMessage::accept(DHTMessageCallback* callback) +{ + callback->visit(this); +} + std::string DHTGetPeersReplyMessage::toStringOptional() const { return strconcat("token=", util::toHex(_token), diff --git a/src/DHTGetPeersReplyMessage.h b/src/DHTGetPeersReplyMessage.h index cdb14358..61e3fca9 100644 --- a/src/DHTGetPeersReplyMessage.h +++ b/src/DHTGetPeersReplyMessage.h @@ -68,6 +68,8 @@ public: virtual const std::string& getMessageType() const; + virtual void accept(DHTMessageCallback* callback); + const std::vector >& getClosestKNodes() const { return _closestKNodes; diff --git a/src/DHTMessageCallback.h b/src/DHTMessageCallback.h index d596fed9..6a2f8cb4 100644 --- a/src/DHTMessageCallback.h +++ b/src/DHTMessageCallback.h @@ -37,17 +37,32 @@ #include "common.h" #include "SharedHandle.h" +#include "DHTResponseMessage.h" namespace aria2 { -class DHTMessage; class DHTNode; +class DHTAnnouncePeerReplyMessage; +class DHTFindNodeReplyMessage; +class DHTGetPeersReplyMessage; +class DHTPingReplyMessage; class DHTMessageCallback { public: virtual ~DHTMessageCallback() {} - virtual void onReceived(const SharedHandle& message) = 0; + void onReceived(const SharedHandle& message) + { + message->accept(this); + } + + virtual void visit(const DHTAnnouncePeerReplyMessage* message) = 0; + + virtual void visit(const DHTFindNodeReplyMessage* message) = 0; + + virtual void visit(const DHTGetPeersReplyMessage* message) = 0; + + virtual void visit(const DHTPingReplyMessage* message) = 0; virtual void onTimeout(const SharedHandle& remoteNode) = 0; }; diff --git a/src/DHTMessageFactory.h b/src/DHTMessageFactory.h index be00f2cd..fb25dc44 100644 --- a/src/DHTMessageFactory.h +++ b/src/DHTMessageFactory.h @@ -46,6 +46,8 @@ namespace aria2 { class DHTMessage; +class DHTQueryMessage; +class DHTResponseMessage; class DHTNode; class Peer; class BDE; @@ -54,62 +56,62 @@ class DHTMessageFactory { public: virtual ~DHTMessageFactory() {} - virtual SharedHandle + virtual SharedHandle createQueryMessage(const BDE& dict, const std::string& ipaddr, uint16_t port) = 0; - virtual SharedHandle + virtual SharedHandle createResponseMessage(const std::string& messageType, const BDE& dict, const std::string& ipaddr, uint16_t port) = 0; - virtual SharedHandle + virtual SharedHandle createPingMessage(const SharedHandle& remoteNode, const std::string& transactionID = A2STR::NIL) = 0; - virtual SharedHandle + virtual SharedHandle createPingReplyMessage(const SharedHandle& remoteNode, const unsigned char* id, const std::string& transactionID) = 0; - virtual SharedHandle + virtual SharedHandle createFindNodeMessage(const SharedHandle& remoteNode, const unsigned char* targetNodeID, const std::string& transactionID = A2STR::NIL) = 0; - virtual SharedHandle + virtual SharedHandle createFindNodeReplyMessage (const SharedHandle& remoteNode, const std::vector >& closestKNodes, const std::string& transactionID) = 0; - virtual SharedHandle + virtual SharedHandle createGetPeersMessage(const SharedHandle& remoteNode, const unsigned char* infoHash, const std::string& transactionID = A2STR::NIL) = 0; - virtual SharedHandle + virtual SharedHandle createGetPeersReplyMessage (const SharedHandle& remoteNode, const std::vector >& closestKNodes, const std::string& token, const std::string& transactionID) = 0; - virtual SharedHandle + virtual SharedHandle createGetPeersReplyMessage (const SharedHandle& remoteNode, const std::vector >& peers, const std::string& token, const std::string& transactionID) = 0; - virtual SharedHandle + virtual SharedHandle createAnnouncePeerMessage(const SharedHandle& remoteNode, const unsigned char* infoHash, uint16_t tcpPort, const std::string& token, const std::string& transactionID = A2STR::NIL) = 0; - virtual SharedHandle + virtual SharedHandle createAnnouncePeerReplyMessage(const SharedHandle& remoteNode, const std::string& transactionID) = 0; diff --git a/src/DHTMessageFactoryImpl.cc b/src/DHTMessageFactoryImpl.cc index b2a9a6ef..3db7e41d 100644 --- a/src/DHTMessageFactoryImpl.cc +++ b/src/DHTMessageFactoryImpl.cc @@ -71,7 +71,8 @@ DHTMessageFactoryImpl::DHTMessageFactoryImpl(): DHTMessageFactoryImpl::~DHTMessageFactoryImpl() {} SharedHandle -DHTMessageFactoryImpl::getRemoteNode(const unsigned char* id, const std::string& ipaddr, uint16_t port) const +DHTMessageFactoryImpl::getRemoteNode +(const unsigned char* id, const std::string& ipaddr, uint16_t port) const { SharedHandle node = _routingTable->getNode(id, ipaddr, port); if(node.isNull()) { @@ -184,10 +185,8 @@ static void setVersion(const SharedHandle& msg, const BDE& dict) } } -SharedHandle DHTMessageFactoryImpl::createQueryMessage -(const BDE& dict, - const std::string& ipaddr, - uint16_t port) +SharedHandle DHTMessageFactoryImpl::createQueryMessage +(const BDE& dict, const std::string& ipaddr, uint16_t port) { const BDE& messageType = getString(dict, DHTQueryMessage::Q); const BDE& transactionID = getString(dict, DHTMessage::T); @@ -199,7 +198,7 @@ SharedHandle DHTMessageFactoryImpl::createQueryMessage const BDE& id = getString(aDict, DHTMessage::ID); validateID(id); SharedHandle remoteNode = getRemoteNode(id.uc(), ipaddr, port); - SharedHandle msg; + SharedHandle msg; if(messageType.s() == DHTPingMessage::PING) { msg = createPingMessage(remoteNode, transactionID.s()); } else if(messageType.s() == DHTFindNodeMessage::FIND_NODE) { @@ -233,11 +232,12 @@ SharedHandle DHTMessageFactoryImpl::createQueryMessage return msg; } -SharedHandle -DHTMessageFactoryImpl::createResponseMessage(const std::string& messageType, - const BDE& dict, - const std::string& ipaddr, - uint16_t port) +SharedHandle +DHTMessageFactoryImpl::createResponseMessage +(const std::string& messageType, + const BDE& dict, + const std::string& ipaddr, + uint16_t port) { const BDE& transactionID = getString(dict, DHTMessage::T); const BDE& y = getString(dict, DHTMessage::Y); @@ -265,7 +265,7 @@ DHTMessageFactoryImpl::createResponseMessage(const std::string& messageType, const BDE& id = getString(rDict, DHTMessage::ID); validateID(id); SharedHandle remoteNode = getRemoteNode(id.uc(), ipaddr, port); - SharedHandle msg; + SharedHandle msg; if(messageType == DHTPingReplyMessage::PING) { msg = createPingReplyMessage(remoteNode, id.uc(), transactionID.s()); } else if(messageType == DHTFindNodeReplyMessage::FIND_NODE) { @@ -317,34 +317,38 @@ void DHTMessageFactoryImpl::setCommonProperty(const SharedHandlesetVersion(getDefaultVersion()); } -SharedHandle DHTMessageFactoryImpl::createPingMessage(const SharedHandle& remoteNode, const std::string& transactionID) +SharedHandle DHTMessageFactoryImpl::createPingMessage +(const SharedHandle& remoteNode, const std::string& transactionID) { SharedHandle m(new DHTPingMessage(_localNode, remoteNode, transactionID)); setCommonProperty(m); return m; } -SharedHandle -DHTMessageFactoryImpl::createPingReplyMessage(const SharedHandle& remoteNode, - const unsigned char* id, - const std::string& transactionID) +SharedHandle DHTMessageFactoryImpl::createPingReplyMessage +(const SharedHandle& remoteNode, + const unsigned char* id, + const std::string& transactionID) { - SharedHandle m(new DHTPingReplyMessage(_localNode, remoteNode, id, transactionID)); + SharedHandle m + (new DHTPingReplyMessage(_localNode, remoteNode, id, transactionID)); setCommonProperty(m); return m; } -SharedHandle -DHTMessageFactoryImpl::createFindNodeMessage(const SharedHandle& remoteNode, - const unsigned char* targetNodeID, - const std::string& transactionID) +SharedHandle DHTMessageFactoryImpl::createFindNodeMessage +(const SharedHandle& remoteNode, + const unsigned char* targetNodeID, + const std::string& transactionID) { - SharedHandle m(new DHTFindNodeMessage(_localNode, remoteNode, targetNodeID, transactionID)); + SharedHandle m + (new DHTFindNodeMessage + (_localNode, remoteNode, targetNodeID, transactionID)); setCommonProperty(m); return m; } -SharedHandle +SharedHandle DHTMessageFactoryImpl::createFindNodeReplyMessage (const SharedHandle& remoteNode, const std::vector >& closestKNodes, @@ -378,7 +382,7 @@ DHTMessageFactoryImpl::extractNodes(const unsigned char* src, size_t length) return nodes; } -SharedHandle +SharedHandle DHTMessageFactoryImpl::createFindNodeReplyMessage (const SharedHandle& remoteNode, const BDE& dict, @@ -392,22 +396,21 @@ DHTMessageFactoryImpl::createFindNodeReplyMessage return createFindNodeReplyMessage(remoteNode, nodes, transactionID); } -SharedHandle -DHTMessageFactoryImpl::createGetPeersMessage(const SharedHandle& remoteNode, - const unsigned char* infoHash, - const std::string& transactionID) +SharedHandle +DHTMessageFactoryImpl::createGetPeersMessage +(const SharedHandle& remoteNode, + const unsigned char* infoHash, + const std::string& transactionID) { - SharedHandle m(new DHTGetPeersMessage(_localNode, - remoteNode, - infoHash, - transactionID)); + SharedHandle m + (new DHTGetPeersMessage(_localNode, remoteNode, infoHash, transactionID)); m->setPeerAnnounceStorage(_peerAnnounceStorage); m->setTokenTracker(_tokenTracker); setCommonProperty(m); return m; } -SharedHandle +SharedHandle DHTMessageFactoryImpl::createGetPeersReplyMessageWithNodes (const SharedHandle& remoteNode, const BDE& dict, @@ -423,7 +426,7 @@ DHTMessageFactoryImpl::createGetPeersReplyMessageWithNodes transactionID); } -SharedHandle +SharedHandle DHTMessageFactoryImpl::createGetPeersReplyMessage (const SharedHandle& remoteNode, const std::vector >& closestKNodes, @@ -437,7 +440,7 @@ DHTMessageFactoryImpl::createGetPeersReplyMessage return m; } -SharedHandle +SharedHandle DHTMessageFactoryImpl::createGetPeersReplyMessageWithValues (const SharedHandle& remoteNode, const BDE& dict, @@ -462,37 +465,40 @@ DHTMessageFactoryImpl::createGetPeersReplyMessageWithValues transactionID); } -SharedHandle +SharedHandle DHTMessageFactoryImpl::createGetPeersReplyMessage (const SharedHandle& remoteNode, const std::vector >& values, const std::string& token, const std::string& transactionID) { - SharedHandle m(new DHTGetPeersReplyMessage(_localNode, remoteNode, token, transactionID)); + SharedHandle m + (new DHTGetPeersReplyMessage(_localNode, remoteNode, token, transactionID)); m->setValues(values); setCommonProperty(m); return m; } -SharedHandle -DHTMessageFactoryImpl::createAnnouncePeerMessage(const SharedHandle& remoteNode, - const unsigned char* infoHash, - uint16_t tcpPort, - const std::string& token, - const std::string& transactionID) +SharedHandle +DHTMessageFactoryImpl::createAnnouncePeerMessage +(const SharedHandle& remoteNode, + const unsigned char* infoHash, + uint16_t tcpPort, + const std::string& token, + const std::string& transactionID) { SharedHandle m - (new DHTAnnouncePeerMessage(_localNode, remoteNode, infoHash, tcpPort, token, transactionID)); + (new DHTAnnouncePeerMessage + (_localNode, remoteNode, infoHash, tcpPort, token, transactionID)); m->setPeerAnnounceStorage(_peerAnnounceStorage); m->setTokenTracker(_tokenTracker); setCommonProperty(m); return m; } -SharedHandle -DHTMessageFactoryImpl::createAnnouncePeerReplyMessage(const SharedHandle& remoteNode, - const std::string& transactionID) +SharedHandle +DHTMessageFactoryImpl::createAnnouncePeerReplyMessage +(const SharedHandle& remoteNode, const std::string& transactionID) { SharedHandle m (new DHTAnnouncePeerReplyMessage(_localNode, remoteNode, transactionID)); @@ -501,8 +507,9 @@ DHTMessageFactoryImpl::createAnnouncePeerReplyMessage(const SharedHandle -DHTMessageFactoryImpl::createUnknownMessage(const unsigned char* data, size_t length, - const std::string& ipaddr, uint16_t port) +DHTMessageFactoryImpl::createUnknownMessage +(const unsigned char* data, size_t length, + const std::string& ipaddr, uint16_t port) { SharedHandle m @@ -510,32 +517,38 @@ DHTMessageFactoryImpl::createUnknownMessage(const unsigned char* data, size_t le return m; } -void DHTMessageFactoryImpl::setRoutingTable(const WeakHandle& routingTable) +void DHTMessageFactoryImpl::setRoutingTable +(const WeakHandle& routingTable) { _routingTable = routingTable; } -void DHTMessageFactoryImpl::setConnection(const WeakHandle& connection) +void DHTMessageFactoryImpl::setConnection +(const WeakHandle& connection) { _connection = connection; } -void DHTMessageFactoryImpl::setMessageDispatcher(const WeakHandle& dispatcher) +void DHTMessageFactoryImpl::setMessageDispatcher +(const WeakHandle& dispatcher) { _dispatcher = dispatcher; } -void DHTMessageFactoryImpl::setPeerAnnounceStorage(const WeakHandle& storage) +void DHTMessageFactoryImpl::setPeerAnnounceStorage +(const WeakHandle& storage) { _peerAnnounceStorage = storage; } -void DHTMessageFactoryImpl::setTokenTracker(const WeakHandle& tokenTracker) +void DHTMessageFactoryImpl::setTokenTracker +(const WeakHandle& tokenTracker) { _tokenTracker = tokenTracker; } -void DHTMessageFactoryImpl::setLocalNode(const SharedHandle& localNode) +void DHTMessageFactoryImpl::setLocalNode +(const SharedHandle& localNode) { _localNode = localNode; } diff --git a/src/DHTMessageFactoryImpl.h b/src/DHTMessageFactoryImpl.h index ce29af08..2aed434d 100644 --- a/src/DHTMessageFactoryImpl.h +++ b/src/DHTMessageFactoryImpl.h @@ -66,7 +66,8 @@ private: Logger* _logger; // search node in routingTable. If it is not found, create new one. - SharedHandle getRemoteNode(const unsigned char* id, const std::string& ipaddr, uint16_t port) const; + SharedHandle getRemoteNode + (const unsigned char* id, const std::string& ipaddr, uint16_t port) const; void validateID(const BDE& id) const; @@ -82,78 +83,78 @@ public: virtual ~DHTMessageFactoryImpl(); - virtual SharedHandle + virtual SharedHandle createQueryMessage(const BDE& dict, const std::string& ipaddr, uint16_t port); - virtual SharedHandle + virtual SharedHandle createResponseMessage(const std::string& messageType, const BDE& dict, const std::string& ipaddr, uint16_t port); - virtual SharedHandle + virtual SharedHandle createPingMessage(const SharedHandle& remoteNode, const std::string& transactionID = A2STR::NIL); - virtual SharedHandle + virtual SharedHandle createPingReplyMessage(const SharedHandle& remoteNode, const unsigned char* id, const std::string& transactionID); - virtual SharedHandle + virtual SharedHandle createFindNodeMessage(const SharedHandle& remoteNode, const unsigned char* targetNodeID, const std::string& transactionID = A2STR::NIL); - SharedHandle + SharedHandle createFindNodeReplyMessage(const SharedHandle& remoteNode, const BDE& dict, const std::string& transactionID); - virtual SharedHandle + virtual SharedHandle createFindNodeReplyMessage (const SharedHandle& remoteNode, const std::vector >& closestKNodes, const std::string& transactionID); - virtual SharedHandle + virtual SharedHandle createGetPeersMessage(const SharedHandle& remoteNode, const unsigned char* infoHash, const std::string& transactionID = A2STR::NIL); - virtual SharedHandle + virtual SharedHandle createGetPeersReplyMessage (const SharedHandle& remoteNode, const std::vector >& closestKNodes, const std::string& token, const std::string& transactionID); - SharedHandle + SharedHandle createGetPeersReplyMessageWithNodes(const SharedHandle& remoteNode, const BDE& dict, const std::string& transactionID); - virtual SharedHandle + virtual SharedHandle createGetPeersReplyMessage (const SharedHandle& remoteNode, const std::vector >& peers, const std::string& token, const std::string& transactionID); - SharedHandle + SharedHandle createGetPeersReplyMessageWithValues(const SharedHandle& remoteNode, const BDE& dict, const std::string& transactionID); - virtual SharedHandle + virtual SharedHandle createAnnouncePeerMessage(const SharedHandle& remoteNode, const unsigned char* infoHash, uint16_t tcpPort, const std::string& token, const std::string& transactionID = A2STR::NIL); - virtual SharedHandle + virtual SharedHandle createAnnouncePeerReplyMessage(const SharedHandle& remoteNode, const std::string& transactionID); @@ -167,7 +168,8 @@ public: void setMessageDispatcher(const WeakHandle& dispatcher); - void setPeerAnnounceStorage(const WeakHandle& storage); + void setPeerAnnounceStorage + (const WeakHandle& storage); void setTokenTracker(const WeakHandle& tokenTracker); diff --git a/src/DHTMessageReceiver.cc b/src/DHTMessageReceiver.cc index 62b40c37..3269d0a8 100644 --- a/src/DHTMessageReceiver.cc +++ b/src/DHTMessageReceiver.cc @@ -40,6 +40,7 @@ #include "DHTMessageTracker.h" #include "DHTConnection.h" #include "DHTMessage.h" +#include "DHTQueryMessage.h" #include "DHTResponseMessage.h" #include "DHTUnknownMessage.h" #include "DHTMessageFactory.h" @@ -91,43 +92,49 @@ SharedHandle DHTMessageReceiver::receiveMessage() " From:%s:%u", remoteAddr.c_str(), remotePort); return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort); } - SharedHandle message; - SharedHandle callback; if(isReply) { - std::pair, SharedHandle > p = + std::pair, + SharedHandle > p = _tracker->messageArrived(dict, remoteAddr, remotePort); - message = p.first; - callback = p.second; - if(message.isNull()) { + if(p.first.isNull()) { // timeout or malicious? message return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort); } + onMessageReceived(p.first); + if(!p.second.isNull()) { + p.second->onReceived(p.first); + } + return p.first; } else { - message = _factory->createQueryMessage(dict, remoteAddr, remotePort); + SharedHandle message = + _factory->createQueryMessage(dict, remoteAddr, remotePort); if(message->getLocalNode() == message->getRemoteNode()) { // drop message from localnode _logger->info("Received DHT message from localnode."); return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort); } + onMessageReceived(message); + return message; } - if(_logger->info()) { - _logger->info("Message received: %s", message->toString().c_str()); - } - message->validate(); - message->doReceivedAction(); - message->getRemoteNode()->markGood(); - message->getRemoteNode()->updateLastContact(); - _routingTable->addGoodNode(message->getRemoteNode()); - if(!callback.isNull()) { - callback->onReceived(message); - } - return message; } catch(RecoverableException& e) { _logger->info("Exception thrown while receiving DHT message.", e); return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort); } } +void DHTMessageReceiver::onMessageReceived +(const SharedHandle& message) +{ + if(_logger->info()) { + _logger->info("Message received: %s", message->toString().c_str()); + } + message->validate(); + message->doReceivedAction(); + message->getRemoteNode()->markGood(); + message->getRemoteNode()->updateLastContact(); + _routingTable->addGoodNode(message->getRemoteNode()); +} + void DHTMessageReceiver::handleTimeout() { _tracker->handleTimeout(); diff --git a/src/DHTMessageReceiver.h b/src/DHTMessageReceiver.h index cafc7ada..46945209 100644 --- a/src/DHTMessageReceiver.h +++ b/src/DHTMessageReceiver.h @@ -63,6 +63,8 @@ private: SharedHandle handleUnknownMessage(const unsigned char* data, size_t length, const std::string& remoteAddr, uint16_t remotePort); + + void onMessageReceived(const SharedHandle& message); public: DHTMessageReceiver(const SharedHandle& tracker); diff --git a/src/DHTMessageTracker.cc b/src/DHTMessageTracker.cc index 60f9feff..02c582bc 100644 --- a/src/DHTMessageTracker.cc +++ b/src/DHTMessageTracker.cc @@ -63,9 +63,9 @@ void DHTMessageTracker::addMessage(const SharedHandle& message, time _entries.push_back(e); } -std::pair, SharedHandle > -DHTMessageTracker::messageArrived(const BDE& dict, - const std::string& ipaddr, uint16_t port) +std::pair, SharedHandle > +DHTMessageTracker::messageArrived +(const BDE& dict, const std::string& ipaddr, uint16_t port) { const BDE& tid = dict[DHTMessage::T]; if(!tid.isString()) { @@ -86,7 +86,7 @@ DHTMessageTracker::messageArrived(const BDE& dict, } SharedHandle targetNode = entry->getTargetNode(); - SharedHandle message = + SharedHandle message = _factory->createResponseMessage(entry->getMessageType(), dict, targetNode->getIPAddress(), targetNode->getPort()); @@ -103,7 +103,8 @@ DHTMessageTracker::messageArrived(const BDE& dict, if(_logger->debug()) { _logger->debug("Tracker entry not found."); } - return std::pair, SharedHandle >(); + return std::pair, + SharedHandle >(); } void DHTMessageTracker::handleTimeout() diff --git a/src/DHTMessageTracker.h b/src/DHTMessageTracker.h index 86d70fe7..c3d36062 100644 --- a/src/DHTMessageTracker.h +++ b/src/DHTMessageTracker.h @@ -46,6 +46,7 @@ namespace aria2 { class DHTMessage; +class DHTResponseMessage; class DHTMessageCallback; class DHTRoutingTable; class DHTMessageFactory; @@ -72,13 +73,14 @@ public: const SharedHandle& callback = SharedHandle()); - std::pair, SharedHandle > + std::pair, SharedHandle > messageArrived(const BDE& dict, const std::string& ipaddr, uint16_t port); void handleTimeout(); - SharedHandle getEntryFor(const SharedHandle& message) const; + SharedHandle getEntryFor + (const SharedHandle& message) const; size_t countEntry() const; diff --git a/src/DHTNodeLookupTask.cc b/src/DHTNodeLookupTask.cc index 82cbb481..46c56c35 100644 --- a/src/DHTNodeLookupTask.cc +++ b/src/DHTNodeLookupTask.cc @@ -39,24 +39,23 @@ #include "DHTNodeLookupEntry.h" #include "LogFactory.h" #include "util.h" +#include "DHTNodeLookupTaskCallback.h" +#include "DHTQueryMessage.h" namespace aria2 { DHTNodeLookupTask::DHTNodeLookupTask(const unsigned char* targetNodeID): - DHTAbstractNodeLookupTask(targetNodeID) + DHTAbstractNodeLookupTask(targetNodeID) {} void DHTNodeLookupTask::getNodesFromMessage (std::vector >& nodes, - const SharedHandle& message) + const DHTFindNodeReplyMessage* message) { - SharedHandle m - (dynamic_pointer_cast(message)); - if(!m.isNull()) { - const std::vector >& knodes = m->getClosestKNodes(); - nodes.insert(nodes.end(), knodes.begin(), knodes.end()); - } + const std::vector >& knodes = + message->getClosestKNodes(); + nodes.insert(nodes.end(), knodes.begin(), knodes.end()); } SharedHandle @@ -65,4 +64,10 @@ DHTNodeLookupTask::createMessage(const SharedHandle& remoteNode) return getMessageFactory()->createFindNodeMessage(remoteNode, getTargetID()); } +SharedHandle DHTNodeLookupTask::createCallback() +{ + return SharedHandle + (new DHTNodeLookupTaskCallback(this)); +} + } // namespace aria2 diff --git a/src/DHTNodeLookupTask.h b/src/DHTNodeLookupTask.h index bf20a60a..270ec6b2 100644 --- a/src/DHTNodeLookupTask.h +++ b/src/DHTNodeLookupTask.h @@ -39,15 +39,21 @@ namespace aria2 { -class DHTNodeLookupTask:public DHTAbstractNodeLookupTask { +class DHTFindNodeReplyMessage; + +class DHTNodeLookupTask: + public DHTAbstractNodeLookupTask { public: DHTNodeLookupTask(const unsigned char* targetNodeID); - virtual void getNodesFromMessage(std::vector >& nodes, - const SharedHandle& message); + virtual void getNodesFromMessage + (std::vector >& nodes, + const DHTFindNodeReplyMessage* message); virtual SharedHandle createMessage (const SharedHandle& remoteNode); + + virtual SharedHandle createCallback(); }; } // namespace aria2 diff --git a/src/DHTMessageCallbackListener.h b/src/DHTNodeLookupTaskCallback.cc similarity index 76% rename from src/DHTMessageCallbackListener.h rename to src/DHTNodeLookupTaskCallback.cc index 9e06fcdb..0b0fd197 100644 --- a/src/DHTMessageCallbackListener.h +++ b/src/DHTNodeLookupTaskCallback.cc @@ -2,7 +2,7 @@ /* * aria2 - The high speed download utility * - * Copyright (C) 2006 Tatsuhiro Tsujikawa + * Copyright (C) 2010 Tatsuhiro Tsujikawa * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -32,26 +32,23 @@ * files in the program, then also delete it here. */ /* copyright --> */ -#ifndef _D_DHT_MESSAGE_CALLBACK_LISTENER_H_ -#define _D_DHT_MESSAGE_CALLBACK_LISTENER_H_ - -#include "common.h" -#include "SharedHandle.h" +#include "DHTNodeLookupTaskCallback.h" +#include "DHTNodeLookupTask.h" namespace aria2 { -class DHTMessage; -class DHTNode; +DHTNodeLookupTaskCallback::DHTNodeLookupTaskCallback(DHTNodeLookupTask* task): + _task(task) {} -class DHTMessageCallbackListener { -public: - virtual ~DHTMessageCallbackListener() {} +void DHTNodeLookupTaskCallback::visit(const DHTFindNodeReplyMessage* message) +{ + _task->onReceived(message); +} - virtual void onReceived(const SharedHandle& message) = 0; - - virtual void onTimeout(const SharedHandle& remoteNode) = 0; -}; +void DHTNodeLookupTaskCallback::onTimeout +(const SharedHandle& remoteNode) +{ + _task->onTimeout(remoteNode); +} } // namespace aria2 - -#endif // _D_DHT_MESSAGE_CALLBACK_LISTENER_H_ diff --git a/src/DHTMessageCallbackImpl.h b/src/DHTNodeLookupTaskCallback.h similarity index 74% rename from src/DHTMessageCallbackImpl.h rename to src/DHTNodeLookupTaskCallback.h index f8334893..02bdd757 100644 --- a/src/DHTMessageCallbackImpl.h +++ b/src/DHTNodeLookupTaskCallback.h @@ -2,7 +2,7 @@ /* * aria2 - The high speed download utility * - * Copyright (C) 2006 Tatsuhiro Tsujikawa + * Copyright (C) 2010 Tatsuhiro Tsujikawa * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -32,29 +32,32 @@ * files in the program, then also delete it here. */ /* copyright --> */ -#ifndef _D_DHT_MESSAGE_CALLBACK_IMPL_H_ -#define _D_DHT_MESSAGE_CALLBACK_IMPL_H_ +#ifndef _D_DHT_NODE_LOOKUP_TASK_CALLBACK_H_ +#define _D_DHT_NODE_LOOKUP_TASK_CALLBACK_H_ #include "DHTMessageCallback.h" namespace aria2 { -class DHTMessageCallbackListener; +class DHTNodeLookupTask; -class DHTMessageCallbackImpl:public DHTMessageCallback { +class DHTNodeLookupTaskCallback:public DHTMessageCallback { private: - WeakHandle _listener; - + DHTNodeLookupTask* _task; public: - DHTMessageCallbackImpl(const WeakHandle& listener); + DHTNodeLookupTaskCallback(DHTNodeLookupTask* task); - virtual ~DHTMessageCallbackImpl(); + virtual void visit(const DHTAnnouncePeerReplyMessage* message) {} - virtual void onReceived(const SharedHandle& message); + virtual void visit(const DHTFindNodeReplyMessage* message); + + virtual void visit(const DHTGetPeersReplyMessage* message) {} + + virtual void visit(const DHTPingReplyMessage* message) {} virtual void onTimeout(const SharedHandle& remoteNode); }; } // namespace aria2 -#endif // _D_DHT_MESSAGE_CALLBACK_IMPL_H_ +#endif // _D_DHT_NODE_LOOKUP_TASK_CALLBACK_H_ diff --git a/src/DHTPeerLookupTask.cc b/src/DHTPeerLookupTask.cc index d1ab7b36..9cb5eeff 100644 --- a/src/DHTPeerLookupTask.cc +++ b/src/DHTPeerLookupTask.cc @@ -46,47 +46,50 @@ #include "util.h" #include "DHTBucket.h" #include "bittorrent_helper.h" +#include "DHTPeerLookupTaskCallback.h" +#include "DHTQueryMessage.h" namespace aria2 { DHTPeerLookupTask::DHTPeerLookupTask (const SharedHandle& downloadContext): - DHTAbstractNodeLookupTask(bittorrent::getInfoHash(downloadContext)) {} + DHTAbstractNodeLookupTask + (bittorrent::getInfoHash(downloadContext)) {} void DHTPeerLookupTask::getNodesFromMessage (std::vector >& nodes, - const SharedHandle& message) + const DHTGetPeersReplyMessage* message) { - SharedHandle m - (dynamic_pointer_cast(message)); - if(!m.isNull()) { - const std::vector >& knodes = m->getClosestKNodes(); - nodes.insert(nodes.end(), knodes.begin(), knodes.end()); - } + const std::vector >& knodes = + message->getClosestKNodes(); + nodes.insert(nodes.end(), knodes.begin(), knodes.end()); } void DHTPeerLookupTask::onReceivedInternal -(const SharedHandle& message) +(const DHTGetPeersReplyMessage* message) { - SharedHandle m - (dynamic_pointer_cast(message)); - if(m.isNull()) { - return; - } - SharedHandle remoteNode = m->getRemoteNode(); + SharedHandle remoteNode = message->getRemoteNode(); _tokenStorage[util::toHex(remoteNode->getID(), DHT_ID_LENGTH)] = - m->getToken(); - _peerStorage->addPeer(m->getValues()); - _peers.insert(_peers.end(), m->getValues().begin(), m->getValues().end()); - getLogger()->info("Received %u peers.", m->getValues().size()); + message->getToken(); + _peerStorage->addPeer(message->getValues()); + _peers.insert(_peers.end(), + message->getValues().begin(), message->getValues().end()); + getLogger()->info("Received %u peers.", message->getValues().size()); } -SharedHandle DHTPeerLookupTask::createMessage(const SharedHandle& remoteNode) +SharedHandle DHTPeerLookupTask::createMessage +(const SharedHandle& remoteNode) { return getMessageFactory()->createGetPeersMessage(remoteNode, getTargetID()); } +SharedHandle DHTPeerLookupTask::createCallback() +{ + return SharedHandle + (new DHTPeerLookupTaskCallback(this)); +} + void DHTPeerLookupTask::onFinish() { // send announce_peer message to K closest nodes diff --git a/src/DHTPeerLookupTask.h b/src/DHTPeerLookupTask.h index fe07b4bf..5ead514a 100644 --- a/src/DHTPeerLookupTask.h +++ b/src/DHTPeerLookupTask.h @@ -44,8 +44,10 @@ class DownloadContext; class Peer; class PeerStorage; class BtRuntime; +class DHTGetPeersReplyMessage; -class DHTPeerLookupTask:public DHTAbstractNodeLookupTask { +class DHTPeerLookupTask: + public DHTAbstractNodeLookupTask { private: std::map _tokenStorage; @@ -57,14 +59,17 @@ private: public: DHTPeerLookupTask(const SharedHandle& downloadContext); - virtual void getNodesFromMessage(std::vector >& nodes, - const SharedHandle& message); + virtual void getNodesFromMessage + (std::vector >& nodes, + const DHTGetPeersReplyMessage* message); - virtual void onReceivedInternal(const SharedHandle& message); + virtual void onReceivedInternal(const DHTGetPeersReplyMessage* message); virtual SharedHandle createMessage (const SharedHandle& remoteNode); + virtual SharedHandle createCallback(); + virtual void onFinish(); const std::vector >& getPeers() const diff --git a/src/DHTMessageCallbackImpl.cc b/src/DHTPeerLookupTaskCallback.cc similarity index 71% rename from src/DHTMessageCallbackImpl.cc rename to src/DHTPeerLookupTaskCallback.cc index 01a66ff5..5343dc70 100644 --- a/src/DHTMessageCallbackImpl.cc +++ b/src/DHTPeerLookupTaskCallback.cc @@ -2,7 +2,7 @@ /* * aria2 - The high speed download utility * - * Copyright (C) 2006 Tatsuhiro Tsujikawa + * Copyright (C) 2010 Tatsuhiro Tsujikawa * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -32,29 +32,23 @@ * files in the program, then also delete it here. */ /* copyright --> */ -#include "DHTMessageCallbackImpl.h" -#include "DHTMessage.h" -#include "DHTNode.h" -#include "DHTMessageCallbackListener.h" +#include "DHTPeerLookupTaskCallback.h" +#include "DHTPeerLookupTask.h" namespace aria2 { - -DHTMessageCallbackImpl::DHTMessageCallbackImpl(const WeakHandle& listener):_listener(listener) {} -DHTMessageCallbackImpl::~DHTMessageCallbackImpl() {} +DHTPeerLookupTaskCallback::DHTPeerLookupTaskCallback(DHTPeerLookupTask* task): + _task(task) {} -void DHTMessageCallbackImpl::onReceived(const SharedHandle& message) +void DHTPeerLookupTaskCallback::visit(const DHTGetPeersReplyMessage* message) { - if(!_listener.isNull()) { - _listener->onReceived(message); - } + _task->onReceived(message); } -void DHTMessageCallbackImpl::onTimeout(const SharedHandle& remoteNode) +void DHTPeerLookupTaskCallback::onTimeout +(const SharedHandle& remoteNode) { - if(!_listener.isNull()) { - _listener->onTimeout(remoteNode); - } + _task->onTimeout(remoteNode); } } // namespace aria2 diff --git a/src/DHTPeerLookupTaskCallback.h b/src/DHTPeerLookupTaskCallback.h new file mode 100644 index 00000000..ef865ee7 --- /dev/null +++ b/src/DHTPeerLookupTaskCallback.h @@ -0,0 +1,63 @@ +/* */ +#ifndef _D_DHT_PEER_LOOKUP_TASK_CALLBACK_H_ +#define _D_DHT_PEER_LOOKUP_TASK_CALLBACK_H_ + +#include "DHTMessageCallback.h" + +namespace aria2 { + +class DHTPeerLookupTask; + +class DHTPeerLookupTaskCallback:public DHTMessageCallback { +private: + DHTPeerLookupTask* _task; +public: + DHTPeerLookupTaskCallback(DHTPeerLookupTask* task); + + virtual void visit(const DHTAnnouncePeerReplyMessage* message) {} + + virtual void visit(const DHTFindNodeReplyMessage* message) {} + + virtual void visit(const DHTGetPeersReplyMessage* message); + + virtual void visit(const DHTPingReplyMessage* message) {} + + virtual void onTimeout(const SharedHandle& remoteNode); +}; + +} // namespace aria2 + +#endif // _D_DHT_PEER_LOOKUP_TASK_CALLBACK_H_ diff --git a/src/DHTPingReplyMessage.cc b/src/DHTPingReplyMessage.cc index 37b81f2a..f004a037 100644 --- a/src/DHTPingReplyMessage.cc +++ b/src/DHTPingReplyMessage.cc @@ -38,6 +38,7 @@ #include "DHTNode.h" #include "bencode.h" +#include "DHTMessageCallback.h" namespace aria2 { @@ -68,4 +69,9 @@ const std::string& DHTPingReplyMessage::getMessageType() const return PING; } +void DHTPingReplyMessage::accept(DHTMessageCallback* callback) +{ + callback->visit(this); +} + } // namespace aria2 diff --git a/src/DHTPingReplyMessage.h b/src/DHTPingReplyMessage.h index fbe00e10..00a8e699 100644 --- a/src/DHTPingReplyMessage.h +++ b/src/DHTPingReplyMessage.h @@ -57,6 +57,8 @@ public: virtual const std::string& getMessageType() const; + virtual void accept(DHTMessageCallback* callback); + const unsigned char* getRemoteID() { return _id; diff --git a/src/DHTPingReplyMessageCallback.h b/src/DHTPingReplyMessageCallback.h new file mode 100644 index 00000000..e8bad843 --- /dev/null +++ b/src/DHTPingReplyMessageCallback.h @@ -0,0 +1,68 @@ +/* */ +#ifndef _D_DHT_PING_REPLY_MESSAGE_CALLBACK_H_ +#define _D_DHT_PING_REPLY_MESSAGE_CALLBACK_H_ + +#include "DHTMessageCallback.h" + +namespace aria2 { + +template +class DHTPingReplyMessageCallback:public DHTMessageCallback { +private: + Task* _task; +public: + DHTPingReplyMessageCallback(Task* task):_task(task) {} + + virtual void visit(const DHTAnnouncePeerReplyMessage* message) {} + + virtual void visit(const DHTFindNodeReplyMessage* message) {} + + virtual void visit(const DHTGetPeersReplyMessage* message) {} + + virtual void visit(const DHTPingReplyMessage* message) + { + _task->onReceived(message); + } + + virtual void onTimeout(const SharedHandle& remoteNode) + { + _task->onTimeout(remoteNode); + } +}; + +} // namespace aria2 + +#endif // _D_DHT_PING_REPLY_MESSAGE_CALLBACK_H_ diff --git a/src/DHTPingTask.cc b/src/DHTPingTask.cc index cdae2948..f58b9067 100644 --- a/src/DHTPingTask.cc +++ b/src/DHTPingTask.cc @@ -33,12 +33,13 @@ */ /* copyright --> */ #include "DHTPingTask.h" -#include "DHTMessageCallbackImpl.h" #include "DHTMessage.h" #include "DHTMessageFactory.h" #include "DHTMessageDispatcher.h" #include "DHTNode.h" #include "DHTConstants.h" +#include "DHTPingReplyMessageCallback.h" +#include "DHTQueryMessage.h" namespace aria2 { @@ -53,17 +54,21 @@ DHTPingTask::DHTPingTask DHTPingTask::~DHTPingTask() {} -void DHTPingTask::startup() +void DHTPingTask::addMessage() { SharedHandle m = getMessageFactory()->createPingMessage(_remoteNode); - WeakHandle listener(this); SharedHandle callback - (new DHTMessageCallbackImpl(listener)); + (new DHTPingReplyMessageCallback(this)); getMessageDispatcher()->addMessageToQueue(m, _timeout, callback); } -void DHTPingTask::onReceived(const SharedHandle& message) +void DHTPingTask::startup() +{ + addMessage(); +} + +void DHTPingTask::onReceived(const DHTPingReplyMessage* message) { _pingSuccessful = true; setFinished(true); @@ -76,12 +81,7 @@ void DHTPingTask::onTimeout(const SharedHandle& node) _pingSuccessful = false; setFinished(true); } else { - SharedHandle m = - getMessageFactory()->createPingMessage(_remoteNode); - WeakHandle listener(this); - SharedHandle callback - (new DHTMessageCallbackImpl(listener)); - getMessageDispatcher()->addMessageToQueue(m, _timeout, callback); + addMessage(); } } diff --git a/src/DHTPingTask.h b/src/DHTPingTask.h index 47ad9c2c..65719f21 100644 --- a/src/DHTPingTask.h +++ b/src/DHTPingTask.h @@ -36,12 +36,13 @@ #define _D_DHT_PING_TASK_H_ #include "DHTAbstractTask.h" -#include "DHTMessageCallbackListener.h" #include "a2time.h" namespace aria2 { -class DHTPingTask:public DHTAbstractTask, public DHTMessageCallbackListener { +class DHTPingReplyMessage; + +class DHTPingTask:public DHTAbstractTask { private: SharedHandle _remoteNode; @@ -52,6 +53,8 @@ private: bool _pingSuccessful; time_t _timeout; + + void addMessage(); public: DHTPingTask(const SharedHandle& remoteNode, size_t numMaxRetry = 0); @@ -59,9 +62,9 @@ public: virtual void startup(); - virtual void onReceived(const SharedHandle& message); + void onReceived(const DHTPingReplyMessage* message); - virtual void onTimeout(const SharedHandle& node); + void onTimeout(const SharedHandle& node); void setTimeout(time_t timeout) { diff --git a/src/DHTReplaceNodeTask.cc b/src/DHTReplaceNodeTask.cc index d9c68bc6..84c7d37f 100644 --- a/src/DHTReplaceNodeTask.cc +++ b/src/DHTReplaceNodeTask.cc @@ -35,11 +35,12 @@ #include "DHTReplaceNodeTask.h" #include "DHTBucket.h" #include "DHTNode.h" -#include "DHTMessage.h" +#include "DHTPingReplyMessage.h" #include "DHTMessageFactory.h" #include "DHTMessageDispatcher.h" -#include "DHTMessageCallbackImpl.h" #include "Logger.h" +#include "DHTPingReplyMessageCallback.h" +#include "DHTQueryMessage.h" namespace aria2 { @@ -66,14 +67,13 @@ void DHTReplaceNodeTask::sendMessage() } else { SharedHandle m = getMessageFactory()->createPingMessage(questionableNode); - WeakHandle listener(this); SharedHandle callback - (new DHTMessageCallbackImpl(listener)); + (new DHTPingReplyMessageCallback(this)); getMessageDispatcher()->addMessageToQueue(m, _timeout, callback); } } -void DHTReplaceNodeTask::onReceived(const SharedHandle& message) +void DHTReplaceNodeTask::onReceived(const DHTPingReplyMessage* message) { getLogger()->info("ReplaceNode: Ping reply received from %s.", message->getRemoteNode()->toString().c_str()); diff --git a/src/DHTReplaceNodeTask.h b/src/DHTReplaceNodeTask.h index efa40f69..4cd3dfea 100644 --- a/src/DHTReplaceNodeTask.h +++ b/src/DHTReplaceNodeTask.h @@ -36,15 +36,14 @@ #define _D_DHT_REPLACE_NODE_TASK_H_ #include "DHTAbstractTask.h" -#include "DHTMessageCallbackListener.h" #include "a2time.h" namespace aria2 { class DHTBucket; +class DHTPingReplyMessage; -class DHTReplaceNodeTask:public DHTAbstractTask, - public DHTMessageCallbackListener { +class DHTReplaceNodeTask:public DHTAbstractTask { private: SharedHandle _bucket; @@ -65,9 +64,9 @@ public: virtual void startup(); - virtual void onReceived(const SharedHandle& message); + void onReceived(const DHTPingReplyMessage* message); - virtual void onTimeout(const SharedHandle& node); + void onTimeout(const SharedHandle& node); void setTimeout(time_t timeout) { diff --git a/src/DHTResponseMessage.h b/src/DHTResponseMessage.h index d8cf958b..6f466368 100644 --- a/src/DHTResponseMessage.h +++ b/src/DHTResponseMessage.h @@ -40,6 +40,8 @@ namespace aria2 { +class DHTMessageCallback; + class DHTResponseMessage:public DHTAbstractMessage { protected: virtual std::string toStringOptional() const { return A2STR::NIL; } @@ -60,6 +62,8 @@ public: virtual std::string toString() const; + virtual void accept(DHTMessageCallback* callback) = 0; + static const std::string R; }; diff --git a/src/Makefile.am b/src/Makefile.am index becda418..193838e5 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -388,15 +388,16 @@ SRCS += PeerAbstractCommand.cc PeerAbstractCommand.h\ DHTNodeLookupEntry.cc DHTNodeLookupEntry.h\ BNode.cc BNode.h\ DHTMessageCallback.h\ - DHTMessageCallbackImpl.cc DHTMessageCallbackImpl.h\ - DHTMessageCallbackListener.h\ + DHTNodeLookupTaskCallback.cc DHTNodeLookupTaskCallback.h\ + DHTPingReplyMessageCallback.h\ + DHTPeerLookupTaskCallback.cc DHTPeerLookupTaskCallback.h\ DHTAbstractTask.cc DHTAbstractTask.h\ DHTTask.h\ DHTPingTask.cc DHTPingTask.h\ DHTTaskQueue.h\ DHTTaskQueueImpl.cc DHTTaskQueueImpl.h\ DHTBucketRefreshTask.cc DHTBucketRefreshTask.h\ - DHTAbstractNodeLookupTask.cc DHTAbstractNodeLookupTask.h\ + DHTAbstractNodeLookupTask.h\ DHTPeerLookupTask.cc DHTPeerLookupTask.h\ DHTSetup.cc DHTSetup.h\ DHTTaskFactory.h\ @@ -544,7 +545,7 @@ aria2c_LDADD = libaria2c.a @LIBINTL@ @ALLOCA@ @LIBGNUTLS_LIBS@\ @LIBGCRYPT_LIBS@ @OPENSSL_LIBS@ @XML_LIBS@\ @LIBCARES_LIBS@ @LIBEXPAT_LIBS@ @LIBZ_LIBS@\ @SQLITE3_LIBS@ #-lprofiler -#aria2c_LDFLAGS = -pg +#aria2c_LDFLAGS = -pg AM_CPPFLAGS = -Wall\ -I../lib -I../intl -I$(top_srcdir)/intl\ @LIBGNUTLS_CFLAGS@ @LIBGCRYPT_CFLAGS@ @OPENSSL_CFLAGS@ @XML_CPPFLAGS@\ diff --git a/src/Makefile.in b/src/Makefile.in index 02e8f664..8764c069 100644 --- a/src/Makefile.in +++ b/src/Makefile.in @@ -186,15 +186,16 @@ bin_PROGRAMS = aria2c$(EXEEXT) @ENABLE_BITTORRENT_TRUE@ DHTNodeLookupEntry.cc DHTNodeLookupEntry.h\ @ENABLE_BITTORRENT_TRUE@ BNode.cc BNode.h\ @ENABLE_BITTORRENT_TRUE@ DHTMessageCallback.h\ -@ENABLE_BITTORRENT_TRUE@ DHTMessageCallbackImpl.cc DHTMessageCallbackImpl.h\ -@ENABLE_BITTORRENT_TRUE@ DHTMessageCallbackListener.h\ +@ENABLE_BITTORRENT_TRUE@ DHTNodeLookupTaskCallback.cc DHTNodeLookupTaskCallback.h\ +@ENABLE_BITTORRENT_TRUE@ DHTPingReplyMessageCallback.h\ +@ENABLE_BITTORRENT_TRUE@ DHTPeerLookupTaskCallback.cc DHTPeerLookupTaskCallback.h\ @ENABLE_BITTORRENT_TRUE@ DHTAbstractTask.cc DHTAbstractTask.h\ @ENABLE_BITTORRENT_TRUE@ DHTTask.h\ @ENABLE_BITTORRENT_TRUE@ DHTPingTask.cc DHTPingTask.h\ @ENABLE_BITTORRENT_TRUE@ DHTTaskQueue.h\ @ENABLE_BITTORRENT_TRUE@ DHTTaskQueueImpl.cc DHTTaskQueueImpl.h\ @ENABLE_BITTORRENT_TRUE@ DHTBucketRefreshTask.cc DHTBucketRefreshTask.h\ -@ENABLE_BITTORRENT_TRUE@ DHTAbstractNodeLookupTask.cc DHTAbstractNodeLookupTask.h\ +@ENABLE_BITTORRENT_TRUE@ DHTAbstractNodeLookupTask.h\ @ENABLE_BITTORRENT_TRUE@ DHTPeerLookupTask.cc DHTPeerLookupTask.h\ @ENABLE_BITTORRENT_TRUE@ DHTSetup.cc DHTSetup.h\ @ENABLE_BITTORRENT_TRUE@ DHTTaskFactory.h\ @@ -347,8 +348,8 @@ am__libaria2c_a_SOURCES_DIST = Socket.h SocketCore.cc SocketCore.h \ DefaultDiskWriterFactory.cc DefaultDiskWriterFactory.h File.cc \ File.h Option.cc Option.h Base64.cc Base64.h base32.cc \ base32.h LogFactory.cc LogFactory.h TimerA2.cc TimerA2.h \ - TimeA2.cc TimeA2.h SharedHandle.h HandleRegistry.h \ - FeatureConfig.cc FeatureConfig.h DownloadEngineFactory.cc \ + TimeA2.cc TimeA2.h SharedHandle.h FeatureConfig.cc \ + FeatureConfig.h DownloadEngineFactory.cc \ DownloadEngineFactory.h SpeedCalc.cc SpeedCalc.h PeerStat.h \ BitfieldMan.cc BitfieldMan.h Randomizer.h SimpleRandomizer.cc \ SimpleRandomizer.h HttpResponse.cc HttpResponse.h \ @@ -547,24 +548,24 @@ am__libaria2c_a_SOURCES_DIST = Socket.h SocketCore.cc SocketCore.h \ DHTMessageFactoryImpl.cc DHTMessageFactoryImpl.h \ DHTNodeLookupTask.cc DHTNodeLookupTask.h DHTNodeLookupEntry.cc \ DHTNodeLookupEntry.h BNode.cc BNode.h DHTMessageCallback.h \ - DHTMessageCallbackImpl.cc DHTMessageCallbackImpl.h \ - DHTMessageCallbackListener.h DHTAbstractTask.cc \ + DHTNodeLookupTaskCallback.cc DHTNodeLookupTaskCallback.h \ + DHTPingReplyMessageCallback.h DHTPeerLookupTaskCallback.cc \ + DHTPeerLookupTaskCallback.h DHTAbstractTask.cc \ DHTAbstractTask.h DHTTask.h DHTPingTask.cc DHTPingTask.h \ DHTTaskQueue.h DHTTaskQueueImpl.cc DHTTaskQueueImpl.h \ DHTBucketRefreshTask.cc DHTBucketRefreshTask.h \ - DHTAbstractNodeLookupTask.cc DHTAbstractNodeLookupTask.h \ - DHTPeerLookupTask.cc DHTPeerLookupTask.h DHTSetup.cc \ - DHTSetup.h DHTTaskFactory.h DHTTaskFactoryImpl.cc \ - DHTTaskFactoryImpl.h DHTInteractionCommand.cc \ - DHTInteractionCommand.h DHTPeerAnnounceEntry.cc \ - DHTPeerAnnounceEntry.h DHTPeerAnnounceStorage.cc \ - DHTPeerAnnounceStorage.h DHTTokenTracker.cc DHTTokenTracker.h \ - DHTGetPeersCommand.cc DHTGetPeersCommand.h \ - DHTTokenUpdateCommand.cc DHTTokenUpdateCommand.h \ - DHTBucketRefreshCommand.cc DHTBucketRefreshCommand.h \ - DHTPeerAnnounceCommand.cc DHTPeerAnnounceCommand.h \ - DHTReplaceNodeTask.cc DHTReplaceNodeTask.h \ - DHTEntryPointNameResolveCommand.cc \ + DHTAbstractNodeLookupTask.h DHTPeerLookupTask.cc \ + DHTPeerLookupTask.h DHTSetup.cc DHTSetup.h DHTTaskFactory.h \ + DHTTaskFactoryImpl.cc DHTTaskFactoryImpl.h \ + DHTInteractionCommand.cc DHTInteractionCommand.h \ + DHTPeerAnnounceEntry.cc DHTPeerAnnounceEntry.h \ + DHTPeerAnnounceStorage.cc DHTPeerAnnounceStorage.h \ + DHTTokenTracker.cc DHTTokenTracker.h DHTGetPeersCommand.cc \ + DHTGetPeersCommand.h DHTTokenUpdateCommand.cc \ + DHTTokenUpdateCommand.h DHTBucketRefreshCommand.cc \ + DHTBucketRefreshCommand.h DHTPeerAnnounceCommand.cc \ + DHTPeerAnnounceCommand.h DHTReplaceNodeTask.cc \ + DHTReplaceNodeTask.h DHTEntryPointNameResolveCommand.cc \ DHTEntryPointNameResolveCommand.h DHTRoutingTableSerializer.cc \ DHTRoutingTableSerializer.h DHTRoutingTableDeserializer.cc \ DHTRoutingTableDeserializer.h DHTAutoSaveCommand.cc \ @@ -725,12 +726,12 @@ am__objects_6 = @ENABLE_BITTORRENT_TRUE@ DHTNodeLookupTask.$(OBJEXT) \ @ENABLE_BITTORRENT_TRUE@ DHTNodeLookupEntry.$(OBJEXT) \ @ENABLE_BITTORRENT_TRUE@ BNode.$(OBJEXT) \ -@ENABLE_BITTORRENT_TRUE@ DHTMessageCallbackImpl.$(OBJEXT) \ +@ENABLE_BITTORRENT_TRUE@ DHTNodeLookupTaskCallback.$(OBJEXT) \ +@ENABLE_BITTORRENT_TRUE@ DHTPeerLookupTaskCallback.$(OBJEXT) \ @ENABLE_BITTORRENT_TRUE@ DHTAbstractTask.$(OBJEXT) \ @ENABLE_BITTORRENT_TRUE@ DHTPingTask.$(OBJEXT) \ @ENABLE_BITTORRENT_TRUE@ DHTTaskQueueImpl.$(OBJEXT) \ @ENABLE_BITTORRENT_TRUE@ DHTBucketRefreshTask.$(OBJEXT) \ -@ENABLE_BITTORRENT_TRUE@ DHTAbstractNodeLookupTask.$(OBJEXT) \ @ENABLE_BITTORRENT_TRUE@ DHTPeerLookupTask.$(OBJEXT) \ @ENABLE_BITTORRENT_TRUE@ DHTSetup.$(OBJEXT) \ @ENABLE_BITTORRENT_TRUE@ DHTTaskFactoryImpl.$(OBJEXT) \ @@ -1123,8 +1124,8 @@ SRCS = Socket.h SocketCore.cc SocketCore.h BinaryStream.h Command.cc \ DefaultDiskWriterFactory.cc DefaultDiskWriterFactory.h File.cc \ File.h Option.cc Option.h Base64.cc Base64.h base32.cc \ base32.h LogFactory.cc LogFactory.h TimerA2.cc TimerA2.h \ - TimeA2.cc TimeA2.h SharedHandle.h HandleRegistry.h \ - FeatureConfig.cc FeatureConfig.h DownloadEngineFactory.cc \ + TimeA2.cc TimeA2.h SharedHandle.h FeatureConfig.cc \ + FeatureConfig.h DownloadEngineFactory.cc \ DownloadEngineFactory.h SpeedCalc.cc SpeedCalc.h PeerStat.h \ BitfieldMan.cc BitfieldMan.h Randomizer.h SimpleRandomizer.cc \ SimpleRandomizer.h HttpResponse.cc HttpResponse.h \ @@ -1230,7 +1231,7 @@ aria2c_LDADD = libaria2c.a @LIBINTL@ @ALLOCA@ @LIBGNUTLS_LIBS@\ @LIBCARES_LIBS@ @LIBEXPAT_LIBS@ @LIBZ_LIBS@\ @SQLITE3_LIBS@ #-lprofiler -#aria2c_LDFLAGS = -pg +#aria2c_LDFLAGS = -pg AM_CPPFLAGS = -Wall\ -I../lib -I../intl -I$(top_srcdir)/intl\ @LIBGNUTLS_CFLAGS@ @LIBGCRYPT_CFLAGS@ @OPENSSL_CFLAGS@ @XML_CPPFLAGS@\ @@ -1389,7 +1390,6 @@ distclean-compile: @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/CookieStorage.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/CreateRequestCommand.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTAbstractMessage.Po@am__quote@ -@AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTAbstractNodeLookupTask.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTAbstractTask.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTAnnouncePeerMessage.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTAnnouncePeerReplyMessage.Po@am__quote@ @@ -1406,7 +1406,6 @@ distclean-compile: @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTGetPeersReplyMessage.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTInteractionCommand.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTMessage.Po@am__quote@ -@AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTMessageCallbackImpl.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTMessageDispatcherImpl.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTMessageEntry.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTMessageFactoryImpl.Po@am__quote@ @@ -1416,10 +1415,12 @@ distclean-compile: @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTNode.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTNodeLookupEntry.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTNodeLookupTask.Po@am__quote@ +@AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTNodeLookupTaskCallback.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTPeerAnnounceCommand.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTPeerAnnounceEntry.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTPeerAnnounceStorage.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTPeerLookupTask.Po@am__quote@ +@AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTPeerLookupTaskCallback.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTPingMessage.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTPingReplyMessage.Po@am__quote@ @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTPingTask.Po@am__quote@ diff --git a/test/DHTAnnouncePeerMessageTest.cc b/test/DHTAnnouncePeerMessageTest.cc index 55272d84..81aa7336 100644 --- a/test/DHTAnnouncePeerMessageTest.cc +++ b/test/DHTAnnouncePeerMessageTest.cc @@ -28,11 +28,13 @@ public: void testDoReceivedAction(); class MockDHTMessageFactory2:public MockDHTMessageFactory { - virtual SharedHandle + virtual SharedHandle createAnnouncePeerReplyMessage(const SharedHandle& remoteNode, const std::string& transactionID) { - return SharedHandle(new MockDHTMessage(_localNode, remoteNode, "announce_peer", transactionID)); + return SharedHandle + (new MockDHTResponseMessage + (_localNode, remoteNode, "announce_peer", transactionID)); } }; }; @@ -106,8 +108,9 @@ void DHTAnnouncePeerMessageTest::testDoReceivedAction() msg.doReceivedAction(); CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher._messageQueue.size()); - SharedHandle m - (dynamic_pointer_cast(dispatcher._messageQueue[0]._message)); + SharedHandle m + (dynamic_pointer_cast + (dispatcher._messageQueue[0]._message)); CPPUNIT_ASSERT(localNode == m->getLocalNode()); CPPUNIT_ASSERT(remoteNode == m->getRemoteNode()); CPPUNIT_ASSERT_EQUAL(std::string("announce_peer"), m->getMessageType()); diff --git a/test/DHTFindNodeMessageTest.cc b/test/DHTFindNodeMessageTest.cc index 9686797b..0fd9214d 100644 --- a/test/DHTFindNodeMessageTest.cc +++ b/test/DHTFindNodeMessageTest.cc @@ -29,14 +29,15 @@ public: class MockDHTMessageFactory2:public MockDHTMessageFactory { public: - virtual SharedHandle + virtual SharedHandle createFindNodeReplyMessage (const SharedHandle& remoteNode, const std::vector >& closestKNodes, const std::string& transactionID) { - SharedHandle m - (new MockDHTMessage(_localNode, remoteNode, "find_node", transactionID)); + SharedHandle m + (new MockDHTResponseMessage + (_localNode, remoteNode, "find_node", transactionID)); m->_nodes = closestKNodes; return m; } @@ -99,8 +100,9 @@ void DHTFindNodeMessageTest::testDoReceivedAction() msg.doReceivedAction(); CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher._messageQueue.size()); - SharedHandle m - (dynamic_pointer_cast(dispatcher._messageQueue[0]._message)); + SharedHandle m + (dynamic_pointer_cast + (dispatcher._messageQueue[0]._message)); CPPUNIT_ASSERT(localNode == m->getLocalNode()); CPPUNIT_ASSERT(remoteNode == m->getRemoteNode()); CPPUNIT_ASSERT_EQUAL(std::string("find_node"), m->getMessageType()); diff --git a/test/DHTGetPeersMessageTest.cc b/test/DHTGetPeersMessageTest.cc index 69e8033b..f867cc90 100644 --- a/test/DHTGetPeersMessageTest.cc +++ b/test/DHTGetPeersMessageTest.cc @@ -31,28 +31,30 @@ public: class MockDHTMessageFactory2:public MockDHTMessageFactory { public: - virtual SharedHandle + virtual SharedHandle createGetPeersReplyMessage(const SharedHandle& remoteNode, const std::vector >& peers, const std::string& token, const std::string& transactionID) { - SharedHandle m - (new MockDHTMessage(_localNode, remoteNode, "get_peers", transactionID)); + SharedHandle m + (new MockDHTResponseMessage + (_localNode, remoteNode, "get_peers", transactionID)); m->_peers = peers; m->_token = token; return m; } - virtual SharedHandle + virtual SharedHandle createGetPeersReplyMessage (const SharedHandle& remoteNode, const std::vector >& closestKNodes, const std::string& token, const std::string& transactionID) { - SharedHandle m - (new MockDHTMessage(_localNode, remoteNode, "get_peers", transactionID)); + SharedHandle m + (new MockDHTResponseMessage + (_localNode, remoteNode, "get_peers", transactionID)); m->_nodes = closestKNodes; m->_token = token; return m; @@ -133,8 +135,9 @@ void DHTGetPeersMessageTest::testDoReceivedAction() msg.doReceivedAction(); CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher._messageQueue.size()); - SharedHandle m - (dynamic_pointer_cast(dispatcher._messageQueue[0]._message)); + SharedHandle m + (dynamic_pointer_cast + (dispatcher._messageQueue[0]._message)); CPPUNIT_ASSERT(localNode == m->getLocalNode()); CPPUNIT_ASSERT(remoteNode == m->getRemoteNode()); CPPUNIT_ASSERT_EQUAL(std::string("get_peers"), m->getMessageType()); @@ -169,8 +172,9 @@ void DHTGetPeersMessageTest::testDoReceivedAction() msg.doReceivedAction(); CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher._messageQueue.size()); - SharedHandle m - (dynamic_pointer_cast(dispatcher._messageQueue[0]._message)); + SharedHandle m + (dynamic_pointer_cast + (dispatcher._messageQueue[0]._message)); CPPUNIT_ASSERT(localNode == m->getLocalNode()); CPPUNIT_ASSERT(remoteNode == m->getRemoteNode()); CPPUNIT_ASSERT_EQUAL(std::string("get_peers"), m->getMessageType()); diff --git a/test/DHTMessageTrackerTest.cc b/test/DHTMessageTrackerTest.cc index 60ec2f9b..a4fed872 100644 --- a/test/DHTMessageTrackerTest.cc +++ b/test/DHTMessageTrackerTest.cc @@ -28,18 +28,6 @@ public: void testMessageArrived(); void testHandleTimeout(); - - class MockDHTMessageCallback2:public MockDHTMessageCallback { - public: - uint32_t _countOnRecivedCalled; - - MockDHTMessageCallback2():_countOnRecivedCalled(0) {} - - virtual void onReceived(const SharedHandle& message) - { - ++_countOnRecivedCalled; - } - }; }; @@ -66,13 +54,11 @@ void DHTMessageTrackerTest::testMessageArrived() m3->getRemoteNode()->setIPAddress("192.168.0.3"); m3->getRemoteNode()->setPort(6883); - SharedHandle c2(new MockDHTMessageCallback2()); - DHTMessageTracker tracker; tracker.setRoutingTable(routingTable); tracker.setMessageFactory(factory); tracker.addMessage(m1, DHT_MESSAGE_TIMEOUT); - tracker.addMessage(m2, DHT_MESSAGE_TIMEOUT, c2); + tracker.addMessage(m2, DHT_MESSAGE_TIMEOUT); tracker.addMessage(m3, DHT_MESSAGE_TIMEOUT); { @@ -85,7 +71,6 @@ void DHTMessageTrackerTest::testMessageArrived() SharedHandle reply = p.first; CPPUNIT_ASSERT(!reply.isNull()); - CPPUNIT_ASSERT_EQUAL((uint32_t)0, c2->_countOnRecivedCalled); CPPUNIT_ASSERT(tracker.getEntryFor(m2).isNull()); CPPUNIT_ASSERT_EQUAL((size_t)2, tracker.countEntry()); } diff --git a/test/DHTPingMessageTest.cc b/test/DHTPingMessageTest.cc index 96f29cc9..95fa8924 100644 --- a/test/DHTPingMessageTest.cc +++ b/test/DHTPingMessageTest.cc @@ -28,14 +28,14 @@ public: class MockDHTMessageFactory2:public MockDHTMessageFactory { public: - virtual SharedHandle + virtual SharedHandle createPingReplyMessage(const SharedHandle& remoteNode, const unsigned char* remoteNodeID, const std::string& transactionID) { - return SharedHandle - (new MockDHTMessage(_localNode, remoteNode, "ping_reply", - transactionID)); + return SharedHandle + (new MockDHTResponseMessage(_localNode, remoteNode, "ping_reply", + transactionID)); } }; }; @@ -89,8 +89,9 @@ void DHTPingMessageTest::testDoReceivedAction() msg.doReceivedAction(); CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher._messageQueue.size()); - SharedHandle m - (dynamic_pointer_cast(dispatcher._messageQueue[0]._message)); + SharedHandle m + (dynamic_pointer_cast + (dispatcher._messageQueue[0]._message)); CPPUNIT_ASSERT(localNode == m->getLocalNode()); CPPUNIT_ASSERT(remoteNode == m->getRemoteNode()); CPPUNIT_ASSERT_EQUAL(std::string("ping_reply"), m->getMessageType()); diff --git a/test/MockDHTMessage.h b/test/MockDHTMessage.h index 1e5b2e80..bbaf7332 100644 --- a/test/MockDHTMessage.h +++ b/test/MockDHTMessage.h @@ -2,14 +2,19 @@ #define _D_MOCK_DHT_MESSAGE_H_ #include "DHTMessage.h" +#include "DHTQueryMessage.h" +#include "DHTResponseMessage.h" #include #include "DHTNode.h" #include "Peer.h" +#include "BDE.h" namespace aria2 { +class DHTMessageCallback; + class MockDHTMessage:public DHTMessage { public: bool _isReply; @@ -43,6 +48,72 @@ public: virtual std::string toString() const { return "MockDHTMessage"; } }; +class MockDHTQueryMessage:public DHTQueryMessage { +public: + std::string _messageType; + + std::vector > _nodes; + + std::vector > _peers; + + std::string _token; +public: + MockDHTQueryMessage(const SharedHandle& localNode, + const SharedHandle& remoteNode, + const std::string& messageType = "mock", + const std::string& transactionID = ""): + DHTQueryMessage(localNode, remoteNode, transactionID), + _messageType(messageType) {} + + virtual ~MockDHTQueryMessage() {} + + virtual void doReceivedAction() {} + + virtual bool send() { return true; } + + virtual bool isReply() const { return false; } + + virtual const std::string& getMessageType() const { return _messageType; } + + virtual std::string toString() const { return "MockDHTMessage"; } + + virtual BDE getArgument() { return BDE::dict(); } +}; + +class MockDHTResponseMessage:public DHTResponseMessage { +public: + std::string _messageType; + + std::vector > _nodes; + + std::vector > _peers; + + std::string _token; +public: + MockDHTResponseMessage(const SharedHandle& localNode, + const SharedHandle& remoteNode, + const std::string& messageType = "mock", + const std::string& transactionID = ""): + DHTResponseMessage(localNode, remoteNode, transactionID), + _messageType(messageType) {} + + virtual ~MockDHTResponseMessage() {} + + virtual void doReceivedAction() {} + + virtual bool send() { return true; } + + virtual bool isReply() const { return true; } + + virtual const std::string& getMessageType() const { return _messageType; } + + virtual std::string toString() const { return "MockDHTMessage"; } + + virtual BDE getResponse() { return BDE::dict(); } + + virtual void accept(DHTMessageCallback* callback) {} +}; + } // namespace aria2 #endif // _D_MOCK_DHT_MESSAGE_H_ diff --git a/test/MockDHTMessageCallback.h b/test/MockDHTMessageCallback.h index 52bade88..6929074d 100644 --- a/test/MockDHTMessageCallback.h +++ b/test/MockDHTMessageCallback.h @@ -11,7 +11,13 @@ public: virtual ~MockDHTMessageCallback() {} - virtual void onReceived(const SharedHandle& message) {} + virtual void visit(const DHTAnnouncePeerReplyMessage* message) {} + + virtual void visit(const DHTFindNodeReplyMessage* message) {} + + virtual void visit(const DHTGetPeersReplyMessage* message) {} + + virtual void visit(const DHTPingReplyMessage* message) {} virtual void onTimeout(const SharedHandle& remoteNode) {} }; diff --git a/test/MockDHTMessageFactory.h b/test/MockDHTMessageFactory.h index 0ae5ccd1..32eec7f6 100644 --- a/test/MockDHTMessageFactory.h +++ b/test/MockDHTMessageFactory.h @@ -16,14 +16,14 @@ public: virtual ~MockDHTMessageFactory() {} - virtual SharedHandle + virtual SharedHandle createQueryMessage(const BDE& dict, const std::string& ipaddr, uint16_t port) { - return SharedHandle(); + return SharedHandle(); } - virtual SharedHandle + virtual SharedHandle createResponseMessage(const std::string& messageType, const BDE& dict, const std::string& ipaddr, uint16_t port) @@ -32,87 +32,86 @@ public: // TODO At this point, removeNode's ID is random. remoteNode->setIPAddress(ipaddr); remoteNode->setPort(port); - SharedHandle m - (new MockDHTMessage(_localNode, remoteNode, dict["t"].s())); - m->setReply(true); + SharedHandle m + (new MockDHTResponseMessage(_localNode, remoteNode, dict["t"].s())); return m; } - virtual SharedHandle + virtual SharedHandle createPingMessage(const SharedHandle& remoteNode, const std::string& transactionID = "") { - return SharedHandle(); + return SharedHandle(); } - virtual SharedHandle + virtual SharedHandle createPingReplyMessage(const SharedHandle& remoteNode, const unsigned char* remoteNodeID, const std::string& transactionID) { - return SharedHandle(); + return SharedHandle(); } - virtual SharedHandle + virtual SharedHandle createFindNodeMessage(const SharedHandle& remoteNode, const unsigned char* targetNodeID, const std::string& transactionID = "") { - return SharedHandle(); + return SharedHandle(); } - virtual SharedHandle + virtual SharedHandle createFindNodeReplyMessage (const SharedHandle& remoteNode, const std::vector >& closestKNodes, const std::string& transactionID) { - return SharedHandle(); + return SharedHandle(); } - virtual SharedHandle + virtual SharedHandle createGetPeersMessage(const SharedHandle& remoteNode, const unsigned char* infoHash, const std::string& transactionID) { - return SharedHandle(); + return SharedHandle(); } - virtual SharedHandle + virtual SharedHandle createGetPeersReplyMessage (const SharedHandle& remoteNode, const std::vector >& closestKNodes, const std::string& token, const std::string& transactionID) { - return SharedHandle(); + return SharedHandle(); } - virtual SharedHandle + virtual SharedHandle createGetPeersReplyMessage(const SharedHandle& remoteNode, const std::vector >& peers, const std::string& token, const std::string& transactionID) { - return SharedHandle(); + return SharedHandle(); } - virtual SharedHandle + virtual SharedHandle createAnnouncePeerMessage(const SharedHandle& remoteNode, const unsigned char* infoHash, uint16_t tcpPort, const std::string& token, const std::string& transactionID = "") { - return SharedHandle(); + return SharedHandle(); } - virtual SharedHandle + virtual SharedHandle createAnnouncePeerReplyMessage(const SharedHandle& remoteNode, const std::string& transactionID) { - return SharedHandle(); + return SharedHandle(); } virtual SharedHandle