/* */
#include "DefaultExtensionMessageFactory.h"
#include "Peer.h"
#include "DlAbortEx.h"
#include "HandshakeExtensionMessage.h"
#include "UTPexExtensionMessage.h"
#include "fmt.h"
#include "PeerStorage.h"
#include "ExtensionMessageRegistry.h"
#include "DownloadContext.h"
#include "BtMessageDispatcher.h"
#include "BtMessageFactory.h"
#include "UTMetadataRequestExtensionMessage.h"
#include "UTMetadataDataExtensionMessage.h"
#include "UTMetadataRejectExtensionMessage.h"
#include "message.h"
#include "PieceStorage.h"
#include "UTMetadataRequestTracker.h"
#include "RequestGroup.h"
#include "bencode2.h"

namespace aria2 {

DefaultExtensionMessageFactory::DefaultExtensionMessageFactory()
  : messageFactory_(0),
    dispatcher_(0),
    tracker_(0)
{}

// Initialize messageFactory_, dispatcher_ and tracker_ to 0 exactly as the
// default constructor does, so they never hold indeterminate values when
// their setters are not called before createMessage() hands them to the
// ut_metadata messages.
DefaultExtensionMessageFactory::DefaultExtensionMessageFactory
(const SharedHandle<Peer>& peer,
 const SharedHandle<ExtensionMessageRegistry>& registry)
  : peer_(peer),
    registry_(registry),
    messageFactory_(0),
    dispatcher_(0),
    tracker_(0)
{}

DefaultExtensionMessageFactory::~DefaultExtensionMessageFactory() {}

// Builds an extension message from an extended-message payload.
//
// data[0] is the extension message ID and data[1..length) is its body.
//  - ID 0 is the extension handshake; the whole payload is handed to
//    HandshakeExtensionMessage::create().
//  - Any other ID is mapped to an extension name via registry_. Only
//    "ut_pex" and "ut_metadata" are supported.
//  - A ut_metadata body starts with a bencoded dictionary that must contain
//    integer "msg_type" and "piece" entries. msg_type 0, 1 and 2 yield a
//    request, data and reject message respectively. For msg_type 1 the
//    dictionary must also contain "total_size", and the bytes following the
//    dictionary are the piece data.
//
// Throws DL_ABORT_EX for an empty payload, an unregistered or unsupported
// ID, or a malformed ut_metadata message.
ExtensionMessageHandle
DefaultExtensionMessageFactory::createMessage(const unsigned char* data,
                                              size_t length)
{
  // Guard before reading the ID byte; an empty payload has no ID at all.
  if(length == 0) {
    throw DL_ABORT_EX
      (fmt(MSG_TOO_SMALL_PAYLOAD_SIZE, "extended",
           static_cast<unsigned long>(length)));
  }
  uint8_t extensionMessageID = *data;
  if(extensionMessageID == 0) {
    // handshake
    HandshakeExtensionMessageHandle m =
      HandshakeExtensionMessage::create(data, length);
    m->setPeer(peer_);
    m->setDownloadContext(dctx_);
    return m;
  } else {
    std::string extensionName = registry_->getExtensionName(extensionMessageID);
    if(extensionName.empty()) {
      throw DL_ABORT_EX
        (fmt("No extension registered for extended message ID %u",
             extensionMessageID));
    }
    if(extensionName == "ut_pex") {
      // uTorrent compatible Peer-Exchange
      UTPexExtensionMessageHandle m =
        UTPexExtensionMessage::create(data, length);
      m->setPeerStorage(peerStorage_);
      return m;
    } else if(extensionName == "ut_metadata") {
      // Only the ID byte is present: there is no bencoded dictionary.
      if(length <= 1) {
        throw DL_ABORT_EX
          (fmt(MSG_TOO_SMALL_PAYLOAD_SIZE, "ut_metadata",
               static_cast<unsigned long>(length)));
      }
      // On return, end is the number of bytes consumed starting at data+1
      // (the setData() call below relies on this).
      size_t end;
      SharedHandle<ValueBase> decoded =
        bencode2::decode(data+1, data+length, end);
      const Dict* dict = downcast<Dict>(decoded);
      if(!dict) {
        throw DL_ABORT_EX("Bad ut_metadata: dictionary not found");
      }
      const Integer* msgType = downcast<Integer>(dict->get("msg_type"));
      if(!msgType) {
        throw DL_ABORT_EX("Bad ut_metadata: msg_type not found");
      }
      const Integer* index = downcast<Integer>(dict->get("piece"));
      if(!index) {
        throw DL_ABORT_EX("Bad ut_metadata: piece not found");
      }
      switch(msgType->i()) {
      case 0: {
        SharedHandle<UTMetadataRequestExtensionMessage> m
          (new UTMetadataRequestExtensionMessage(extensionMessageID));
        m->setIndex(index->i());
        m->setDownloadContext(dctx_);
        m->setPeer(peer_);
        m->setBtMessageFactory(messageFactory_);
        m->setBtMessageDispatcher(dispatcher_);
        return m;
      }
      case 1: {
        // The dictionary occupies data[1, 1+end) and the piece data occupies
        // data[1+end, length). The data is empty when 1+end reaches length.
        // The former test (end == length) could never be true, so empty
        // data was silently accepted.
        if(1+end >= length) {
          throw DL_ABORT_EX("Bad ut_metadata data: data not found");
        }
        const Integer* totalSize = downcast<Integer>(dict->get("total_size"));
        if(!totalSize) {
          throw DL_ABORT_EX("Bad ut_metadata data: total_size not found");
        }
        SharedHandle<UTMetadataDataExtensionMessage> m
          (new UTMetadataDataExtensionMessage(extensionMessageID));
        m->setIndex(index->i());
        m->setTotalSize(totalSize->i());
        m->setData(&data[1+end], &data[length]);
        m->setUTMetadataRequestTracker(tracker_);
        m->setPieceStorage(dctx_->getOwnerRequestGroup()->getPieceStorage());
        m->setDownloadContext(dctx_);
        return m;
      }
      case 2: {
        SharedHandle<UTMetadataRejectExtensionMessage> m
          (new UTMetadataRejectExtensionMessage(extensionMessageID));
        m->setIndex(index->i());
        // No need to inject tracker because peer will be disconnected.
        return m;
      }
      default:
        throw DL_ABORT_EX
          (fmt("Bad ut_metadata: unknown msg_type=%u",
               static_cast<unsigned int>(msgType->i())));
      }
    } else {
      throw DL_ABORT_EX
        (fmt("Unsupported extension message received."
             " extensionMessageID=%u, extensionName=%s",
             extensionMessageID, extensionName.c_str()));
    }
  }
}

void DefaultExtensionMessageFactory::setPeerStorage
(const SharedHandle<PeerStorage>& peerStorage)
{
  peerStorage_ = peerStorage;
}

void DefaultExtensionMessageFactory::setPeer(const SharedHandle<Peer>& peer)
{
  peer_ = peer;
}

} // namespace aria2