diff --git a/ChangeLog b/ChangeLog index ffbe81c0..1190fa30 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,22 @@ +2008-10-26 Tatsuhiro Tsujikawa + + Changed signature of DHTMessageFactory::createResponseMessage(). + Removed unused validateIDMatch. + * src/DHTMessageFactory.h + * src/DHTMessageFactoryImpl.cc + * src/DHTMessageFactoryImpl.h + * src/DHTMessageTracker.cc + * test/DHTMessageFactoryImplTest.cc + * test/MockDHTMessageFactory.h + + Dropped DHT message coming from same ID of localhost. + * src/DHTMessageReceiver.cc + + Rejected adding node whose ID is the same as localhost's. + * src/DHTRoutingTable.cc + * test/BtPortMessageTest.cc + * test/DHTRoutingTableTest.cc + 2008-10-23 Tatsuhiro Tsujikawa Pool connection when redirection occurs with Content-Length = 0. diff --git a/src/DHTMessageFactory.h b/src/DHTMessageFactory.h index 3c04eda0..157e7322 100644 --- a/src/DHTMessageFactory.h +++ b/src/DHTMessageFactory.h @@ -36,11 +36,13 @@ #define _D_DHT_MESSAGE_FACTORY_H_ #include "common.h" -#include "SharedHandle.h" -#include "A2STR.h" + #include #include +#include "SharedHandle.h" +#include "A2STR.h" + namespace aria2 { class DHTMessage; @@ -59,7 +61,7 @@ public: virtual SharedHandle createResponseMessage(const std::string& messageType, const Dictionary* d, - const SharedHandle& remoteNode) = 0; + const std::string& ipaddr, uint16_t port) = 0; virtual SharedHandle createPingMessage(const SharedHandle& remoteNode, diff --git a/src/DHTMessageFactoryImpl.cc b/src/DHTMessageFactoryImpl.cc index e2621708..f1985335 100644 --- a/src/DHTMessageFactoryImpl.cc +++ b/src/DHTMessageFactoryImpl.cc @@ -33,6 +33,10 @@ */ /* copyright --> */ #include "DHTMessageFactoryImpl.h" + +#include +#include + #include "LogFactory.h" #include "DlAbortEx.h" #include "Data.h" @@ -60,8 +64,6 @@ #include "Peer.h" #include "Logger.h" #include "StringFormat.h" -#include -#include namespace aria2 { @@ -135,13 +137,6 @@ void DHTMessageFactoryImpl::validateID(const Data* id) const } } -void DHTMessageFactoryImpl::validateIDMatch(const unsigned char* expected, const unsigned char* actual) const -{ - if(memcmp(expected, actual, DHT_ID_LENGTH) != 0) { - //throw DlAbortEx("Different ID received."); - } -} - void DHTMessageFactoryImpl::validatePort(const Data* i) const { if(!i->isNumber()) { @@ -202,7 +197,8 @@ SharedHandle DHTMessageFactoryImpl::createQueryMessage(const Diction SharedHandle DHTMessageFactoryImpl::createResponseMessage(const std::string& messageType, const Dictionary* d, - const SharedHandle& remoteNode) + const std::string& ipaddr, + uint16_t port) { const Data* t = getData(d, DHTMessage::T); const Data* y = getData(d, DHTMessage::Y); @@ -225,7 +221,8 @@ DHTMessageFactoryImpl::createResponseMessage(const std::string& messageType, const Dictionary* r = getDictionary(d, DHTResponseMessage::R); const Data* id = getData(r, DHTMessage::ID); validateID(id); - validateIDMatch(remoteNode->getID(), id->getData()); + SharedHandle remoteNode = getRemoteNode(id->getData(), ipaddr, port); + std::string transactionID = t->toString(); if(messageType == DHTPingReplyMessage::PING) { return createPingReplyMessage(remoteNode, diff --git a/src/DHTMessageFactoryImpl.h b/src/DHTMessageFactoryImpl.h index 7014a3b1..57986c8b 100644 --- a/src/DHTMessageFactoryImpl.h +++ b/src/DHTMessageFactoryImpl.h @@ -71,8 +71,6 @@ private: void validateID(const Data* id) const; - void validateIDMatch(const unsigned char* expected, const unsigned char* actual) const; - void validatePort(const Data* i) const; std::deque > extractNodes(const unsigned char* src, size_t length); @@ -91,7 +89,7 @@ public: virtual SharedHandle createResponseMessage(const std::string& messageType, const Dictionary* d, - const SharedHandle& remoteNode); + const std::string& ipaddr, uint16_t port); virtual SharedHandle createPingMessage(const SharedHandle& remoteNode, diff --git a/src/DHTMessageReceiver.cc b/src/DHTMessageReceiver.cc index 8fcfaefa..ad624127 100644 --- a/src/DHTMessageReceiver.cc +++ b/src/DHTMessageReceiver.cc @@ -33,6 +33,10 @@ */ /* copyright --> */ #include "DHTMessageReceiver.h" + +#include +#include + #include "DHTMessageTracker.h" #include "DHTConnection.h" #include "DHTMessage.h" @@ -49,7 +53,6 @@ #include "LogFactory.h" #include "Logger.h" #include "Util.h" -#include namespace aria2 { @@ -102,6 +105,11 @@ SharedHandle DHTMessageReceiver::receiveMessage() } } else { message = _factory->createQueryMessage(d, remoteAddr, remotePort); + if(message->getLocalNode() == message->getRemoteNode()) { + // drop message from localnode + _logger->info("Recieved DHT message from localnode."); + return handleUnknownMessage(data, sizeof(data), remoteAddr, remotePort); + } } _logger->info("Message received: %s", message->toString().c_str()); message->validate(); diff --git a/src/DHTMessageTracker.cc b/src/DHTMessageTracker.cc index 7b1ae9df..3d0d5db5 100644 --- a/src/DHTMessageTracker.cc +++ b/src/DHTMessageTracker.cc @@ -33,6 +33,9 @@ */ /* copyright --> */ #include "DHTMessageTracker.h" + +#include + #include "DHTMessage.h" #include "DHTMessageCallback.h" #include "DHTMessageTrackerEntry.h" @@ -47,7 +50,6 @@ #include "DlAbortEx.h" #include "DHTConstants.h" #include "StringFormat.h" -#include namespace aria2 { @@ -86,11 +88,14 @@ DHTMessageTracker::messageArrived(const Dictionary* d, _logger->debug("Tracker entry found."); SharedHandle targetNode = entry->getTargetNode(); - SharedHandle message = _factory->createResponseMessage(entry->getMessageType(), - d, targetNode); + SharedHandle message = + _factory->createResponseMessage(entry->getMessageType(), d, + targetNode->getIPAddress(), + targetNode->getPort()); + int64_t rtt = entry->getElapsedMillis(); _logger->debug("RTT is %s", Util::itos(rtt).c_str()); - targetNode->updateRTT(rtt); + message->getRemoteNode()->updateRTT(rtt); SharedHandle callback = entry->getCallback(); return std::pair, SharedHandle >(message, callback); } diff --git a/src/DHTRoutingTable.cc b/src/DHTRoutingTable.cc index a74a1c1f..a3bd3140 100644 --- a/src/DHTRoutingTable.cc +++ b/src/DHTRoutingTable.cc @@ -33,6 +33,9 @@ */ /* copyright --> */ #include "DHTRoutingTable.h" + +#include + #include "DHTNode.h" #include "DHTBucket.h" #include "BNode.h" @@ -72,6 +75,11 @@ bool DHTRoutingTable::addGoodNode(const SharedHandle& node) bool DHTRoutingTable::addNode(const SharedHandle& node, bool good) { _logger->debug("Trying to add node:%s", node->toString().c_str()); + if(_localNode == node) { + _logger->debug("Adding node with the same ID with localnode is not" + " allowed."); + return false; + } BNode* bnode = BNode::findBNodeFor(_root, node->getID()); SharedHandle bucket = bnode->getBucket(); while(1) { diff --git a/test/BtPortMessageTest.cc b/test/BtPortMessageTest.cc index 3d2d184c..648cfeb1 100644 --- a/test/BtPortMessageTest.cc +++ b/test/BtPortMessageTest.cc @@ -104,7 +104,7 @@ void BtPortMessageTest::testDoReceivedAction() SharedHandle nodes[9]; for(size_t i = 0; i < arrayLength(nodes); ++i) { memset(nodeID, 0, DHT_ID_LENGTH); - nodeID[DHT_ID_LENGTH-1] = i; + nodeID[DHT_ID_LENGTH-1] = i+1; nodes[i].reset(new DHTNode(nodeID)); } diff --git a/test/DHTMessageFactoryImplTest.cc b/test/DHTMessageFactoryImplTest.cc index c22525df..ca62b239 100644 --- a/test/DHTMessageFactoryImplTest.cc +++ b/test/DHTMessageFactoryImplTest.cc @@ -1,4 +1,10 @@ #include "DHTMessageFactoryImpl.h" + +#include +#include + +#include + #include "RecoverableException.h" #include "Util.h" #include "DHTNode.h" @@ -17,9 +23,6 @@ #include "DHTGetPeersReplyMessage.h" #include "DHTAnnouncePeerMessage.h" #include "DHTAnnouncePeerReplyMessage.h" -#include -#include -#include namespace aria2 { @@ -112,7 +115,10 @@ void DHTMessageFactoryImplTest::testCreatePingReplyMessage() remoteNode->setPort(6881); SharedHandle m - (dynamic_pointer_cast(factory->createResponseMessage("ping", d.get(), remoteNode))); + (dynamic_pointer_cast + (factory->createResponseMessage("ping", d.get(), + remoteNode->getIPAddress(), + remoteNode->getPort()))); CPPUNIT_ASSERT(localNode == m->getLocalNode()); CPPUNIT_ASSERT(remoteNode == m->getRemoteNode()); @@ -176,7 +182,10 @@ void DHTMessageFactoryImplTest::testCreateFindNodeReplyMessage() remoteNode->setPort(6881); SharedHandle m - (dynamic_pointer_cast(factory->createResponseMessage("find_node", d.get(), remoteNode))); + (dynamic_pointer_cast + (factory->createResponseMessage("find_node", d.get(), + remoteNode->getIPAddress(), + remoteNode->getPort()))); CPPUNIT_ASSERT(localNode == m->getLocalNode()); CPPUNIT_ASSERT(remoteNode == m->getRemoteNode()); @@ -247,7 +256,10 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage_nodes() remoteNode->setPort(6881); SharedHandle m - (dynamic_pointer_cast(factory->createResponseMessage("get_peers", d.get(), remoteNode))); + (dynamic_pointer_cast + (factory->createResponseMessage("get_peers", d.get(), + remoteNode->getIPAddress(), + remoteNode->getPort()))); CPPUNIT_ASSERT(localNode == m->getLocalNode()); CPPUNIT_ASSERT(remoteNode == m->getRemoteNode()); @@ -290,7 +302,10 @@ void DHTMessageFactoryImplTest::testCreateGetPeersReplyMessage_values() remoteNode->setPort(6881); SharedHandle m - (dynamic_pointer_cast(factory->createResponseMessage("get_peers", d.get(), remoteNode))); + (dynamic_pointer_cast + (factory->createResponseMessage("get_peers", d.get(), + remoteNode->getIPAddress(), + remoteNode->getPort()))); CPPUNIT_ASSERT(localNode == m->getLocalNode()); CPPUNIT_ASSERT(remoteNode == m->getRemoteNode()); @@ -356,7 +371,10 @@ void DHTMessageFactoryImplTest::testCreateAnnouncePeerReplyMessage() remoteNode->setPort(6881); SharedHandle m - (dynamic_pointer_cast(factory->createResponseMessage("announce_peer", d.get(), remoteNode))); + (dynamic_pointer_cast + (factory->createResponseMessage("announce_peer", d.get(), + remoteNode->getIPAddress(), + remoteNode->getPort()))); CPPUNIT_ASSERT(localNode == m->getLocalNode()); CPPUNIT_ASSERT(remoteNode == m->getRemoteNode()); @@ -379,7 +397,9 @@ void DHTMessageFactoryImplTest::testReceivedErrorMessage() remoteNode->setPort(6881); try { - factory->createResponseMessage("announce_peer", d.get(), remoteNode); + factory->createResponseMessage("announce_peer", d.get(), + remoteNode->getIPAddress(), + remoteNode->getPort()); CPPUNIT_FAIL("exception must be thrown."); } catch(RecoverableException& e) { std::cerr << e.stackTrace() << std::endl; diff --git a/test/DHTRoutingTableTest.cc b/test/DHTRoutingTableTest.cc index 717e849d..486da953 100644 --- a/test/DHTRoutingTableTest.cc +++ b/test/DHTRoutingTableTest.cc @@ -1,4 +1,8 @@ #include "DHTRoutingTable.h" + +#include +#include + #include "Exception.h" #include "Util.h" #include "DHTNode.h" @@ -6,8 +10,6 @@ #include "MockDHTTaskQueue.h" #include "MockDHTTaskFactory.h" #include "DHTTask.h" -#include -#include namespace aria2 { @@ -15,6 +17,7 @@ class DHTRoutingTableTest:public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(DHTRoutingTableTest); CPPUNIT_TEST(testAddNode); + CPPUNIT_TEST(testAddNode_localNode); CPPUNIT_TEST(testGetClosestKNodes); CPPUNIT_TEST_SUITE_END(); public: @@ -23,6 +26,7 @@ public: void tearDown() {} void testAddNode(); + void testAddNode_localNode(); void testGetClosestKNodes(); }; @@ -47,6 +51,19 @@ void DHTRoutingTableTest::testAddNode() table.showBuckets(); } +void DHTRoutingTableTest::testAddNode_localNode() +{ + SharedHandle localNode(new DHTNode()); + DHTRoutingTable table(localNode); + SharedHandle taskFactory(new MockDHTTaskFactory()); + table.setTaskFactory(taskFactory); + SharedHandle taskQueue(new MockDHTTaskQueue()); + table.setTaskQueue(taskQueue); + + SharedHandle newNode(new DHTNode(localNode->getID())); + CPPUNIT_ASSERT(!table.addNode(newNode)); +} + static void createID(unsigned char* id, unsigned char firstChar, unsigned char lastChar) { memset(id, 0, DHT_ID_LENGTH); diff --git a/test/MockDHTMessageFactory.h b/test/MockDHTMessageFactory.h index f9c4265e..8cfcacf1 100644 --- a/test/MockDHTMessageFactory.h +++ b/test/MockDHTMessageFactory.h @@ -27,8 +27,12 @@ public: virtual SharedHandle createResponseMessage(const std::string& messageType, const Dictionary* d, - const SharedHandle& remoteNode) + const std::string& ipaddr, uint16_t port) { + SharedHandle remoteNode(new DHTNode()); + // TODO At this point, removeNode's ID is random. + remoteNode->setIPAddress(ipaddr); + remoteNode->setPort(port); SharedHandle m (new MockDHTMessage(_localNode, remoteNode, reinterpret_cast(d->get("t"))->toString()));