/* */
// DefaultBtMessageFactory implementation.
//
// Creates BitTorrent peer-wire message objects in two ways:
// - by parsing raw bytes (createBtMessage, createHandshakeMessage(data, ...));
// - from explicit field values (the remaining create*Message methods).
// Every message produced here is passed through setCommonProperty(), so it
// carries the collaborators this factory was configured with.
#include "DefaultBtMessageFactory.h"
#include "DlAbortEx.h"
#include "bittorrent_helper.h"
#include "BtKeepAliveMessage.h"
#include "BtChokeMessage.h"
#include "BtUnchokeMessage.h"
#include "BtInterestedMessage.h"
#include "BtNotInterestedMessage.h"
#include "BtHaveMessage.h"
#include "BtBitfieldMessage.h"
#include "BtBitfieldMessageValidator.h"
#include "RangeBtMessageValidator.h"
#include "IndexBtMessageValidator.h"
#include "BtRequestMessage.h"
#include "BtCancelMessage.h"
#include "BtPieceMessage.h"
#include "BtPieceMessageValidator.h"
#include "BtPortMessage.h"
#include "BtHaveAllMessage.h"
#include "BtHaveNoneMessage.h"
#include "BtRejectMessage.h"
#include "BtSuggestPieceMessage.h"
#include "BtAllowedFastMessage.h"
#include "BtHandshakeMessage.h"
#include "BtHandshakeMessageValidator.h"
#include "BtExtendedMessage.h"
#include "ExtensionMessage.h"
#include "Peer.h"
#include "Piece.h"
#include "DownloadContext.h"
#include "PieceStorage.h"
#include "PeerStorage.h"
#include "fmt.h"
#include "ExtensionMessageFactory.h"
// NOTE(review): "bittorrent_helper.h" is already included near the top of
// this list. This second include is redundant (presumably harmless if the
// header is include-guarded) and could be dropped.
#include "bittorrent_helper.h"

