/* */
#include "DHTMessageTracker.h"

#include <utility> // std::pair

#include "DHTMessage.h"
#include "DHTMessageCallback.h"
#include "DHTMessageTrackerEntry.h"
#include "DHTNode.h"
#include "DHTRoutingTable.h"
#include "DHTMessageFactory.h"
#include "util.h"
#include "LogFactory.h"
#include "Logger.h"
#include "DlAbortEx.h"
#include "DHTConstants.h"
#include "StringFormat.h"
#include "bencode.h"

namespace aria2 {

// Creates an empty tracker. The logger is obtained from LogFactory; the
// routing table and message factory must be supplied afterwards through
// setRoutingTable() and setMessageFactory() before messageArrived() or
// handleTimeout() dereference them.
DHTMessageTracker::DHTMessageTracker():
  _logger(LogFactory::getInstance()) {}

DHTMessageTracker::~DHTMessageTracker() {}

// Starts tracking `message`: wraps it, together with `timeout` and
// `callback`, in a new DHTMessageTrackerEntry and appends that entry to the
// end of _entries.
void DHTMessageTracker::addMessage(const SharedHandle<DHTMessage>& message,
                                   time_t timeout,
                                   const SharedHandle<DHTMessageCallback>& callback)
{
  SharedHandle<DHTMessageTrackerEntry> e
    (new DHTMessageTrackerEntry(message, timeout, callback));
  _entries.push_back(e);
}

// Convenience overload: tracks `message` with DHT_MESSAGE_TIMEOUT as the
// timeout value.
void DHTMessageTracker::addMessage(const SharedHandle<DHTMessage>& message,
                                   const SharedHandle<DHTMessageCallback>& callback)
{
  addMessage(message, DHT_MESSAGE_TIMEOUT, callback);
}

// Matches an incoming dictionary against the tracked entries.
//
// The transaction ID is read from dict[DHTMessage::T]. The first entry (in
// insertion order) whose match(tid, ipaddr, port) returns true is removed
// from the tracker, a response message is built for it through _factory,
// and the RTT of the response's remote node is updated with the entry's
// elapsed time.
//
// Returns the (response message, callback of the matched entry) pair, or a
// default-constructed pair when no entry matches.
// Throws DL_ABORT_EX when dict[DHTMessage::T] is not a string.
std::pair<SharedHandle<DHTMessage>, SharedHandle<DHTMessageCallback> >
DHTMessageTracker::messageArrived(const BDE& dict,
                                  const std::string& ipaddr, uint16_t port)
{
  const BDE& tid = dict[DHTMessage::T];
  if(!tid.isString()) {
    throw DL_ABORT_EX(StringFormat("Malformed DHT message. From:%s:%u",
                                   ipaddr.c_str(), port).str());
  }
  _logger->debug("Searching tracker entry for TransactionID=%s, Remote=%s:%u",
                 util::toHex(tid.s()).c_str(), ipaddr.c_str(), port);
  for(std::deque<SharedHandle<DHTMessageTrackerEntry> >::iterator i =
        _entries.begin(); i != _entries.end(); ++i) {
    if((*i)->match(tid.s(), ipaddr, port)) {
      // Copy the handle before erasing. `i` is invalid after erase(), which
      // is safe here only because this branch returns without touching the
      // iterator again.
      SharedHandle<DHTMessageTrackerEntry> entry = *i;
      _entries.erase(i);
      _logger->debug("Tracker entry found.");
      SharedHandle<DHTNode> targetNode = entry->getTargetNode();

      // The response is created with the address/port recorded in the
      // tracked entry's target node, not with the ipaddr/port arguments.
      SharedHandle<DHTMessage> message =
        _factory->createResponseMessage(entry->getMessageType(), dict,
                                        targetNode->getIPAddress(),
                                        targetNode->getPort());

      int64_t rtt = entry->getElapsedMillis();
      _logger->debug("RTT is %s", util::itos(rtt).c_str());
      message->getRemoteNode()->updateRTT(rtt);
      SharedHandle<DHTMessageCallback> callback = entry->getCallback();
      return std::pair<SharedHandle<DHTMessage>,
                       SharedHandle<DHTMessageCallback> >(message, callback);
    }
  }
  _logger->debug("Tracker entry not found.");
  return std::pair<SharedHandle<DHTMessage>,
                   SharedHandle<DHTMessageCallback> >();
}

// Removes every entry whose isTimeout() returns true and, for each one:
// updates the target node's RTT with the entry's elapsed time, calls
// timeout() on the node, drops the node from _routingTable if isBad()
// reports true afterwards, and calls onTimeout(node) on the entry's
// callback when one is set.
//
// A RecoverableException raised while processing an entry is logged at info
// level and processing continues with the next entry; the entry that
// triggered it has already been removed at that point.
void DHTMessageTracker::handleTimeout()
{
  for(std::deque<SharedHandle<DHTMessageTrackerEntry> >::iterator i =
        _entries.begin(); i != _entries.end();) {
    if((*i)->isTimeout()) {
      try {
        SharedHandle<DHTMessageTrackerEntry> entry = *i;
        // erase() yields the iterator to the following element, so the loop
        // must not increment `i` on this path.
        i = _entries.erase(i);
        SharedHandle<DHTNode> node = entry->getTargetNode();
        _logger->debug("Message timeout: To:%s:%u",
                       node->getIPAddress().c_str(), node->getPort());
        node->updateRTT(entry->getElapsedMillis());
        node->timeout();
        if(node->isBad()) {
          _logger->debug("Marked bad: %s:%u",
                         node->getIPAddress().c_str(), node->getPort());
          _routingTable->dropNode(node);
        }
        SharedHandle<DHTMessageCallback> callback = entry->getCallback();
        if(!callback.isNull()) {
          callback->onTimeout(node);
        }
      } catch(RecoverableException& e) {
        _logger->info("Exception thrown while handling timeouts.", e);
      }
    } else {
      ++i;
    }
  }
}

// Looks up, without removing it, the first tracked entry that matches the
// transaction ID of `message` and the IP address/port of its remote node.
// Returns a default-constructed SharedHandle when there is no such entry.
SharedHandle<DHTMessageTrackerEntry>
DHTMessageTracker::getEntryFor(const SharedHandle<DHTMessage>& message) const
{
  for(std::deque<SharedHandle<DHTMessageTrackerEntry> >::const_iterator i =
        _entries.begin(); i != _entries.end(); ++i) {
    if((*i)->match(message->getTransactionID(),
                   message->getRemoteNode()->getIPAddress(),
                   message->getRemoteNode()->getPort())) {
      return *i;
    }
  }
  return SharedHandle<DHTMessageTrackerEntry>();
}

// Returns the number of entries currently tracked.
size_t DHTMessageTracker::countEntry() const
{
  return _entries.size();
}

// Sets the routing table from which handleTimeout() drops nodes that have
// become bad.
void DHTMessageTracker::setRoutingTable(const SharedHandle<DHTRoutingTable>& routingTable)
{
  _routingTable = routingTable;
}

// Sets the factory that messageArrived() uses to build response messages.
void DHTMessageTracker::setMessageFactory(const SharedHandle<DHTMessageFactory>& factory)
{
  _factory = factory;
}

} // namespace aria2