/* */
#include "DefaultBtRequestFactory.h"

#include <algorithm>

#include "LogFactory.h"
#include "Logger.h"
#include "Piece.h"
#include "Peer.h"
#include "PieceStorage.h"
#include "BtMessageDispatcher.h"
#include "BtMessageFactory.h"
#include "BtMessage.h"
#include "a2functional.h"
#include "SimpleRandomizer.h"
#include "array_fun.h"
#include "fmt.h"
#include "BtRequestMessage.h"

namespace aria2 {

// All collaborators start out unset; they are injected through the
// set*() methods before the factory is used.
DefaultBtRequestFactory::DefaultBtRequestFactory()
  : pieceStorage_(nullptr),
    dispatcher_(nullptr),
    messageFactory_(nullptr),
    cuid_(0)
{}

DefaultBtRequestFactory::~DefaultBtRequestFactory() {}

// Appends |piece| to the list of target pieces. No duplicate check is
// performed here.
void DefaultBtRequestFactory::addTargetPiece
(const std::shared_ptr<Piece>& piece)
{
  pieces_.push_back(piece);
}

// Aborts the outstanding requests of every completed target piece, then
// removes those pieces from the target list.
void DefaultBtRequestFactory::removeCompletedPiece()
{
  for(const auto& target : pieces_) {
    if(target->pieceComplete()) {
      dispatcher_->doAbortOutstandingRequestAction(target);
    }
  }
  auto isComplete = [](const std::shared_ptr<Piece>& target) {
    return target->pieceComplete();
  };
  auto firstRemoved =
    std::remove_if(std::begin(pieces_), std::end(pieces_), isComplete);
  pieces_.erase(firstRemoved, std::end(pieces_));
}

// Drops every target equal (by pointee, via derefEqual) to |piece|, aborts
// its outstanding requests and cancels it in the piece storage for cuid_.
void DefaultBtRequestFactory::removeTargetPiece
(const std::shared_ptr<Piece>& piece)
{
  auto firstRemoved =
    std::remove_if(std::begin(pieces_), std::end(pieces_), derefEqual(piece));
  pieces_.erase(firstRemoved, std::end(pieces_));
  dispatcher_->doAbortOutstandingRequestAction(piece);
  pieceStorage_->cancelPiece(piece, cuid_);
}

namespace {
// Function object used by doChokedAction(): cancels, in the piece storage
// and on behalf of the given cuid, each piece whose index is not in the
// peer's allowed-index set.
class ProcessChokedPiece {
private:
  std::shared_ptr<Peer> peer_;
  PieceStorage* pieceStorage_;
  cuid_t cuid_;
public:
  ProcessChokedPiece(const std::shared_ptr<Peer>& peer,
                     PieceStorage* pieceStorage,
                     cuid_t cuid)
    : peer_(peer),
      pieceStorage_(pieceStorage),
      cuid_(cuid)
  {}

