/* */
#include "DefaultBtMessageReceiver.h"

#include <cstring>

#include "BtHandshakeMessage.h"
#include "message.h"
#include "DownloadContext.h"
#include "Peer.h"
#include "PeerConnection.h"
#include "BtMessageDispatcher.h"
#include "BtMessageFactory.h"
#include "Logger.h"
#include "LogFactory.h"
#include "bittorrent_helper.h"
#include "BtPieceMessage.h"
#include "util.h"
#include "fmt.h"
#include "DlAbortEx.h"

namespace aria2 {

namespace {
// Byte offset of the peer's info hash inside the buffered handshake.
// Per the BitTorrent handshake layout, this is pstrlen(1) + pstr(19)
// + reserved(8).
const size_t HANDSHAKE_INFO_HASH_OFFSET = 28;
// Number of buffered bytes required before the info hash can be compared.
// It is derived from the range that memcmp reads, so the guard and the read
// cannot drift apart. This was previously the literal 48, which equals this
// value when INFO_HASH_LENGTH is 20.
const size_t HANDSHAKE_INFO_HASH_END =
  HANDSHAKE_INFO_HASH_OFFSET + INFO_HASH_LENGTH;
} // namespace

DefaultBtMessageReceiver::DefaultBtMessageReceiver():
  handshakeSent_(false),
  peerConnection_(0),
  dispatcher_(0),
  messageFactory_(0)
{}

// Tries to read the peer's handshake from peerConnection_.
//
// If quickReply is true and our handshake has not been sent yet, our own
// handshake is sent as soon as enough bytes are buffered to check the peer's
// info hash. This happens without waiting for the rest of the peer's
// handshake, for the tracker's NAT-checking feature. Our handshake is sent
// at most once (tracked by handshakeSent_).
//
// Returns the validated handshake message. Returns a null handle if
// PeerConnection::receiveHandshake() did not yield one on this call.
// Throws DL_ABORT_EX if the buffered info hash does not match the download's
// info hash. Exceptions from validate() propagate to the caller.
SharedHandle<BtHandshakeMessage>
DefaultBtMessageReceiver::receiveHandshake(bool quickReply)
{
  A2_LOG_DEBUG
    (fmt("Receiving handshake bufferLength=%lu",
         static_cast<unsigned long>(peerConnection_->getBufferLength())));
  unsigned char data[BtHandshakeMessage::MESSAGE_LENGTH];
  size_t dataLength = BtHandshakeMessage::MESSAGE_LENGTH;
  SharedHandle<BtHandshakeMessage> msg;
  if(handshakeSent_ || !quickReply ||
     peerConnection_->getBufferLength() < HANDSHAKE_INFO_HASH_END) {
    if(peerConnection_->receiveHandshake(data, dataLength)) {
      msg = messageFactory_->createHandshakeMessage(data, dataLength);
      msg->validate();
    }
  }
  // Handle tracker's NAT-checking feature.
  // The buffer length is queried again on purpose: the receiveHandshake()
  // call above may have changed it.
  if(!handshakeSent_ && quickReply &&
     peerConnection_->getBufferLength() >= HANDSHAKE_INFO_HASH_END) {
    handshakeSent_ = true;
    // Check info_hash. The guard above ensures that
    // [HANDSHAKE_INFO_HASH_OFFSET, HANDSHAKE_INFO_HASH_END) is buffered.
    if(memcmp(bittorrent::getInfoHash(downloadContext_),
              peerConnection_->getBuffer()+HANDSHAKE_INFO_HASH_OFFSET,
              INFO_HASH_LENGTH) == 0) {
      sendHandshake();
    } else {
      throw DL_ABORT_EX
        (fmt("Bad Info Hash %s",
             util::toHex(peerConnection_->getBuffer()+
                         HANDSHAKE_INFO_HASH_OFFSET,
                         INFO_HASH_LENGTH).c_str()));
    }
    // If exactly one full handshake is buffered, consume it now so the
    // caller gets it in the same call.
    if(!msg &&
       peerConnection_->getBufferLength() ==
       BtHandshakeMessage::MESSAGE_LENGTH &&
       peerConnection_->receiveHandshake(data, dataLength)) {
      msg = messageFactory_->createHandshakeMessage(data, dataLength);
      msg->validate();
    }
  }
  return msg;
}

// Same as receiveHandshake(true): replies with our handshake as early as
// possible.
SharedHandle<BtHandshakeMessage>
DefaultBtMessageReceiver::receiveAndSendHandshake()
{
  return receiveHandshake(true);
}

// Builds our handshake (download info hash + static peer id), queues it on
// the dispatcher, and asks the dispatcher to send queued messages.
void DefaultBtMessageReceiver::sendHandshake()
{
  SharedHandle<BtMessage> msg =
    messageFactory_->createHandshakeMessage
    (bittorrent::getInfoHash(downloadContext_), bittorrent::getStaticPeerId());
  dispatcher_->addMessageToQueue(msg);
  dispatcher_->sendMessages();
}

// Reads one BitTorrent message from peerConnection_.
//
// Returns a null handle if PeerConnection::receiveMessage() reports that no
// message is available yet. Otherwise returns the validated message. For
// piece messages, the connection's payload buffer is also attached via
// setMsgPayload(). Exceptions from validate() propagate to the caller.
SharedHandle<BtMessage> DefaultBtMessageReceiver::receiveMessage()
{
  size_t dataLength = 0;
  // Give 0 to PeerConnection::receiveMessage() to prevent memcpy.
  // The payload is read from getMsgPayloadBuffer() instead.
  if(!peerConnection_->receiveMessage(0, dataLength)) {
    return SharedHandle<BtMessage>();
  }
  SharedHandle<BtMessage> msg =
    messageFactory_->createBtMessage(peerConnection_->getMsgPayloadBuffer(),
                                     dataLength);
  msg->validate();
  if(msg->getId() == BtPieceMessage::ID) {
    SharedHandle<BtPieceMessage> piecemsg =
      static_pointer_cast<BtPieceMessage>(msg);
    piecemsg->setMsgPayload(peerConnection_->getMsgPayloadBuffer());
  }
  return msg;
}

void DefaultBtMessageReceiver::setDownloadContext
(const SharedHandle<DownloadContext>& downloadContext)
{
  downloadContext_ = downloadContext;
}

// NOTE(review): stores the raw pointer only; ownership is not visible in this
// file.
void DefaultBtMessageReceiver::setPeerConnection(PeerConnection* peerConnection)
{
  peerConnection_ = peerConnection;
}

// NOTE(review): stores the raw pointer only; ownership is not visible in this
// file.
void DefaultBtMessageReceiver::setDispatcher(BtMessageDispatcher* dispatcher)
{
  dispatcher_ = dispatcher;
}

// NOTE(review): stores the raw pointer only; ownership is not visible in this
// file.
void DefaultBtMessageReceiver::setBtMessageFactory(BtMessageFactory* factory)
{
  messageFactory_ = factory;
}

} // namespace aria2