diff --git a/ChangeLog b/ChangeLog index e0e71ca3..3e0223c9 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,24 @@ +2008-06-05 Tatsuhiro Tsujikawa + + Calculate piece hash when data is arrived if the data is arrived in + order. This removes additional read operation for hash calculation. + If the data is arrived in out of order, the hash is calucated when the + piece is completed. This is the same behavior as the old implementation. + * src/BtPieceMessage.cc + * src/DefaultBtProgressInfoFile.cc + * src/DefaultPieceStorage.cc + * src/DownloadCommand.cc + * src/DownloadCommand.h + * src/DownloadEngine.cc + * src/GrowSegment.cc + * src/GrowSegment.h + * src/Piece.cc + * src/Piece.h + * src/PiecedSegment.cc + * src/PiecedSegment.h + * src/Segment.h + * test/PieceTest.cc + 2008-06-05 Tatsuhiro Tsujikawa Try to keep the ordering of outgoing piece message. diff --git a/src/BtPieceMessage.cc b/src/BtPieceMessage.cc index 1b43def5..68562684 100644 --- a/src/BtPieceMessage.cc +++ b/src/BtPieceMessage.cc @@ -97,6 +97,7 @@ void BtPieceMessage::doReceivedAction() { logger->debug(MSG_PIECE_BITFIELD, cuid, Util::toHex(piece->getBitfield(), piece->getBitfieldLength()).c_str()); + piece->updateHash(begin, block, blockLength); dispatcher->removeOutstandingRequest(slot); if(piece->pieceComplete()) { if(checkPieceHash(piece)) { @@ -202,10 +203,15 @@ std::string BtPieceMessage::toString() const { } bool BtPieceMessage::checkPieceHash(const PieceHandle& piece) { - off_t offset = (off_t)piece->getIndex()*btContext->getPieceLength(); - - return MessageDigestHelper::staticSHA1Digest(pieceStorage->getDiskAdaptor(), offset, piece->getLength()) - == btContext->getPieceHash(piece->getIndex()); + if(piece->isHashCalculated()) { + logger->debug("Hash is available!! index=%zu", piece->getIndex()); + return piece->getHashString() == btContext->getPieceHash(piece->getIndex()); + } else { + off_t offset = (off_t)piece->getIndex()*btContext->getPieceLength(); + + return MessageDigestHelper::staticSHA1Digest(pieceStorage->getDiskAdaptor(), offset, piece->getLength()) + == btContext->getPieceHash(piece->getIndex()); + } } void BtPieceMessage::onNewPiece(const PieceHandle& piece) { @@ -218,6 +224,7 @@ void BtPieceMessage::onWrongPiece(const PieceHandle& piece) { logger->info(MSG_GOT_WRONG_PIECE, cuid, piece->getIndex()); erasePieceOnDisk(piece); piece->clearAllBlock(); + piece->destroyHashContext(); requestFactory->removeTargetPiece(piece); } diff --git a/src/DefaultBtProgressInfoFile.cc b/src/DefaultBtProgressInfoFile.cc index 966dcf65..91d580a8 100644 --- a/src/DefaultBtProgressInfoFile.cc +++ b/src/DefaultBtProgressInfoFile.cc @@ -280,6 +280,13 @@ void DefaultBtProgressInfoFile::load() savedBitfield = new unsigned char[bitfieldLength]; in.read(reinterpret_cast(savedBitfield), bitfieldLength); piece->setBitfield(savedBitfield, bitfieldLength); + +#ifdef ENABLE_MESSAGE_DIGEST + + piece->setHashAlgo(_dctx->getPieceHashAlgo()); + +#endif // ENABLE_MESSAGE_DIGEST + delete [] savedBitfield; savedBitfield = 0; diff --git a/src/DefaultPieceStorage.cc b/src/DefaultPieceStorage.cc index a45bdde5..0df5a237 100644 --- a/src/DefaultPieceStorage.cc +++ b/src/DefaultPieceStorage.cc @@ -115,6 +115,13 @@ PieceHandle DefaultPieceStorage::checkOutPiece(size_t index) PieceHandle piece = findUsedPiece(index); if(piece.isNull()) { piece.reset(new Piece(index, bitfieldMan->getBlockLength(index))); + +#ifdef ENABLE_MESSAGE_DIGEST + + piece->setHashAlgo(downloadContext->getPieceHashAlgo()); + +#endif // ENABLE_MESSAGE_DIGEST + addUsedPiece(piece); return piece; } else { @@ -548,9 +555,17 @@ void DefaultPieceStorage::markPiecesDone(uint64_t length) size_t r = (length%bitfieldMan->getBlockLength())/Piece::BLOCK_LENGTH; if(r > 0) { PieceHandle p(new Piece(numPiece, bitfieldMan->getBlockLength(numPiece))); + for(size_t i = 0; i < r; ++i) { p->completeBlock(i); } + +#ifdef ENABLE_MESSAGE_DIGEST + + p->setHashAlgo(downloadContext->getPieceHashAlgo()); + +#endif // ENABLE_MESSAGE_DIGEST + addUsedPiece(p); } } diff --git a/src/DownloadCommand.cc b/src/DownloadCommand.cc index ff8862ed..e15023e7 100644 --- a/src/DownloadCommand.cc +++ b/src/DownloadCommand.cc @@ -68,14 +68,21 @@ DownloadCommand::DownloadCommand(int cuid, DownloadEngine* e, const SocketHandle& s): AbstractCommand(cuid, req, requestGroup, e, s) +#ifdef ENABLE_MESSAGE_DIGEST + , _pieceHashValidationEnabled(false) +#endif // ENABLE_MESSAGE_DIGEST { #ifdef ENABLE_MESSAGE_DIGEST { - std::string algo = _requestGroup->getDownloadContext()->getPieceHashAlgo(); - if(MessageDigestContext::supports(algo)) { - _messageDigestContext.reset(new MessageDigestContext()); - _messageDigestContext->trySetAlgo(algo); - _messageDigestContext->digestInit(); + if(e->option->getAsBool(PREF_REALTIME_CHUNK_CHECKSUM)) { + std::string algo = _requestGroup->getDownloadContext()->getPieceHashAlgo(); + if(MessageDigestContext::supports(algo)) { + _messageDigestContext.reset(new MessageDigestContext()); + _messageDigestContext->trySetAlgo(algo); + _messageDigestContext->digestInit(); + + _pieceHashValidationEnabled = true; + } } } #endif // ENABLE_MESSAGE_DIGEST @@ -116,7 +123,17 @@ bool DownloadCommand::executeInternal() { _requestGroup->getPieceStorage()->getDiskAdaptor()->writeData(buf, bufSize, segment->getPositionToWrite()); //logger->debug("bufSize = %d, posToWrite = %lld", bufSize, segment->getPositionToWrite()); +#ifdef ENABLE_MESSAGE_DIGEST + + if(_pieceHashValidationEnabled) { + segment->updateHash(segment->getWrittenLength(), buf, bufSize); + } + +#endif // ENABLE_MESSAGE_DIGEST + segment->updateWrittenLength(bufSize); + + //logger->debug("overflow length = %d, next posToWrite = %lld", segment->getOverflowLength(), segment->getPositionToWrite()); //logger->debug("%s", Util::toHex(segment->getPiece()->getBitfield(), //segment->getPiece()->getBitfieldLength()).c_str()); @@ -128,7 +145,17 @@ bool DownloadCommand::executeInternal() { transferDecoder->inflate(infbuf, infbufSize, buf, bufSize); _requestGroup->getPieceStorage()->getDiskAdaptor()->writeData(infbuf, infbufSize, segment->getPositionToWrite()); + +#ifdef ENABLE_MESSAGE_DIGEST + + if(_pieceHashValidationEnabled) { + segment->updateHash(segment->getWrittenLength(), infbuf, infbufSize); + } + +#endif // ENABLE_MESSAGE_DIGEST + segment->updateWrittenLength(infbufSize); + //segment->writtenLength += infbufSize; peerStat->updateDownloadLength(infbufSize); } @@ -140,7 +167,36 @@ bool DownloadCommand::executeInternal() { || bufSize == 0) { if(!transferDecoder.isNull()) transferDecoder->end(); logger->info(MSG_SEGMENT_DOWNLOAD_COMPLETED, cuid); - validatePieceHash(segment); + +#ifdef ENABLE_MESSAGE_DIGEST + + { + std::string expectedPieceHash = + _requestGroup->getDownloadContext()->getPieceHash(segment->getIndex()); + if(_pieceHashValidationEnabled && !expectedPieceHash.empty()) { + if(segment->isHashCalculated()) { + logger->debug("Hash is available! index=%zu", segment->getIndex()); + validatePieceHash(segment, expectedPieceHash, segment->getHashString()); + } else { + _messageDigestContext->digestReset(); + validatePieceHash(segment, expectedPieceHash, + MessageDigestHelper::digest + (_messageDigestContext.get(), + _requestGroup->getPieceStorage()->getDiskAdaptor(), + segment->getPosition(), + segment->getLength())); + } + } else { + _requestGroup->getSegmentMan()->completeSegment(cuid, segment); + } + } + +#else // !ENABLE_MESSAGE_DIGEST + + _requestGroup->getSegmentMan()->completeSegment(cuid, segment); + +#endif // !ENABLE_MESSAGE_DIGEST + checkLowestDownloadSpeed(); // this unit is going to download another segment. return prepareForNextSegment(); @@ -195,41 +251,30 @@ bool DownloadCommand::prepareForNextSegment() { } } -void DownloadCommand::validatePieceHash(const SegmentHandle& segment) -{ #ifdef ENABLE_MESSAGE_DIGEST - std::string expectedPieceHash = - _requestGroup->getDownloadContext()->getPieceHash(segment->getIndex()); - if(!_messageDigestContext.isNull() && - e->option->getAsBool(PREF_REALTIME_CHUNK_CHECKSUM) && - !expectedPieceHash.empty()) { - _messageDigestContext->digestReset(); - std::string actualPieceHash = - MessageDigestHelper::digest(_messageDigestContext.get(), - _requestGroup->getPieceStorage()->getDiskAdaptor(), - segment->getPosition(), - segment->getLength()); - if(actualPieceHash == expectedPieceHash) { - logger->info(MSG_GOOD_CHUNK_CHECKSUM, actualPieceHash.c_str()); - _requestGroup->getSegmentMan()->completeSegment(cuid, segment); - } else { - logger->info(EX_INVALID_CHUNK_CHECKSUM, - segment->getIndex(), - Util::itos(segment->getPosition(), true).c_str(), - expectedPieceHash.c_str(), - actualPieceHash.c_str()); - segment->clear(); - _requestGroup->getSegmentMan()->cancelSegment(cuid); - throw DlRetryEx - (StringFormat("Invalid checksum index=%d", segment->getIndex()).str()); - } - } else -#endif // ENABLE_MESSAGE_DIGEST - { - _requestGroup->getSegmentMan()->completeSegment(cuid, segment); - } + +void DownloadCommand::validatePieceHash(const SharedHandle& segment, + const std::string& expectedPieceHash, + const std::string& actualPieceHash) +{ + if(actualPieceHash == expectedPieceHash) { + logger->info(MSG_GOOD_CHUNK_CHECKSUM, actualPieceHash.c_str()); + _requestGroup->getSegmentMan()->completeSegment(cuid, segment); + } else { + logger->info(EX_INVALID_CHUNK_CHECKSUM, + segment->getIndex(), + Util::itos(segment->getPosition(), true).c_str(), + expectedPieceHash.c_str(), + actualPieceHash.c_str()); + segment->clear(); + _requestGroup->getSegmentMan()->cancelSegment(cuid); + throw DlRetryEx + (StringFormat("Invalid checksum index=%d", segment->getIndex()).str()); + } } +#endif // ENABLE_MESSAGE_DIGEST + void DownloadCommand::setTransferDecoder(const TransferEncodingHandle& transferDecoder) { this->transferDecoder = transferDecoder; diff --git a/src/DownloadCommand.h b/src/DownloadCommand.h index 2a5bfd3d..a87b62ee 100644 --- a/src/DownloadCommand.h +++ b/src/DownloadCommand.h @@ -51,11 +51,18 @@ private: time_t startupIdleTime; unsigned int lowestDownloadSpeedLimit; SharedHandle peerStat; + #ifdef ENABLE_MESSAGE_DIGEST + + bool _pieceHashValidationEnabled; + SharedHandle _messageDigestContext; + #endif // ENABLE_MESSAGE_DIGEST - void validatePieceHash(const SharedHandle& segment); + void validatePieceHash(const SharedHandle& segment, + const std::string& expectedPieceHash, + const std::string& actualPieceHash); void checkLowestDownloadSpeed() const; protected: diff --git a/src/DownloadEngine.cc b/src/DownloadEngine.cc index 31319b2b..1b1fdf29 100644 --- a/src/DownloadEngine.cc +++ b/src/DownloadEngine.cc @@ -133,12 +133,12 @@ void ADNSEvent::processEvents(int events) { ares_socket_t readfd; ares_socket_t writefd; - if(events&EPOLLIN) { + if(events&(SocketEntry::EVENT_READ|SocketEntry::EVENT_ERROR|SocketEntry::EVENT_HUP)) { readfd = _socket; } else { readfd = ARES_SOCKET_BAD; } - if(events&EPOLLOUT) { + if(events&(SocketEntry::EVENT_WRITE|SocketEntry::EVENT_ERROR|SocketEntry::EVENT_HUP)) { writefd = _socket; } else { writefd = ARES_SOCKET_BAD; diff --git a/src/GrowSegment.cc b/src/GrowSegment.cc index 0fbd9e88..3d78c303 100644 --- a/src/GrowSegment.cc +++ b/src/GrowSegment.cc @@ -34,6 +34,7 @@ /* copyright --> */ #include "GrowSegment.h" #include "Piece.h" +#include "A2STR.h" namespace aria2 { @@ -49,6 +50,15 @@ void GrowSegment::updateWrittenLength(size_t bytes) _piece->setAllBlock(); } +#ifdef ENABLE_MESSAGE_DIGEST + +std::string GrowSegment::getHashString() +{ + return A2STR::NIL; +} + +#endif // ENABLE_MESSAGE_DIGEST + void GrowSegment::clear() { _writtenLength = 0; diff --git a/src/GrowSegment.h b/src/GrowSegment.h index 697d88d9..b2a66097 100644 --- a/src/GrowSegment.h +++ b/src/GrowSegment.h @@ -90,6 +90,23 @@ public: virtual void updateWrittenLength(size_t bytes); +#ifdef ENABLE_MESSAGE_DIGEST + + virtual bool updateHash(size_t begin, + const unsigned char* data, size_t dataLength) + { + return false; + } + + virtual bool isHashCalculated() const + { + return false; + } + + virtual std::string getHashString(); + +#endif // ENABLE_MESSAGE_DIGEST + virtual void clear(); virtual SharedHandle getPiece() const; diff --git a/src/Piece.cc b/src/Piece.cc index 45537009..fd743b34 100644 --- a/src/Piece.cc +++ b/src/Piece.cc @@ -36,12 +36,25 @@ #include "Util.h" #include "BitfieldManFactory.h" #include "BitfieldMan.h" +#include "A2STR.h" +#include "Util.h" +#ifdef ENABLE_MESSAGE_DIGEST +# include "messageDigest.h" +#endif // ENABLE_MESSAGE_DIGEST namespace aria2 { -Piece::Piece():index(0), length(0), _blockLength(BLOCK_LENGTH), bitfield(0) {} +Piece::Piece():index(0), length(0), _blockLength(BLOCK_LENGTH), bitfield(0) +#ifdef ENABLE_MESSAGE_DIGEST + , _nextBegin(0) +#endif // ENABLE_MESSAGE_DIGEST +{} -Piece::Piece(size_t index, size_t length, size_t blockLength):index(index), length(length), _blockLength(blockLength) { +Piece::Piece(size_t index, size_t length, size_t blockLength):index(index), length(length), _blockLength(blockLength) +#ifdef ENABLE_MESSAGE_DIGEST + , _nextBegin(0) +#endif // ENABLE_MESSAGE_DIGEST +{ bitfield = BitfieldManFactory::getFactoryInstance()->createBitfieldMan(_blockLength, length); } @@ -55,6 +68,11 @@ Piece::Piece(const Piece& piece) { } else { bitfield = new BitfieldMan(*piece.bitfield); } +#ifdef ENABLE_MESSAGE_DIGEST + _nextBegin = piece._nextBegin; + // TODO Is this OK? + _mdctx = piece._mdctx; +#endif // ENABLE_MESSAGE_DIGEST } Piece::~Piece() @@ -200,4 +218,55 @@ size_t Piece::getCompletedLength() return bitfield->getCompletedLength(); } +#ifdef ENABLE_MESSAGE_DIGEST + +void Piece::setHashAlgo(const std::string& algo) +{ + _hashAlgo = algo; +} + +bool Piece::updateHash(size_t begin, const unsigned char* data, size_t dataLength) +{ + if(_hashAlgo.empty()) { + return false; + } + if(begin == _nextBegin && _nextBegin+dataLength <= length) { + + if(_mdctx.isNull()) { + _mdctx.reset(new MessageDigestContext()); + _mdctx->trySetAlgo(_hashAlgo); + _mdctx->digestInit(); + } + + _mdctx->digestUpdate(data, dataLength); + _nextBegin += dataLength; + return true; + } else { + return false; + } +} + +bool Piece::isHashCalculated() const +{ + return !_mdctx.isNull() && _nextBegin == length; +} + +// TODO should be getHashString() +std::string Piece::getHashString() +{ + if(_mdctx.isNull()) { + return A2STR::NIL; + } else { + return Util::toHex(_mdctx->digestFinal()); + } +} + +void Piece::destroyHashContext() +{ + _mdctx.reset(); + _nextBegin = 0; +} + +#endif // ENABLE_MESSAGE_DIGEST + } // namespace aria2 diff --git a/src/Piece.h b/src/Piece.h index cd52e3e6..6b1f7c62 100644 --- a/src/Piece.h +++ b/src/Piece.h @@ -45,12 +45,29 @@ namespace aria2 { class BitfieldMan; +#ifdef ENABLE_MESSAGE_DIGEST + +class MessageDigestContext; + +#endif // ENABLE_MESSAGE_DIGEST + class Piece { private: size_t index; size_t length; size_t _blockLength; BitfieldMan* bitfield; + +#ifdef ENABLE_MESSAGE_DIGEST + + size_t _nextBegin; + + std::string _hashAlgo; + + SharedHandle _mdctx; + +#endif // ENABLE_MESSAGE_DIGEST + public: static const size_t BLOCK_LENGTH = 16*1024; @@ -116,6 +133,25 @@ public: // Calculates completed length size_t getCompletedLength(); +#ifdef ENABLE_MESSAGE_DIGEST + + void setHashAlgo(const std::string& algo); + + // Updates hash value. This function compares begin and private variable + // _nextBegin and only when they are equal, hash is updated eating data and + // returns true. Otherwise returns false. + bool updateHash(size_t begin, const unsigned char* data, size_t dataLength); + + bool isHashCalculated() const; + + // Returns hash value in ASCII hexadecimal form. + // WARN: This function must be called only once. + std::string getHashString(); + + void destroyHashContext(); + +#endif // ENABLE_MESSAGE_DIGEST + /** * Loses current bitfield state. */ diff --git a/src/PiecedSegment.cc b/src/PiecedSegment.cc index 1570c1bd..e0c15d42 100644 --- a/src/PiecedSegment.cc +++ b/src/PiecedSegment.cc @@ -90,11 +90,37 @@ void PiecedSegment::updateWrittenLength(size_t bytes) _writtenLength = newWrittenLength; } +#ifdef ENABLE_MESSAGE_DIGEST + +bool PiecedSegment::updateHash(size_t begin, + const unsigned char* data, size_t dataLength) +{ + return _piece->updateHash(begin, data, dataLength); +} + +bool PiecedSegment::isHashCalculated() const +{ + return _piece->isHashCalculated(); +} + +std::string PiecedSegment::getHashString() +{ + return _piece->getHashString(); +} + +#endif // ENABLE_MESSAGE_DIGEST + void PiecedSegment::clear() { _writtenLength = 0; _overflowLength = 0; _piece->clearAllBlock(); + +#ifdef ENABLE_MESSAGE_DIGEST + + _piece->destroyHashContext(); + +#endif // ENABLE_MESSAGE_DIGEST } PieceHandle PiecedSegment::getPiece() const diff --git a/src/PiecedSegment.h b/src/PiecedSegment.h index 87b415c4..9f6f34f9 100644 --- a/src/PiecedSegment.h +++ b/src/PiecedSegment.h @@ -82,6 +82,18 @@ public: virtual void updateWrittenLength(size_t bytes); +#ifdef ENABLE_MESSAGE_DIGEST + + // `begin' is a offset inside this segment. + virtual bool updateHash(size_t begin, + const unsigned char* data, size_t dataLength); + + virtual bool isHashCalculated() const; + + virtual std::string getHashString(); + +#endif // ENABLE_MESSAGE_DIGEST + virtual void clear(); virtual SharedHandle getPiece() const; diff --git a/src/Segment.h b/src/Segment.h index 2e732236..edc5a36a 100644 --- a/src/Segment.h +++ b/src/Segment.h @@ -40,6 +40,7 @@ #include #include #include +#include namespace aria2 { @@ -67,6 +68,18 @@ public: virtual void updateWrittenLength(size_t bytes) = 0; +#ifdef ENABLE_MESSAGE_DIGEST + + // `begin' is a offset inside this segment. + virtual bool updateHash(size_t begin, + const unsigned char* data, size_t dataLength) = 0; + + virtual bool isHashCalculated() const = 0; + + virtual std::string getHashString() = 0; + +#endif // ENABLE_MESSAGE_DIGEST + virtual void clear() = 0; virtual SharedHandle getPiece() const = 0; diff --git a/test/PieceTest.cc b/test/PieceTest.cc index 34a0d7ab..0efaa2bc 100644 --- a/test/PieceTest.cc +++ b/test/PieceTest.cc @@ -1,4 +1,7 @@ #include "Piece.h" +#ifdef ENABLE_MESSAGE_DIGEST +# include "messageDigest.h" +#endif // ENABLE_MESSAGE_DIGEST #include #include @@ -9,6 +12,13 @@ class PieceTest:public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(PieceTest); CPPUNIT_TEST(testCompleteBlock); CPPUNIT_TEST(testGetCompletedLength); + +#ifdef ENABLE_MESSAGE_DIGEST + + CPPUNIT_TEST(testUpdateHash); + +#endif // ENABLE_MESSAGE_DIGEST + CPPUNIT_TEST_SUITE_END(); private: @@ -17,6 +27,12 @@ public: void testCompleteBlock(); void testGetCompletedLength(); + +#ifdef ENABLE_MESSAGE_DIGEST + + void testUpdateHash(); + +#endif // ENABLE_MESSAGE_DIGEST }; @@ -45,4 +61,30 @@ void PieceTest::testGetCompletedLength() CPPUNIT_ASSERT_EQUAL(blockLength*3+100, p.getCompletedLength()); } +#ifdef ENABLE_MESSAGE_DIGEST + +void PieceTest::testUpdateHash() +{ + Piece p(0, 16, 2*1024*1024); + p.setHashAlgo(MessageDigestContext::SHA1); + + std::string spam("SPAM!"); + CPPUNIT_ASSERT(p.updateHash + (0, reinterpret_cast(spam.c_str()), + spam.size())); + CPPUNIT_ASSERT(!p.isHashCalculated()); + + std::string spamspam("SPAM!SPAM!!"); + CPPUNIT_ASSERT(p.updateHash + (spam.size(), + reinterpret_cast(spamspam.c_str()), + spamspam.size())); + CPPUNIT_ASSERT(p.isHashCalculated()); + + CPPUNIT_ASSERT_EQUAL(std::string("d9189aff79e075a2e60271b9556a710dc1bc7de7"), + p.getHashString()); +} + +#endif // ENABLE_MESSAGE_DIGEST + } // namespace aria2