/* */
#include "DefaultExtensionMessageFactory.h"

#include <cstring>

#include "Peer.h"
#include "DlAbortEx.h"
#include "HandshakeExtensionMessage.h"
#include "UTPexExtensionMessage.h"
#include "fmt.h"
#include "PeerStorage.h"
#include "ExtensionMessageRegistry.h"
#include "DownloadContext.h"
#include "BtMessageDispatcher.h"
#include "BtMessageFactory.h"
#include "UTMetadataRequestExtensionMessage.h"
#include "UTMetadataDataExtensionMessage.h"
#include "UTMetadataRejectExtensionMessage.h"
#include "message.h"
#include "PieceStorage.h"
#include "UTMetadataRequestTracker.h"
#include "RequestGroup.h"
#include "bencode2.h"

namespace aria2 {

// i686-w64-mingw32-g++ 4.6 does not support constructor delegate
DefaultExtensionMessageFactory::DefaultExtensionMessageFactory()
  : peerStorage_{nullptr},
    registry_{nullptr},
    dctx_{nullptr},
    messageFactory_{nullptr},
    dispatcher_{nullptr},
    tracker_{nullptr}
{}

DefaultExtensionMessageFactory::DefaultExtensionMessageFactory
(const std::shared_ptr<Peer>& peer, ExtensionMessageRegistry* registry)
  : peerStorage_{nullptr},
    peer_{peer},
    registry_{registry},
    dctx_{nullptr},
    messageFactory_{nullptr},
    dispatcher_{nullptr},
    tracker_{nullptr}
{}

// Builds an extension message from a raw extended-message payload.
//
// data[0] is the extension message ID:
//   - 0 creates a HandshakeExtensionMessage.
//   - Any other ID is resolved to a name through registry_.
//       "ut_pex" creates a UTPexExtensionMessage.
//       "ut_metadata" payloads are bencode-decoded, then mapped by msg_type
//       to a request (0), data (1) or reject (2) message.
//
// Throws DlAbortEx in these cases:
//   - the payload is empty;
//   - the ID is not registered;
//   - the extension name is unsupported;
//   - a ut_metadata payload is malformed.
//
// Note on returns: every branch returns std::move(m). m is a unique_ptr to
// a concrete message type, and converting it to
// std::unique_ptr<ExtensionMessage> needs an rvalue on older compilers that
// lack the implicit-move rule for converting returns.
std::unique_ptr<ExtensionMessage>
DefaultExtensionMessageFactory::createMessage
(const unsigned char* data, size_t length)
{
  // The ID byte is read unconditionally below, so reject an empty payload
  // before touching data[0].
  if(length == 0) {
    throw DL_ABORT_EX
      (fmt(MSG_TOO_SMALL_PAYLOAD_SIZE, "extended",
           static_cast<unsigned long>(length)));
  }
  uint8_t extensionMessageID = *data;
  if(extensionMessageID == 0) {
    // handshake
    auto m = HandshakeExtensionMessage::create(data, length);
    m->setPeer(peer_);
    m->setDownloadContext(dctx_);
    return std::move(m);
  } else {
    const char* extensionName =
      registry_->getExtensionName(extensionMessageID);
    if(!extensionName) {
      throw DL_ABORT_EX
        (fmt("No extension registered for extended message ID %u",
             extensionMessageID));
    }
    if(strcmp(extensionName, "ut_pex") == 0) {
      // uTorrent compatible Peer-Exchange
      auto m = UTPexExtensionMessage::create(data, length);
      m->setPeerStorage(peerStorage_);
      return std::move(m);
    } else if(strcmp(extensionName, "ut_metadata") == 0) {
      // length >= 1 is guaranteed by the check at the top of this function.
      // The bencoded dictionary starts right after the ID byte. On return,
      // `end` holds the number of bytes decode() consumed, counted from
      // data+1.
      size_t end = 0;
      auto decoded = bencode2::decode(data+1, length - 1, end);
      const Dict* dict = downcast<Dict>(decoded);
      if(!dict) {
        throw DL_ABORT_EX("Bad ut_metadata: dictionary not found");
      }
      const Integer* msgType = downcast<Integer>(dict->get("msg_type"));
      if(!msgType) {
        throw DL_ABORT_EX("Bad ut_metadata: msg_type not found");
      }
      const Integer* index = downcast<Integer>(dict->get("piece"));
      if(!index) {
        throw DL_ABORT_EX("Bad ut_metadata: piece not found");
      }
      switch(msgType->i()) {
      case 0: {
        auto m = make_unique<UTMetadataRequestExtensionMessage>
          (extensionMessageID);
        m->setIndex(index->i());
        m->setDownloadContext(dctx_);
        m->setPeer(peer_);
        m->setBtMessageFactory(messageFactory_);
        m->setBtMessageDispatcher(dispatcher_);
        return std::move(m);
      }
      case 1: {
        // The metadata bytes are [data+1+end, data+length). That range is
        // empty when 1+end reaches length. The previous comparison,
        // end == length, could never hold because decode() only sees
        // length-1 bytes.
        if(end + 1 >= length) {
          throw DL_ABORT_EX("Bad ut_metadata data: data not found");
        }
        const Integer* totalSize =
          downcast<Integer>(dict->get("total_size"));
        if(!totalSize) {
          throw DL_ABORT_EX("Bad ut_metadata data: total_size not found");
        }
        auto m = make_unique<UTMetadataDataExtensionMessage>
          (extensionMessageID);
        m->setIndex(index->i());
        m->setTotalSize(totalSize->i());
        m->setData(&data[1+end], &data[length]);
        m->setUTMetadataRequestTracker(tracker_);
        m->setPieceStorage
          (dctx_->getOwnerRequestGroup()->getPieceStorage().get());
        m->setDownloadContext(dctx_);
        return std::move(m);
      }
      case 2: {
        auto m = make_unique<UTMetadataRejectExtensionMessage>
          (extensionMessageID);
        m->setIndex(index->i());
        // No need to inject tracker because peer will be disconnected.
        return std::move(m);
      }
      default:
        throw DL_ABORT_EX
          (fmt("Bad ut_metadata: unknown msg_type=%u",
               static_cast<unsigned int>(msgType->i())));
      }
    } else {
      throw DL_ABORT_EX
        (fmt("Unsupported extension message received."
             " extensionMessageID=%u, extensionName=%s",
             extensionMessageID, extensionName));
    }
  }
}

void DefaultExtensionMessageFactory::setPeerStorage(PeerStorage* peerStorage)
{
  peerStorage_ = peerStorage;
}

void DefaultExtensionMessageFactory::setPeer(const std::shared_ptr<Peer>& peer)
{
  peer_ = peer;
}

void DefaultExtensionMessageFactory::setExtensionMessageRegistry
(ExtensionMessageRegistry* registry)
{
  registry_ = registry;
}

void DefaultExtensionMessageFactory::setDownloadContext(DownloadContext* dctx)
{
  dctx_ = dctx;
}

void DefaultExtensionMessageFactory::setBtMessageFactory
(BtMessageFactory* factory)
{
  messageFactory_ = factory;
}

void DefaultExtensionMessageFactory::setBtMessageDispatcher
(BtMessageDispatcher* disp)
{
  dispatcher_ = disp;
}

void DefaultExtensionMessageFactory::setUTMetadataRequestTracker
(UTMetadataRequestTracker* tracker)
{
  tracker_ = tracker;
}

} // namespace aria2