AppleTLS: Implement AppleTLS and Apple Message Digest

pull/61/head
Nils Maier 2013-04-05 05:55:57 +02:00
parent b292ae1305
commit 0bcbd947b4
21 changed files with 1156 additions and 287 deletions

View File

@ -24,6 +24,7 @@ esac
AC_DEFINE_UNQUOTED([TARGET], ["$target"], [Define target-type])
# Checks for arguments.
ARIA2_ARG_WITHOUT([appletls])
ARIA2_ARG_WITHOUT([gnutls])
ARIA2_ARG_WITHOUT([libnettle])
ARIA2_ARG_WITHOUT([libgmp])
@ -145,7 +146,28 @@ if test "x$with_sqlite3" = "xyes"; then
fi
fi
if test "x$with_gnutls" = "xyes"; then
case "$host" in
*darwin*)
have_osx="yes"
;;
esac
if test "x$with_appletls" = "xyes"; then
AC_MSG_CHECKING([whether to enable Mac OS X native SSL/TLS])
if test "x$have_osx" = "xyes"; then
AC_DEFINE([HAVE_APPLETLS], [1], [Define to 1 if you have Apple TLS])
LDFLAGS="$LDFLAGS -framework CoreFoundation -framework Security"
have_appletls="yes"
AC_MSG_RESULT(yes)
else
AC_MSG_RESULT(no)
if test "x$with_appletls_requested" = "xyes"; then
ARIA2_DEP_NOT_MET([appletls])
fi
fi
fi
if test "x$with_gnutls" = "xyes" && test "x$have_appletls" != "xyes"; then
# gnutls >= 2.8 doesn't have libgnutls-config anymore. We require
# 2.2.0 because we use gnutls_priority_set_direct()
PKG_CHECK_MODULES([LIBGNUTLS], [gnutls >= 2.2.0],
@ -163,7 +185,7 @@ if test "x$with_gnutls" = "xyes"; then
fi
fi
if test "x$with_openssl" = "xyes" && test "x$have_libgnutls" != "xyes"; then
if test "x$with_openssl" = "xyes" && test "x$have_appletls" != "xyes" && test "x$have_libgnutls" != "xyes"; then
PKG_CHECK_MODULES([OPENSSL], [openssl >= 0.9.8],
[have_openssl=yes], [have_openssl=no])
if test "x$have_openssl" = "xyes"; then
@ -235,8 +257,30 @@ if test "x$with_libcares" = "xyes"; then
fi
fi
use_md=""
if test "x$have_osx" == "xyes"; then
use_md="apple"
AC_DEFINE([USE_APPLE_MD], [1], [What message digest implementation to use])
else
if test "x$have_libnettle" = "xyes"; then
AC_DEFINE([USE_LIBNETTLE_MD], [1], [What message digest implementation to use])
use_md="libnettle"
else
if test "x$have_libgcrypt" = "xyes"; then
AC_DEFINE([USE_LIBGCRYPT_MD], [1], [What message digest implementation to use])
use_md="libgcrypt"
else
if test = "x$have_openssl" = "xyes"; then
AC_DEFINE([USE_OPENSSL_MD], [1], [What message digest implementation to use])
use_md="openssl"
fi
fi
fi
fi
# Define variables based on the result of the checks for libraries.
if test "x$have_libgnutls" = "xyes" || test "x$have_openssl" = "xyes"; then
if test "x$have_appletls" = "xyes" || test "x$have_libgnutls" = "xyes" || test "x$have_openssl" = "xyes"; then
have_ssl="yes"
AC_DEFINE([ENABLE_SSL], [1], [Define to 1 if ssl support is enabled.])
AM_CONDITIONAL([ENABLE_SSL], true)
AC_SUBST([ca_bundle])
@ -244,14 +288,20 @@ else
AM_CONDITIONAL([ENABLE_SSL], false)
fi
AM_CONDITIONAL([HAVE_OSX], [ test "x$have_osx" = "xyes" ])
AM_CONDITIONAL([HAVE_APPLETLS], [ test "x$have_appletls" = "xyes" ])
AM_CONDITIONAL([USE_APPLE_MD], [ test "x$use_md" = "xapple" ])
AM_CONDITIONAL([HAVE_LIBGNUTLS], [ test "x$have_libgnutls" = "xyes" ])
AM_CONDITIONAL([HAVE_LIBNETTLE], [ test "x$have_libnettle" = "xyes" ])
AM_CONDITIONAL([USE_LIBNETTLE_MD], [ test "x$use_md" = "xlibnettle"])
AM_CONDITIONAL([HAVE_LIBGMP], [ test "x$have_libgmp" = "xyes" ])
AM_CONDITIONAL([HAVE_LIBGCRYPT], [ test "x$have_libgcrypt" = "xyes" ])
AM_CONDITIONAL([USE_LIBGCRYPT_MD], [ test "x$use_md" = "xlibgcrypt"])
AM_CONDITIONAL([HAVE_OPENSSL], [ test "x$have_openssl" = "xyes" ])
AM_CONDITIONAL([USE_OPENSSL_MD], [ test "x$use_md" = "xopenssl"])
if test "x$have_libnettle" = "xyes" || test "x$have_libgcrypt" = "xyes" ||
test "x$have_openssl" = "xyes"; then
if test "x$use_md" != "x"; then
AC_DEFINE([ENABLE_MESSAGE_DIGEST], [1],
[Define to 1 if message digest support is enabled.])
AM_CONDITIONAL([ENABLE_MESSAGE_DIGEST], true)
@ -325,9 +375,9 @@ AM_CONDITIONAL([HAVE_SQLITE3], [test "x$have_sqlite3" = "xyes"])
AC_SEARCH_LIBS([clock_gettime], [rt])
case "$host" in
*solaris*)
AC_SEARCH_LIBS([getaddrinfo], [nsl socket])
;;
*solaris*)
AC_SEARCH_LIBS([getaddrinfo], [nsl socket])
;;
esac
# Checks for header files.
@ -670,6 +720,8 @@ echo "LDFLAGS: $LDFLAGS"
echo "LIBS: $LIBS"
echo "DEFS: $DEFS"
echo "SQLite3: $have_sqlite3"
echo "SSL Support: $have_ssl"
echo "AppleTLS: $have_appletls"
echo "GnuTLS: $have_libgnutls"
echo "OpenSSL: $have_openssl"
echo "CA Bundle: $ca_bundle"

View File

