/* */
#include "DHTMessageReceiver.h"

#include <string>
#include <utility>

#include "DHTMessageTracker.h"
#include "DHTConnection.h"
#include "DHTMessage.h"
#include "DHTResponseMessage.h"
#include "DHTUnknownMessage.h"
#include "DHTMessageFactory.h"
#include "DHTRoutingTable.h"
#include "DHTNode.h"
#include "DHTMessageCallback.h"
#include "DlAbortEx.h"
#include "LogFactory.h"
#include "Logger.h"
#include "Util.h"
#include "bencode.h"

namespace aria2 {

// Only the tracker and the logger are set here. receiveMessage() also
// dereferences _connection, _factory and _routingTable, so those must be
// injected through setConnection(), setMessageFactory() and
// setRoutingTable() before it is called.
DHTMessageReceiver::DHTMessageReceiver
(const SharedHandle<DHTMessageTracker>& tracker):
  _tracker(tracker),
  _logger(LogFactory::getInstance()) {}

DHTMessageReceiver::~DHTMessageReceiver() {}

// Reads one datagram from the connection and turns it into a DHTMessage.
//
// Returns a null handle when the connection yields no data (return value
// <= 0). Otherwise the payload is bencode-decoded:
//  - If it is not a dictionary, has a non-string "y" entry, is a reply the
//    tracker does not recognize, or is a query whose remote node equals the
//    local node, it is wrapped by handleUnknownMessage().
//  - If any RecoverableException is thrown while decoding, building or
//    processing the message, it is also wrapped by handleUnknownMessage().
//  - Otherwise the message is validated and its received-action is run. The
//    remote node is marked good, its last-contact time is updated and it is
//    added to the routing table. For tracked replies the associated callback
//    (if any) is notified.
SharedHandle<DHTMessage> DHTMessageReceiver::receiveMessage()
{
  std::string remoteAddr;
  uint16_t remotePort;
  unsigned char data[64*1024];
  ssize_t length = _connection->receiveMessage(data, sizeof(data),
                                               remoteAddr, remotePort);
  if(length <= 0) {
    return SharedHandle<DHTMessage>();
  }
  // Only the first `length` bytes of `data` were filled in by the
  // connection; the rest of the buffer is uninitialized and must not be
  // passed on (previously sizeof(data) was forwarded here).
  const size_t dataLength = static_cast<size_t>(length);
  try {
    bool isReply = false;
    const bencode::BDE dict = bencode::decode(data, dataLength);
    if(dict.isDict()) {
      const bencode::BDE& y = dict[DHTMessage::Y];
      if(y.isString()) {
        // A "y" equal to DHTResponseMessage::R or DHTUnknownMessage::E marks
        // a reply (presumably the KRPC 'r'/'e' types); anything else is
        // handled as a query.
        if(y.s() == DHTResponseMessage::R || y.s() == DHTUnknownMessage::E) {
          isReply = true;
        }
      } else {
        _logger->info("Malformed DHT message. Missing 'y' key. From:%s:%u",
                      remoteAddr.c_str(), remotePort);
        return handleUnknownMessage(data, dataLength, remoteAddr, remotePort);
      }
    } else {
      _logger->info("Malformed DHT message. This is not a bencoded dictionary."
                    " From:%s:%u", remoteAddr.c_str(), remotePort);
      return handleUnknownMessage(data, dataLength, remoteAddr, remotePort);
    }
    SharedHandle<DHTMessage> message;
    SharedHandle<DHTMessageCallback> callback;
    if(isReply) {
      std::pair<SharedHandle<DHTMessage>, SharedHandle<DHTMessageCallback> > p =
        _tracker->messageArrived(dict, remoteAddr, remotePort);
      message = p.first;
      callback = p.second;
      if(message.isNull()) {
        // The tracker produced no message for this reply: presumably the
        // matching query already timed out, or the reply is unsolicited
        // (possibly malicious).
        return handleUnknownMessage(data, dataLength, remoteAddr, remotePort);
      }
    } else {
      message = _factory->createQueryMessage(dict, remoteAddr, remotePort);
      if(message->getLocalNode() == message->getRemoteNode()) {
        // drop message from localnode
        _logger->info("Recieved DHT message from localnode.");
        return handleUnknownMessage(data, dataLength, remoteAddr, remotePort);
      }
    }
    _logger->info("Message received: %s", message->toString().c_str());
    message->validate();
    message->doReceivedAction();
    message->getRemoteNode()->markGood();
    message->getRemoteNode()->updateLastContact();
    _routingTable->addGoodNode(message->getRemoteNode());
    if(!callback.isNull()) {
      callback->onReceived(message);
    }
    return message;
  } catch(RecoverableException& e) {
    _logger->info("Exception thrown while receiving DHT message.", e);
    return handleUnknownMessage(data, dataLength, remoteAddr, remotePort);
  }
}

// Delegates timeout processing to the message tracker.
void DHTMessageReceiver::handleTimeout()
{
  _tracker->handleTimeout();
}

// Wraps the first `length` bytes of `data` in an unknown message created by
// the message factory, logs it, and returns it.
SharedHandle<DHTMessage>
DHTMessageReceiver::handleUnknownMessage(const unsigned char* data,
                                         size_t length,
                                         const std::string& remoteAddr,
                                         uint16_t remotePort)
{
  SharedHandle<DHTMessage> m =
    _factory->createUnknownMessage(data, length, remoteAddr, remotePort);
  _logger->info("Message received: %s", m->toString().c_str());
  return m;
}

SharedHandle<DHTConnection> DHTMessageReceiver::getConnection() const
{
  return _connection;
}

SharedHandle<DHTMessageTracker> DHTMessageReceiver::getMessageTracker() const
{
  return _tracker;
}

// Connection that receiveMessage() reads datagrams from.
void DHTMessageReceiver::setConnection
(const SharedHandle<DHTConnection>& connection)
{
  _connection = connection;
}

// Factory used to build query and unknown messages.
void DHTMessageReceiver::setMessageFactory
(const SharedHandle<DHTMessageFactory>& factory)
{
  _factory = factory;
}

// Routing table that successfully received messages' remote nodes are added to.
void DHTMessageReceiver::setRoutingTable
(const SharedHandle<DHTRoutingTable>& routingTable)
{
  _routingTable = routingTable;
}

} // namespace aria2