diff --git a/src/DefaultBtProgressInfoFile.cc b/src/DefaultBtProgressInfoFile.cc index 7907284f..2a59d863 100644 --- a/src/DefaultBtProgressInfoFile.cc +++ b/src/DefaultBtProgressInfoFile.cc @@ -55,9 +55,7 @@ #include "fmt.h" #include "array_fun.h" #include "DownloadContext.h" -#include "BufferedFile.h" #include "SHA1IOFile.h" -#include "BtConstants.h" #ifdef ENABLE_BITTORRENT # include "PeerStorage.h" # include "BtRuntime.h" @@ -223,23 +221,17 @@ void DefaultBtProgressInfoFile::save() } } -#define READ_CHECK(fp, ptr, count) \ +#define READ_CHECK_STATIC(fp, ptr, count, filename) \ if (fp.read((ptr), (count)) != (count)) { \ - throw DL_ABORT_EX(fmt(EX_SEGMENT_FILE_READ, filename_.c_str())); \ + throw DL_ABORT_EX(fmt(EX_SEGMENT_FILE_READ, filename.c_str())); \ } -// It is assumed that integers are saved as: -// 1) host byte order if version == 0000 -// 2) network byte order if version == 0001 -void DefaultBtProgressInfoFile::load() +#define READ_CHECK(fp, ptr, count) READ_CHECK_STATIC(fp, ptr, count, filename_) + +uint DefaultBtProgressInfoFile::getControlFileVersion(BufferedFile& fp, const std::string& filename) { - A2_LOG_INFO(fmt(MSG_LOADING_SEGMENT_FILE, filename_.c_str())); - BufferedFile fp(filename_.c_str(), BufferedFile::READ); - if (!fp) { - throw DL_ABORT_EX(fmt(EX_SEGMENT_FILE_READ, filename_.c_str())); - } unsigned char versionBuf[2]; - READ_CHECK(fp, versionBuf, sizeof(versionBuf)); + READ_CHECK_STATIC(fp, versionBuf, sizeof(versionBuf), filename); std::string versionHex = util::toHex(versionBuf, sizeof(versionBuf)); int version; if ("0000" == versionHex) { @@ -252,6 +244,53 @@ void DefaultBtProgressInfoFile::load() throw DL_ABORT_EX( fmt("Unsupported ctrl file version: %s", versionHex.c_str())); } + + return version; +} + +std::array DefaultBtProgressInfoFile::getInfoHash(const std::string& control_file) +{ + A2_LOG_INFO(fmt(MSG_LOADING_SEGMENT_FILE, control_file.c_str())); + BufferedFile fp(control_file.c_str(), BufferedFile::READ); + if (!fp) { + throw DL_ABORT_EX(fmt(EX_SEGMENT_FILE_READ, control_file.c_str())); + } + + auto version = getControlFileVersion(fp, control_file); + + unsigned char extension[4]; + READ_CHECK_STATIC(fp, extension, sizeof(extension), control_file); + + uint32_t infoHashLength; + READ_CHECK_STATIC(fp, &infoHashLength, sizeof(infoHashLength), control_file); + if (version >= 1) { + infoHashLength = ntohl(infoHashLength); + } + if (infoHashLength != INFO_HASH_LENGTH) { + throw DL_ABORT_EX(fmt("Invalid info hash length: %d", infoHashLength)); + } + + std::array savedInfoHash; + if (infoHashLength > 0) { + READ_CHECK_STATIC(fp, savedInfoHash.data(), infoHashLength, control_file); + } + + return savedInfoHash; +} + +// It is assumed that integers are saved as: +// 1) host byte order if version == 0000 +// 2) network byte order if version == 0001 +void DefaultBtProgressInfoFile::load() +{ + A2_LOG_INFO(fmt(MSG_LOADING_SEGMENT_FILE, filename_.c_str())); + BufferedFile fp(filename_.c_str(), BufferedFile::READ); + if (!fp) { + throw DL_ABORT_EX(fmt(EX_SEGMENT_FILE_READ, filename_.c_str())); + } + + auto version = getControlFileVersion(fp, filename_); + unsigned char extension[4]; READ_CHECK(fp, extension, sizeof(extension)); bool infoHashCheckEnabled = false; diff --git a/src/DefaultBtProgressInfoFile.h b/src/DefaultBtProgressInfoFile.h index 33a28044..ef7a4a4a 100644 --- a/src/DefaultBtProgressInfoFile.h +++ b/src/DefaultBtProgressInfoFile.h @@ -36,6 +36,8 @@ #define D_DEFAULT_BT_PROGRESS_INFO_FILE_H #include "BtProgressInfoFile.h" +#include "BufferedFile.h" +#include "BtConstants.h" #include @@ -93,6 +95,11 @@ public: void setBtRuntime(const std::shared_ptr& btRuntime); #endif // ENABLE_BITTORRENT + // Assume getting pointer to the start of the file + static uint getControlFileVersion(BufferedFile& fp, const std::string& filename); + + static std::array getInfoHash(const std::string& control_file); + static const std::string& getSuffix() { static std::string suffix = ".aria2";