namespace aria2 {

// Constructs a factory with no collaborators attached.
// - cuid_ starts at 0.
// - Every listed collaborator pointer starts as nullptr.
// - dhtEnabled_ and metadataGetMode_ start as false.
// - peer_ is not listed, so it keeps its default construction.
// Collaborators are supplied afterwards through setters; several of them are
// defined at the end of this file.
DefaultBtMessageFactory::DefaultBtMessageFactory()
    : cuid_{0},
      downloadContext_{nullptr},
      pieceStorage_{nullptr},
      peerStorage_{nullptr},
      dhtEnabled_(false),
      dispatcher_{nullptr},
      requestFactory_{nullptr},
      peerConnection_{nullptr},
      extensionMessageFactory_{nullptr},
      localNode_{nullptr},
      routingTable_{nullptr},
      taskQueue_{nullptr},
      taskFactory_{nullptr},
      metadataGetMode_(false)
{
}

// Parses one peer-wire message from |data| (|dataLength| bytes) and returns
// it with setCommonProperty() applied.
//
// - dataLength == 0 produces a keep-alive message; |data| is not read.
// - Otherwise the ID comes from bittorrent::getId(data), and the matching
//   Bt*Message::create(data, dataLength) builds the message.
// - Outside metadata-get mode, several messages get a validator constructed
//   with downloadContext_->getNumPieces(): have, bitfield, request, piece,
//   cancel, suggest-piece, reject and allowed-fast.
// - Request, piece, cancel and reject validators also receive
//   pieceStorage_->getPieceLength(m->getIndex()).
//
// Throws the exception built by DL_ABORT_EX in two cases:
// - the message ID is unknown;
// - an extended message arrives while peer_ does not have extended
//   messaging enabled.
// NOTE(review): peer_ is dereferenced for extended messages without a null
// check, so setPeer() presumably must be called before parsing.
std::unique_ptr
DefaultBtMessageFactory::createBtMessage(const unsigned char* data,
                                         size_t dataLength)
{
  auto msg = std::unique_ptr{};
  if (dataLength == 0) {
    // keep-alive: a zero-length message has no ID byte to inspect.
    msg = make_unique();
  }
  else {
    uint8_t id = bittorrent::getId(data);
    switch (id) {
    case BtChokeMessage::ID:
      msg = BtChokeMessage::create(data, dataLength);
      break;
    case BtUnchokeMessage::ID:
      msg = BtUnchokeMessage::create(data, dataLength);
      break;
    case BtInterestedMessage::ID: {
      // Interested / not-interested messages additionally get peerStorage_.
      auto m = BtInterestedMessage::create(data, dataLength);
      m->setPeerStorage(peerStorage_);
      msg = std::move(m);
      break;
    }
    case BtNotInterestedMessage::ID: {
      auto m = BtNotInterestedMessage::create(data, dataLength);
      m->setPeerStorage(peerStorage_);
      msg = std::move(m);
      break;
    }
    case BtHaveMessage::ID:
      msg = BtHaveMessage::create(data, dataLength);
      // Validators are skipped in metadata-get mode (presumably because the
      // piece layout is not known until the metadata has been fetched).
      if (!metadataGetMode_) {
        msg->setBtMessageValidator(make_unique(
            static_cast(msg.get()), downloadContext_->getNumPieces()));
      }
      break;
    case BtBitfieldMessage::ID:
      msg = BtBitfieldMessage::create(data, dataLength);
      if (!metadataGetMode_) {
        msg->setBtMessageValidator(make_unique(
            static_cast(msg.get()), downloadContext_->getNumPieces()));
      }
      break;
    case BtRequestMessage::ID: {
      auto m = BtRequestMessage::create(data, dataLength);
      if (!metadataGetMode_) {
        m->setBtMessageValidator(make_unique(
            static_cast(m.get()), downloadContext_->getNumPieces(),
            pieceStorage_->getPieceLength(m->getIndex())));
      }
      msg = std::move(m);
      break;
    }
    case BtPieceMessage::ID: {
      auto m = BtPieceMessage::create(data, dataLength);
      if (!metadataGetMode_) {
        m->setBtMessageValidator(make_unique(
            static_cast(m.get()), downloadContext_->getNumPieces(),
            pieceStorage_->getPieceLength(m->getIndex())));
      }
      // Piece messages also need the download context and peer storage.
      m->setDownloadContext(downloadContext_);
      m->setPeerStorage(peerStorage_);
      msg = std::move(m);
      break;
    }
    case BtCancelMessage::ID: {
      auto m = BtCancelMessage::create(data, dataLength);
      if (!metadataGetMode_) {
        m->setBtMessageValidator(make_unique(
            static_cast(m.get()), downloadContext_->getNumPieces(),
            pieceStorage_->getPieceLength(m->getIndex())));
      }
      msg = std::move(m);
      break;
    }
    case BtPortMessage::ID: {
      // PORT messages are handed the DHT collaborators.
      auto m = BtPortMessage::create(data, dataLength);
      m->setLocalNode(localNode_);
      m->setRoutingTable(routingTable_);
      m->setTaskQueue(taskQueue_);
      m->setTaskFactory(taskFactory_);
      msg = std::move(m);
      break;
    }
    case BtSuggestPieceMessage::ID: {
      auto m = BtSuggestPieceMessage::create(data, dataLength);
      if (!metadataGetMode_) {
        m->setBtMessageValidator(make_unique(
            static_cast(m.get()), downloadContext_->getNumPieces()));
      }
      msg = std::move(m);
      break;
    }
    case BtHaveAllMessage::ID:
      msg = BtHaveAllMessage::create(data, dataLength);
      break;
    case BtHaveNoneMessage::ID:
      msg = BtHaveNoneMessage::create(data, dataLength);
      break;
    case BtRejectMessage::ID: {
      auto m = BtRejectMessage::create(data, dataLength);
      if (!metadataGetMode_) {
        m->setBtMessageValidator(make_unique(
            static_cast(m.get()), downloadContext_->getNumPieces(),
            pieceStorage_->getPieceLength(m->getIndex())));
      }
      msg = std::move(m);
      break;
    }
    case BtAllowedFastMessage::ID: {
      auto m = BtAllowedFastMessage::create(data, dataLength);
      if (!metadataGetMode_) {
        m->setBtMessageValidator(make_unique(
            static_cast(m.get()), downloadContext_->getNumPieces()));
      }
      msg = std::move(m);
      break;
    }
    case BtExtendedMessage::ID: {
      // Extended messages are only accepted when the peer has extended
      // messaging enabled; otherwise the session is aborted.
      if (peer_->isExtendedMessagingEnabled()) {
        msg = BtExtendedMessage::create(extensionMessageFactory_, peer_, data,
                                        dataLength);
      }
      else {
        throw DL_ABORT_EX("Received extended message from peer during"
                          " a session with extended messaging disabled.");
      }
      break;
    }
    default:
      // Unknown IDs abort rather than being silently ignored.
      throw DL_ABORT_EX(fmt("Invalid message ID. id=%u", id));
    }
  }
  setCommonProperty(msg.get());
  // NOTE(review): std::move is presumably kept because msg's declared type
  // (AbstractBtMessage-based, judging by setCommonProperty's parameter) may
  // differ from the return type. Older compilers did not apply the implicit
  // move on return in that case. Verify before removing.
  return std::move(msg);
}

// Attaches the collaborators shared by every message: cuid, peer, piece
// storage, dispatcher, request factory, this factory and peer connection.
// When this factory is in metadata-get mode, also enables that mode on |msg|.
void DefaultBtMessageFactory::setCommonProperty(AbstractBtMessage* msg)
{
  msg->setCuid(cuid_);
  msg->setPeer(peer_);
  msg->setPieceStorage(pieceStorage_);
  msg->setBtMessageDispatcher(dispatcher_);
  msg->setBtRequestFactory(requestFactory_);
  msg->setBtMessageFactory(this);
  msg->setPeerConnection(peerConnection_);
  if (metadataGetMode_) {
    msg->enableMetadataGetMode();
  }
}

// Parses a handshake from |data| (|dataLength| bytes). Attaches a validator
// constructed with this download's info hash, taken from
// bittorrent::getInfoHash(downloadContext_).
std::unique_ptr
DefaultBtMessageFactory::createHandshakeMessage(const unsigned char* data,
                                                size_t dataLength)
{
  auto msg = BtHandshakeMessage::create(data, dataLength);
  msg->setBtMessageValidator(make_unique(
      msg.get(), bittorrent::getInfoHash(downloadContext_)));
  setCommonProperty(msg.get());
  return msg;
}

// Builds a handshake from |infoHash| and |peerId| and propagates
// dhtEnabled_ to it.
std::unique_ptr
DefaultBtMessageFactory::createHandshakeMessage(const unsigned char* infoHash,
                                                const unsigned char* peerId)
{
  auto msg = make_unique(infoHash, peerId);
  msg->setDHTEnabled(dhtEnabled_);
  setCommonProperty(msg.get());
  return msg;
}

// Builds a request for block |blockIndex| of |piece|.
// - begin is blockIndex * piece->getBlockLength().
// - length is piece->getBlockLength(blockIndex), presumably shorter for a
//   final partial block.
// - blockIndex itself is also passed to the message.
std::unique_ptr DefaultBtMessageFactory::createRequestMessage(
    const std::shared_ptr& piece, size_t blockIndex)
{
  auto msg = make_unique(
      piece->getIndex(), blockIndex * piece->getBlockLength(),
      piece->getBlockLength(blockIndex), blockIndex);
  setCommonProperty(msg.get());
  return msg;
}

// The factories below build messages from explicit field values. Each one
// constructs the message and then applies setCommonProperty().

// Cancel for the range (|index|, |begin|, |length|).
std::unique_ptr
DefaultBtMessageFactory::createCancelMessage(size_t index, int32_t begin,
                                             int32_t length)
{
  auto msg = make_unique(index, begin, length);
  setCommonProperty(msg.get());
  return msg;
}

// Piece message for the range (|index|, |begin|, |length|). It also receives
// downloadContext_.
std::unique_ptr
DefaultBtMessageFactory::createPieceMessage(size_t index, int32_t begin,
                                            int32_t length)
{
  auto msg = make_unique(index, begin, length);
  msg->setDownloadContext(downloadContext_);
  setCommonProperty(msg.get());
  return msg;
}

// Have message for piece |index|.
std::unique_ptr DefaultBtMessageFactory::createHaveMessage(size_t index)
{
  auto msg = make_unique(index);
  setCommonProperty(msg.get());
  return msg;
}

std::unique_ptr DefaultBtMessageFactory::createChokeMessage()
{
  auto msg = make_unique();
  setCommonProperty(msg.get());
  return msg;
}

std::unique_ptr DefaultBtMessageFactory::createUnchokeMessage()
{
  auto msg = make_unique();
  setCommonProperty(msg.get());
  return msg;
}

std::unique_ptr DefaultBtMessageFactory::createInterestedMessage()
{
  auto msg = make_unique();
  setCommonProperty(msg.get());
  return msg;
}

std::unique_ptr DefaultBtMessageFactory::createNotInterestedMessage()
{
  auto msg = make_unique();
  setCommonProperty(msg.get());
  return msg;
}

// Bitfield message built from pieceStorage_'s bitfield and its length.
std::unique_ptr DefaultBtMessageFactory::createBitfieldMessage()
{
  auto msg = make_unique(pieceStorage_->getBitfield(),
                         pieceStorage_->getBitfieldLength());
  setCommonProperty(msg.get());
  return msg;
}

std::unique_ptr DefaultBtMessageFactory::createKeepAliveMessage()
{
  auto msg = make_unique();
  setCommonProperty(msg.get());
  return msg;
}

std::unique_ptr DefaultBtMessageFactory::createHaveAllMessage()
{
  auto msg = make_unique();
  setCommonProperty(msg.get());
  return msg;
}

std::unique_ptr DefaultBtMessageFactory::createHaveNoneMessage()
{
  auto msg = make_unique();
  setCommonProperty(msg.get());
  return msg;
}

// Reject for the range (|index|, |begin|, |length|).
std::unique_ptr
DefaultBtMessageFactory::createRejectMessage(size_t index, int32_t begin,
                                             int32_t length)
{
  auto msg = make_unique(index, begin, length);
  setCommonProperty(msg.get());
  return msg;
}

// Allowed-fast message for piece |index|.
std::unique_ptr
DefaultBtMessageFactory::createAllowedFastMessage(size_t index)
{
  auto msg = make_unique(index);
  setCommonProperty(msg.get());
  return msg;
}

// Port message advertising |port|.
std::unique_ptr DefaultBtMessageFactory::createPortMessage(uint16_t port)
{
  auto msg = make_unique(port);
  setCommonProperty(msg.get());
  return msg;
}

// Wraps |exmsg| in an extended message, taking ownership of it.
std::unique_ptr DefaultBtMessageFactory::createBtExtendedMessage(
    std::unique_ptr exmsg)
{
  auto msg = make_unique(std::move(exmsg));
  setCommonProperty(msg.get());
  return msg;
}

// Setters: each stores its argument in the corresponding member. Raw
// pointers are stored as given; no ownership transfer is visible here.

void DefaultBtMessageFactory::setTaskQueue(DHTTaskQueue* taskQueue)
{
  taskQueue_ = taskQueue;
}

void DefaultBtMessageFactory::setTaskFactory(DHTTaskFactory* taskFactory)
{
  taskFactory_ = taskFactory;
}

// Stores a copy of the shared pointer to the remote peer.
void DefaultBtMessageFactory::setPeer(const std::shared_ptr& peer)
{
  peer_ = peer;
}

void DefaultBtMessageFactory::setDownloadContext(
    DownloadContext* downloadContext)
{
  downloadContext_ = downloadContext;
}

void DefaultBtMessageFactory::setPieceStorage(PieceStorage* pieceStorage)
{
  pieceStorage_ = pieceStorage;
}

void DefaultBtMessageFactory::setPeerStorage(PeerStorage* peerStorage)
{
  peerStorage_ = peerStorage;
}

void DefaultBtMessageFactory::setBtMessageDispatcher(
    BtMessageDispatcher* dispatcher)
{
  dispatcher_ = dispatcher;
}

void DefaultBtMessageFactory::setExtensionMessageFactory(
    ExtensionMessageFactory* factory)
{
  extensionMessageFactory_ = factory;
}

void DefaultBtMessageFactory::setLocalNode(DHTNode* localNode)
{
  localNode_ = localNode;
}

void DefaultBtMessageFactory::setRoutingTable(DHTRoutingTable* routingTable)
{
  routingTable_ = routingTable;
}

void DefaultBtMessageFactory::setBtRequestFactory(BtRequestFactory* factory)
{
  requestFactory_ = factory;
}

void DefaultBtMessageFactory::setPeerConnection(PeerConnection* connection)
{
  peerConnection_ = connection;
}

} // namespace aria2