/* */
#include "DHTMessageTracker.h"

#include <algorithm>
#include <utility>

#include "DHTMessage.h"
#include "DHTMessageCallback.h"
#include "DHTMessageTrackerEntry.h"
#include "DHTNode.h"
#include "DHTRoutingTable.h"
#include "DHTMessageFactory.h"
#include "util.h"
#include "LogFactory.h"
#include "Logger.h"
#include "DlAbortEx.h"
#include "DHTConstants.h"
#include "fmt.h"

namespace aria2 {

DHTMessageTracker::DHTMessageTracker() {}

DHTMessageTracker::~DHTMessageTracker() {}

// Registers |message| for tracking: a new DHTMessageTrackerEntry is
// built from |message|, |timeout| and |callback| and appended to the
// back of entries_.
void DHTMessageTracker::addMessage(
    const std::shared_ptr<DHTMessage>& message, time_t timeout,
    const std::shared_ptr<DHTMessageCallback>& callback)
{
  entries_.push_back(
      std::make_shared<DHTMessageTrackerEntry>(message, timeout, callback));
}

// Handles an incoming message |dict| received from |ipaddr|:|port|.
//
// The transaction ID is read from |dict| under the key DHTMessage::T.
// The first tracked entry matching that ID and the remote endpoint is
// removed from entries_, and a response message is created through
// factory_ using the entry's target node address.  The response's
// remote node has its RTT updated; if that node does not compare
// equal to the entry's target node, the target node is dropped from
// routingTable_.
//
// Returns the (response message, callback) pair on success, or a
// default-constructed pair when no entry matches.
//
// Throws DL_ABORT_EX if |dict| carries no string transaction ID.  If a
// RecoverableException escapes while the response is being built or
// processed, the already-removed entry is handled as a timeout and the
// exception is rethrown.
std::pair<std::shared_ptr<DHTResponseMessage>,
          std::shared_ptr<DHTMessageCallback>>
DHTMessageTracker::messageArrived(const Dict* dict, const std::string& ipaddr,
                                  uint16_t port)
{
  const String* tid = downcast<String>(dict->get(DHTMessage::T));
  if (!tid) {
    throw DL_ABORT_EX(
        fmt("Malformed DHT message. From:%s:%u", ipaddr.c_str(), port));
  }
  A2_LOG_DEBUG(
      fmt("Searching tracker entry for TransactionID=%s, Remote=%s:%u",
          util::toHex(tid->s()).c_str(), ipaddr.c_str(), port));

  auto found = std::find_if(
      entries_.begin(), entries_.end(),
      [&](const std::shared_ptr<DHTMessageTrackerEntry>& candidate) {
        return candidate->match(tid->s(), ipaddr, port);
      });
  if (found == entries_.end()) {
    A2_LOG_DEBUG("Tracker entry not found.");
    return std::pair<std::shared_ptr<DHTResponseMessage>,
                     std::shared_ptr<DHTMessageCallback>>();
  }

  // Keep the entry alive in a local before erasing it; |found| must
  // not be used after the erase.
  std::shared_ptr<DHTMessageTrackerEntry> entry = *found;
  entries_.erase(found);
  A2_LOG_DEBUG("Tracker entry found.");

  auto targetNode = entry->getTargetNode();
  try {
    auto message = factory_->createResponseMessage(
        entry->getMessageType(), dict, targetNode->getIPAddress(),
        targetNode->getPort());

    int64_t rtt = entry->getElapsedMillis();
    A2_LOG_DEBUG(fmt("RTT is %" PRId64 "", rtt));
    message->getRemoteNode()->updateRTT(rtt);

    auto callback = entry->getCallback();

    if (!(*targetNode == *message->getRemoteNode())) {
      // Node ID has changed. Drop previous node ID from
      // DHTRoutingTable
      A2_LOG_DEBUG(fmt(
          "Node ID has changed: old:%s, new:%s",
          util::toHex(targetNode->getID(), DHT_ID_LENGTH).c_str(),
          util::toHex(message->getRemoteNode()->getID(), DHT_ID_LENGTH)
              .c_str()));
      routingTable_->dropNode(targetNode);
    }
    return std::make_pair(message, callback);
  }
  catch (RecoverableException&) {
    // The entry is already gone from entries_, so account for it as a
    // timeout before propagating the error to the caller.
    handleTimeoutEntry(entry);
    throw;
  }
}

// Processes |entry| as timed out: updates the target node's RTT with
// the entry's elapsed time, calls timeout() on the node, drops the
// node from routingTable_ if it reports isBad(), and finally invokes
// the entry's callback (when one is set) via onTimeout(node).
//
// A RecoverableException raised along the way is logged and not
// propagated.
void DHTMessageTracker::handleTimeoutEntry(
    const std::shared_ptr<DHTMessageTrackerEntry>& entry)
{
  try {
    auto node = entry->getTargetNode();
    A2_LOG_DEBUG(fmt("Message timeout: To:%s:%u",
                     node->getIPAddress().c_str(), node->getPort()));
    node->updateRTT(entry->getElapsedMillis());
    node->timeout();
    if (node->isBad()) {
      A2_LOG_DEBUG(fmt("Marked bad: %s:%u", node->getIPAddress().c_str(),
                       node->getPort()));
      routingTable_->dropNode(node);
    }
    auto callback = entry->getCallback();
    if (callback) {
      callback->onTimeout(node);
    }
  }
  catch (RecoverableException& e) {
    A2_LOG_INFO_EX("Exception thrown while handling timeouts.", e);
  }
}

namespace {
// Predicate for std::remove_if: for an entry whose isTimeout() is
// true, runs the tracker's timeout handling and returns true so the
// entry gets removed; returns false otherwise.
struct HandleTimeout {
  explicit HandleTimeout(DHTMessageTracker* tracker) : tracker(tracker) {}

  bool operator()(const std::shared_ptr<DHTMessageTrackerEntry>& ent) const
  {
    if (!ent->isTimeout()) {
      return false;
    }
    tracker->handleTimeoutEntry(ent);
    return true;
  }

  DHTMessageTracker* tracker;
};
} // namespace

// Runs timeout handling for every entry reporting isTimeout() and
// removes those entries from entries_.
void DHTMessageTracker::handleTimeout()
{
  entries_.erase(std::remove_if(entries_.begin(), entries_.end(),
                                HandleTimeout(this)),
                 entries_.end());
}

// Returns the first tracked entry matching |message|'s transaction ID
// and its remote node's IP address and port, or an empty pointer if
// there is none.  The entry stays in entries_.
std::shared_ptr<DHTMessageTrackerEntry> DHTMessageTracker::getEntryFor(
    const std::shared_ptr<DHTMessage>& message) const
{
  for (const auto& candidate : entries_) {
    if (candidate->match(message->getTransactionID(),
                         message->getRemoteNode()->getIPAddress(),
                         message->getRemoteNode()->getPort())) {
      return candidate;
    }
  }
  return std::shared_ptr<DHTMessageTrackerEntry>();
}

// Returns the number of entries currently tracked.
size_t DHTMessageTracker::countEntry() const { return entries_.size(); }

// Sets the routing table from which nodes are dropped.
void DHTMessageTracker::setRoutingTable(
    const std::shared_ptr<DHTRoutingTable>& routingTable)
{
  routingTable_ = routingTable;
}

// Sets the factory used to create response messages.
void DHTMessageTracker::setMessageFactory(
    const std::shared_ptr<DHTMessageFactory>& factory)
{
  factory_ = factory;
}

} // namespace aria2