  void operator()(const std::shared_ptr<Piece>& piece)
  {
    const bool allowed = peer_->isInPeerAllowedIndexSet(piece->getIndex());
    if(allowed) {
      return;
    }
    pieceStorage_->cancelPiece(piece, cuid_);
  }
};
} // namespace
namespace { class FindChokedPiece { private: std::shared_ptr peer_; public: FindChokedPiece(const std::shared_ptr& peer):peer_(peer) {} bool operator()(const std::shared_ptr& piece) { return !peer_->isInPeerAllowedIndexSet(piece->getIndex()); } }; } // namespace void DefaultBtRequestFactory::doChokedAction() { std::for_each(pieces_.begin(), pieces_.end(), ProcessChokedPiece(peer_, pieceStorage_, cuid_)); pieces_.erase(std::remove_if(pieces_.begin(), pieces_.end(), FindChokedPiece(peer_)), pieces_.end()); } void DefaultBtRequestFactory::removeAllTargetPiece() { for(auto & elem : pieces_) { dispatcher_->doAbortOutstandingRequestAction(elem); pieceStorage_->cancelPiece(elem, cuid_); } pieces_.clear(); } std::vector> DefaultBtRequestFactory::createRequestMessages(size_t max, bool endGame) { if(endGame) { return createRequestMessagesOnEndGame(max); } auto requests = std::vector>{}; size_t getnum = max-requests.size(); auto blockIndexes = std::vector{}; blockIndexes.reserve(getnum); for(auto itr = std::begin(pieces_), eoi = std::end(pieces_); itr != eoi && getnum; ++itr) { auto& piece = *itr; if(piece->getMissingUnusedBlockIndex(blockIndexes, getnum)) { getnum -= blockIndexes.size(); for(auto i = std::begin(blockIndexes), eoi2 = std::end(blockIndexes); i != eoi2; ++i) { A2_LOG_DEBUG (fmt("Creating RequestMessage index=%lu, begin=%u," " blockIndex=%lu", static_cast(piece->getIndex()), static_cast((*i)*piece->getBlockLength()), static_cast(*i))); requests.push_back(messageFactory_->createRequestMessage(piece, *i)); } blockIndexes.clear(); } } return requests; } std::vector> DefaultBtRequestFactory::createRequestMessagesOnEndGame(size_t max) { auto requests = std::vector>{}; for(auto itr = std::begin(pieces_), eoi = std::end(pieces_); itr != eoi && requests.size() < max; ++itr) { auto& piece = *itr; const size_t mislen = piece->getBitfieldLength(); auto misbitfield = make_unique(mislen); piece->getAllMissingBlockIndexes(misbitfield.get(), mislen); auto missingBlockIndexes = 
std::vector{}; size_t blockIndex = 0; for(size_t i = 0; i < mislen; ++i) { unsigned char bits = misbitfield[i]; unsigned char mask = 128; for(size_t bi = 0; bi < 8; ++bi, mask >>= 1, ++blockIndex) { if(bits & mask) { missingBlockIndexes.push_back(blockIndex); } } } std::random_shuffle(std::begin(missingBlockIndexes), std::end(missingBlockIndexes), *SimpleRandomizer::getInstance()); for(auto bitr = std::begin(missingBlockIndexes), eoi2 = std::end(missingBlockIndexes); bitr != eoi2 && requests.size() < max; ++bitr) { size_t blockIndex = *bitr; if(!dispatcher_->isOutstandingRequest(piece->getIndex(), blockIndex)) { A2_LOG_DEBUG (fmt("Creating RequestMessage index=%lu, begin=%u," " blockIndex=%lu", static_cast(piece->getIndex()), static_cast(blockIndex*piece->getBlockLength()), static_cast(blockIndex))); requests.push_back(messageFactory_->createRequestMessage (piece, blockIndex)); } } } return requests; } namespace { class CountMissingBlock { private: size_t numMissingBlock_; public: CountMissingBlock():numMissingBlock_(0) {} size_t getNumMissingBlock() { return numMissingBlock_; } void operator()(const std::shared_ptr& piece) { numMissingBlock_ += piece->countMissingBlock(); } }; } // namespace size_t DefaultBtRequestFactory::countMissingBlock() { return std::for_each(pieces_.begin(), pieces_.end(), CountMissingBlock()).getNumMissingBlock(); } std::vector DefaultBtRequestFactory::getTargetPieceIndexes() const { auto res = std::vector{}; res.reserve(pieces_.size()); std::transform(std::begin(pieces_), std::end(pieces_), std::back_inserter(res), std::mem_fn(&Piece::getIndex)); return res; } void DefaultBtRequestFactory::setPieceStorage(PieceStorage* pieceStorage) { pieceStorage_ = pieceStorage; } void DefaultBtRequestFactory::setPeer(const std::shared_ptr& peer) { peer_ = peer; } void DefaultBtRequestFactory::setBtMessageDispatcher (BtMessageDispatcher* dispatcher) { dispatcher_ = dispatcher; } void DefaultBtRequestFactory::setBtMessageFactory(BtMessageFactory* 
factory) { messageFactory_ = factory; } } // namespace aria2