@ -0,0 +1,153 @@
/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2013 Nils Maier
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#include "AppleMessageDigestImpl.h"
#include <CommonCrypto/CommonDigest.h>
#include "array_fun.h"
#include "HashFuncEntry.h"
namespace aria2 {
template<size_t dlen,
typename ctx_t,
int (*init_fn)(ctx_t*),
int (*update_fn)(ctx_t*, const void*, CC_LONG),
int(*final_fn)(unsigned char*, ctx_t*)>
class MessageDigestBase : public MessageDigestImpl {
public:
MessageDigestBase() { reset(); }
virtual size_t getDigestLength() const {
return dlen;
}
virtual void reset() {
init_fn(&ctx_);
}
virtual void update(const void* data, size_t length) {
while (length) {
CC_LONG l = std::min(length, (size_t)std::numeric_limits<uint32_t>::max());
update_fn(&ctx_, data, l);
length -= l;
}
}
virtual void digest(unsigned char* md) {
final_fn(md, &ctx_);
}
private:
ctx_t ctx_;
};
typedef MessageDigestBase<CC_MD5_DIGEST_LENGTH,
CC_MD5_CTX,
CC_MD5_Init,
CC_MD5_Update,
CC_MD5_Final>
MessageDigestMD5;
typedef MessageDigestBase<CC_SHA1_DIGEST_LENGTH,
CC_SHA1_CTX,
CC_SHA1_Init,
CC_SHA1_Update,
CC_SHA1_Final>
MessageDigestSHA1;
typedef MessageDigestBase<CC_SHA224_DIGEST_LENGTH,
CC_SHA256_CTX,
CC_SHA224_Init,
CC_SHA224_Update,
CC_SHA224_Final>
MessageDigestSHA224;
typedef MessageDigestBase<CC_SHA256_DIGEST_LENGTH,
CC_SHA256_CTX,
CC_SHA256_Init,
CC_SHA256_Update,
CC_SHA256_Final>
MessageDigestSHA256;
typedef MessageDigestBase<CC_SHA384_DIGEST_LENGTH,
CC_SHA512_CTX,
CC_SHA384_Init,
CC_SHA384_Update,
CC_SHA384_Final>
MessageDigestSHA384;
typedef MessageDigestBase<CC_SHA512_DIGEST_LENGTH,
CC_SHA512_CTX,
CC_SHA512_Init,
CC_SHA512_Update,
CC_SHA512_Final>
MessageDigestSHA512;
SharedHandle<MessageDigestImpl> MessageDigestImpl::sha1()
{
return SharedHandle<MessageDigestImpl>(new MessageDigestSHA1());
}
SharedHandle<MessageDigestImpl> MessageDigestImpl::create
(const std::string& hashType)
{
if (hashType == "sha-1") {
return SharedHandle<MessageDigestImpl>(new MessageDigestSHA1());
}
if (hashType == "sha-224") {
return SharedHandle<MessageDigestImpl>(new MessageDigestSHA224());
}
if (hashType == "sha-256") {
return SharedHandle<MessageDigestImpl>(new MessageDigestSHA256());
}
if (hashType == "sha-384") {
return SharedHandle<MessageDigestImpl>(new MessageDigestSHA384());
}
if (hashType == "sha-512") {
return SharedHandle<MessageDigestImpl>(new MessageDigestSHA512());
}
if (hashType == "md5") {
return SharedHandle<MessageDigestImpl>(new MessageDigestMD5());
}
return SharedHandle<MessageDigestImpl>();
}
bool MessageDigestImpl::supports(const std::string& hashType)
{
return hashType == "sha-1" || hashType == "sha-224" || hashType == "sha-256" || hashType == "sha-384" || hashType == "sha-512" || hashType == "md5";
}
size_t MessageDigestImpl::getDigestLength(const std::string& hashType)
{
SharedHandle<MessageDigestImpl> impl = create(hashType);
if (!impl) {
return 0;
}
return impl->getDigestLength();
}
} // namespace aria2

View File

@ -0,0 +1,71 @@
/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2013 Nils Maier
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#ifndef D_APPLE_MESSAGE_DIGEST_IMPL_H
#define D_APPLE_MESSAGE_DIGEST_IMPL_H
#include "common.h"
#include <string>
#include "SharedHandle.h"
namespace aria2 {
class MessageDigestImpl {
public:
static SharedHandle<MessageDigestImpl> sha1();
static SharedHandle<MessageDigestImpl> create(const std::string& hashType);
static bool supports(const std::string& hashType);
static size_t getDigestLength(const std::string& hashType);
public:
virtual size_t getDigestLength() const = 0;
virtual void reset() = 0;
virtual void update(const void* data, size_t length) = 0;
virtual void digest(unsigned char* md) = 0;
protected:
MessageDigestImpl() {}
private:
MessageDigestImpl(const MessageDigestImpl&);
MessageDigestImpl& operator=(const MessageDigestImpl&);
};
} // namespace aria2
#endif // D_APPLE_MESSAGE_DIGEST_IMPL_H

View File

@ -2,7 +2,7 @@
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2013 Tatsuhiro Tsujikawa
* Copyright (C) 2013 Nils Maier
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
@ -32,24 +32,31 @@
* files in the program, then also delete it here.
*/
/* copyright --> */
#ifndef TLS_SESSION_CONST_H
#define TLS_SESSION_CONST_H
#include "AppleTLSContext.h"
#include "common.h"
#include "LogFactory.h"
#include "Logger.h"
#include "fmt.h"
#include "message.h"
namespace aria2 {
enum TLSDirection {
TLS_WANT_READ = 1,
TLS_WANT_WRITE
};
TLSContext* TLSContext::make(TLSSessionSide side) {
return new AppleTLSContext(side);
}
bool AppleTLSContext::addCredentialFile(const std::string& certfile,
const std::string& keyfile)
{
A2_LOG_WARN("TLS credential files are not supported. Use the KeyChain to manage your certificates.");
return false;
}
bool AppleTLSContext::addTrustedCACertFile(const std::string& certfile)
{
A2_LOG_WARN("TLS CA bundle files are not supported. Use the KeyChain to manage your certificates.");
return false;
}
enum TLSErrorCode {
TLS_ERR_OK = 0,
TLS_ERR_ERROR = -1,
TLS_ERR_WOULDBLOCK = -2
};
} // namespace aria2
#endif // TLS_SESSION_CONST_H

90
src/AppleTLSContext.h Normal file
View File

@ -0,0 +1,90 @@
/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2013 Nils Maier
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#ifndef D_APPLE_TLS_CONTEXT_H
#define D_APPLE_TLS_CONTEXT_H
#include "common.h"
#include <string>
#include <Security/Security.h>
#include <Security/SecureTransport.h>
#include "TLSContext.h"
#include "DlAbortEx.h"
namespace aria2 {
class AppleTLSContext : public TLSContext {
public:
AppleTLSContext(TLSSessionSide side)
: side_(side),
verifyPeer_(true)
{}
virtual ~AppleTLSContext() {}
// private key `keyfile' must be decrypted.
virtual bool addCredentialFile(const std::string& certfile,
const std::string& keyfile);
virtual bool addSystemTrustedCACerts() {
return true;
}
// certfile can contain multiple certificates.
virtual bool addTrustedCACertFile(const std::string& certfile);
virtual bool good() const {
return true;
}
virtual TLSSessionSide getSide() const {
return side_;
}
virtual bool getVerifyPeer() const {
return verifyPeer_;
}
virtual void setVerifyPeer(bool verify) {
verifyPeer_ = verify;
}
private:
TLSSessionSide side_;
bool verifyPeer_;
};
} // namespace aria2
#endif // D_LIBSSL_TLS_CONTEXT_H

354
src/AppleTLSSession.cc Normal file
View File

