diff --git a/src/HttpResponseCommand.cc b/src/HttpResponseCommand.cc index 88c19a21..c02e324b 100644 --- a/src/HttpResponseCommand.cc +++ b/src/HttpResponseCommand.cc @@ -72,6 +72,10 @@ #include "ChunkedDecodingStreamFilter.h" #include "uri.h" #include "SocketRecvBuffer.h" +#include "MetalinkHttpEntry.h" +#ifdef ENABLE_MESSAGE_DIGEST +#include "Checksum.h" +#endif // ENABLE_MESSAGE_DIGEST #ifdef HAVE_ZLIB # include "GZipDecodingStreamFilter.h" #endif // HAVE_ZLIB @@ -189,6 +193,42 @@ bool HttpResponseCommand::executeInternal() getFileEntry()->poolRequest(getRequest()); return true; } + if(!getPieceStorage()) { + // Metalink/HTTP + if(!getDownloadContext()->getMetalinkServerContacted()) { + if(httpHeader->defined(HttpHeader::LINK)) { + getDownloadContext()->setMetalinkServerContacted(true); + std::vector entries; + httpResponse->getMetalinKHttpEntries(entries, getOption()); + for(std::vector::iterator i = entries.begin(), + eoi = entries.end(); i != eoi; ++i) { + getFileEntry()->addUri((*i).uri); + A2_LOG_DEBUG(fmt("Adding URI=%s", (*i).uri.c_str())); + } + } + } +#ifdef ENABLE_MESSAGE_DIGEST + if(httpHeader->defined(HttpHeader::DIGEST)) { + std::vector checksums; + httpResponse->getDigest(checksums); + for(std::vector::iterator i = checksums.begin(), + eoi = checksums.end(); i != eoi; ++i) { + if(getDownloadContext()->getChecksumHashAlgo().empty()) { + A2_LOG_DEBUG(fmt("Setting digest: type=%s, digest=%s", + (*i).getAlgo().c_str(), + (*i).getMessageDigest().c_str())); + getDownloadContext()->setChecksumHashAlgo((*i).getAlgo()); + getDownloadContext()->setChecksum((*i).getMessageDigest()); + break; + } else { + if(checkChecksum(getDownloadContext(), *i)) { + break; + } + } + } + } +#endif // ENABLE_MESSAGE_DIGEST + } if(statusCode >= 300) { if(statusCode == 404) { getRequestGroup()->increaseAndValidateFileNotFoundCount(); @@ -241,6 +281,19 @@ bool HttpResponseCommand::executeInternal() return handleDefaultEncoding(httpResponse); } } else { +#ifdef ENABLE_MESSAGE_DIGEST + if(!getDownloadContext()->getChecksumHashAlgo().empty() && + httpHeader->defined(HttpHeader::DIGEST)) { + std::vector checksums; + httpResponse->getDigest(checksums); + for(std::vector::iterator i = checksums.begin(), + eoi = checksums.end(); i != eoi; ++i) { + if(checkChecksum(getDownloadContext(), *i)) { + break; + } + } + } +#endif // ENABLE_MESSAGE_DIGEST // validate totalsize getRequestGroup()->validateTotalLength(getFileEntry()->getLength(), httpResponse->getEntityLength()); @@ -501,4 +554,21 @@ void HttpResponseCommand::onDryRunFileFound() poolConnection(); } +#ifdef ENABLE_MESSAGE_DIGEST +bool HttpResponseCommand::checkChecksum +(const SharedHandle& dctx, + const Checksum& checksum) +{ + if(dctx->getChecksumHashAlgo() == checksum.getAlgo()) { + if(dctx->getChecksum() == checksum.getMessageDigest()) { + A2_LOG_INFO("Valid hash found in Digest header field."); + return true; + } else { + throw DL_ABORT_EX("Invalid hash found in Digest header field."); + } + } + return false; +} +#endif // ENABLE_MESSAGE_DIGEST + } // namespace aria2 diff --git a/src/HttpResponseCommand.h b/src/HttpResponseCommand.h index 6d341154..6c691e0e 100644 --- a/src/HttpResponseCommand.h +++ b/src/HttpResponseCommand.h @@ -45,6 +45,9 @@ class HttpDownloadCommand; class HttpResponse; class SocketCore; class StreamFilter; +#ifdef ENABLE_MESSAGE_DIGEST +class Checksum; +#endif // ENABLE_MESSAGE_DIGEST // HttpResponseCommand receives HTTP response header from remote // server. Because network I/O is non-blocking, execute() returns @@ -74,6 +77,14 @@ private: void poolConnection(); void onDryRunFileFound(); +#ifdef ENABLE_MESSAGE_DIGEST + // Returns true if dctx and checksum has same hash type and hash + // value. If they have same hash type but different hash value, + // throws exception. Otherwise returns false. + bool checkChecksum + (const SharedHandle& dctx, + const Checksum& checksum); +#endif // ENABLE_MESSAGE_DIGEST protected: bool executeInternal();