diff --git a/configure.ac b/configure.ac index 5d480507..70ef70c9 100644 --- a/configure.ac +++ b/configure.ac @@ -24,6 +24,7 @@ esac AC_DEFINE_UNQUOTED([TARGET], ["$target"], [Define target-type]) # Checks for arguments. +ARIA2_ARG_WITHOUT([appletls]) ARIA2_ARG_WITHOUT([gnutls]) ARIA2_ARG_WITHOUT([libnettle]) ARIA2_ARG_WITHOUT([libgmp]) @@ -145,7 +146,28 @@ if test "x$with_sqlite3" = "xyes"; then fi fi -if test "x$with_gnutls" = "xyes"; then +case "$host" in + *darwin*) + have_osx="yes" + ;; +esac + +if test "x$with_appletls" = "xyes"; then + AC_MSG_CHECKING([whether to enable Mac OS X native SSL/TLS]) + if test "x$have_osx" = "xyes"; then + AC_DEFINE([HAVE_APPLETLS], [1], [Define to 1 if you have Apple TLS]) + LDFLAGS="$LDFLAGS -framework CoreFoundation -framework Security" + have_appletls="yes" + AC_MSG_RESULT(yes) + else + AC_MSG_RESULT(no) + if test "x$with_appletls_requested" = "xyes"; then + ARIA2_DEP_NOT_MET([appletls]) + fi + fi +fi + +if test "x$with_gnutls" = "xyes" && test "x$have_appletls" != "xyes"; then # gnutls >= 2.8 doesn't have libgnutls-config anymore. We require # 2.2.0 because we use gnutls_priority_set_direct() PKG_CHECK_MODULES([LIBGNUTLS], [gnutls >= 2.2.0], @@ -163,7 +185,7 @@ if test "x$with_gnutls" = "xyes"; then fi fi -if test "x$with_openssl" = "xyes" && test "x$have_libgnutls" != "xyes"; then +if test "x$with_openssl" = "xyes" && test "x$have_appletls" != "xyes" && test "x$have_libgnutls" != "xyes"; then PKG_CHECK_MODULES([OPENSSL], [openssl >= 0.9.8], [have_openssl=yes], [have_openssl=no]) if test "x$have_openssl" = "xyes"; then @@ -235,8 +257,30 @@ if test "x$with_libcares" = "xyes"; then fi fi +use_md="" +if test "x$have_osx" == "xyes"; then + use_md="apple" + AC_DEFINE([USE_APPLE_MD], [1], [What message digest implementation to use]) +else + if test "x$have_libnettle" = "xyes"; then + AC_DEFINE([USE_LIBNETTLE_MD], [1], [What message digest implementation to use]) + use_md="libnettle" + else + if test "x$have_libgcrypt" = "xyes"; then + AC_DEFINE([USE_LIBGCRYPT_MD], [1], [What message digest implementation to use]) + use_md="libgcrypt" + else + if test = "x$have_openssl" = "xyes"; then + AC_DEFINE([USE_OPENSSL_MD], [1], [What message digest implementation to use]) + use_md="openssl" + fi + fi + fi +fi + # Define variables based on the result of the checks for libraries. -if test "x$have_libgnutls" = "xyes" || test "x$have_openssl" = "xyes"; then +if test "x$have_appletls" = "xyes" || test "x$have_libgnutls" = "xyes" || test "x$have_openssl" = "xyes"; then + have_ssl="yes" AC_DEFINE([ENABLE_SSL], [1], [Define to 1 if ssl support is enabled.]) AM_CONDITIONAL([ENABLE_SSL], true) AC_SUBST([ca_bundle]) @@ -244,14 +288,20 @@ else AM_CONDITIONAL([ENABLE_SSL], false) fi + +AM_CONDITIONAL([HAVE_OSX], [ test "x$have_osx" = "xyes" ]) +AM_CONDITIONAL([HAVE_APPLETLS], [ test "x$have_appletls" = "xyes" ]) +AM_CONDITIONAL([USE_APPLE_MD], [ test "x$use_md" = "xapple" ]) AM_CONDITIONAL([HAVE_LIBGNUTLS], [ test "x$have_libgnutls" = "xyes" ]) AM_CONDITIONAL([HAVE_LIBNETTLE], [ test "x$have_libnettle" = "xyes" ]) +AM_CONDITIONAL([USE_LIBNETTLE_MD], [ test "x$use_md" = "xlibnettle"]) AM_CONDITIONAL([HAVE_LIBGMP], [ test "x$have_libgmp" = "xyes" ]) AM_CONDITIONAL([HAVE_LIBGCRYPT], [ test "x$have_libgcrypt" = "xyes" ]) +AM_CONDITIONAL([USE_LIBGCRYPT_MD], [ test "x$use_md" = "xlibgcrypt"]) AM_CONDITIONAL([HAVE_OPENSSL], [ test "x$have_openssl" = "xyes" ]) +AM_CONDITIONAL([USE_OPENSSL_MD], [ test "x$use_md" = "xopenssl"]) -if test "x$have_libnettle" = "xyes" || test "x$have_libgcrypt" = "xyes" || - test "x$have_openssl" = "xyes"; then +if test "x$use_md" != "x"; then AC_DEFINE([ENABLE_MESSAGE_DIGEST], [1], [Define to 1 if message digest support is enabled.]) AM_CONDITIONAL([ENABLE_MESSAGE_DIGEST], true) @@ -325,9 +375,9 @@ AM_CONDITIONAL([HAVE_SQLITE3], [test "x$have_sqlite3" = "xyes"]) AC_SEARCH_LIBS([clock_gettime], [rt]) case "$host" in - *solaris*) - AC_SEARCH_LIBS([getaddrinfo], [nsl socket]) - ;; + *solaris*) + AC_SEARCH_LIBS([getaddrinfo], [nsl socket]) + ;; esac # Checks for header files. @@ -670,6 +720,8 @@ echo "LDFLAGS: $LDFLAGS" echo "LIBS: $LIBS" echo "DEFS: $DEFS" echo "SQLite3: $have_sqlite3" +echo "SSL Support: $have_ssl" +echo "AppleTLS: $have_appletls" echo "GnuTLS: $have_libgnutls" echo "OpenSSL: $have_openssl" echo "CA Bundle: $ca_bundle" diff --git a/src/AppleMessageDigestImpl.cc b/src/AppleMessageDigestImpl.cc new file mode 100644 index 00000000..1498f075 --- /dev/null +++ b/src/AppleMessageDigestImpl.cc @@ -0,0 +1,153 @@ +/* */ +#include "AppleMessageDigestImpl.h" + +#include + +#include "array_fun.h" +#include "HashFuncEntry.h" + +namespace aria2 { + +template +class MessageDigestBase : public MessageDigestImpl { +public: + MessageDigestBase() { reset(); } + + virtual size_t getDigestLength() const { + return dlen; + } + virtual void reset() { + init_fn(&ctx_); + } + virtual void update(const void* data, size_t length) { + while (length) { + CC_LONG l = std::min(length, (size_t)std::numeric_limits::max()); + update_fn(&ctx_, data, l); + length -= l; + } + } + virtual void digest(unsigned char* md) { + final_fn(md, &ctx_); + } +private: + ctx_t ctx_; +}; + +typedef MessageDigestBase +MessageDigestMD5; +typedef MessageDigestBase +MessageDigestSHA1; +typedef MessageDigestBase +MessageDigestSHA224; +typedef MessageDigestBase +MessageDigestSHA256; +typedef MessageDigestBase +MessageDigestSHA384; +typedef MessageDigestBase +MessageDigestSHA512; + +SharedHandle MessageDigestImpl::sha1() +{ + return SharedHandle(new MessageDigestSHA1()); +} + +SharedHandle MessageDigestImpl::create +(const std::string& hashType) +{ + if (hashType == "sha-1") { + return SharedHandle(new MessageDigestSHA1()); + } + if (hashType == "sha-224") { + return SharedHandle(new MessageDigestSHA224()); + } + if (hashType == "sha-256") { + return SharedHandle(new MessageDigestSHA256()); + } + if (hashType == "sha-384") { + return SharedHandle(new MessageDigestSHA384()); + } + if (hashType == "sha-512") { + return SharedHandle(new MessageDigestSHA512()); + } + if (hashType == "md5") { + return SharedHandle(new MessageDigestMD5()); + } + return SharedHandle(); +} + +bool MessageDigestImpl::supports(const std::string& hashType) +{ + return hashType == "sha-1" || hashType == "sha-224" || hashType == "sha-256" || hashType == "sha-384" || hashType == "sha-512" || hashType == "md5"; +} + +size_t MessageDigestImpl::getDigestLength(const std::string& hashType) +{ + SharedHandle impl = create(hashType); + if (!impl) { + return 0; + } + return impl->getDigestLength(); +} + +} // namespace aria2 diff --git a/src/AppleMessageDigestImpl.h b/src/AppleMessageDigestImpl.h new file mode 100644 index 00000000..b3851da5 --- /dev/null +++ b/src/AppleMessageDigestImpl.h @@ -0,0 +1,71 @@ +/* */ +#ifndef D_APPLE_MESSAGE_DIGEST_IMPL_H +#define D_APPLE_MESSAGE_DIGEST_IMPL_H + +#include "common.h" + +#include + +#include "SharedHandle.h" + +namespace aria2 { + +class MessageDigestImpl { +public: + static SharedHandle sha1(); + static SharedHandle create(const std::string& hashType); + + static bool supports(const std::string& hashType); + static size_t getDigestLength(const std::string& hashType); + +public: + virtual size_t getDigestLength() const = 0; + virtual void reset() = 0; + virtual void update(const void* data, size_t length) = 0; + virtual void digest(unsigned char* md) = 0; + +protected: + MessageDigestImpl() {} + +private: + MessageDigestImpl(const MessageDigestImpl&); + MessageDigestImpl& operator=(const MessageDigestImpl&); + +}; + +} // namespace aria2 + +#endif // D_APPLE_MESSAGE_DIGEST_IMPL_H diff --git a/src/TLSSessionConst.h b/src/AppleTLSContext.cc similarity index 69% rename from src/TLSSessionConst.h rename to src/AppleTLSContext.cc index f3b8422d..99d88505 100644 --- a/src/TLSSessionConst.h +++ b/src/AppleTLSContext.cc @@ -2,7 +2,7 @@ /* * aria2 - The high speed download utility * - * Copyright (C) 2013 Tatsuhiro Tsujikawa + * Copyright (C) 2013 Nils Maier * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -32,24 +32,31 @@ * files in the program, then also delete it here. */ /* copyright --> */ -#ifndef TLS_SESSION_CONST_H -#define TLS_SESSION_CONST_H +#include "AppleTLSContext.h" -#include "common.h" +#include "LogFactory.h" +#include "Logger.h" +#include "fmt.h" +#include "message.h" namespace aria2 { -enum TLSDirection { - TLS_WANT_READ = 1, - TLS_WANT_WRITE -}; +TLSContext* TLSContext::make(TLSSessionSide side) { + return new AppleTLSContext(side); +} + +bool AppleTLSContext::addCredentialFile(const std::string& certfile, + const std::string& keyfile) +{ + A2_LOG_WARN("TLS credential files are not supported. Use the KeyChain to manage your certificates."); + return false; +} + +bool AppleTLSContext::addTrustedCACertFile(const std::string& certfile) +{ + A2_LOG_WARN("TLS CA bundle files are not supported. Use the KeyChain to manage your certificates."); + return false; +} -enum TLSErrorCode { - TLS_ERR_OK = 0, - TLS_ERR_ERROR = -1, - TLS_ERR_WOULDBLOCK = -2 -}; } // namespace aria2 - -#endif // TLS_SESSION_CONST_H diff --git a/src/AppleTLSContext.h b/src/AppleTLSContext.h new file mode 100644 index 00000000..549d432e --- /dev/null +++ b/src/AppleTLSContext.h @@ -0,0 +1,90 @@ +/* */ +#ifndef D_APPLE_TLS_CONTEXT_H +#define D_APPLE_TLS_CONTEXT_H + +#include "common.h" + +#include +#include +#include + +#include "TLSContext.h" +#include "DlAbortEx.h" + +namespace aria2 { + +class AppleTLSContext : public TLSContext { +public: + AppleTLSContext(TLSSessionSide side) + : side_(side), + verifyPeer_(true) + {} + + virtual ~AppleTLSContext() {} + + // private key `keyfile' must be decrypted. + virtual bool addCredentialFile(const std::string& certfile, + const std::string& keyfile); + + virtual bool addSystemTrustedCACerts() { + return true; + } + + // certfile can contain multiple certificates. + virtual bool addTrustedCACertFile(const std::string& certfile); + + virtual bool good() const { + return true; + } + virtual TLSSessionSide getSide() const { + return side_; + } + + virtual bool getVerifyPeer() const { + return verifyPeer_; + } + virtual void setVerifyPeer(bool verify) { + verifyPeer_ = verify; + } + +private: + TLSSessionSide side_; + bool verifyPeer_; +}; + +} // namespace aria2 + +#endif // D_LIBSSL_TLS_CONTEXT_H diff --git a/src/AppleTLSSession.cc b/src/AppleTLSSession.cc new file mode 100644 index 00000000..a85c0c89 --- /dev/null +++ b/src/AppleTLSSession.cc @@ -0,0 +1,354 @@ +/* */ + +#include "AppleTLSSession.h" + +#include + +#include "fmt.h" +#include "LogFactory.h" + +#define ioErr -36 +#define paramErr -50 +#define errSSLServerAuthCompleted -9841 + +namespace { + static const SSLProtocol kTLSProtocol11_h = (SSLProtocol)(kSSLProtocolAll + 1); + static const SSLProtocol kTLSProtocol12_h = (SSLProtocol)(kSSLProtocolAll + 2); +} + +namespace aria2 { + +TLSSession* TLSSession::make(TLSContext* ctx) +{ + return new AppleTLSSession(static_cast(ctx)); +} + +AppleTLSSession::AppleTLSSession(AppleTLSContext* ctx) + : ctx_(ctx), + sslCtx_(0), + sockfd_(0), + state_(st_constructed), + lastError_(noErr), + writeBuffered_(0) +{ + lastError_ = SSLNewContext(ctx->getSide() == TLS_SERVER, &sslCtx_) == noErr; + if (lastError_ == noErr) { + state_ = st_error; + return; + } +#if defined(__MAC_10_8) + (void)SSLSetProtocolVersionMin(sslCtx_, kSSLProtocol3); + (void)SSLSetProtocolVersionMax(sslCtx_, kTLSProtocol12); +#else + (void)SSLSetProtocolVersionEnabled(sslCtx_, kSSLProtocolAll, false); + (void)SSLSetProtocolVersionEnabled(sslCtx_, kSSLProtocol3, true); + (void)SSLSetProtocolVersionEnabled(sslCtx_, kTLSProtocol1, true); + (void)SSLSetProtocolVersionEnabled(sslCtx_, kTLSProtocol11_h, true); + (void)SSLSetProtocolVersionEnabled(sslCtx_, kTLSProtocol12_h, true); +#endif + (void)SSLSetEnableCertVerify(sslCtx_, ctx->getVerifyPeer()); +} + +AppleTLSSession::~AppleTLSSession() +{ + closeConnection(); + if (sslCtx_) { + SSLDisposeContext(sslCtx_); + sslCtx_ = 0; + } + state_ = st_error; +} + +int AppleTLSSession::init(sock_t sockfd) +{ + if (state_ != st_constructed) { + lastError_ = noErr; + return TLS_ERR_ERROR; + } + lastError_ = SSLSetIOFuncs(sslCtx_, SocketRead, SocketWrite); + if (lastError_ != noErr) { + state_ = st_error; + return TLS_ERR_ERROR; + } + lastError_ = SSLSetConnection(sslCtx_, this); + if (lastError_ != noErr) { + state_ = st_error; + return TLS_ERR_ERROR; + } + sockfd_ = sockfd; + state_ = st_initialized; + return TLS_ERR_OK; +} + +int AppleTLSSession::setSNIHostname(const std::string& hostname) +{ + if (state_ != st_initialized) { + lastError_ = noErr; + return TLS_ERR_ERROR; + } + lastError_ = SSLSetPeerDomainName(sslCtx_, hostname.c_str(), hostname.length()); + return (lastError_ != noErr) ? TLS_ERR_ERROR : TLS_ERR_OK; +} + +int AppleTLSSession::closeConnection() +{ + if (state_ != st_connected) { + lastError_ = noErr; + return TLS_ERR_ERROR; + } + lastError_ = SSLClose(sslCtx_); + state_ = st_closed; + return lastError_ == noErr ? TLS_ERR_OK : TLS_ERR_ERROR; +} + +int AppleTLSSession::checkDirection() { + if (writeBuffered_) { + return TLS_WANT_WRITE; + } + if (state_ == st_connected) { + size_t buffered; + lastError_ = SSLGetBufferedReadSize(sslCtx_, &buffered); + if (lastError_ == noErr && buffered) { + return TLS_WANT_READ; + } + } + return 0; +} + +ssize_t AppleTLSSession::writeData(const void* data, size_t len) +{ + if (state_ != st_connected) { + lastError_ = noErr; + return TLS_ERR_ERROR; + } + size_t processed = 0; + if (writeBuffered_) { + lastError_ = SSLWrite(sslCtx_, 0, 0, &processed); + switch (lastError_) { + case noErr: + processed = writeBuffered_; + writeBuffered_ = 0; + return processed; + case errSSLWouldBlock: + return TLS_ERR_WOULDBLOCK; + case errSSLClosedGraceful: + case errSSLClosedNoNotify: + closeConnection(); + return TLS_ERR_ERROR; + default: + closeConnection(); + state_ = st_error; + return TLS_ERR_ERROR; + } + } + + lastError_ = SSLWrite(sslCtx_, data, len, &processed); + switch (lastError_) { + case noErr: + return processed; + case errSSLWouldBlock: + writeBuffered_ = len; + return TLS_ERR_WOULDBLOCK; + case errSSLClosedGraceful: + case errSSLClosedNoNotify: + closeConnection(); + return TLS_ERR_ERROR; + default: + closeConnection(); + state_ = st_error; + return TLS_ERR_ERROR; + } +} +OSStatus AppleTLSSession::sockWrite(const void* data, size_t* len) +{ + size_t remain = *len; + const uint8_t *buffer = static_cast(data); + *len = 0; + while (remain) { + ssize_t w = write(sockfd_, buffer, remain); + if (w <= 0) { + switch (errno) { + case EAGAIN: + return errSSLWouldBlock; + default: + return errSSLClosedAbort; + } + } + remain -= w; + buffer += w; + *len += w; + } + return noErr; +} +ssize_t AppleTLSSession::readData(void* data, size_t len) +{ + if (state_ != st_connected) { + lastError_ = noErr; + return TLS_ERR_ERROR; + } + size_t processed = 0; + lastError_ = SSLRead(sslCtx_, data, len, &processed); + switch (lastError_) { + case noErr: + return processed; + case errSSLWouldBlock: + if (processed) { + return processed; + } + return TLS_ERR_WOULDBLOCK; + case errSSLClosedGraceful: + case errSSLClosedNoNotify: + closeConnection(); + return TLS_ERR_ERROR; + default: + closeConnection(); + state_ = st_error; + return TLS_ERR_ERROR; + } +} + +OSStatus AppleTLSSession::sockRead(void* data, size_t* len) +{ + size_t remain = *len; + uint8_t *buffer = static_cast(data); + *len = 0; + while (remain) { + ssize_t r = read(sockfd_, buffer, remain); + if (r == 0) { + return errSSLClosedGraceful; + } + if (r < 0) { + switch (errno) { + case ENOENT: + return errSSLClosedGraceful; + case ECONNRESET: + return errSSLClosedAbort; + case EAGAIN: + return errSSLWouldBlock; + default: + return errSSLClosedAbort; + } + } + remain -= r; + buffer += r; + *len += r; + } + return noErr; +} + +int AppleTLSSession::tlsConnect(const std::string& hostname, std::string& handshakeErr) +{ + if (state_ != st_initialized) { + return TLS_ERR_ERROR; + } + if (!hostname.empty()) { + setSNIHostname(hostname); + } + lastError_ = SSLHandshake(sslCtx_); + switch (lastError_) { + case noErr: + state_ = st_connected; + return TLS_ERR_OK; + case errSSLWouldBlock: + return TLS_ERR_WOULDBLOCK; + case errSSLServerAuthCompleted: + return tlsConnect(hostname, handshakeErr); + default: + handshakeErr = getLastErrorString(); + return TLS_ERR_ERROR; + } +} + +int AppleTLSSession::tlsAccept() +{ + std::string hostname, err; + return tlsConnect(hostname, err); +} + +std::string AppleTLSSession::getLastErrorString() +{ + switch (lastError_) { + case errSSLProtocol: + return "Protocol error"; + case errSSLNegotiation: + return "No common cipher suites"; + case errSSLFatalAlert: + return "Received fatal alert"; + case errSSLSessionNotFound: + return "Unknown session"; + case errSSLClosedGraceful: + return "Closed gracefully"; + case errSSLClosedAbort: + return "Connection aborted"; + case errSSLXCertChainInvalid: + return "Invalid certificate chain"; + case errSSLBadCert: + return "Invalid certificate format"; + case errSSLCrypto: + return "Cryptographic error"; + case paramErr: + case errSSLInternal: + return "Internal SSL error"; + case errSSLUnknownRootCert: + return "Self-signed certificate"; + case errSSLNoRootCert: + return "No root certificate"; + case errSSLCertExpired: + return "Certificate expired"; + case errSSLCertNotYetValid: + return "Certificate not yet valid"; + case errSSLClosedNoNotify: + return "Closed without notification"; + case errSSLBufferOverflow: + return "Buffer not large enough"; + case errSSLBadCipherSuite: + return "Bad cipher suite"; + case errSSLPeerUnexpectedMsg: + return "Unexpected peer message"; + case errSSLPeerBadRecordMac: + return "Bad MAC"; + case errSSLPeerDecryptionFail: + return "Decryption failure"; + case errSSLHostNameMismatch: + return "Invalid hostname"; + case errSSLConnectionRefused: + return "Connection refused"; + default: + return fmt("Unspecified error %d", lastError_); + } +} + +} diff --git a/src/AppleTLSSession.h b/src/AppleTLSSession.h new file mode 100644 index 00000000..59a3e917 --- /dev/null +++ b/src/AppleTLSSession.h @@ -0,0 +1,127 @@ +/* */ +#ifndef APPLE_TLS_SESSION_H +#define APPLE_TLS_SESSION_H + +#include "common.h" +#include "TLSSession.h" +#include "AppleTLSContext.h" + +namespace aria2 { + +class AppleTLSSession : public TLSSession { + enum state_t { + st_constructed, + st_initialized, + st_connected, + st_closed, + st_error + }; +public: + AppleTLSSession(AppleTLSContext* ctx); + + // MUST deallocate all resources + virtual ~AppleTLSSession(); + + // Initializes SSL/TLS session. The |sockfd| is the underlying + // tranport socket. This function returns TLS_ERR_OK if it + // succeeds, or TLS_ERR_ERROR. + virtual int init(sock_t sockfd); + + // Sets |hostname| for TLS SNI extension. This is only meaningful for + // client side session. This function returns TLS_ERR_OK if it + // succeeds, or TLS_ERR_ERROR. + virtual int setSNIHostname(const std::string& hostname); + + // Closes the SSL/TLS session. Don't close underlying transport + // socket. This function returns TLS_ERR_OK if it succeeds, or + // TLS_ERR_ERROR. + virtual int closeConnection(); + + // Returns TLS_WANT_READ if SSL/TLS session needs more data from + // remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session + // needs to write more data to proceed. If SSL/TLS session needs + // neither read nor write data at the moment, return value is + // undefined. + virtual int checkDirection(); + + // Sends |data| with length |len|. This function returns the number + // of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the + // underlying tranport blocks, or TLS_ERR_ERROR. + virtual ssize_t writeData(const void* data, size_t len); + + // Receives data into |data| with length |len|. This function returns + // the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK + // if the underlying tranport blocks, or TLS_ERR_ERROR. + virtual ssize_t readData(void* data, size_t len); + + // Performs client side handshake. The |hostname| is the hostname of + // the remote endpoint and is used to verify its certificate. This + // function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK + // if the underlying transport blocks, or TLS_ERR_ERROR. + // When returning TLS_ERR_ERROR, provide certificate validation error + // in |handshakeErr|. + virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr); + + // Performs server side handshake. This function returns TLS_ERR_OK + // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport + // blocks, or TLS_ERR_ERROR. + virtual int tlsAccept(); + + // Returns last error string + virtual std::string getLastErrorString(); + +private: + static OSStatus SocketWrite(SSLConnectionRef conn, const void* data, size_t* len) { + return ((AppleTLSSession*)conn)->sockWrite(data, len); + } + static OSStatus SocketRead(SSLConnectionRef conn, void* data, size_t* len) { + return ((AppleTLSSession*)conn)->sockRead(data, len); + } + + AppleTLSContext *ctx_; + SSLContextRef sslCtx_; + sock_t sockfd_; + state_t state_; + OSStatus lastError_; + size_t writeBuffered_; + + OSStatus sockWrite(const void* data, size_t* len); + OSStatus sockRead(void* data, size_t* len); +}; + +} + +#endif // TLS_SESSION_H diff --git a/src/LibgnutlsTLSContext.cc b/src/LibgnutlsTLSContext.cc index fa57ce73..507c7222 100644 --- a/src/LibgnutlsTLSContext.cc +++ b/src/LibgnutlsTLSContext.cc @@ -45,10 +45,15 @@ namespace aria2 { -TLSContext::TLSContext(TLSSessionSide side) +TLSContext* TLSContext::make(TLSSessionSide side) +{ + return new GnuTLSContext(side); +} + +GnuTLSContext::GnuTLSContext(TLSSessionSide side) : certCred_(0), side_(side), - peerVerificationEnabled_(false) + verifyPeer_(true) { int r = gnutls_certificate_allocate_credentials(&certCred_); if(r == GNUTLS_E_SUCCESS) { @@ -63,24 +68,19 @@ TLSContext::TLSContext(TLSSessionSide side) } } -TLSContext::~TLSContext() +GnuTLSContext::~GnuTLSContext() { if(certCred_) { gnutls_certificate_free_credentials(certCred_); } } -bool TLSContext::good() const +bool GnuTLSContext::good() const { return good_; } -bool TLSContext::bad() const -{ - return !good_; -} - -bool TLSContext::addCredentialFile(const std::string& certfile, +bool GnuTLSContext::addCredentialFile(const std::string& certfile, const std::string& keyfile) { int ret = gnutls_certificate_set_x509_key_file(certCred_, @@ -101,7 +101,7 @@ bool TLSContext::addCredentialFile(const std::string& certfile, } } -bool TLSContext::addSystemTrustedCACerts() +bool GnuTLSContext::addSystemTrustedCACerts() { #ifdef HAVE_GNUTLS_CERTIFICATE_SET_X509_SYSTEM_TRUST int ret = gnutls_certificate_set_x509_system_trust(certCred_); @@ -114,11 +114,12 @@ bool TLSContext::addSystemTrustedCACerts() return true; } #else + A2_LOG_WARN("System certificates not supported"); return false; #endif } -bool TLSContext::addTrustedCACertFile(const std::string& certfile) +bool GnuTLSContext::addTrustedCACertFile(const std::string& certfile) { int ret = gnutls_certificate_set_x509_trust_file(certCred_, certfile.c_str(), @@ -133,24 +134,9 @@ bool TLSContext::addTrustedCACertFile(const std::string& certfile) } } -gnutls_certificate_credentials_t TLSContext::getCertCred() const +gnutls_certificate_credentials_t GnuTLSContext::getCertCred() const { return certCred_; } -void TLSContext::enablePeerVerification() -{ - peerVerificationEnabled_ = true; -} - -void TLSContext::disablePeerVerification() -{ - peerVerificationEnabled_ = false; -} - -bool TLSContext::peerVerificationEnabled() const -{ - return peerVerificationEnabled_; -} - } // namespace aria2 diff --git a/src/LibgnutlsTLSContext.h b/src/LibgnutlsTLSContext.h index 4fc49fbc..8d744bd6 100644 --- a/src/LibgnutlsTLSContext.h +++ b/src/LibgnutlsTLSContext.h @@ -37,8 +37,6 @@ #include "common.h" -#include - #include #include "TLSContext.h" @@ -46,45 +44,41 @@ namespace aria2 { -class TLSContext { -private: - gnutls_certificate_credentials_t certCred_; - - TLSSessionSide side_; - - bool good_; - - bool peerVerificationEnabled_; +class GnuTLSContext : public TLSContext { public: - TLSContext(TLSSessionSide side); + GnuTLSContext(TLSSessionSide side); - ~TLSContext(); + virtual ~GnuTLSContext(); // private key `keyfile' must be decrypted. - bool addCredentialFile(const std::string& certfile, - const std::string& keyfile); + virtual bool addCredentialFile(const std::string& certfile, + const std::string& keyfile); - bool addSystemTrustedCACerts(); + virtual bool addSystemTrustedCACerts(); // certfile can contain multiple certificates. - bool addTrustedCACertFile(const std::string& certfile); + virtual bool addTrustedCACertFile(const std::string& certfile); - bool good() const; + virtual bool good() const; - bool bad() const; - - gnutls_certificate_credentials_t getCertCred() const; - - TLSSessionSide getSide() const - { + virtual TLSSessionSide getSide() const { return side_; } - void enablePeerVerification(); + virtual bool getVerifyPeer() const { + return verifyPeer_; + } + virtual void setVerifyPeer(bool verify) { + verifyPeer_ = verify; + } - void disablePeerVerification(); + gnutls_certificate_credentials_t getCertCred() const; - bool peerVerificationEnabled() const; +private: + gnutls_certificate_credentials_t certCred_; + TLSSessionSide side_; + bool good_; + bool verifyPeer_; }; } // namespace aria2 diff --git a/src/LibgnutlsTLSSession.cc b/src/LibgnutlsTLSSession.cc index c1d449f3..d8721aeb 100644 --- a/src/LibgnutlsTLSSession.cc +++ b/src/LibgnutlsTLSSession.cc @@ -42,20 +42,25 @@ namespace aria2 { -TLSSession::TLSSession(TLSContext* tlsContext) +TLSSession* TLSSession::make(TLSContext* ctx) +{ + return new GnuTLSSession(static_cast(ctx)); +} + +GnuTLSSession::GnuTLSSession(GnuTLSContext* tlsContext) : sslSession_(0), tlsContext_(tlsContext), rv_(0) {} -TLSSession::~TLSSession() +GnuTLSSession::~GnuTLSSession() { if(sslSession_) { gnutls_deinit(sslSession_); } } -int TLSSession::init(sock_t sockfd) +int GnuTLSSession::init(sock_t sockfd) { rv_ = gnutls_init(&sslSession_, tlsContext_->getSide() == TLS_CLIENT ? @@ -89,7 +94,7 @@ int TLSSession::init(sock_t sockfd) return TLS_ERR_OK; } -int TLSSession::setSNIHostname(const std::string& hostname) +int GnuTLSSession::setSNIHostname(const std::string& hostname) { // TLS extensions: SNI rv_ = gnutls_server_name_set(sslSession_, GNUTLS_NAME_DNS, @@ -100,7 +105,7 @@ int TLSSession::setSNIHostname(const std::string& hostname) return TLS_ERR_OK; } -int TLSSession::closeConnection() +int GnuTLSSession::closeConnection() { rv_ = gnutls_bye(sslSession_, GNUTLS_SHUT_WR); if(rv_ == GNUTLS_E_SUCCESS) { @@ -112,13 +117,13 @@ int TLSSession::closeConnection() } } -int TLSSession::checkDirection() +int GnuTLSSession::checkDirection() { int direction = gnutls_record_get_direction(sslSession_); return direction == 0 ? TLS_WANT_READ : TLS_WANT_WRITE; } -ssize_t TLSSession::writeData(const void* data, size_t len) +ssize_t GnuTLSSession::writeData(const void* data, size_t len) { while((rv_ = gnutls_record_send(sslSession_, data, len)) == GNUTLS_E_INTERRUPTED); @@ -133,7 +138,7 @@ ssize_t TLSSession::writeData(const void* data, size_t len) } } -ssize_t TLSSession::readData(void* data, size_t len) +ssize_t GnuTLSSession::readData(void* data, size_t len) { while((rv_ = gnutls_record_recv(sslSession_, data, len)) == GNUTLS_E_INTERRUPTED); @@ -148,7 +153,7 @@ ssize_t TLSSession::readData(void* data, size_t len) } } -int TLSSession::tlsConnect(const std::string& hostname, +int GnuTLSSession::tlsConnect(const std::string& hostname, std::string& handshakeErr) { handshakeErr = ""; @@ -160,7 +165,7 @@ int TLSSession::tlsConnect(const std::string& hostname, return TLS_ERR_ERROR; } } - if(tlsContext_->peerVerificationEnabled()) { + if(tlsContext_->getVerifyPeer()) { // verify peer unsigned int status; rv_ = gnutls_certificate_verify_peers2(sslSession_, &status); @@ -246,7 +251,7 @@ int TLSSession::tlsConnect(const std::string& hostname, return TLS_ERR_OK; } -int TLSSession::tlsAccept() +int GnuTLSSession::tlsAccept() { rv_ = gnutls_handshake(sslSession_); if(rv_ == GNUTLS_E_SUCCESS) { @@ -258,7 +263,7 @@ int TLSSession::tlsAccept() } } -std::string TLSSession::getLastErrorString() +std::string GnuTLSSession::getLastErrorString() { return gnutls_strerror(rv_); } diff --git a/src/LibgnutlsTLSSession.h b/src/LibgnutlsTLSSession.h index 7118ab21..48b7f121 100644 --- a/src/LibgnutlsTLSSession.h +++ b/src/LibgnutlsTLSSession.h @@ -39,31 +39,28 @@ #include -#include - -#include "TLSSessionConst.h" +#include "LibgnutlsTLSContext.h" +#include "TLSSession.h" #include "a2netcompat.h" namespace aria2 { -class TLSContext; - -class TLSSession { +class GnuTLSSession : public TLSSession { public: - TLSSession(TLSContext* tlsContext); - ~TLSSession(); - int init(sock_t sockfd); - int setSNIHostname(const std::string& hostname); - int closeConnection(); - int checkDirection(); - ssize_t writeData(const void* data, size_t len); - ssize_t readData(void* data, size_t len); - int tlsConnect(const std::string& hostname, std::string& handshakeErr); - int tlsAccept(); - std::string getLastErrorString(); + GnuTLSSession(GnuTLSContext* tlsContext); + ~GnuTLSSession(); + virtual int init(sock_t sockfd); + virtual int setSNIHostname(const std::string& hostname); + virtual int closeConnection(); + virtual int checkDirection(); + virtual ssize_t writeData(const void* data, size_t len); + virtual ssize_t readData(void* data, size_t len); + virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr); + virtual int tlsAccept(); + virtual std::string getLastErrorString(); private: gnutls_session_t sslSession_; - TLSContext* tlsContext_; + GnuTLSContext* tlsContext_; // Last error code from gnutls library functions int rv_; }; diff --git a/src/LibsslTLSContext.cc b/src/LibsslTLSContext.cc index 8a2b1757..c0507227 100644 --- a/src/LibsslTLSContext.cc +++ b/src/LibsslTLSContext.cc @@ -43,10 +43,15 @@ namespace aria2 { -TLSContext::TLSContext(TLSSessionSide side) +TLSContext* TLSContext::make(TLSSessionSide side) +{ + return new OpenSSLTLSContext(side); +} + +OpenSSLTLSContext::OpenSSLTLSContext(TLSSessionSide side) : sslCtx_(0), side_(side), - peerVerificationEnabled_(false) + verifyPeer_(true) { sslCtx_ = SSL_CTX_new(SSLv23_method()); if(sslCtx_) { @@ -70,22 +75,17 @@ TLSContext::TLSContext(TLSSessionSide side) #endif } -TLSContext::~TLSContext() +OpenSSLTLSContext::~OpenSSLTLSContext() { SSL_CTX_free(sslCtx_); } -bool TLSContext::good() const +bool OpenSSLTLSContext::good() const { return good_; } -bool TLSContext::bad() const -{ - return !good_; -} - -bool TLSContext::addCredentialFile(const std::string& certfile, +bool OpenSSLTLSContext::addCredentialFile(const std::string& certfile, const std::string& keyfile) { if(SSL_CTX_use_PrivateKey_file(sslCtx_, keyfile.c_str(), @@ -107,7 +107,7 @@ bool TLSContext::addCredentialFile(const std::string& certfile, return true; } -bool TLSContext::addSystemTrustedCACerts() +bool OpenSSLTLSContext::addSystemTrustedCACerts() { if(SSL_CTX_set_default_verify_paths(sslCtx_) != 1) { A2_LOG_INFO(fmt(MSG_LOADING_SYSTEM_TRUSTED_CA_CERTS_FAILED, @@ -119,7 +119,7 @@ bool TLSContext::addSystemTrustedCACerts() } } -bool TLSContext::addTrustedCACertFile(const std::string& certfile) +bool OpenSSLTLSContext::addTrustedCACertFile(const std::string& certfile) { if(SSL_CTX_load_verify_locations(sslCtx_, certfile.c_str(), 0) != 1) { A2_LOG_ERROR(fmt(MSG_LOADING_TRUSTED_CA_CERT_FAILED, @@ -132,14 +132,4 @@ bool TLSContext::addTrustedCACertFile(const std::string& certfile) } } -void TLSContext::enablePeerVerification() -{ - peerVerificationEnabled_ = true; -} - -void TLSContext::disablePeerVerification() -{ - peerVerificationEnabled_ = false; -} - } // namespace aria2 diff --git a/src/LibsslTLSContext.h b/src/LibsslTLSContext.h index 038d8990..00def36a 100644 --- a/src/LibsslTLSContext.h +++ b/src/LibsslTLSContext.h @@ -46,52 +46,43 @@ namespace aria2 { -class TLSContext { -private: - SSL_CTX* sslCtx_; - - TLSSessionSide side_; - - bool good_; - - bool peerVerificationEnabled_; +class OpenSSLTLSContext : public TLSContext { public: - TLSContext(TLSSessionSide side); + OpenSSLTLSContext(TLSSessionSide side); - ~TLSContext(); + ~OpenSSLTLSContext(); // private key `keyfile' must be decrypted. - bool addCredentialFile(const std::string& certfile, - const std::string& keyfile); + virtual bool addCredentialFile(const std::string& certfile, + const std::string& keyfile); - bool addSystemTrustedCACerts(); + virtual bool addSystemTrustedCACerts(); // certfile can contain multiple certificates. - bool addTrustedCACertFile(const std::string& certfile); + virtual bool addTrustedCACertFile(const std::string& certfile); - bool good() const; + virtual bool good() const; - bool bad() const; - - SSL_CTX* getSSLCtx() const - { - return sslCtx_; - } - - TLSSessionSide getSide() const - { + virtual TLSSessionSide getSide() const { return side_; } - void enablePeerVerification(); - - void disablePeerVerification(); - - bool peerVerificationEnabled() const - { - return peerVerificationEnabled_; + virtual bool getVerifyPeer() const { + return verifyPeer_; + } + virtual void setVerifyPeer(bool verify) { + verifyPeer_ = verify; } + SSL_CTX* getSSLCtx() const { + return sslCtx_; + } + +private: + SSL_CTX* sslCtx_; + TLSSessionSide side_; + bool good_; + bool verifyPeer_; }; } // namespace aria2 diff --git a/src/LibsslTLSSession.cc b/src/LibsslTLSSession.cc index 8cbcac5c..65b66b43 100644 --- a/src/LibsslTLSSession.cc +++ b/src/LibsslTLSSession.cc @@ -38,26 +38,31 @@ #include #include -#include "TLSContext.h" +#include "LogFactory.h" #include "util.h" #include "SocketCore.h" namespace aria2 { -TLSSession::TLSSession(TLSContext* tlsContext) +TLSSession* TLSSession::make(TLSContext* ctx) +{ + return new OpenSSLTLSSession(static_cast(ctx)); +} + +OpenSSLTLSSession::OpenSSLTLSSession(OpenSSLTLSContext* tlsContext) : ssl_(0), tlsContext_(tlsContext), rv_(1) {} -TLSSession::~TLSSession() +OpenSSLTLSSession::~OpenSSLTLSSession() { if(ssl_) { SSL_shutdown(ssl_); } } -int TLSSession::init(sock_t sockfd) +int OpenSSLTLSSession::init(sock_t sockfd) { ERR_clear_error(); ssl_ = SSL_new(tlsContext_->getSSLCtx()); @@ -71,7 +76,7 @@ int TLSSession::init(sock_t sockfd) return TLS_ERR_OK; } -int TLSSession::setSNIHostname(const std::string& hostname) +int OpenSSLTLSSession::setSNIHostname(const std::string& hostname) { #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME ERR_clear_error(); @@ -83,7 +88,7 @@ int TLSSession::setSNIHostname(const std::string& hostname) return TLS_ERR_OK; } -int TLSSession::closeConnection() +int OpenSSLTLSSession::closeConnection() { ERR_clear_error(); SSL_shutdown(ssl_); @@ -91,7 +96,7 @@ int TLSSession::closeConnection() return TLS_ERR_OK; } -int TLSSession::checkDirection() +int OpenSSLTLSSession::checkDirection() { int error = SSL_get_error(ssl_, rv_); if(error == SSL_ERROR_WANT_WRITE) { @@ -110,7 +115,7 @@ bool wouldblock(SSL* ssl, int rv) } } // namespace -ssize_t TLSSession::writeData(const void* data, size_t len) +ssize_t OpenSSLTLSSession::writeData(const void* data, size_t len) { ERR_clear_error(); rv_ = SSL_write(ssl_, data, len); @@ -127,7 +132,7 @@ ssize_t TLSSession::writeData(const void* data, size_t len) } } -ssize_t TLSSession::readData(void* data, size_t len) +ssize_t OpenSSLTLSSession::readData(void* data, size_t len) { ERR_clear_error(); rv_ = SSL_read(ssl_, data, len); @@ -144,7 +149,7 @@ ssize_t TLSSession::readData(void* data, size_t len) } } -int TLSSession::handshake() +int OpenSSLTLSSession::handshake() { ERR_clear_error(); if(tlsContext_->getSide() == TLS_CLIENT) { @@ -171,7 +176,7 @@ int TLSSession::handshake() return TLS_ERR_OK; } -int TLSSession::tlsConnect(const std::string& hostname, +int OpenSSLTLSSession::tlsConnect(const std::string& hostname, std::string& handshakeErr) { handshakeErr = ""; @@ -181,7 +186,7 @@ int TLSSession::tlsConnect(const std::string& hostname, return ret; } if(tlsContext_->getSide() == TLS_CLIENT && - tlsContext_->peerVerificationEnabled()) { + tlsContext_->getVerifyPeer()) { // verify peer X509* peerCert = SSL_get_peer_certificate(ssl_); if(!peerCert) { @@ -256,12 +261,12 @@ int TLSSession::tlsConnect(const std::string& hostname, return TLS_ERR_OK; } -int TLSSession::tlsAccept() +int OpenSSLTLSSession::tlsAccept() { return handshake(); } -std::string TLSSession::getLastErrorString() +std::string OpenSSLTLSSession::getLastErrorString() { if(rv_ <= 0) { int sslError = SSL_get_error(ssl_, rv_); diff --git a/src/LibsslTLSSession.h b/src/LibsslTLSSession.h index 6a162031..29412ab8 100644 --- a/src/LibsslTLSSession.h +++ b/src/LibsslTLSSession.h @@ -39,32 +39,29 @@ #include -#include - -#include "TLSSessionConst.h" +#include "LibsslTLSContext.h" +#include "TLSSession.h" #include "a2netcompat.h" namespace aria2 { -class TLSContext; - -class TLSSession { +class OpenSSLTLSSession : public TLSSession { public: - TLSSession(TLSContext* tlsContext); - ~TLSSession(); - int init(sock_t sockfd); - int setSNIHostname(const std::string& hostname); - int closeConnection(); - int checkDirection(); - ssize_t writeData(const void* data, size_t len); - ssize_t readData(void* data, size_t len); - int tlsConnect(const std::string& hostname, std::string& handshakeErr); - int tlsAccept(); - std::string getLastErrorString(); + OpenSSLTLSSession(OpenSSLTLSContext* tlsContext); + virtual ~OpenSSLTLSSession(); + virtual int init(sock_t sockfd); + virtual int setSNIHostname(const std::string& hostname); + virtual int closeConnection(); + virtual int checkDirection(); + virtual ssize_t writeData(const void* data, size_t len); + virtual ssize_t readData(void* data, size_t len); + virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr); + virtual int tlsAccept(); + virtual std::string getLastErrorString(); private: int handshake(); SSL* ssl_; - TLSContext* tlsContext_; + OpenSSLTLSContext* tlsContext_; // Last error code from openSSL library functions int rv_; }; diff --git a/src/Makefile.am b/src/Makefile.am index 22be425b..7a26a078 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -299,38 +299,53 @@ SRCS += EpollEventPoll.cc EpollEventPoll.h endif # HAVE_EPOLL if ENABLE_SSL -SRCS += TLSContext.h\ - TLSSession.h\ - TLSSessionConst.h +SRCS += TLSSession.h TLSSessionConst.h endif # ENABLE_SSL +if USE_APPLE_MD +SRCS += AppleMessageDigestImpl.cc AppleMessageDigestImpl.h +endif + +if HAVE_APPLETLS +SRCS += AppleTLSContext.cc AppleTLSContext.h \ + AppleTLSSession.cc AppleTLSSession.h +endif + if HAVE_LIBGNUTLS -SRCS += LibgnutlsTLSContext.cc LibgnutlsTLSContext.h\ - LibgnutlsTLSSession.cc LibgnutlsTLSSession.h +SRCS += LibgnutlsTLSContext.cc LibgnutlsTLSContext.h \ + LibgnutlsTLSSession.cc LibgnutlsTLSSession.h endif # HAVE_LIBGNUTLS if HAVE_LIBGCRYPT -SRCS += LibgcryptMessageDigestImpl.cc LibgcryptMessageDigestImpl.h\ - LibgcryptARC4Encryptor.cc LibgcryptARC4Encryptor.h\ - LibgcryptDHKeyExchange.cc LibgcryptDHKeyExchange.h +SRCS += LibgcryptARC4Encryptor.cc LibgcryptARC4Encryptor.h \ + LibgcryptDHKeyExchange.cc LibgcryptDHKeyExchange.h +if USE_LIBGCRYPT_MD +SRCS += LibgcryptMessageDigestImpl.cc LibgcryptMessageDigestImpl.h +endif endif # HAVE_LIBGCRYPT if HAVE_LIBNETTLE -SRCS += LibnettleMessageDigestImpl.cc LibnettleMessageDigestImpl.h\ - LibnettleARC4Encryptor.cc LibnettleARC4Encryptor.h +SRCS += LibnettleARC4Encryptor.cc LibnettleARC4Encryptor.h +if USE_LIBNETTLE_MD +SRCS += LibnettleMessageDigestImpl.cc LibnettleMessageDigestImpl.h +endif endif # HAVE_LIBNETTLE if HAVE_LIBGMP -SRCS += a2gmp.cc a2gmp.h\ - LibgmpDHKeyExchange.cc LibgmpDHKeyExchange.h +SRCS += a2gmp.cc a2gmp.h \ + LibgmpDHKeyExchange.cc LibgmpDHKeyExchange.h endif # HAVE_LIBGMP if HAVE_OPENSSL -SRCS += LibsslTLSContext.cc LibsslTLSContext.h\ - LibsslTLSSession.cc LibsslTLSSession.h\ - LibsslMessageDigestImpl.cc LibsslMessageDigestImpl.h\ - LibsslARC4Encryptor.cc LibsslARC4Encryptor.h\ - LibsslDHKeyExchange.cc LibsslDHKeyExchange.h +SRCS += LibsslARC4Encryptor.cc LibsslARC4Encryptor.h \ + LibsslDHKeyExchange.cc LibsslDHKeyExchange.h +if !HAVE_APPLETLS +SRCS += LibsslTLSContext.cc LibsslTLSContext.h \ + LibsslTLSSession.cc LibsslTLSSession.h +endif +if USE_OPENSSL_MD +SRCS += LibsslMessageDigestImpl.cc LibsslMessageDigestImpl.h +endif endif # HAVE_OPENSSL if HAVE_ZLIB diff --git a/src/MessageDigestImpl.h b/src/MessageDigestImpl.h index 0da4e41f..e7de69bb 100644 --- a/src/MessageDigestImpl.h +++ b/src/MessageDigestImpl.h @@ -35,12 +35,15 @@ #ifndef D_MESSAGE_DIGEST_IMPL_H #define D_MESSAGE_DIGEST_IMPL_H -#ifdef HAVE_LIBNETTLE + +#ifdef USE_APPLE_MD +# include "AppleMessageDigestImpl.h" +#elif defined(USE_LIBNETTLE_MD) # include "LibnettleMessageDigestImpl.h" -#elif HAVE_LIBGCRYPT +#elif defined(USE_LIBGCRYPT_MD) # include "LibgcryptMessageDigestImpl.h" -#elif HAVE_OPENSSL +#elif defined(USE_OPENSSL_MD) # include "LibsslMessageDigestImpl.h" -#endif // HAVE_OPENSSL +#endif #endif // D_MESSAGE_DIGEST_IMPL_H diff --git a/src/MultiUrlRequestInfo.cc b/src/MultiUrlRequestInfo.cc index 389e40ca..a8bb2342 100644 --- a/src/MultiUrlRequestInfo.cc +++ b/src/MultiUrlRequestInfo.cc @@ -145,7 +145,7 @@ error_code::Value MultiUrlRequestInfo::execute() !option_->blank(PREF_RPC_PRIVATE_KEY)) { // We set server TLS context to the SocketCore before creating // DownloadEngine instance. - SharedHandle svTlsContext(new TLSContext(TLS_SERVER)); + SharedHandle svTlsContext(TLSContext::make(TLS_SERVER)); svTlsContext->addCredentialFile(option_->get(PREF_RPC_CERTIFICATE), option_->get(PREF_RPC_PRIVATE_KEY)); SocketCore::setServerTLSContext(svTlsContext); @@ -194,7 +194,7 @@ error_code::Value MultiUrlRequestInfo::execute() e->setAuthConfigFactory(authConfigFactory); #ifdef ENABLE_SSL - SharedHandle clTlsContext(new TLSContext(TLS_CLIENT)); + SharedHandle clTlsContext(TLSContext::make(TLS_CLIENT)); if(!option_->blank(PREF_CERTIFICATE) && !option_->blank(PREF_PRIVATE_KEY)) { clTlsContext->addCredentialFile(option_->get(PREF_CERTIFICATE), @@ -211,9 +211,7 @@ error_code::Value MultiUrlRequestInfo::execute() A2_LOG_INFO(MSG_WARN_NO_CA_CERT); } } - if(option_->getAsBool(PREF_CHECK_CERTIFICATE)) { - clTlsContext->enablePeerVerification(); - } + clTlsContext->setVerifyPeer(option_->getAsBool(PREF_CHECK_CERTIFICATE)); SocketCore::setClientTLSContext(clTlsContext); #endif #ifdef HAVE_ARES_ADDR_NODE diff --git a/src/SocketCore.cc b/src/SocketCore.cc index 165d3bdd..b7ed9d85 100644 --- a/src/SocketCore.cc +++ b/src/SocketCore.cc @@ -819,7 +819,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname) wantWrite_ = false; switch(secure_) { case A2_TLS_NONE: - tlsSession_.reset(new TLSSession(tlsctx)); + tlsSession_.reset(TLSSession::make(tlsctx)); rv = tlsSession_->init(sockfd_); if(rv != TLS_ERR_OK) { std::string error = tlsSession_->getLastErrorString(); diff --git a/src/TLSContext.h b/src/TLSContext.h index 6ac1e727..64223ada 100644 --- a/src/TLSContext.h +++ b/src/TLSContext.h @@ -35,6 +35,8 @@ #ifndef D_TLS_CONTEXT_H #define D_TLS_CONTEXT_H +#include + #include "common.h" namespace aria2 { @@ -44,12 +46,27 @@ enum TLSSessionSide { TLS_SERVER }; +class TLSContext { +public: + static TLSContext* make(TLSSessionSide side); + virtual ~TLSContext() {} + + // private key `keyfile' must be decrypted. + virtual bool addCredentialFile(const std::string& certfile, + const std::string& keyfile) = 0; + + virtual bool addSystemTrustedCACerts() = 0; + + // certfile can contain multiple certificates. + virtual bool addTrustedCACertFile(const std::string& certfile) = 0; + + virtual bool good() const = 0; + + virtual TLSSessionSide getSide() const = 0; + virtual bool getVerifyPeer() const = 0; + virtual void setVerifyPeer(bool) = 0; +}; + } // namespace aria2 -#ifdef HAVE_OPENSSL -# include "LibsslTLSContext.h" -#elif HAVE_LIBGNUTLS -# include "LibgnutlsTLSContext.h" -#endif // HAVE_LIBGNUTLS - #endif // D_TLS_CONTEXT_H diff --git a/src/TLSSession.h b/src/TLSSession.h index 77c7c833..ec06a142 100644 --- a/src/TLSSession.h +++ b/src/TLSSession.h @@ -36,69 +36,86 @@ #define TLS_SESSION_H #include "common.h" +#include "a2netcompat.h" +#include "TLSContext.h" + +namespace aria2 { + +enum TLSDirection { + TLS_WANT_READ = 1, + TLS_WANT_WRITE +}; + +enum TLSErrorCode { + TLS_ERR_OK = 0, + TLS_ERR_ERROR = -1, + TLS_ERR_WOULDBLOCK = -2 +}; // To create another SSL/TLS backend, implement TLSSession class below. // -// class TLSSession { -// public: -// TLSSession(TLSContext* tlsContext); -// -// // MUST deallocate all resources -// ~TLSSession(); -// -// // Initializes SSL/TLS session. The |sockfd| is the underlying -// // tranport socket. This function returns TLS_ERR_OK if it -// // succeeds, or TLS_ERR_ERROR. -// int init(sock_t sockfd); -// -// // Sets |hostname| for TLS SNI extension. This is only meaningful for -// // client side session. This function returns TLS_ERR_OK if it -// // succeeds, or TLS_ERR_ERROR. -// int setSNIHostname(const std::string& hostname); -// -// // Closes the SSL/TLS session. Don't close underlying transport -// // socket. This function returns TLS_ERR_OK if it succeeds, or -// // TLS_ERR_ERROR. -// int closeConnection(); -// -// // Returns TLS_WANT_READ if SSL/TLS session needs more data from -// // remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session -// // needs to write more data to proceed. If SSL/TLS session needs -// // neither read nor write data at the moment, return value is -// // undefined. -// int checkDirection(); -// -// // Sends |data| with length |len|. This function returns the number -// // of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the -// // underlying tranport blocks, or TLS_ERR_ERROR. -// ssize_t writeData(const void* data, size_t len); -// -// // Receives data into |data| with length |len|. This function returns -// // the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK -// // if the underlying tranport blocks, or TLS_ERR_ERROR. -// ssize_t readData(void* data, size_t len); -// -// // Performs client side handshake. The |hostname| is the hostname of -// // the remote endpoint and is used to verify its certificate. This -// // function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK -// // if the underlying transport blocks, or TLS_ERR_ERROR. -// // When returning TLS_ERR_ERROR, provide certificate validation error -// // in |handshakeErr|. -// int tlsConnect(const std::string& hostname, std::string& handshakeErr); -// -// // Performs server side handshake. This function returns TLS_ERR_OK -// // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport -// // blocks, or TLS_ERR_ERROR. -// int tlsAccept(); -// -// // Returns last error string -// std::string getLastErrorString(); -// }; +class TLSSession { +public: + static TLSSession* make(TLSContext* ctx); -#ifdef HAVE_OPENSSL -# include "LibsslTLSSession.h" -#elif defined HAVE_LIBGNUTLS -# include "LibgnutlsTLSSession.h" -#endif + // MUST deallocate all resources + virtual ~TLSSession() {} + + // Initializes SSL/TLS session. The |sockfd| is the underlying + // tranport socket. This function returns TLS_ERR_OK if it + // succeeds, or TLS_ERR_ERROR. + virtual int init(sock_t sockfd) = 0; + + // Sets |hostname| for TLS SNI extension. This is only meaningful for + // client side session. This function returns TLS_ERR_OK if it + // succeeds, or TLS_ERR_ERROR. + virtual int setSNIHostname(const std::string& hostname) = 0; + + // Closes the SSL/TLS session. Don't close underlying transport + // socket. This function returns TLS_ERR_OK if it succeeds, or + // TLS_ERR_ERROR. + virtual int closeConnection() = 0; + + // Returns TLS_WANT_READ if SSL/TLS session needs more data from + // remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session + // needs to write more data to proceed. If SSL/TLS session needs + // neither read nor write data at the moment, return value is + // undefined. + virtual int checkDirection() = 0; + + // Sends |data| with length |len|. This function returns the number + // of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the + // underlying tranport blocks, or TLS_ERR_ERROR. + virtual ssize_t writeData(const void* data, size_t len) = 0; + + // Receives data into |data| with length |len|. This function returns + // the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK + // if the underlying tranport blocks, or TLS_ERR_ERROR. + virtual ssize_t readData(void* data, size_t len) = 0; + + // Performs client side handshake. The |hostname| is the hostname of + // the remote endpoint and is used to verify its certificate. This + // function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK + // if the underlying transport blocks, or TLS_ERR_ERROR. + // When returning TLS_ERR_ERROR, provide certificate validation error + // in |handshakeErr|. + virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr) = 0; + + // Performs server side handshake. This function returns TLS_ERR_OK + // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport + // blocks, or TLS_ERR_ERROR. + virtual int tlsAccept() = 0; + + // Returns last error string + virtual std::string getLastErrorString() = 0; + +protected: + TLSSession() {} +private: + TLSSession(const TLSSession&); + TLSSession& operator=(const TLSSession&); +}; + +} #endif // TLS_SESSION_H