@ -0,0 +1,354 @@
/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2013 Nils Maier
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#include "AppleTLSSession.h"
#include <CoreFoundation/CoreFoundation.h>
#include "fmt.h"
#include "LogFactory.h"
#define ioErr -36
#define paramErr -50
#define errSSLServerAuthCompleted -9841
namespace {
static const SSLProtocol kTLSProtocol11_h = (SSLProtocol)(kSSLProtocolAll + 1);
static const SSLProtocol kTLSProtocol12_h = (SSLProtocol)(kSSLProtocolAll + 2);
}
namespace aria2 {
TLSSession* TLSSession::make(TLSContext* ctx)
{
return new AppleTLSSession(static_cast<AppleTLSContext*>(ctx));
}
AppleTLSSession::AppleTLSSession(AppleTLSContext* ctx)
: ctx_(ctx),
sslCtx_(0),
sockfd_(0),
state_(st_constructed),
lastError_(noErr),
writeBuffered_(0)
{
lastError_ = SSLNewContext(ctx->getSide() == TLS_SERVER, &sslCtx_) == noErr;
if (lastError_ == noErr) {
state_ = st_error;
return;
}
#if defined(__MAC_10_8)
(void)SSLSetProtocolVersionMin(sslCtx_, kSSLProtocol3);
(void)SSLSetProtocolVersionMax(sslCtx_, kTLSProtocol12);
#else
(void)SSLSetProtocolVersionEnabled(sslCtx_, kSSLProtocolAll, false);
(void)SSLSetProtocolVersionEnabled(sslCtx_, kSSLProtocol3, true);
(void)SSLSetProtocolVersionEnabled(sslCtx_, kTLSProtocol1, true);
(void)SSLSetProtocolVersionEnabled(sslCtx_, kTLSProtocol11_h, true);
(void)SSLSetProtocolVersionEnabled(sslCtx_, kTLSProtocol12_h, true);
#endif
(void)SSLSetEnableCertVerify(sslCtx_, ctx->getVerifyPeer());
}
AppleTLSSession::~AppleTLSSession()
{
closeConnection();
if (sslCtx_) {
SSLDisposeContext(sslCtx_);
sslCtx_ = 0;
}
state_ = st_error;
}
int AppleTLSSession::init(sock_t sockfd)
{
if (state_ != st_constructed) {
lastError_ = noErr;
return TLS_ERR_ERROR;
}
lastError_ = SSLSetIOFuncs(sslCtx_, SocketRead, SocketWrite);
if (lastError_ != noErr) {
state_ = st_error;
return TLS_ERR_ERROR;
}
lastError_ = SSLSetConnection(sslCtx_, this);
if (lastError_ != noErr) {
state_ = st_error;
return TLS_ERR_ERROR;
}
sockfd_ = sockfd;
state_ = st_initialized;
return TLS_ERR_OK;
}
int AppleTLSSession::setSNIHostname(const std::string& hostname)
{
if (state_ != st_initialized) {
lastError_ = noErr;
return TLS_ERR_ERROR;
}
lastError_ = SSLSetPeerDomainName(sslCtx_, hostname.c_str(), hostname.length());
return (lastError_ != noErr) ? TLS_ERR_ERROR : TLS_ERR_OK;
}
int AppleTLSSession::closeConnection()
{
if (state_ != st_connected) {
lastError_ = noErr;
return TLS_ERR_ERROR;
}
lastError_ = SSLClose(sslCtx_);
state_ = st_closed;
return lastError_ == noErr ? TLS_ERR_OK : TLS_ERR_ERROR;
}
int AppleTLSSession::checkDirection() {
if (writeBuffered_) {
return TLS_WANT_WRITE;
}
if (state_ == st_connected) {
size_t buffered;
lastError_ = SSLGetBufferedReadSize(sslCtx_, &buffered);
if (lastError_ == noErr && buffered) {
return TLS_WANT_READ;
}
}
return 0;
}
ssize_t AppleTLSSession::writeData(const void* data, size_t len)
{
if (state_ != st_connected) {
lastError_ = noErr;
return TLS_ERR_ERROR;
}
size_t processed = 0;
if (writeBuffered_) {
lastError_ = SSLWrite(sslCtx_, 0, 0, &processed);
switch (lastError_) {
case noErr:
processed = writeBuffered_;
writeBuffered_ = 0;
return processed;
case errSSLWouldBlock:
return TLS_ERR_WOULDBLOCK;
case errSSLClosedGraceful:
case errSSLClosedNoNotify:
closeConnection();
return TLS_ERR_ERROR;
default:
closeConnection();
state_ = st_error;
return TLS_ERR_ERROR;
}
}
lastError_ = SSLWrite(sslCtx_, data, len, &processed);
switch (lastError_) {
case noErr:
return processed;
case errSSLWouldBlock:
writeBuffered_ = len;
return TLS_ERR_WOULDBLOCK;
case errSSLClosedGraceful:
case errSSLClosedNoNotify:
closeConnection();
return TLS_ERR_ERROR;
default:
closeConnection();
state_ = st_error;
return TLS_ERR_ERROR;
}
}
OSStatus AppleTLSSession::sockWrite(const void* data, size_t* len)
{
size_t remain = *len;
const uint8_t *buffer = static_cast<const uint8_t*>(data);
*len = 0;
while (remain) {
ssize_t w = write(sockfd_, buffer, remain);
if (w <= 0) {
switch (errno) {
case EAGAIN:
return errSSLWouldBlock;
default:
return errSSLClosedAbort;
}
}
remain -= w;
buffer += w;
*len += w;
}
return noErr;
}
ssize_t AppleTLSSession::readData(void* data, size_t len)
{
if (state_ != st_connected) {
lastError_ = noErr;
return TLS_ERR_ERROR;
}
size_t processed = 0;
lastError_ = SSLRead(sslCtx_, data, len, &processed);
switch (lastError_) {
case noErr:
return processed;
case errSSLWouldBlock:
if (processed) {
return processed;
}
return TLS_ERR_WOULDBLOCK;
case errSSLClosedGraceful:
case errSSLClosedNoNotify:
closeConnection();
return TLS_ERR_ERROR;
default:
closeConnection();
state_ = st_error;
return TLS_ERR_ERROR;
}
}
OSStatus AppleTLSSession::sockRead(void* data, size_t* len)
{
size_t remain = *len;
uint8_t *buffer = static_cast<uint8_t*>(data);
*len = 0;
while (remain) {
ssize_t r = read(sockfd_, buffer, remain);
if (r == 0) {
return errSSLClosedGraceful;
}
if (r < 0) {
switch (errno) {
case ENOENT:
return errSSLClosedGraceful;
case ECONNRESET:
return errSSLClosedAbort;
case EAGAIN:
return errSSLWouldBlock;
default:
return errSSLClosedAbort;
}
}
remain -= r;
buffer += r;
*len += r;
}
return noErr;
}
int AppleTLSSession::tlsConnect(const std::string& hostname, std::string& handshakeErr)
{
if (state_ != st_initialized) {
return TLS_ERR_ERROR;
}
if (!hostname.empty()) {
setSNIHostname(hostname);
}
lastError_ = SSLHandshake(sslCtx_);
switch (lastError_) {
case noErr:
state_ = st_connected;
return TLS_ERR_OK;
case errSSLWouldBlock:
return TLS_ERR_WOULDBLOCK;
case errSSLServerAuthCompleted:
return tlsConnect(hostname, handshakeErr);
default:
handshakeErr = getLastErrorString();
return TLS_ERR_ERROR;
}
}
int AppleTLSSession::tlsAccept()
{
std::string hostname, err;
return tlsConnect(hostname, err);
}
std::string AppleTLSSession::getLastErrorString()
{
switch (lastError_) {
case errSSLProtocol:
return "Protocol error";
case errSSLNegotiation:
return "No common cipher suites";
case errSSLFatalAlert:
return "Received fatal alert";
case errSSLSessionNotFound:
return "Unknown session";
case errSSLClosedGraceful:
return "Closed gracefully";
case errSSLClosedAbort:
return "Connection aborted";
case errSSLXCertChainInvalid:
return "Invalid certificate chain";
case errSSLBadCert:
return "Invalid certificate format";
case errSSLCrypto:
return "Cryptographic error";
case paramErr:
case errSSLInternal:
return "Internal SSL error";
case errSSLUnknownRootCert:
return "Self-signed certificate";
case errSSLNoRootCert:
return "No root certificate";
case errSSLCertExpired:
return "Certificate expired";
case errSSLCertNotYetValid:
return "Certificate not yet valid";
case errSSLClosedNoNotify:
return "Closed without notification";
case errSSLBufferOverflow:
return "Buffer not large enough";
case errSSLBadCipherSuite:
return "Bad cipher suite";
case errSSLPeerUnexpectedMsg:
return "Unexpected peer message";
case errSSLPeerBadRecordMac:
return "Bad MAC";
case errSSLPeerDecryptionFail:
return "Decryption failure";
case errSSLHostNameMismatch:
return "Invalid hostname";
case errSSLConnectionRefused:
return "Connection refused";
default:
return fmt("Unspecified error %d", lastError_);
}
}
}

