/* */
#include "DHTMessageReceiver.h"

#include <cstring>
#include <utility>

#include "DHTMessageTracker.h"
#include "DHTConnection.h"
#include "DHTMessage.h"
#include "DHTQueryMessage.h"
#include "DHTResponseMessage.h"
#include "DHTUnknownMessage.h"
#include "DHTMessageFactory.h"
#include "DHTRoutingTable.h"
#include "DHTNode.h"
#include "DHTMessageCallback.h"
#include "DlAbortEx.h"
#include "LogFactory.h"
#include "Logger.h"
#include "util.h"
#include "bencode2.h"
#include "fmt.h"

namespace aria2 {

// Stores the tracker that incoming replies are matched against. The
// connection, message factory and routing table are supplied afterwards
// through setConnection(), setMessageFactory() and setRoutingTable().
DHTMessageReceiver::DHTMessageReceiver
(const SharedHandle<DHTMessageTracker>& tracker)
  : tracker_(tracker)
{}

DHTMessageReceiver::~DHTMessageReceiver() {}

// Reads one datagram from connection_ and converts it into a DHT message.
//
// Returns an empty handle when the connection yields no data (length <= 0).
// A bencoded dictionary whose 'y' value equals DHTResponseMessage::R or
// DHTUnknownMessage::E is treated as a reply and passed to
// tracker_->messageArrived(); any other dictionary carrying a 'y' key is
// handed to factory_->createQueryMessage().
//
// Input that cannot be used -- not a dictionary, no 'y' key, a reply the
// tracker returns no message for, a query whose remote node equals the
// local node, or any RecoverableException raised while processing -- is
// wrapped by handleUnknownMessage() and returned instead.
SharedHandle<DHTMessage> DHTMessageReceiver::receiveMessage()
{
  std::string remoteAddr;
  // Initialized so the fallback path never reads an indeterminate value if
  // connection_->receiveMessage() throws before filling it in.
  uint16_t remotePort = 0;
  unsigned char data[64*1024];
  // Declared outside the try block so the catch handler can forward only
  // the bytes that were actually received.
  ssize_t length = 0;
  try {
    length = connection_->receiveMessage(data, sizeof(data),
                                         remoteAddr, remotePort);
    if(length <= 0) {
      return SharedHandle<DHTMessage>();
    }
    // Only the first `received` bytes of `data` hold the datagram; the rest
    // of the buffer is uninitialized and must not be passed on.
    const size_t received = static_cast<size_t>(length);
    bool isReply = false;
    SharedHandle<ValueBase> decoded = bencode2::decode(data, data+length);
    const Dict* dict = downcast<Dict>(decoded);
    if(dict) {
      const String* y = downcast<String>(dict->get(DHTMessage::Y));
      if(y) {
        if(y->s() == DHTResponseMessage::R || y->s() == DHTUnknownMessage::E) {
          isReply = true;
        }
      } else {
        A2_LOG_INFO(fmt("Malformed DHT message. Missing 'y' key. From:%s:%u",
                        remoteAddr.c_str(), remotePort));
        return handleUnknownMessage(data, received, remoteAddr, remotePort);
      }
    } else {
      A2_LOG_INFO(fmt("Malformed DHT message. This is not a bencoded dictionary."
                      " From:%s:%u",
                      remoteAddr.c_str(), remotePort));
      return handleUnknownMessage(data, received, remoteAddr, remotePort);
    }
    if(isReply) {
      std::pair<SharedHandle<DHTResponseMessage>,
                SharedHandle<DHTMessageCallback> > p =
        tracker_->messageArrived(dict, remoteAddr, remotePort);
      if(!p.first) {
        // timeout or malicious message?
        return handleUnknownMessage(data, received, remoteAddr, remotePort);
      }
      onMessageReceived(p.first);
      if(p.second) {
        p.second->onReceived(p.first);
      }
      return p.first;
    } else {
      SharedHandle<DHTQueryMessage> message =
        factory_->createQueryMessage(dict, remoteAddr, remotePort);
      if(*message->getLocalNode() == *message->getRemoteNode()) {
        // drop message from localnode
        A2_LOG_INFO("Received DHT message from localnode.");
        return handleUnknownMessage(data, received, remoteAddr, remotePort);
      }
      onMessageReceived(message);
      return message;
    }
  } catch(RecoverableException& e) {
    A2_LOG_INFO_EX("Exception thrown while receiving DHT message.", e);
    // `length` is still 0 if the read itself threw, so no uninitialized
    // buffer contents are forwarded in that case.
    const size_t received = length > 0 ? static_cast<size_t>(length) : 0;
    return handleUnknownMessage(data, received, remoteAddr, remotePort);
  }
}

// Logs the message, then runs its validate() and doReceivedAction() hooks.
// Afterwards the sender is marked good, its last-contact time is refreshed
// and it is added to the routing table.
void DHTMessageReceiver::onMessageReceived
(const SharedHandle<DHTMessage>& message)
{
  A2_LOG_INFO(fmt("Message received: %s", message->toString().c_str()));
  message->validate();
  message->doReceivedAction();
  message->getRemoteNode()->markGood();
  message->getRemoteNode()->updateLastContact();
  routingTable_->addGoodNode(message->getRemoteNode());
}

// Delegates timeout processing to the message tracker.
void DHTMessageReceiver::handleTimeout()
{
  tracker_->handleTimeout();
}

// Wraps raw bytes that could not be interpreted as a usable DHT message in
// a message created by factory_->createUnknownMessage(), and logs it.
// `length` is the number of valid bytes at `data`; callers may pass 0.
// NOTE(review): assumes createUnknownMessage accepts length 0 -- confirm in
// DHTMessageFactory.
SharedHandle<DHTMessage>
DHTMessageReceiver::handleUnknownMessage(const unsigned char* data,
                                         size_t length,
                                         const std::string& remoteAddr,
                                         uint16_t remotePort)
{
  SharedHandle<DHTMessage> m =
    factory_->createUnknownMessage(data, length, remoteAddr, remotePort);
  A2_LOG_INFO(fmt("Message received: %s", m->toString().c_str()));
  return m;
}

// Sets the connection that receiveMessage() reads datagrams from.
void DHTMessageReceiver::setConnection
(const SharedHandle<DHTConnection>& connection)
{
  connection_ = connection;
}

// Sets the factory used to build query and unknown messages.
void DHTMessageReceiver::setMessageFactory
(const SharedHandle<DHTMessageFactory>& factory)
{
  factory_ = factory;
}

// Sets the routing table that senders of accepted messages are added to.
void DHTMessageReceiver::setRoutingTable
(const SharedHandle<DHTRoutingTable>& routingTable)
{
  routingTable_ = routingTable;
}

} // namespace aria2