diff --git a/ChangeLog b/ChangeLog index 356e03e4..a0653dba 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,29 @@ +2009-11-23 Tatsuhiro Tsujikawa + + Fixed ut_metadata data handling. Implemented + UTMetadataDataExtensionMessage::doReceivedAction(). Initialize + PeerStorage in HandshakeExtensionMessage::doReceivedAction() when + metadata_size is received. + * src/DefaultExtensionMessageFactory.cc + * src/DefaultExtensionMessageFactory.h + * src/HandshakeExtensionMessage.cc + * src/HandshakeExtensionMessage.h + * src/UTMetadataDataExtensionMessage.cc + * src/UTMetadataDataExtensionMessage.h + * src/UTMetadataRequestExtensionMessage.cc + * src/UTMetadataRequestFactory.cc + * src/UTMetadataRequestFactory.h + * src/UTMetadataRequestTracker.cc + * src/UTMetadataRequestTracker.h + * test/DefaultExtensionMessageFactoryTest.cc + * test/HandshakeExtensionMessageTest.cc + * test/MockBtMessage.h + * test/UTMetadataDataExtensionMessageTest.cc + * test/UTMetadataRequestExtensionMessageTest.cc + * test/UTMetadataRequestFactoryTest.cc + * test/UTMetadataRequestTrackerTest.cc + * test/extension_message_test_helper.h + 2009-11-23 Tatsuhiro Tsujikawa Drop connection if ut_metadata reject message is received. diff --git a/src/DefaultExtensionMessageFactory.cc b/src/DefaultExtensionMessageFactory.cc index 4dad47e4..87261d5d 100644 --- a/src/DefaultExtensionMessageFactory.cc +++ b/src/DefaultExtensionMessageFactory.cc @@ -50,6 +50,10 @@ #include "UTMetadataRejectExtensionMessage.h" #include "message.h" #include "bencode.h" +#include "PieceStorage.h" +#include "UTMetadataRequestTracker.h" +#include "BtRuntime.h" +#include "RequestGroup.h" namespace aria2 { @@ -93,16 +97,8 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t throw DL_ABORT_EX(StringFormat(MSG_TOO_SMALL_PAYLOAD_SIZE, "ut_metadata", length).str()); } - std::string listdata; - listdata += 'l'; - listdata += std::string(&data[1], &data[length]); - listdata += 'e'; - - const BDE& list = bencode::decode(listdata); - if(!list.isList() || list.empty()) { - throw DL_ABORT_EX("Bad ut_metadata"); - } - const BDE& dict = list[0]; + size_t end; + BDE dict = bencode::decode(data+1, length-1, end); if(!dict.isDict()) { throw DL_ABORT_EX("Bad ut_metadata: dictionary not found"); } @@ -126,13 +122,9 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t return m; } case 1: { - if(list.size() != 2) { + if(end == length) { throw DL_ABORT_EX("Bad ut_metadata data: data not found"); } - const BDE& pieceData = list[1]; - if(!pieceData.isString()) { - throw DL_ABORT_EX("Bad ut_metadata data: data is not string"); - } const BDE& totalSize = dict["total_size"]; if(!totalSize.isInteger()) { throw DL_ABORT_EX("Bad ut_metadata data: total_size not found"); @@ -141,16 +133,18 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t (new UTMetadataDataExtensionMessage(extensionMessageID)); m->setIndex(index.i()); m->setTotalSize(totalSize.i()); - m->setData(pieceData.s()); - // set tracker - // set piecestorage + m->setData(std::string(&data[1+end], &data[length])); + m->setUTMetadataRequestTracker(_tracker); + m->setPieceStorage(_dctx->getOwnerRequestGroup()->getPieceStorage()); + m->setDownloadContext(_dctx); + m->setBtRuntime(_btRuntime); return m; } case 2: { SharedHandle m (new UTMetadataRejectExtensionMessage(extensionMessageID)); m->setIndex(index.i()); - // set tracker if disconnecing peer on receive. + // No need to inject tracker because peer will be disconnected. return m; } default: diff --git a/src/DefaultExtensionMessageFactory.h b/src/DefaultExtensionMessageFactory.h index 319528f6..778a41ef 100644 --- a/src/DefaultExtensionMessageFactory.h +++ b/src/DefaultExtensionMessageFactory.h @@ -46,6 +46,8 @@ class ExtensionMessageRegistry; class DownloadContext; class BtMessageFactory; class BtMessageDispatcher; +class UTMetadataRequestTracker; +class BtRuntime; class DefaultExtensionMessageFactory:public ExtensionMessageFactory { private: @@ -57,10 +59,14 @@ private: SharedHandle _dctx; + SharedHandle _btRuntime; + WeakHandle _messageFactory; WeakHandle _dispatcher; + WeakHandle _tracker; + Logger* _logger; public: @@ -99,6 +105,17 @@ public: { _dispatcher = disp; } + + void setUTMetadataRequestTracker + (const WeakHandle& tracker) + { + _tracker = tracker; + } + + void setBtRuntime(const SharedHandle& btRuntime) + { + _btRuntime = btRuntime; + } }; typedef SharedHandle DefaultExtensionMessageFactoryHandle; diff --git a/src/HandshakeExtensionMessage.cc b/src/HandshakeExtensionMessage.cc index a0c04c3a..fd722e78 100644 --- a/src/HandshakeExtensionMessage.cc +++ b/src/HandshakeExtensionMessage.cc @@ -43,6 +43,8 @@ #include "bencode.h" #include "DownloadContext.h" #include "bittorrent_helper.h" +#include "RequestGroup.h" +#include "PieceStorage.h" namespace aria2 { @@ -110,8 +112,19 @@ void HandshakeExtensionMessage::doReceivedAction() } if(_metadataSize > 0) { BDE& attrs = _dctx->getAttribute(bittorrent::BITTORRENT); - if(!attrs.containsKey(bittorrent::METADATA_SIZE)) { + if(attrs.containsKey(bittorrent::METADATA_SIZE)) { + if(_metadataSize != (size_t)attrs[bittorrent::METADATA_SIZE].i()) { + throw DL_ABORT_EX("Wrong metadata_size. Which one is correct!?"); + } + } else { attrs[bittorrent::METADATA_SIZE] = _metadataSize; + _dctx->getFirstFileEntry()->setLength(_metadataSize); + _dctx->markTotalLengthIsKnown(); + _dctx->getOwnerRequestGroup()->initPieceStorage(); + + SharedHandle pieceStorage = + _dctx->getOwnerRequestGroup()->getPieceStorage(); + pieceStorage->setEndGamePieceNum(0); } } } diff --git a/src/HandshakeExtensionMessage.h b/src/HandshakeExtensionMessage.h index 73386883..e1b3c892 100644 --- a/src/HandshakeExtensionMessage.h +++ b/src/HandshakeExtensionMessage.h @@ -45,9 +45,7 @@ namespace aria2 { class Peer; class Logger; -class HandshakeExtensionMessage; class DownloadContext; -typedef SharedHandle HandshakeExtensionMessageHandle; class HandshakeExtensionMessage:public ExtensionMessage { private: @@ -137,8 +135,8 @@ public: void setPeer(const SharedHandle& peer); - static HandshakeExtensionMessageHandle create(const unsigned char* data, - size_t dataLength); + static SharedHandle + create(const unsigned char* data, size_t dataLength); }; typedef SharedHandle HandshakeExtensionMessageHandle; diff --git a/src/UTMetadataDataExtensionMessage.cc b/src/UTMetadataDataExtensionMessage.cc index 7f80a608..bb39d4e4 100644 --- a/src/UTMetadataDataExtensionMessage.cc +++ b/src/UTMetadataDataExtensionMessage.cc @@ -37,29 +37,30 @@ #include "bencode.h" #include "util.h" #include "a2functional.h" +#include "DownloadContext.h" +#include "UTMetadataRequestTracker.h" +#include "PieceStorage.h" +#include "BtConstants.h" +#include "MessageDigestHelper.h" +#include "bittorrent_helper.h" +#include "DiskAdaptor.h" +#include "Piece.h" +#include "BtRuntime.h" +#include "LogFactory.h" namespace aria2 { UTMetadataDataExtensionMessage::UTMetadataDataExtensionMessage -(uint8_t extensionMessageID):UTMetadataExtensionMessage(extensionMessageID) {} +(uint8_t extensionMessageID):UTMetadataExtensionMessage(extensionMessageID), + _logger(LogFactory::getInstance()) {} std::string UTMetadataDataExtensionMessage::getBencodedData() { - BDE list = BDE::list(); - BDE dict = BDE::dict(); dict["msg_type"] = 1; dict["piece"] = _index; dict["total_size"] = _totalSize; - - BDE data = _data; - - list << dict; - list << data; - - std::string encodedList = bencode::encode(list); - // Remove first 'l' and last 'e' and return. - return std::string(encodedList.begin()+1, encodedList.end()-1); + return bencode::encode(dict)+_data; } std::string UTMetadataDataExtensionMessage::toString() const @@ -69,9 +70,36 @@ std::string UTMetadataDataExtensionMessage::toString() const void UTMetadataDataExtensionMessage::doReceivedAction() { - // Update tracker - - // Write to pieceStorage + if(_tracker->tracks(_index)) { + _logger->debug("ut_metadata index=%lu found in tracking list", + static_cast(_index)); + _tracker->remove(_index); + _pieceStorage->getDiskAdaptor()->writeData + (reinterpret_cast(_data.c_str()), _data.size(), + _index*METADATA_PIECE_SIZE); + _pieceStorage->completePiece(_pieceStorage->getPiece(_index)); + if(_pieceStorage->downloadFinished()) { + std::string metadata = util::toString(_pieceStorage->getDiskAdaptor()); + unsigned char infoHash[INFO_HASH_LENGTH]; + MessageDigestHelper::digest(infoHash, INFO_HASH_LENGTH, + MessageDigestContext::SHA1, + metadata.data(), metadata.size()); + const BDE& attrs = _dctx->getAttribute(bittorrent::BITTORRENT); + if(std::string(&infoHash[0], &infoHash[INFO_HASH_LENGTH]) == + attrs[bittorrent::INFO_HASH].s()){ + _logger->info("Got ut_metadata"); + _btRuntime->setHalt(true); + } else { + _logger->info("Got wrong ut_metadata"); + for(size_t i = 0; i < _dctx->getNumPieces(); ++i) { + _pieceStorage->markPieceMissing(i); + } + } + } + } else { + _logger->debug("ut_metadata index=%lu is not tracked", + static_cast(_index)); + } } } // namespace aria2 diff --git a/src/UTMetadataDataExtensionMessage.h b/src/UTMetadataDataExtensionMessage.h index d2d6aa39..11d71521 100644 --- a/src/UTMetadataDataExtensionMessage.h +++ b/src/UTMetadataDataExtensionMessage.h @@ -39,11 +39,27 @@ namespace aria2 { +class DownloadContext; +class PieceStorage; +class UTMetadataRequestTracker; +class BtRuntime; +class Logger; + class UTMetadataDataExtensionMessage:public UTMetadataExtensionMessage { private: size_t _totalSize; std::string _data; + + SharedHandle _dctx; + + SharedHandle _pieceStorage; + + SharedHandle _btRuntime; + + WeakHandle _tracker; + + Logger* _logger; public: UTMetadataDataExtensionMessage(uint8_t extensionMessageID); @@ -72,6 +88,27 @@ public: { return _data; } + + void setPieceStorage(const SharedHandle& pieceStorage) + { + _pieceStorage = pieceStorage; + } + + void setUTMetadataRequestTracker + (const WeakHandle& tracker) + { + _tracker = tracker; + } + + void setDownloadContext(const SharedHandle& dctx) + { + _dctx = dctx; + } + + void setBtRuntime(const SharedHandle& btRuntime) + { + _btRuntime = btRuntime; + } }; } // namespace aria2 diff --git a/src/UTMetadataRequestExtensionMessage.cc b/src/UTMetadataRequestExtensionMessage.cc index 92de8fae..a57b6fd2 100644 --- a/src/UTMetadataRequestExtensionMessage.cc +++ b/src/UTMetadataRequestExtensionMessage.cc @@ -48,6 +48,8 @@ #include "BtConstants.h" #include "DownloadContext.h" #include "BtMessage.h" +#include "PieceStorage.h" +#include "BtRuntime.h" namespace aria2 { diff --git a/src/UTMetadataRequestFactory.cc b/src/UTMetadataRequestFactory.cc new file mode 100644 index 00000000..38b5b595 --- /dev/null +++ b/src/UTMetadataRequestFactory.cc @@ -0,0 +1,79 @@ +/* */ +#include "UTMetadataRequestFactory.h" +#include "PieceStorage.h" +#include "DownloadContext.h" +#include "Peer.h" +#include "BtMessageDispatcher.h" +#include "BtMessageFactory.h" +#include "UTMetadataRequestExtensionMessage.h" +#include "UTMetadataRequestTracker.h" +#include "BtMessage.h" +#include "LogFactory.h" + +namespace aria2 { + +UTMetadataRequestFactory::UTMetadataRequestFactory(): + _logger(LogFactory::getInstance()) {} + +void UTMetadataRequestFactory::create +(std::deque >& msgs, size_t num, + const SharedHandle& pieceStorage) +{ + for(size_t index = 0; index < _dctx->getNumPieces() && num; ++index) { + SharedHandle p = pieceStorage->getMissingPiece(index); + if(p.isNull()) { + _logger->debug("ut_metadata piece %lu is used or already acquired."); + continue; + } + --num; + _logger->debug("Creating ut_metadata request index=%lu", + static_cast(index)); + SharedHandle m + (new UTMetadataRequestExtensionMessage + (_peer->getExtensionMessageID("ut_metadata"))); + m->setIndex(index); + m->setDownloadContext(_dctx); + m->setBtMessageDispatcher(_dispatcher); + m->setBtMessageFactory(_messageFactory); + m->setPeer(_peer); + + SharedHandle msg = _messageFactory->createBtExtendedMessage(m); + msgs.push_back(msg); + _tracker->add(index); + } +} + +} // namespace aria2 diff --git a/src/UTMetadataRequestFactory.h b/src/UTMetadataRequestFactory.h new file mode 100644 index 00000000..d16fb6a8 --- /dev/null +++ b/src/UTMetadataRequestFactory.h @@ -0,0 +1,105 @@ +/* */ +#ifndef _UT_METADATA_REQUEST_FACTORY_H_ +#define _UT_METADATA_REQUEST_FACTORY_H_ + +#include "common.h" + +#include + +#include "SharedHandle.h" + +namespace aria2 { + +class PieceStorage; +class DownloadContext; +class Peer; +class BtMessageDispatcher; +class BtMessageFactory; +class UTMetadataRequestTracker; +class BtMessage; +class Logger; + +class UTMetadataRequestFactory { +private: + SharedHandle _dctx; + + SharedHandle _peer; + + WeakHandle _dispatcher; + + WeakHandle _messageFactory; + + WeakHandle _tracker; + + Logger* _logger; +public: + UTMetadataRequestFactory(); + + // Creates at most num of ut_metadata request message and appends + // them to msgs. pieceStorage is used to identify missing piece. + void create(std::deque >& msgs, size_t num, + const SharedHandle& pieceStorage); + + void setDownloadContext(const SharedHandle& dctx) + { + _dctx = dctx; + } + + void setBtMessageDispatcher(const WeakHandle& disp) + { + _dispatcher = disp; + } + + void setBtMessageFactory(const WeakHandle& factory) + { + _messageFactory = factory; + } + + void setPeer(const SharedHandle& peer) + { + _peer = peer; + } + + void setUTMetadataRequestTracker + (const WeakHandle& tracker) + { + _tracker = tracker; + } +}; + +} // namespace aria2 + +#endif // _UT_METADATA_REQUEST_FACTORY_H_ diff --git a/src/UTMetadataRequestTracker.cc b/src/UTMetadataRequestTracker.cc new file mode 100644 index 00000000..cbe11b42 --- /dev/null +++ b/src/UTMetadataRequestTracker.cc @@ -0,0 +1,106 @@ +/* */ +#include "UTMetadataRequestTracker.h" + +#include + +#include "LogFactory.h" + +namespace aria2 { + +UTMetadataRequestTracker::UTMetadataRequestTracker(): + _logger(LogFactory::getInstance()) {} + +void UTMetadataRequestTracker::add(size_t index) +{ + _trackedRequests.push_back(RequestEntry(index)); +} + +bool UTMetadataRequestTracker::tracks(size_t index) +{ + return std::find(_trackedRequests.begin(), _trackedRequests.end(), + RequestEntry(index)) != _trackedRequests.end(); +} + +void UTMetadataRequestTracker::remove(size_t index) +{ + std::vector::iterator i = + std::find(_trackedRequests.begin(), _trackedRequests.end(), + RequestEntry(index)); + if(i != _trackedRequests.end()) { + _trackedRequests.erase(i); + } +} + +std::vector UTMetadataRequestTracker::removeTimeoutEntry() +{ + std::vector indexes; + const time_t TIMEOUT = 20; + for(std::vector::iterator i = _trackedRequests.begin(); + i != _trackedRequests.end();) { + if((*i).elapsed(TIMEOUT)) { + LogFactory::getInstance()->debug + ("ut_metadata request timeout. index=%lu", + static_cast((*i)._index)); + indexes.push_back((*i)._index); + i = _trackedRequests.erase(i); + } else { + ++i; + } + } + return indexes; +} + +size_t UTMetadataRequestTracker::avail() const +{ + const size_t MAX_OUTSTANDING_REQUEST = 1; + if(MAX_OUTSTANDING_REQUEST > count()) { + return MAX_OUTSTANDING_REQUEST-count(); + } else { + return 0; + } +} + +std::vector UTMetadataRequestTracker::getAllTrackedIndex() const +{ + std::vector indexes; + for(std::vector::const_iterator i = _trackedRequests.begin(); + i != _trackedRequests.end(); ++i) { + indexes.push_back((*i)._index); + } + return indexes; +} + +} // namespace aria2 diff --git a/src/UTMetadataRequestTracker.h b/src/UTMetadataRequestTracker.h new file mode 100644 index 00000000..3788496e --- /dev/null +++ b/src/UTMetadataRequestTracker.h @@ -0,0 +1,100 @@ +/* */ +#ifndef _UT_METADATA_REQUEST_TRACKER_H_ +#define _UT_METADATA_REQUEST_TRACKER_H_ + +#include "common.h" + +#include + +#include "TimeA2.h" + +namespace aria2 { + +class Logger; + +class UTMetadataRequestTracker { +private: + struct RequestEntry { + size_t _index; + Time _dispatchedTime; + + RequestEntry(size_t index):_index(index) {} + + bool elapsed(time_t t) const + { + return _dispatchedTime.elapsed(t); + } + + bool operator==(const RequestEntry& e) const + { + return _index == e._index; + } + }; + + std::vector _trackedRequests; + + Logger* _logger; +public: + UTMetadataRequestTracker(); + + // Add request index to tracking list. + void add(size_t index); + + // Returns true if request index is tracked. + bool tracks(size_t index); + + // Remove index from tracking list. + void remove(size_t index); + + // Returns all tracking indexes. + std::vector getAllTrackedIndex() const; + + // Removes request index which is timed out and returns their indexes. + std::vector removeTimeoutEntry(); + + // Returns the number of tracking list. + size_t count() const + { + return _trackedRequests.size(); + } + + // Returns the number of additional index this tracker can track. + size_t avail() const; +}; + +} // namespace aria2 + +#endif // _UT_METADATA_REQUEST_TRACKER_H_ diff --git a/test/DefaultExtensionMessageFactoryTest.cc b/test/DefaultExtensionMessageFactoryTest.cc index 878bdb06..9ace3e29 100644 --- a/test/DefaultExtensionMessageFactoryTest.cc +++ b/test/DefaultExtensionMessageFactoryTest.cc @@ -21,6 +21,10 @@ #include "UTMetadataRequestExtensionMessage.h" #include "UTMetadataDataExtensionMessage.h" #include "UTMetadataRejectExtensionMessage.h" +#include "BtRuntime.h" +#include "PieceStorage.h" +#include "RequestGroup.h" +#include "Option.h" namespace aria2 { @@ -42,6 +46,7 @@ private: SharedHandle _dispatcher; SharedHandle _messageFactory; SharedHandle _dctx; + SharedHandle _requestGroup; public: void setUp() { @@ -59,6 +64,11 @@ public: _dctx.reset(new DownloadContext()); + SharedHandle