127
src/AppleTLSSession.h Normal file
View File

@ -0,0 +1,127 @@
/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2013 Nils Maier
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#ifndef APPLE_TLS_SESSION_H
#define APPLE_TLS_SESSION_H
#include "common.h"
#include "TLSSession.h"
#include "AppleTLSContext.h"
namespace aria2 {
class AppleTLSSession : public TLSSession {
enum state_t {
st_constructed,
st_initialized,
st_connected,
st_closed,
st_error
};
public:
AppleTLSSession(AppleTLSContext* ctx);
// MUST deallocate all resources
virtual ~AppleTLSSession();
// Initializes SSL/TLS session. The |sockfd| is the underlying
// tranport socket. This function returns TLS_ERR_OK if it
// succeeds, or TLS_ERR_ERROR.
virtual int init(sock_t sockfd);
// Sets |hostname| for TLS SNI extension. This is only meaningful for
// client side session. This function returns TLS_ERR_OK if it
// succeeds, or TLS_ERR_ERROR.
virtual int setSNIHostname(const std::string& hostname);
// Closes the SSL/TLS session. Don't close underlying transport
// socket. This function returns TLS_ERR_OK if it succeeds, or
// TLS_ERR_ERROR.
virtual int closeConnection();
// Returns TLS_WANT_READ if SSL/TLS session needs more data from
// remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session
// needs to write more data to proceed. If SSL/TLS session needs
// neither read nor write data at the moment, return value is
// undefined.
virtual int checkDirection();
// Sends |data| with length |len|. This function returns the number
// of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the
// underlying tranport blocks, or TLS_ERR_ERROR.
virtual ssize_t writeData(const void* data, size_t len);
// Receives data into |data| with length |len|. This function returns
// the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK
// if the underlying tranport blocks, or TLS_ERR_ERROR.
virtual ssize_t readData(void* data, size_t len);
// Performs client side handshake. The |hostname| is the hostname of
// the remote endpoint and is used to verify its certificate. This
// function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK
// if the underlying transport blocks, or TLS_ERR_ERROR.
// When returning TLS_ERR_ERROR, provide certificate validation error
// in |handshakeErr|.
virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr);
// Performs server side handshake. This function returns TLS_ERR_OK
// if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport
// blocks, or TLS_ERR_ERROR.
virtual int tlsAccept();
// Returns last error string
virtual std::string getLastErrorString();
private:
static OSStatus SocketWrite(SSLConnectionRef conn, const void* data, size_t* len) {
return ((AppleTLSSession*)conn)->sockWrite(data, len);
}
static OSStatus SocketRead(SSLConnectionRef conn, void* data, size_t* len) {
return ((AppleTLSSession*)conn)->sockRead(data, len);
}
AppleTLSContext *ctx_;
SSLContextRef sslCtx_;
sock_t sockfd_;
state_t state_;
OSStatus lastError_;
size_t writeBuffered_;
OSStatus sockWrite(const void* data, size_t* len);
OSStatus sockRead(void* data, size_t* len);
};
}
#endif // TLS_SESSION_H

View File

