From fa30fe4b15dc10d5ac278534da54e63951a099e7 Mon Sep 17 00:00:00 2001 From: Nils Maier Date: Fri, 20 Sep 2013 07:54:30 +0200 Subject: [PATCH] One MessageDigestImpl.h to rule them all. --- src/AppleMessageDigestImpl.cc | 60 ++++--------- src/AppleMessageDigestImpl.h | 71 ---------------- src/LibgcryptMessageDigestImpl.cc | 133 ++++++++++++++--------------- src/LibgcryptMessageDigestImpl.h | 75 ---------------- src/LibnettleMessageDigestImpl.cc | 123 ++++++++++++--------------- src/LibnettleMessageDigestImpl.h | 75 ---------------- src/LibsslMessageDigestImpl.cc | 137 +++++++++++++----------------- src/LibsslMessageDigestImpl.h | 75 ---------------- src/Makefile.am | 10 +-- src/MessageDigestImpl.h | 83 +++++++++++++++--- src/WinMessageDigestImpl.cc | 104 +++++++++++++---------- src/WinMessageDigestImpl.h | 70 --------------- 12 files changed, 329 insertions(+), 687 deletions(-) delete mode 100644 src/AppleMessageDigestImpl.h delete mode 100644 src/LibgcryptMessageDigestImpl.h delete mode 100644 src/LibnettleMessageDigestImpl.h delete mode 100644 src/LibsslMessageDigestImpl.h delete mode 100644 src/WinMessageDigestImpl.h diff --git a/src/AppleMessageDigestImpl.cc b/src/AppleMessageDigestImpl.cc index 80ac9e70..a10bf0b6 100644 --- a/src/AppleMessageDigestImpl.cc +++ b/src/AppleMessageDigestImpl.cc @@ -32,15 +32,13 @@ * files in the program, then also delete it here. */ /* copyright --> */ -#include "AppleMessageDigestImpl.h" + +#include "MessageDigestImpl.h" #include -#include "array_fun.h" -#include "a2functional.h" -#include "HashFuncEntry.h" - namespace aria2 { +namespace { template MessageDigestSHA512; +} // namespace + std::unique_ptr MessageDigestImpl::sha1() { return std::unique_ptr(new MessageDigestSHA1()); } -std::unique_ptr MessageDigestImpl::create -(const std::string& hashType) -{ - if (hashType == "sha-1") { - return make_unique(); - } - if (hashType == "sha-224") { - return make_unique(); - } - if (hashType == "sha-256") { - return make_unique(); - } - if (hashType == "sha-384") { - return make_unique(); - } - if (hashType == "sha-512") { - return make_unique(); - } - if (hashType == "md5") { - return make_unique(); - } - return nullptr; -} - -bool MessageDigestImpl::supports(const std::string& hashType) -{ - return hashType == "sha-1" || hashType == "sha-224" || - hashType == "sha-256" || hashType == "sha-384" || - hashType == "sha-512" || hashType == "md5"; -} - -size_t MessageDigestImpl::getDigestLength(const std::string& hashType) -{ - std::unique_ptr impl = create(hashType); - if (!impl) { - return 0; - } - return impl->getDigestLength(); -} +MessageDigestImpl::hashes_t MessageDigestImpl::hashes = { + { "sha-1", make_hi() }, + { "sha-224", make_hi() }, + { "sha-256", make_hi() }, + { "sha-384", make_hi() }, + { "sha-512", make_hi() }, + { "md5", make_hi() }, +}; } // namespace aria2 diff --git a/src/AppleMessageDigestImpl.h b/src/AppleMessageDigestImpl.h deleted file mode 100644 index a552966c..00000000 --- a/src/AppleMessageDigestImpl.h +++ /dev/null @@ -1,71 +0,0 @@ -/* */ -#ifndef D_APPLE_MESSAGE_DIGEST_IMPL_H -#define D_APPLE_MESSAGE_DIGEST_IMPL_H - -#include "common.h" - -#include -#include - -namespace aria2 { - -class MessageDigestImpl { -public: - virtual ~MessageDigestImpl() {} - static std::unique_ptr sha1(); - static std::unique_ptr create(const std::string& hashType); - - static bool supports(const std::string& hashType); - static size_t getDigestLength(const std::string& hashType); - -public: - virtual size_t getDigestLength() const = 0; - virtual void reset() = 0; - virtual void update(const void* data, size_t length) = 0; - virtual void digest(unsigned char* md) = 0; - -protected: - MessageDigestImpl() {} - -private: - MessageDigestImpl(const MessageDigestImpl&); - MessageDigestImpl& operator=(const MessageDigestImpl&); - -}; - -} // namespace aria2 - -#endif // D_APPLE_MESSAGE_DIGEST_IMPL_H diff --git a/src/LibgcryptMessageDigestImpl.cc b/src/LibgcryptMessageDigestImpl.cc index d6a52e57..f15ba8c9 100644 --- a/src/LibgcryptMessageDigestImpl.cc +++ b/src/LibgcryptMessageDigestImpl.cc @@ -2,7 +2,7 @@ /* * aria2 - The high speed download utility * - * Copyright (C) 2010 Tatsuhiro Tsujikawa + * Copyright (C) 2013 Nils Maier * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -32,84 +32,81 @@ * files in the program, then also delete it here. */ /* copyright --> */ -#include "LibgcryptMessageDigestImpl.h" -#include +#include "MessageDigestImpl.h" -#include "array_fun.h" -#include "HashFuncEntry.h" -#include "a2functional.h" +#include namespace aria2 { -MessageDigestImpl::MessageDigestImpl(int hashFunc) : hashFunc_{hashFunc} -{ - gcry_md_open(&ctx_, hashFunc_, 0); -} +namespace { +template +class MessageDigestBase : public MessageDigestImpl { +private: + struct Deleter { + void operator()(gcry_md_hd_t ctx) { + if (ctx) { + gcry_md_close(ctx); + } + } + }; -MessageDigestImpl::~MessageDigestImpl() -{ - gcry_md_close(ctx_); -} +public: + MessageDigestBase() { + gcry_md_hd_t ctx = nullptr; + gcry_md_open(&ctx, hash, 0); + ctx_.reset(ctx); + reset(); + } + virtual ~MessageDigestBase() {} + + static size_t length() { + return ::gcry_md_get_algo_dlen(hash); + } + virtual size_t getDigestLength() const CXX11_OVERRIDE { + return ::gcry_md_get_algo_dlen(hash); + } + virtual void reset() CXX11_OVERRIDE { + ::gcry_md_reset(ctx_.get()); + } + virtual void update(const void* data, size_t length) CXX11_OVERRIDE { + auto bytes = reinterpret_cast(data); + while (length) { + size_t l = std::min(length, (size_t)std::numeric_limits::max()); + gcry_md_write(ctx_.get(), bytes, length); + length -= l; + bytes += l; + } + } + virtual void digest(unsigned char* md) CXX11_OVERRIDE { + ::memcpy(md, gcry_md_read(ctx_.get(), 0), getDigestLength()); + } + +private: + std::unique_ptr::type, Deleter> ctx_; + size_t len_; +}; + +typedef MessageDigestBase MessageDigestMD5; +typedef MessageDigestBase MessageDigestSHA1; +typedef MessageDigestBase MessageDigestSHA224; +typedef MessageDigestBase MessageDigestSHA256; +typedef MessageDigestBase MessageDigestSHA384; +typedef MessageDigestBase MessageDigestSHA512; +} // namespace std::unique_ptr MessageDigestImpl::sha1() { - return make_unique(GCRY_MD_SHA1); + return std::unique_ptr(new MessageDigestSHA1()); } -typedef HashFuncEntry CHashFuncEntry; -typedef FindHashFunc CFindHashFunc; - -namespace { -CHashFuncEntry hashFuncs[] = { - CHashFuncEntry("sha-1", GCRY_MD_SHA1), - CHashFuncEntry("sha-224", GCRY_MD_SHA224), - CHashFuncEntry("sha-256", GCRY_MD_SHA256), - CHashFuncEntry("sha-384", GCRY_MD_SHA384), - CHashFuncEntry("sha-512", GCRY_MD_SHA512), - CHashFuncEntry("md5", GCRY_MD_MD5) +MessageDigestImpl::hashes_t MessageDigestImpl::hashes = { + { "sha-1", make_hi() }, + { "sha-224", make_hi() }, + { "sha-256", make_hi() }, + { "sha-384", make_hi() }, + { "sha-512", make_hi() }, + { "md5", make_hi() }, }; -} // namespace - -std::unique_ptr MessageDigestImpl::create -(const std::string& hashType) -{ - int hashFunc = getHashFunc(std::begin(hashFuncs), std::end(hashFuncs), - hashType); - return make_unique(hashFunc); -} - -bool MessageDigestImpl::supports(const std::string& hashType) -{ - return std::end(hashFuncs) != std::find_if(std::begin(hashFuncs), - std::end(hashFuncs), - CFindHashFunc(hashType)); -} - -size_t MessageDigestImpl::getDigestLength(const std::string& hashType) -{ - int hashFunc = getHashFunc(std::begin(hashFuncs), std::end(hashFuncs), - hashType); - return gcry_md_get_algo_dlen(hashFunc); -} - -size_t MessageDigestImpl::getDigestLength() const -{ - return gcry_md_get_algo_dlen(hashFunc_); -} - -void MessageDigestImpl::reset() -{ - gcry_md_reset(ctx_); -} -void MessageDigestImpl::update(const void* data, size_t length) -{ - gcry_md_write(ctx_, data, length); -} - -void MessageDigestImpl::digest(unsigned char* md) -{ - memcpy(md, gcry_md_read(ctx_, 0), gcry_md_get_algo_dlen(hashFunc_)); -} } // namespace aria2 diff --git a/src/LibgcryptMessageDigestImpl.h b/src/LibgcryptMessageDigestImpl.h deleted file mode 100644 index 56cadf14..00000000 --- a/src/LibgcryptMessageDigestImpl.h +++ /dev/null @@ -1,75 +0,0 @@ -/* */ -#ifndef D_LIBGCRYPT_MESSAGE_DIGEST_IMPL_H -#define D_LIBGCRYPT_MESSAGE_DIGEST_IMPL_H - -#include "common.h" - -#include -#include - -#include - -namespace aria2 { - -class MessageDigestImpl { -public: - MessageDigestImpl(int hashFunc); - // We don't implement copy ctor. - MessageDigestImpl(const MessageDigestImpl&) = delete; - // We don't implement assignment operator. - MessageDigestImpl& operator==(const MessageDigestImpl&) = delete; - - ~MessageDigestImpl(); - - static std::unique_ptr sha1(); - static std::unique_ptr create - (const std::string& hashType); - - static bool supports(const std::string& hashType); - static size_t getDigestLength(const std::string& hashType); - - size_t getDigestLength() const; - void reset(); - void update(const void* data, size_t length); - void digest(unsigned char* md); -private: - int hashFunc_; - gcry_md_hd_t ctx_; -}; - -} // namespace aria2 - -#endif // D_LIBGCRYPT_MESSAGE_DIGEST_IMPL_H diff --git a/src/LibnettleMessageDigestImpl.cc b/src/LibnettleMessageDigestImpl.cc index 5bb1326a..1f03c2e6 100644 --- a/src/LibnettleMessageDigestImpl.cc +++ b/src/LibnettleMessageDigestImpl.cc @@ -2,7 +2,7 @@ /* * aria2 - The high speed download utility * - * Copyright (C) 2011 Tatsuhiro Tsujikawa + * Copyright (C) 2013 Nils Maier * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -32,86 +32,69 @@ * files in the program, then also delete it here. */ /* copyright --> */ -#include "LibnettleMessageDigestImpl.h" -#include +#include "MessageDigestImpl.h" -#include "array_fun.h" -#include "HashFuncEntry.h" -#include "a2functional.h" +#include namespace aria2 { -MessageDigestImpl::MessageDigestImpl(const nettle_hash* hashInfo) - : hashInfo_{hashInfo}, - ctx_{new char[hashInfo->context_size]} -{ - reset(); -} +namespace { +template +class MessageDigestBase : public MessageDigestImpl { +public: + MessageDigestBase() : ctx_(new char[hash->context_size]) { + reset(); + } + virtual ~MessageDigestBase() {} -MessageDigestImpl::~MessageDigestImpl() -{ - delete [] ctx_; -} + static size_t length() { + return hash->digest_size; + } + virtual size_t getDigestLength() const CXX11_OVERRIDE { + return hash->digest_size; + } + virtual void reset() CXX11_OVERRIDE { + hash->init(ctx_.get()); + } + virtual void update(const void* data, size_t length) CXX11_OVERRIDE { + auto bytes = reinterpret_cast(data); + while (length) { + size_t l = std::min(length, (size_t)std::numeric_limits::max()); + hash->update(ctx_.get(), l, bytes); + length -= l; + bytes += l; + } + } + virtual void digest(unsigned char* md) CXX11_OVERRIDE { + hash->digest(ctx_.get(), getDigestLength(), md); + } + +private: + std::unique_ptr ctx_; + size_t len_; +}; + +typedef MessageDigestBase<&nettle_md5> MessageDigestMD5; +typedef MessageDigestBase<&nettle_sha1> MessageDigestSHA1; +typedef MessageDigestBase<&nettle_sha224> MessageDigestSHA224; +typedef MessageDigestBase<&nettle_sha256> MessageDigestSHA256; +typedef MessageDigestBase<&nettle_sha384> MessageDigestSHA384; +typedef MessageDigestBase<&nettle_sha512> MessageDigestSHA512; +} // namespace std::unique_ptr MessageDigestImpl::sha1() { - return make_unique(&nettle_sha1); + return std::unique_ptr(new MessageDigestSHA1()); } -typedef HashFuncEntry CHashFuncEntry; -typedef FindHashFunc CFindHashFunc; - -namespace { -CHashFuncEntry hashFuncs[] = { - CHashFuncEntry("sha-1", &nettle_sha1), - CHashFuncEntry("sha-224", &nettle_sha224), - CHashFuncEntry("sha-256", &nettle_sha256), - CHashFuncEntry("sha-384", &nettle_sha384), - CHashFuncEntry("sha-512", &nettle_sha512), - CHashFuncEntry("md5", &nettle_md5) +MessageDigestImpl::hashes_t MessageDigestImpl::hashes = { + { "sha-1", make_hi() }, + { "sha-224", make_hi() }, + { "sha-256", make_hi() }, + { "sha-384", make_hi() }, + { "sha-512", make_hi() }, + { "md5", make_hi() }, }; -} // namespace - -std::unique_ptr MessageDigestImpl::create -(const std::string& hashType) -{ - auto hashInfo = - getHashFunc(std::begin(hashFuncs), std::end(hashFuncs), hashType); - return make_unique(hashInfo); -} - -bool MessageDigestImpl::supports(const std::string& hashType) -{ - return std::end(hashFuncs) != std::find_if(std::begin(hashFuncs), - std::end(hashFuncs), - CFindHashFunc(hashType)); -} - -size_t MessageDigestImpl::getDigestLength(const std::string& hashType) -{ - auto hashInfo = - getHashFunc(std::begin(hashFuncs), std::end(hashFuncs), hashType); - return hashInfo->digest_size; -} - -size_t MessageDigestImpl::getDigestLength() const -{ - return hashInfo_->digest_size; -} - -void MessageDigestImpl::reset() -{ - hashInfo_->init(ctx_); -} -void MessageDigestImpl::update(const void* data, size_t length) -{ - hashInfo_->update(ctx_, length, static_cast(data)); -} - -void MessageDigestImpl::digest(unsigned char* md) -{ - hashInfo_->digest(ctx_, getDigestLength(), md); -} } // namespace aria2 diff --git a/src/LibnettleMessageDigestImpl.h b/src/LibnettleMessageDigestImpl.h deleted file mode 100644 index 4d0b3d4c..00000000 --- a/src/LibnettleMessageDigestImpl.h +++ /dev/null @@ -1,75 +0,0 @@ -/* */ -#ifndef D_LIBNETTLE_MESSAGE_DIGEST_IMPL_H -#define D_LIBNETTLE_MESSAGE_DIGEST_IMPL_H - -#include "common.h" - -#include -#include - -#include - -namespace aria2 { - -class MessageDigestImpl { -public: - MessageDigestImpl(const nettle_hash* hashInfo); - // We don't implement copy ctor. - MessageDigestImpl(const MessageDigestImpl&) = delete; - // We don't implement assignment operator. - MessageDigestImpl& operator==(const MessageDigestImpl&) = delete; - - ~MessageDigestImpl(); - - static std::unique_ptr sha1(); - static std::unique_ptr create - (const std::string& hashType); - - static bool supports(const std::string& hashType); - static size_t getDigestLength(const std::string& hashType); - - size_t getDigestLength() const; - void reset(); - void update(const void* data, size_t length); - void digest(unsigned char* md); -private: - const nettle_hash* hashInfo_; - char* ctx_; -}; - -} // namespace aria2 - -#endif // D_LIBNETTLE_MESSAGE_DIGEST_IMPL_H diff --git a/src/LibsslMessageDigestImpl.cc b/src/LibsslMessageDigestImpl.cc index eea12505..e4536a4c 100644 --- a/src/LibsslMessageDigestImpl.cc +++ b/src/LibsslMessageDigestImpl.cc @@ -2,7 +2,7 @@ /* * aria2 - The high speed download utility * - * Copyright (C) 2010 Tatsuhiro Tsujikawa + * Copyright (C) 2013 Nils Maier * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -32,95 +32,76 @@ * files in the program, then also delete it here. */ /* copyright --> */ -#include "LibsslMessageDigestImpl.h" -#include +#include "MessageDigestImpl.h" -#include "array_fun.h" -#include "HashFuncEntry.h" -#include "a2functional.h" +#include namespace aria2 { -MessageDigestImpl::MessageDigestImpl(const EVP_MD* hashFunc) - : hashFunc_{hashFunc} -{ - EVP_MD_CTX_init(&ctx_); - reset(); -} +template +class MessageDigestBase : public MessageDigestImpl { +public: + MessageDigestBase() : md_(init_fn()), len_(EVP_MD_size(md_)) { + EVP_MD_CTX_init(&ctx_); + reset(); + } + virtual ~MessageDigestBase() { + EVP_MD_CTX_cleanup(&ctx_); + } -MessageDigestImpl::~MessageDigestImpl() -{ - EVP_MD_CTX_cleanup(&ctx_); -} + static size_t length() { + return EVP_MD_size(init_fn()); + } + virtual size_t getDigestLength() const CXX11_OVERRIDE { + return len_; + } + virtual void reset() CXX11_OVERRIDE { + EVP_DigestInit_ex(&ctx_, md_, 0); + } + virtual void update(const void* data, size_t length) CXX11_OVERRIDE { + auto bytes = reinterpret_cast(data); + while (length) { + size_t l = std::min(length, (size_t)std::numeric_limits::max()); + EVP_DigestUpdate(&ctx_, bytes, l); + length -= l; + bytes += l; + } + } + virtual void digest(unsigned char* md) CXX11_OVERRIDE { + unsigned int len; + EVP_DigestFinal_ex(&ctx_, md, &len); + } + +private: + EVP_MD_CTX ctx_; + const EVP_MD* md_; + const size_t len_; +}; + +typedef MessageDigestBase MessageDigestMD5; +typedef MessageDigestBase MessageDigestSHA1; std::unique_ptr MessageDigestImpl::sha1() { - return make_unique(EVP_sha1()); + return std::unique_ptr(new MessageDigestSHA1()); } -typedef HashFuncEntry CHashFuncEntry; -typedef FindHashFunc CFindHashFunc; - -namespace { -CHashFuncEntry hashFuncs[] = { - CHashFuncEntry("sha-1", EVP_sha1()), +MessageDigestImpl::hashes_t MessageDigestImpl::hashes = { + { "sha-1", make_hi() }, #ifdef HAVE_EVP_SHA224 - CHashFuncEntry("sha-224", EVP_sha224()), -#endif // HAVE_EVP_SHA224 -#ifdef HAVE_EVP_SHA256 - CHashFuncEntry("sha-256", EVP_sha256()), -#endif // HAVE_EVP_SHA256 -#ifdef HAVE_EVP_SHA384 - CHashFuncEntry("sha-384", EVP_sha384()), -#endif // HAVE_EVP_SHA384 -#ifdef HAVE_EVP_SHA512 - CHashFuncEntry("sha-512", EVP_sha512()), -#endif // HAVE_EVP_SHA512 - CHashFuncEntry("md5", EVP_md5()) + { "sha-224", make_hi >() }, +#endif +#ifdef HAVE_EVP_SHA224 + { "sha-256", make_hi >() }, +#endif +#ifdef HAVE_EVP_SHA224 + { "sha-384", make_hi >() }, +#endif +#ifdef HAVE_EVP_SHA224 + { "sha-512", make_hi >() }, +#endif + { "md5", make_hi() }, }; -} // namespace - -std::unique_ptr MessageDigestImpl::create -(const std::string& hashType) -{ - auto hashFunc = getHashFunc(std::begin(hashFuncs), std::end(hashFuncs), - hashType); - return make_unique(hashFunc); -} - -bool MessageDigestImpl::supports(const std::string& hashType) -{ - return std::end(hashFuncs) != std::find_if(std::begin(hashFuncs), - std::end(hashFuncs), - CFindHashFunc(hashType)); -} - -size_t MessageDigestImpl::getDigestLength(const std::string& hashType) -{ - auto hashFunc = getHashFunc(std::begin(hashFuncs), std::end(hashFuncs), - hashType); - return EVP_MD_size(hashFunc); -} - -size_t MessageDigestImpl::getDigestLength() const -{ - return EVP_MD_size(hashFunc_); -} - -void MessageDigestImpl::reset() -{ - EVP_DigestInit_ex(&ctx_, hashFunc_, 0); -} -void MessageDigestImpl::update(const void* data, size_t length) -{ - EVP_DigestUpdate(&ctx_, data, length); -} - -void MessageDigestImpl::digest(unsigned char* md) -{ - unsigned int len; - EVP_DigestFinal_ex(&ctx_, md, &len); -} } // namespace aria2 diff --git a/src/LibsslMessageDigestImpl.h b/src/LibsslMessageDigestImpl.h deleted file mode 100644 index d739ec37..00000000 --- a/src/LibsslMessageDigestImpl.h +++ /dev/null @@ -1,75 +0,0 @@ -/* */ -#ifndef D_LIBSSL_MESSAGE_DIGEST_IMPL_H -#define D_LIBSSL_MESSAGE_DIGEST_IMPL_H - -#include "common.h" - -#include -#include - -#include - -namespace aria2 { - -class MessageDigestImpl { -public: - MessageDigestImpl(const EVP_MD* hashFunc); - // We don't implement copy ctor. - MessageDigestImpl(const MessageDigestImpl&) = delete; - // We don't implement assignment operator. - MessageDigestImpl& operator==(const MessageDigestImpl&) = delete; - - ~MessageDigestImpl(); - - static std::unique_ptr sha1(); - static std::unique_ptr create - (const std::string& hashType); - - static bool supports(const std::string& hashType); - static size_t getDigestLength(const std::string& hashType); - - size_t getDigestLength() const; - void reset(); - void update(const void* data, size_t length); - void digest(unsigned char* md); -private: - const EVP_MD* hashFunc_; - EVP_MD_CTX ctx_; -}; - -} // namespace aria2 - -#endif // D_LIBSSL_MESSAGE_DIGEST_IMPL_H diff --git a/src/Makefile.am b/src/Makefile.am index c3652181..15ac829a 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -321,7 +321,7 @@ SRCS += TLSContext.h TLSSession.h endif # ENABLE_SSL if USE_APPLE_MD -SRCS += AppleMessageDigestImpl.cc AppleMessageDigestImpl.h +SRCS += AppleMessageDigestImpl.cc endif # USE_APPLE_MD if HAVE_APPLETLS @@ -330,7 +330,7 @@ SRCS += AppleTLSContext.cc AppleTLSContext.h \ endif # HAVE_APPLETLS if USE_WINDOWS_MD -SRCS += WinMessageDigestImpl.cc WinMessageDigestImpl.h +SRCS += WinMessageDigestImpl.cc endif # USE_WINDOWS_MD if HAVE_LIBGNUTLS @@ -342,14 +342,14 @@ if HAVE_LIBGCRYPT SRCS += LibgcryptARC4Encryptor.cc LibgcryptARC4Encryptor.h \ LibgcryptDHKeyExchange.cc LibgcryptDHKeyExchange.h if USE_LIBGCRYPT_MD -SRCS += LibgcryptMessageDigestImpl.cc LibgcryptMessageDigestImpl.h +SRCS += LibgcryptMessageDigestImpl.cc endif # USE_LIBGCRYPT_MD endif # HAVE_LIBGCRYPT if HAVE_LIBNETTLE SRCS += LibnettleARC4Encryptor.cc LibnettleARC4Encryptor.h if USE_LIBNETTLE_MD -SRCS += LibnettleMessageDigestImpl.cc LibnettleMessageDigestImpl.h +SRCS += LibnettleMessageDigestImpl.cc endif # USE_LIBNETTLE_MD endif # HAVE_LIBNETTLE @@ -366,7 +366,7 @@ SRCS += LibsslTLSContext.cc LibsslTLSContext.h \ LibsslTLSSession.cc LibsslTLSSession.h endif # !HAVE_APPLETLS if USE_OPENSSL_MD -SRCS += LibsslMessageDigestImpl.cc LibsslMessageDigestImpl.h +SRCS += LibsslMessageDigestImpl.cc endif endif # HAVE_OPENSSL diff --git a/src/MessageDigestImpl.h b/src/MessageDigestImpl.h index aee3057f..0cee1cb5 100644 --- a/src/MessageDigestImpl.h +++ b/src/MessageDigestImpl.h @@ -2,7 +2,7 @@ /* * aria2 - The high speed download utility * - * Copyright (C) 2010 Tatsuhiro Tsujikawa + * Copyright (C) 2013 Nils Maier * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -32,20 +32,79 @@ * files in the program, then also delete it here. */ /* copyright --> */ + #ifndef D_MESSAGE_DIGEST_IMPL_H #define D_MESSAGE_DIGEST_IMPL_H +#include "common.h" -#ifdef USE_APPLE_MD -# include "AppleMessageDigestImpl.h" -#elif defined(USE_WINDOWS_MD) -# include "WinMessageDigestImpl.h" -#elif defined(USE_LIBNETTLE_MD) -# include "LibnettleMessageDigestImpl.h" -#elif defined(USE_LIBGCRYPT_MD) -# include "LibgcryptMessageDigestImpl.h" -#elif defined(USE_OPENSSL_MD) -# include "LibsslMessageDigestImpl.h" -#endif +#include +#include +#include +#include +#include + +#include "a2functional.h" + +namespace aria2 { + +class MessageDigestImpl { +public: + typedef std::function()> factory_t; + typedef std::tuple hash_info_t; + typedef std::map hashes_t; + + template + inline static hash_info_t make_hi() { + return std::make_tuple([]() { return make_unique(); }, T::length()); + } + +private: + static hashes_t hashes; + + MessageDigestImpl(const MessageDigestImpl&) = delete; + MessageDigestImpl& operator=(const MessageDigestImpl&) = delete; + +public: + virtual ~MessageDigestImpl() {} + static std::unique_ptr sha1(); + + inline static std::unique_ptr create( + const std::string& hashType) { + auto i = hashes.find(hashType); + if (i == hashes.end()) { + return nullptr; + } + factory_t factory; + std::tie(factory, std::ignore) = i->second; + return factory(); + } + + inline static bool supports(const std::string& hashType) { + auto i = hashes.find(hashType); + return i != hashes.end(); + } + + inline static size_t getDigestLength(const std::string& hashType) { + auto i = hashes.find(hashType); + if (i == hashes.end()) { + return 0; + } + size_t len; + std::tie(std::ignore, len) = i->second; + return len; + } + +public: + virtual size_t getDigestLength() const = 0; + virtual void reset() = 0; + virtual void update(const void* data, size_t length) = 0; + virtual void digest(unsigned char* md) = 0; + +protected: + MessageDigestImpl() {} +}; + +} // namespace aria2 #endif // D_MESSAGE_DIGEST_IMPL_H diff --git a/src/WinMessageDigestImpl.cc b/src/WinMessageDigestImpl.cc index bbc6d897..5628f484 100644 --- a/src/WinMessageDigestImpl.cc +++ b/src/WinMessageDigestImpl.cc @@ -33,14 +33,13 @@ */ /* copyright --> */ -#include "WinMessageDigestImpl.h" +#include "MessageDigestImpl.h" #include -#include "array_fun.h" -#include "a2functional.h" -#include "HashFuncEntry.h" +#include "fmt.h" #include "DlAbortEx.h" +#include "LogFactory.h" namespace { using namespace aria2; @@ -50,9 +49,12 @@ private: HCRYPTPROV provider_; public: Context() { - if (!::CryptAcquireContext(&provider_, nullptr, nullptr, PROV_RSA_FULL, - CRYPT_VERIFYCONTEXT)) { - throw DL_ABORT_EX("Failed to get cryptographic provider"); + if (!::CryptAcquireContext(&provider_, nullptr, nullptr, + PROV_RSA_AES, CRYPT_VERIFYCONTEXT)) { + if (!::CryptAcquireContext(&provider_, nullptr, nullptr, PROV_RSA_AES, + CRYPT_VERIFYCONTEXT)) { + throw DL_ABORT_EX("Failed to get cryptographic provider"); + } } } ~Context() { @@ -67,10 +69,30 @@ public: // XXX static OK? static Context context_; +inline size_t getAlgLength(ALG_ID id) +{ + Context context; + HCRYPTHASH hash; + if (!::CryptCreateHash(context.get(), id, 0, 0, &hash)) { + throw DL_ABORT_EX(fmt("Failed to initialize hash %d", id)); + } + + DWORD rv = 0; + DWORD len = sizeof(rv); + if (!::CryptGetHashParam(hash, HP_HASHSIZE, reinterpret_cast(&rv), + &len, 0)) { + throw DL_ABORT_EX("Failed to initialize hash(2)"); + } + ::CryptDestroyHash(hash); + + return rv; +} + } // namespace namespace aria2 { + template class MessageDigestBase : public MessageDigestImpl { private: @@ -88,6 +110,10 @@ public: MessageDigestBase() : hash_(0), len_(0) { reset(); } virtual ~MessageDigestBase() { destroy(); } + static size_t length() { + MessageDigestBase rv; + return rv.getDigestLength(); + } virtual size_t getDigestLength() const CXX11_OVERRIDE { return len_; } @@ -96,11 +122,10 @@ public: if (!::CryptCreateHash(context_.get(), id, 0, 0, &hash_)) { throw DL_ABORT_EX("Failed to create hash"); } - DWORD len = sizeof(len_); if (!::CryptGetHashParam(hash_, HP_HASHSIZE, reinterpret_cast(&len_), &len, 0)) { - throw DL_ABORT_EX("Failed to create hash"); + throw DL_ABORT_EX("Failed to initialize hash"); } } virtual void update(const void* data, size_t length) CXX11_OVERRIDE { @@ -133,45 +158,36 @@ std::unique_ptr MessageDigestImpl::sha1() return std::unique_ptr(new MessageDigestSHA1()); } -std::unique_ptr MessageDigestImpl::create( - const std::string& hashType) -{ - if (hashType == "sha-1") { - return make_unique(); - } - if (hashType == "sha-256") { - return make_unique(); - } - if (hashType == "sha-384") { - return make_unique(); - } - if (hashType == "sha-512") { - return make_unique(); - } - if (hashType == "md5") { - return make_unique(); - } - return nullptr; -} +namespace { +MessageDigestImpl::hashes_t initialize() { + MessageDigestImpl::hashes_t rv = { + { "sha-1", MessageDigestImpl::make_hi() }, + { "md5", MessageDigestImpl::make_hi() }, + }; -bool MessageDigestImpl::supports(const std::string& hashType) -{ try { - return !!create(hashType); + rv.emplace("sha-256", MessageDigestImpl::make_hi()); } - catch (RecoverableException& ex) { - // no op + catch (RecoverableException &ex) { + printf("SHA-256 is not supported on this machine"); + } + try { + rv.emplace("sha-384", MessageDigestImpl::make_hi()); + } + catch (RecoverableException &ex) { + printf("SHA-384 is not supported on this machine"); + } + try { + rv.emplace("sha-512", MessageDigestImpl::make_hi()); + } + catch (RecoverableException &ex) { + printf("SHA-512 is not supported on this machine"); } - return false; -} -size_t MessageDigestImpl::getDigestLength(const std::string& hashType) -{ - std::unique_ptr impl = create(hashType); - if (!impl) { - return 0; - } - return impl->getDigestLength(); -} + return rv; +}; +} // namespace + +MessageDigestImpl::hashes_t MessageDigestImpl::hashes = initialize(); } // namespace aria2 diff --git a/src/WinMessageDigestImpl.h b/src/WinMessageDigestImpl.h deleted file mode 100644 index 758937a7..00000000 --- a/src/WinMessageDigestImpl.h +++ /dev/null @@ -1,70 +0,0 @@ -/* */ -#ifndef D_WIN_MESSAGE_DIGEST_IMPL_H -#define D_WIN_MESSAGE_DIGEST_IMPL_H - -#include "common.h" - -#include -#include - -namespace aria2 { - -class MessageDigestImpl { -public: - virtual ~MessageDigestImpl() {} - static std::unique_ptr sha1(); - static std::unique_ptr create(const std::string& hashType); - - static bool supports(const std::string& hashType); - static size_t getDigestLength(const std::string& hashType); - -public: - virtual size_t getDigestLength() const = 0; - virtual void reset() = 0; - virtual void update(const void* data, size_t length) = 0; - virtual void digest(unsigned char* md) = 0; - -protected: - MessageDigestImpl() {} - -private: - MessageDigestImpl(const MessageDigestImpl&); - MessageDigestImpl& operator=(const MessageDigestImpl&); -}; - -} // namespace aria2 - -#endif // D_WIN_MESSAGE_DIGEST_IMPL_H