@ -45,10 +45,15 @@
namespace aria2 {
TLSContext::TLSContext(TLSSessionSide side)
TLSContext* TLSContext::make(TLSSessionSide side)
{
return new GnuTLSContext(side);
}
GnuTLSContext::GnuTLSContext(TLSSessionSide side)
: certCred_(0),
side_(side),
peerVerificationEnabled_(false)
verifyPeer_(true)
{
int r = gnutls_certificate_allocate_credentials(&certCred_);
if(r == GNUTLS_E_SUCCESS) {
@ -63,24 +68,19 @@ TLSContext::TLSContext(TLSSessionSide side)
}
}
TLSContext::~TLSContext()
GnuTLSContext::~GnuTLSContext()
{
if(certCred_) {
gnutls_certificate_free_credentials(certCred_);
}
}
bool TLSContext::good() const
bool GnuTLSContext::good() const
{
return good_;
}
bool TLSContext::bad() const
{
return !good_;
}
bool TLSContext::addCredentialFile(const std::string& certfile,
bool GnuTLSContext::addCredentialFile(const std::string& certfile,
const std::string& keyfile)
{
int ret = gnutls_certificate_set_x509_key_file(certCred_,
@ -101,7 +101,7 @@ bool TLSContext::addCredentialFile(const std::string& certfile,
}
}
bool TLSContext::addSystemTrustedCACerts()
bool GnuTLSContext::addSystemTrustedCACerts()
{
#ifdef HAVE_GNUTLS_CERTIFICATE_SET_X509_SYSTEM_TRUST
int ret = gnutls_certificate_set_x509_system_trust(certCred_);
@ -114,11 +114,12 @@ bool TLSContext::addSystemTrustedCACerts()
return true;
}
#else
A2_LOG_WARN("System certificates not supported");
return false;
#endif
}
bool TLSContext::addTrustedCACertFile(const std::string& certfile)
bool GnuTLSContext::addTrustedCACertFile(const std::string& certfile)
{
int ret = gnutls_certificate_set_x509_trust_file(certCred_,
certfile.c_str(),
@ -133,24 +134,9 @@ bool TLSContext::addTrustedCACertFile(const std::string& certfile)
}
}
gnutls_certificate_credentials_t TLSContext::getCertCred() const
gnutls_certificate_credentials_t GnuTLSContext::getCertCred() const
{
return certCred_;
}
void TLSContext::enablePeerVerification()
{
peerVerificationEnabled_ = true;
}
void TLSContext::disablePeerVerification()
{
peerVerificationEnabled_ = false;
}
bool TLSContext::peerVerificationEnabled() const
{
return peerVerificationEnabled_;
}
} // namespace aria2

View File

@ -37,8 +37,6 @@
#include "common.h"
#include <string>
#include <gnutls/gnutls.h>
#include "TLSContext.h"
@ -46,45 +44,41 @@
namespace aria2 {
class TLSContext {
private:
gnutls_certificate_credentials_t certCred_;
TLSSessionSide side_;
bool good_;
bool peerVerificationEnabled_;
class GnuTLSContext : public TLSContext {
public:
TLSContext(TLSSessionSide side);
GnuTLSContext(TLSSessionSide side);
~TLSContext();
virtual ~GnuTLSContext();
// private key `keyfile' must be decrypted.
bool addCredentialFile(const std::string& certfile,
const std::string& keyfile);
virtual bool addCredentialFile(const std::string& certfile,
const std::string& keyfile);
bool addSystemTrustedCACerts();
virtual bool addSystemTrustedCACerts();
// certfile can contain multiple certificates.
bool addTrustedCACertFile(const std::string& certfile);
virtual bool addTrustedCACertFile(const std::string& certfile);
bool good() const;
virtual bool good() const;
bool bad() const;
gnutls_certificate_credentials_t getCertCred() const;
TLSSessionSide getSide() const
{
virtual TLSSessionSide getSide() const {
return side_;
}
void enablePeerVerification();
virtual bool getVerifyPeer() const {
return verifyPeer_;
}
virtual void setVerifyPeer(bool verify) {
verifyPeer_ = verify;
}
void disablePeerVerification();
gnutls_certificate_credentials_t getCertCred() const;
bool peerVerificationEnabled() const;
private:
gnutls_certificate_credentials_t certCred_;
TLSSessionSide side_;
bool good_;
bool verifyPeer_;
};
} // namespace aria2

View File

@ -42,20 +42,25 @@
namespace aria2 {
TLSSession::TLSSession(TLSContext* tlsContext)
TLSSession* TLSSession::make(TLSContext* ctx)
{
return new GnuTLSSession(static_cast<GnuTLSContext*>(ctx));
}
GnuTLSSession::GnuTLSSession(GnuTLSContext* tlsContext)
: sslSession_(0),
tlsContext_(tlsContext),
rv_(0)
{}
TLSSession::~TLSSession()
GnuTLSSession::~GnuTLSSession()
{
if(sslSession_) {
gnutls_deinit(sslSession_);
}
}
int TLSSession::init(sock_t sockfd)
int GnuTLSSession::init(sock_t sockfd)
{
rv_ = gnutls_init(&sslSession_,
tlsContext_->getSide() == TLS_CLIENT ?
@ -89,7 +94,7 @@ int TLSSession::init(sock_t sockfd)
return TLS_ERR_OK;
}
int TLSSession::setSNIHostname(const std::string& hostname)
int GnuTLSSession::setSNIHostname(const std::string& hostname)
{
// TLS extensions: SNI
rv_ = gnutls_server_name_set(sslSession_, GNUTLS_NAME_DNS,
@ -100,7 +105,7 @@ int TLSSession::setSNIHostname(const std::string& hostname)
return TLS_ERR_OK;
}
int TLSSession::closeConnection()
int GnuTLSSession::closeConnection()
{
rv_ = gnutls_bye(sslSession_, GNUTLS_SHUT_WR);
if(rv_ == GNUTLS_E_SUCCESS) {
@ -112,13 +117,13 @@ int TLSSession::closeConnection()
}
}
int TLSSession::checkDirection()
int GnuTLSSession::checkDirection()
{
int direction = gnutls_record_get_direction(sslSession_);
return direction == 0 ? TLS_WANT_READ : TLS_WANT_WRITE;
}
ssize_t TLSSession::writeData(const void* data, size_t len)
ssize_t GnuTLSSession::writeData(const void* data, size_t len)
{
while((rv_ = gnutls_record_send(sslSession_, data, len)) ==
GNUTLS_E_INTERRUPTED);
@ -133,7 +138,7 @@ ssize_t TLSSession::writeData(const void* data, size_t len)
}
}
ssize_t TLSSession::readData(void* data, size_t len)
ssize_t GnuTLSSession::readData(void* data, size_t len)
{
while((rv_ = gnutls_record_recv(sslSession_, data, len)) ==
GNUTLS_E_INTERRUPTED);
@ -148,7 +153,7 @@ ssize_t TLSSession::readData(void* data, size_t len)
}
}
int TLSSession::tlsConnect(const std::string& hostname,
int GnuTLSSession::tlsConnect(const std::string& hostname,
std::string& handshakeErr)
{
handshakeErr = "";
@ -160,7 +165,7 @@ int TLSSession::tlsConnect(const std::string& hostname,
return TLS_ERR_ERROR;
}
}
if(tlsContext_->peerVerificationEnabled()) {
if(tlsContext_->getVerifyPeer()) {
// verify peer
unsigned int status;
rv_ = gnutls_certificate_verify_peers2(sslSession_, &status);
@ -246,7 +251,7 @@ int TLSSession::tlsConnect(const std::string& hostname,
return TLS_ERR_OK;
}
int TLSSession::tlsAccept()
int GnuTLSSession::tlsAccept()
{
rv_ = gnutls_handshake(sslSession_);
if(rv_ == GNUTLS_E_SUCCESS) {
@ -258,7 +263,7 @@ int TLSSession::tlsAccept()
}
}
std::string TLSSession::getLastErrorString()
std::string GnuTLSSession::getLastErrorString()
{
return gnutls_strerror(rv_);
}

View File

@ -39,31 +39,28 @@
#include <gnutls/gnutls.h>
#include <string>
#include "TLSSessionConst.h"
#include "LibgnutlsTLSContext.h"
#include "TLSSession.h"
#include "a2netcompat.h"
namespace aria2 {
class TLSContext;
class TLSSession {
class GnuTLSSession : public TLSSession {
public:
TLSSession(TLSContext* tlsContext);
~TLSSession();
int init(sock_t sockfd);
int setSNIHostname(const std::string& hostname);
int closeConnection();
int checkDirection();
ssize_t writeData(const void* data, size_t len);
ssize_t readData(void* data, size_t len);
int tlsConnect(const std::string& hostname, std::string& handshakeErr);
int tlsAccept();
std::string getLastErrorString();
GnuTLSSession(GnuTLSContext* tlsContext);
~GnuTLSSession();
virtual int init(sock_t sockfd);
virtual int setSNIHostname(const std::string& hostname);
virtual int closeConnection();
virtual int checkDirection();
virtual ssize_t writeData(const void* data, size_t len);
virtual ssize_t readData(void* data, size_t len);
virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr);
virtual int tlsAccept();
virtual std::string getLastErrorString();
private:
gnutls_session_t sslSession_;
TLSContext* tlsContext_;
GnuTLSContext* tlsContext_;
// Last error code from gnutls library functions
int rv_;
};

View File

@ -43,10 +43,15 @@
namespace aria2 {
TLSContext::TLSContext(TLSSessionSide side)
TLSContext* TLSContext::make(TLSSessionSide side)
{
return new OpenSSLTLSContext(side);
}
OpenSSLTLSContext::OpenSSLTLSContext(TLSSessionSide side)
: sslCtx_(0),
side_(side),
peerVerificationEnabled_(false)
verifyPeer_(true)
{
sslCtx_ = SSL_CTX_new(SSLv23_method());
if(sslCtx_) {
@ -70,22 +75,17 @@ TLSContext::TLSContext(TLSSessionSide side)
#endif
}
TLSContext::~TLSContext()
OpenSSLTLSContext::~OpenSSLTLSContext()
{
SSL_CTX_free(sslCtx_);
}
bool TLSContext::good() const
bool OpenSSLTLSContext::good() const
{
return good_;
}
bool TLSContext::bad() const
{
return !good_;
}
bool TLSContext::addCredentialFile(const std::string& certfile,
bool OpenSSLTLSContext::addCredentialFile(const std::string& certfile,
const std::string& keyfile)
{
if(SSL_CTX_use_PrivateKey_file(sslCtx_, keyfile.c_str(),
@ -107,7 +107,7 @@ bool TLSContext::addCredentialFile(const std::string& certfile,
return true;
}
bool TLSContext::addSystemTrustedCACerts()
bool OpenSSLTLSContext::addSystemTrustedCACerts()
{
if(SSL_CTX_set_default_verify_paths(sslCtx_) != 1) {
A2_LOG_INFO(fmt(MSG_LOADING_SYSTEM_TRUSTED_CA_CERTS_FAILED,
@ -119,7 +119,7 @@ bool TLSContext::addSystemTrustedCACerts()
}
}
bool TLSContext::addTrustedCACertFile(const std::string& certfile)
bool OpenSSLTLSContext::addTrustedCACertFile(const std::string& certfile)
{
if(SSL_CTX_load_verify_locations(sslCtx_, certfile.c_str(), 0) != 1) {
A2_LOG_ERROR(fmt(MSG_LOADING_TRUSTED_CA_CERT_FAILED,
@ -132,14 +132,4 @@ bool TLSContext::addTrustedCACertFile(const std::string& certfile)
}
}
void TLSContext::enablePeerVerification()
{
peerVerificationEnabled_ = true;
}
void TLSContext::disablePeerVerification()
{
peerVerificationEnabled_ = false;
}
} // namespace aria2

View File

@ -46,52 +46,43 @@
namespace aria2 {
class TLSContext {
private:
SSL_CTX* sslCtx_;
TLSSessionSide side_;
bool good_;
bool peerVerificationEnabled_;
class OpenSSLTLSContext : public TLSContext {
public:
TLSContext(TLSSessionSide side);
OpenSSLTLSContext(TLSSessionSide side);
~TLSContext();
~OpenSSLTLSContext();
// private key `keyfile' must be decrypted.
bool addCredentialFile(const std::string& certfile,
const std::string& keyfile);
virtual bool addCredentialFile(const std::string& certfile,
const std::string& keyfile);
bool addSystemTrustedCACerts();
virtual bool addSystemTrustedCACerts();
// certfile can contain multiple certificates.
bool addTrustedCACertFile(const std::string& certfile);
virtual bool addTrustedCACertFile(const std::string& certfile);
bool good() const;
virtual bool good() const;
bool bad() const;
SSL_CTX* getSSLCtx() const
{
return sslCtx_;
}
TLSSessionSide getSide() const
{
virtual TLSSessionSide getSide() const {
return side_;
}
void enablePeerVerification();
void disablePeerVerification();
bool peerVerificationEnabled() const
{
return peerVerificationEnabled_;
virtual bool getVerifyPeer() const {
return verifyPeer_;
}
virtual void setVerifyPeer(bool verify) {
verifyPeer_ = verify;
}
SSL_CTX* getSSLCtx() const {
return sslCtx_;
}
private:
SSL_CTX* sslCtx_;
TLSSessionSide side_;
bool good_;
bool verifyPeer_;
};
} // namespace aria2

View File

@ -38,26 +38,31 @@
#include <openssl/x509.h>
#include <openssl/x509v3.h>
#include "TLSContext.h"
#include "LogFactory.h"
#include "util.h"
#include "SocketCore.h"
namespace aria2 {
TLSSession::TLSSession(TLSContext* tlsContext)
TLSSession* TLSSession::make(TLSContext* ctx)
{
return new OpenSSLTLSSession(static_cast<OpenSSLTLSContext*>(ctx));
}
OpenSSLTLSSession::OpenSSLTLSSession(OpenSSLTLSContext* tlsContext)
: ssl_(0),
tlsContext_(tlsContext),
rv_(1)
{}
TLSSession::~TLSSession()
OpenSSLTLSSession::~OpenSSLTLSSession()
{
if(ssl_) {
SSL_shutdown(ssl_);
}
}
int TLSSession::init(sock_t sockfd)
int OpenSSLTLSSession::init(sock_t sockfd)
{
ERR_clear_error();
ssl_ = SSL_new(tlsContext_->getSSLCtx());
@ -71,7 +76,7 @@ int TLSSession::init(sock_t sockfd)
return TLS_ERR_OK;
}
int TLSSession::setSNIHostname(const std::string& hostname)
int OpenSSLTLSSession::setSNIHostname(const std::string& hostname)
{
#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
ERR_clear_error();
@ -83,7 +88,7 @@ int TLSSession::setSNIHostname(const std::string& hostname)
return TLS_ERR_OK;
}
int TLSSession::closeConnection()
int OpenSSLTLSSession::closeConnection()
{
ERR_clear_error();
SSL_shutdown(ssl_);
@ -91,7 +96,7 @@ int TLSSession::closeConnection()
return TLS_ERR_OK;
}
int TLSSession::checkDirection()
int OpenSSLTLSSession::checkDirection()
{
int error = SSL_get_error(ssl_, rv_);
if(error == SSL_ERROR_WANT_WRITE) {
@ -110,7 +115,7 @@ bool wouldblock(SSL* ssl, int rv)
}
} // namespace
ssize_t TLSSession::writeData(const void* data, size_t len)
ssize_t OpenSSLTLSSession::writeData(const void* data, size_t len)
{
ERR_clear_error();
rv_ = SSL_write(ssl_, data, len);
@ -127,7 +132,7 @@ ssize_t TLSSession::writeData(const void* data, size_t len)
}
}
ssize_t TLSSession::readData(void* data, size_t len)
ssize_t OpenSSLTLSSession::readData(void* data, size_t len)
{
ERR_clear_error();
rv_ = SSL_read(ssl_, data, len);
@ -144,7 +149,7 @@ ssize_t TLSSession::readData(void* data, size_t len)
}
}
int TLSSession::handshake()
int OpenSSLTLSSession::handshake()
{
ERR_clear_error();
if(tlsContext_->getSide() == TLS_CLIENT) {
@ -171,7 +176,7 @@ int TLSSession::handshake()
return TLS_ERR_OK;
}
int TLSSession::tlsConnect(const std::string& hostname,
int OpenSSLTLSSession::tlsConnect(const std::string& hostname,
std::string& handshakeErr)
{
handshakeErr = "";
@ -181,7 +186,7 @@ int TLSSession::tlsConnect(const std::string& hostname,
return ret;
}
if(tlsContext_->getSide() == TLS_CLIENT &&
tlsContext_->peerVerificationEnabled()) {
tlsContext_->getVerifyPeer()) {
// verify peer
X509* peerCert = SSL_get_peer_certificate(ssl_);
if(!peerCert) {
@ -256,12 +261,12 @@ int TLSSession::tlsConnect(const std::string& hostname,
return TLS_ERR_OK;
}
int TLSSession::tlsAccept()
int OpenSSLTLSSession::tlsAccept()
{
return handshake();
}
std::string TLSSession::getLastErrorString()
std::string OpenSSLTLSSession::getLastErrorString()
{
if(rv_ <= 0) {
int sslError = SSL_get_error(ssl_, rv_);

View File

@ -39,32 +39,29 @@
#include <openssl/ssl.h>
#include <string>
#include "TLSSessionConst.h"
#include "LibsslTLSContext.h"
#include "TLSSession.h"
#include "a2netcompat.h"
namespace aria2 {
class TLSContext;
class TLSSession {
class OpenSSLTLSSession : public TLSSession {
public:
TLSSession(TLSContext* tlsContext);
~TLSSession();
int init(sock_t sockfd);
int setSNIHostname(const std::string& hostname);
int closeConnection();
int checkDirection();
ssize_t writeData(const void* data, size_t len);
ssize_t readData(void* data, size_t len);
int tlsConnect(const std::string& hostname, std::string& handshakeErr);
int tlsAccept();
std::string getLastErrorString();
OpenSSLTLSSession(OpenSSLTLSContext* tlsContext);
virtual ~OpenSSLTLSSession();
virtual int init(sock_t sockfd);
virtual int setSNIHostname(const std::string& hostname);
virtual int closeConnection();
virtual int checkDirection();
virtual ssize_t writeData(const void* data, size_t len);
virtual ssize_t readData(void* data, size_t len);
virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr);
virtual int tlsAccept();
virtual std::string getLastErrorString();
private:
int handshake();
SSL* ssl_;
TLSContext* tlsContext_;
OpenSSLTLSContext* tlsContext_;
// Last error code from openSSL library functions
int rv_;
};

View File

@ -299,38 +299,53 @@ SRCS += EpollEventPoll.cc EpollEventPoll.h
endif # HAVE_EPOLL
if ENABLE_SSL
SRCS += TLSContext.h\
TLSSession.h\
TLSSessionConst.h
SRCS += TLSSession.h TLSSessionConst.h
endif # ENABLE_SSL
if USE_APPLE_MD
SRCS += AppleMessageDigestImpl.cc AppleMessageDigestImpl.h
endif
if HAVE_APPLETLS
SRCS += AppleTLSContext.cc AppleTLSContext.h \
AppleTLSSession.cc AppleTLSSession.h
endif
if HAVE_LIBGNUTLS
SRCS += LibgnutlsTLSContext.cc LibgnutlsTLSContext.h\
LibgnutlsTLSSession.cc LibgnutlsTLSSession.h
SRCS += LibgnutlsTLSContext.cc LibgnutlsTLSContext.h \
LibgnutlsTLSSession.cc LibgnutlsTLSSession.h
endif # HAVE_LIBGNUTLS
if HAVE_LIBGCRYPT
SRCS += LibgcryptMessageDigestImpl.cc LibgcryptMessageDigestImpl.h\
LibgcryptARC4Encryptor.cc LibgcryptARC4Encryptor.h\
LibgcryptDHKeyExchange.cc LibgcryptDHKeyExchange.h
SRCS += LibgcryptARC4Encryptor.cc LibgcryptARC4Encryptor.h \
LibgcryptDHKeyExchange.cc LibgcryptDHKeyExchange.h
if USE_LIBGCRYPT_MD
SRCS += LibgcryptMessageDigestImpl.cc LibgcryptMessageDigestImpl.h
endif
endif # HAVE_LIBGCRYPT
if HAVE_LIBNETTLE
SRCS += LibnettleMessageDigestImpl.cc LibnettleMessageDigestImpl.h\
LibnettleARC4Encryptor.cc LibnettleARC4Encryptor.h
SRCS += LibnettleARC4Encryptor.cc LibnettleARC4Encryptor.h
if USE_LIBNETTLE_MD
SRCS += LibnettleMessageDigestImpl.cc LibnettleMessageDigestImpl.h
endif
endif # HAVE_LIBNETTLE
if HAVE_LIBGMP
SRCS += a2gmp.cc a2gmp.h\
LibgmpDHKeyExchange.cc LibgmpDHKeyExchange.h
SRCS += a2gmp.cc a2gmp.h \
LibgmpDHKeyExchange.cc LibgmpDHKeyExchange.h
endif # HAVE_LIBGMP
if HAVE_OPENSSL
SRCS += LibsslTLSContext.cc LibsslTLSContext.h\
LibsslTLSSession.cc LibsslTLSSession.h\
LibsslMessageDigestImpl.cc LibsslMessageDigestImpl.h\
LibsslARC4Encryptor.cc LibsslARC4Encryptor.h\
LibsslDHKeyExchange.cc LibsslDHKeyExchange.h
SRCS += LibsslARC4Encryptor.cc LibsslARC4Encryptor.h \
LibsslDHKeyExchange.cc LibsslDHKeyExchange.h
if !HAVE_APPLETLS
SRCS += LibsslTLSContext.cc LibsslTLSContext.h \
LibsslTLSSession.cc LibsslTLSSession.h
endif
if USE_OPENSSL_MD
SRCS += LibsslMessageDigestImpl.cc LibsslMessageDigestImpl.h
endif
endif # HAVE_OPENSSL
if HAVE_ZLIB

View File

@ -35,12 +35,15 @@
#ifndef D_MESSAGE_DIGEST_IMPL_H
#define D_MESSAGE_DIGEST_IMPL_H
#ifdef HAVE_LIBNETTLE
#ifdef USE_APPLE_MD
# include "AppleMessageDigestImpl.h"
#elif defined(USE_LIBNETTLE_MD)
# include "LibnettleMessageDigestImpl.h"
#elif HAVE_LIBGCRYPT
#elif defined(USE_LIBGCRYPT_MD)
# include "LibgcryptMessageDigestImpl.h"
#elif HAVE_OPENSSL
#elif defined(USE_OPENSSL_MD)
# include "LibsslMessageDigestImpl.h"
#endif // HAVE_OPENSSL
#endif
#endif // D_MESSAGE_DIGEST_IMPL_H

View File

@ -145,7 +145,7 @@ error_code::Value MultiUrlRequestInfo::execute()
!option_->blank(PREF_RPC_PRIVATE_KEY)) {
// We set server TLS context to the SocketCore before creating
// DownloadEngine instance.
SharedHandle<TLSContext> svTlsContext(new TLSContext(TLS_SERVER));
SharedHandle<TLSContext> svTlsContext(TLSContext::make(TLS_SERVER));
svTlsContext->addCredentialFile(option_->get(PREF_RPC_CERTIFICATE),
option_->get(PREF_RPC_PRIVATE_KEY));
SocketCore::setServerTLSContext(svTlsContext);
@ -194,7 +194,7 @@ error_code::Value MultiUrlRequestInfo::execute()
e->setAuthConfigFactory(authConfigFactory);
#ifdef ENABLE_SSL
SharedHandle<TLSContext> clTlsContext(new TLSContext(TLS_CLIENT));
SharedHandle<TLSContext> clTlsContext(TLSContext::make(TLS_CLIENT));
if(!option_->blank(PREF_CERTIFICATE) &&
!option_->blank(PREF_PRIVATE_KEY)) {
clTlsContext->addCredentialFile(option_->get(PREF_CERTIFICATE),
@ -211,9 +211,7 @@ error_code::Value MultiUrlRequestInfo::execute()
A2_LOG_INFO(MSG_WARN_NO_CA_CERT);
}
}
if(option_->getAsBool(PREF_CHECK_CERTIFICATE)) {
clTlsContext->enablePeerVerification();
}
clTlsContext->setVerifyPeer(option_->getAsBool(PREF_CHECK_CERTIFICATE));
SocketCore::setClientTLSContext(clTlsContext);
#endif
#ifdef HAVE_ARES_ADDR_NODE

View File

@ -819,7 +819,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
wantWrite_ = false;
switch(secure_) {
case A2_TLS_NONE:
tlsSession_.reset(new TLSSession(tlsctx));
tlsSession_.reset(TLSSession::make(tlsctx));
rv = tlsSession_->init(sockfd_);
if(rv != TLS_ERR_OK) {
std::string error = tlsSession_->getLastErrorString();

View File

@ -35,6 +35,8 @@
#ifndef D_TLS_CONTEXT_H
#define D_TLS_CONTEXT_H
#include <string>
#include "common.h"
namespace aria2 {
@ -44,12 +46,27 @@ enum TLSSessionSide {
TLS_SERVER
};
class TLSContext {
public:
static TLSContext* make(TLSSessionSide side);
virtual ~TLSContext() {}
// private key `keyfile' must be decrypted.
virtual bool addCredentialFile(const std::string& certfile,
const std::string& keyfile) = 0;
virtual bool addSystemTrustedCACerts() = 0;
// certfile can contain multiple certificates.
virtual bool addTrustedCACertFile(const std::string& certfile) = 0;
virtual bool good() const = 0;
virtual TLSSessionSide getSide() const = 0;
virtual bool getVerifyPeer() const = 0;
virtual void setVerifyPeer(bool) = 0;
};
} // namespace aria2
#ifdef HAVE_OPENSSL
# include "LibsslTLSContext.h"
#elif HAVE_LIBGNUTLS
# include "LibgnutlsTLSContext.h"
#endif // HAVE_LIBGNUTLS
#endif // D_TLS_CONTEXT_H

View File

@ -36,69 +36,86 @@
#define TLS_SESSION_H
#include "common.h"
#include "a2netcompat.h"
#include "TLSContext.h"
namespace aria2 {
enum TLSDirection {
TLS_WANT_READ = 1,
TLS_WANT_WRITE
};
enum TLSErrorCode {
TLS_ERR_OK = 0,
TLS_ERR_ERROR = -1,
TLS_ERR_WOULDBLOCK = -2
};
// To create another SSL/TLS backend, implement TLSSession class below.
//
// class TLSSession {
// public:
// TLSSession(TLSContext* tlsContext);
//
// // MUST deallocate all resources
// ~TLSSession();
//
// // Initializes SSL/TLS session. The |sockfd| is the underlying
// // tranport socket. This function returns TLS_ERR_OK if it
// // succeeds, or TLS_ERR_ERROR.
// int init(sock_t sockfd);
//
// // Sets |hostname| for TLS SNI extension. This is only meaningful for
// // client side session. This function returns TLS_ERR_OK if it
// // succeeds, or TLS_ERR_ERROR.
// int setSNIHostname(const std::string& hostname);
//
// // Closes the SSL/TLS session. Don't close underlying transport
// // socket. This function returns TLS_ERR_OK if it succeeds, or
// // TLS_ERR_ERROR.
// int closeConnection();
//
// // Returns TLS_WANT_READ if SSL/TLS session needs more data from
// // remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session
// // needs to write more data to proceed. If SSL/TLS session needs
// // neither read nor write data at the moment, return value is
// // undefined.
// int checkDirection();
//
// // Sends |data| with length |len|. This function returns the number
// // of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the
// // underlying tranport blocks, or TLS_ERR_ERROR.
// ssize_t writeData(const void* data, size_t len);
//
// // Receives data into |data| with length |len|. This function returns
// // the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK
// // if the underlying tranport blocks, or TLS_ERR_ERROR.
// ssize_t readData(void* data, size_t len);
//
// // Performs client side handshake. The |hostname| is the hostname of
// // the remote endpoint and is used to verify its certificate. This
// // function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK
// // if the underlying transport blocks, or TLS_ERR_ERROR.
// // When returning TLS_ERR_ERROR, provide certificate validation error
// // in |handshakeErr|.
// int tlsConnect(const std::string& hostname, std::string& handshakeErr);
//
// // Performs server side handshake. This function returns TLS_ERR_OK
// // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport
// // blocks, or TLS_ERR_ERROR.
// int tlsAccept();
//
// // Returns last error string
// std::string getLastErrorString();
// };
class TLSSession {
public:
static TLSSession* make(TLSContext* ctx);
#ifdef HAVE_OPENSSL
# include "LibsslTLSSession.h"
#elif defined HAVE_LIBGNUTLS
# include "LibgnutlsTLSSession.h"
#endif
// MUST deallocate all resources
virtual ~TLSSession() {}
// Initializes SSL/TLS session. The |sockfd| is the underlying
// tranport socket. This function returns TLS_ERR_OK if it
// succeeds, or TLS_ERR_ERROR.
virtual int init(sock_t sockfd) = 0;
// Sets |hostname| for TLS SNI extension. This is only meaningful for
// client side session. This function returns TLS_ERR_OK if it
// succeeds, or TLS_ERR_ERROR.
virtual int setSNIHostname(const std::string& hostname) = 0;
// Closes the SSL/TLS session. Don't close underlying transport
// socket. This function returns TLS_ERR_OK if it succeeds, or
// TLS_ERR_ERROR.
virtual int closeConnection() = 0;
// Returns TLS_WANT_READ if SSL/TLS session needs more data from
// remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session
// needs to write more data to proceed. If SSL/TLS session needs
// neither read nor write data at the moment, return value is
// undefined.
virtual int checkDirection() = 0;
// Sends |data| with length |len|. This function returns the number
// of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the
// underlying tranport blocks, or TLS_ERR_ERROR.
virtual ssize_t writeData(const void* data, size_t len) = 0;
// Receives data into |data| with length |len|. This function returns
// the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK
// if the underlying tranport blocks, or TLS_ERR_ERROR.
virtual ssize_t readData(void* data, size_t len) = 0;
// Performs client side handshake. The |hostname| is the hostname of
// the remote endpoint and is used to verify its certificate. This
// function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK
// if the underlying transport blocks, or TLS_ERR_ERROR.
// When returning TLS_ERR_ERROR, provide certificate validation error
// in |handshakeErr|.
virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr) = 0;
// Performs server side handshake. This function returns TLS_ERR_OK
// if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport
// blocks, or TLS_ERR_ERROR.
virtual int tlsAccept() = 0;
// Returns last error string
virtual std::string getLastErrorString() = 0;
protected:
TLSSession() {}
private:
TLSSession(const TLSSession&);
TLSSession& operator=(const TLSSession&);
};
}
#endif // TLS_SESSION_H