From 8580c98bce8e5b85c1cf7be150d9315cd663db93 Mon Sep 17 00:00:00 2001 From: Tatsuhiro Tsujikawa Date: Wed, 3 Apr 2013 02:24:41 +0900 Subject: [PATCH] Abstract TLS session implementation Now TLS session object is abstracted as TLSSession class. Currently, we have GNUTLS and OpenSSL implementations. --- src/LibgnutlsTLSSession.cc | 266 +++++++++++++++++++++ src/LibgnutlsTLSSession.h | 73 ++++++ src/LibsslTLSSession.cc | 299 +++++++++++++++++++++++ src/LibsslTLSSession.h | 74 ++++++ src/Makefile.am | 8 +- src/SocketCore.cc | 473 +++++-------------------------------- src/SocketCore.h | 37 +-- src/TLSSession.h | 104 ++++++++ src/TLSSessionConst.h | 55 +++++ 9 files changed, 948 insertions(+), 441 deletions(-) create mode 100644 src/LibgnutlsTLSSession.cc create mode 100644 src/LibgnutlsTLSSession.h create mode 100644 src/LibsslTLSSession.cc create mode 100644 src/LibsslTLSSession.h create mode 100644 src/TLSSession.h create mode 100644 src/TLSSessionConst.h diff --git a/src/LibgnutlsTLSSession.cc b/src/LibgnutlsTLSSession.cc new file mode 100644 index 00000000..c1d449f3 --- /dev/null +++ b/src/LibgnutlsTLSSession.cc @@ -0,0 +1,266 @@ +/* */ +#include "LibgnutlsTLSSession.h" + +#include + +#include "TLSContext.h" +#include "util.h" +#include "SocketCore.h" + +namespace aria2 { + +TLSSession::TLSSession(TLSContext* tlsContext) + : sslSession_(0), + tlsContext_(tlsContext), + rv_(0) +{} + +TLSSession::~TLSSession() +{ + if(sslSession_) { + gnutls_deinit(sslSession_); + } +} + +int TLSSession::init(sock_t sockfd) +{ + rv_ = gnutls_init(&sslSession_, + tlsContext_->getSide() == TLS_CLIENT ? + GNUTLS_CLIENT : GNUTLS_SERVER); + if(rv_ != GNUTLS_E_SUCCESS) { + return TLS_ERR_ERROR; + } + // It seems err is not error message, but the argument string + // which causes syntax error. + const char* err; + // For client side, disables TLS1.1 here because there are servers + // that don't understand TLS1.1. TODO Is this still necessary? + rv_ = gnutls_priority_set_direct(sslSession_, + tlsContext_->getSide() == TLS_CLIENT ? + "NORMAL:-VERS-TLS1.1" : + "NORMAL", + &err); + if(rv_ != GNUTLS_E_SUCCESS) { + return TLS_ERR_ERROR; + } + // put the x509 credentials to the current session + rv_ = gnutls_credentials_set(sslSession_, GNUTLS_CRD_CERTIFICATE, + tlsContext_->getCertCred()); + if(rv_ != GNUTLS_E_SUCCESS) { + return TLS_ERR_ERROR; + } + // TODO Consider to use gnutls_transport_set_int() for GNUTLS 3.1.9 + // or later + gnutls_transport_set_ptr(sslSession_, + (gnutls_transport_ptr_t)(ptrdiff_t)sockfd); + return TLS_ERR_OK; +} + +int TLSSession::setSNIHostname(const std::string& hostname) +{ + // TLS extensions: SNI + rv_ = gnutls_server_name_set(sslSession_, GNUTLS_NAME_DNS, + hostname.c_str(), hostname.size()); + if(rv_ != GNUTLS_E_SUCCESS) { + return TLS_ERR_ERROR; + } + return TLS_ERR_OK; +} + +int TLSSession::closeConnection() +{ + rv_ = gnutls_bye(sslSession_, GNUTLS_SHUT_WR); + if(rv_ == GNUTLS_E_SUCCESS) { + return TLS_ERR_OK; + } else if(rv_ == GNUTLS_E_AGAIN) { + return TLS_ERR_WOULDBLOCK; + } else { + return TLS_ERR_ERROR; + } +} + +int TLSSession::checkDirection() +{ + int direction = gnutls_record_get_direction(sslSession_); + return direction == 0 ? TLS_WANT_READ : TLS_WANT_WRITE; +} + +ssize_t TLSSession::writeData(const void* data, size_t len) +{ + while((rv_ = gnutls_record_send(sslSession_, data, len)) == + GNUTLS_E_INTERRUPTED); + if(rv_ >= 0) { + ssize_t ret = rv_; + rv_ = 0; + return ret; + } else if(rv_ == GNUTLS_E_AGAIN) { + return TLS_ERR_WOULDBLOCK; + } else { + return TLS_ERR_ERROR; + } +} + +ssize_t TLSSession::readData(void* data, size_t len) +{ + while((rv_ = gnutls_record_recv(sslSession_, data, len)) == + GNUTLS_E_INTERRUPTED); + if(rv_ >= 0) { + ssize_t ret = rv_; + rv_ = 0; + return ret; + } else if(rv_ == GNUTLS_E_AGAIN) { + return TLS_ERR_WOULDBLOCK; + } else { + return TLS_ERR_ERROR; + } +} + +int TLSSession::tlsConnect(const std::string& hostname, + std::string& handshakeErr) +{ + handshakeErr = ""; + rv_ = gnutls_handshake(sslSession_); + if(rv_ < 0) { + if(rv_ == GNUTLS_E_AGAIN) { + return TLS_ERR_WOULDBLOCK; + } else { + return TLS_ERR_ERROR; + } + } + if(tlsContext_->peerVerificationEnabled()) { + // verify peer + unsigned int status; + rv_ = gnutls_certificate_verify_peers2(sslSession_, &status); + if(rv_ != GNUTLS_E_SUCCESS) { + return TLS_ERR_ERROR; + } + if(status) { + handshakeErr = ""; + if(status & GNUTLS_CERT_INVALID) { + handshakeErr += " `not signed by known authorities or invalid'"; + } + if(status & GNUTLS_CERT_REVOKED) { + handshakeErr += " `revoked by its CA'"; + } + if(status & GNUTLS_CERT_SIGNER_NOT_FOUND) { + handshakeErr += " `issuer is not known'"; + } + // TODO should check GNUTLS_CERT_SIGNER_NOT_CA ? + if(status & GNUTLS_CERT_INSECURE_ALGORITHM) { + handshakeErr += " `insecure algorithm'"; + } + if(status & GNUTLS_CERT_NOT_ACTIVATED) { + handshakeErr += " `not activated yet'"; + } + if(status & GNUTLS_CERT_EXPIRED) { + handshakeErr += " `expired'"; + } + // TODO Add GNUTLS_CERT_SIGNATURE_FAILURE here + if(!handshakeErr.empty()) { + return TLS_ERR_ERROR; + } + } + // certificate type: only X509 is allowed. + if(gnutls_certificate_type_get(sslSession_) != GNUTLS_CRT_X509) { + handshakeErr = "certificate type must be X509"; + return TLS_ERR_ERROR; + } + unsigned int peerCertsLength; + const gnutls_datum_t* peerCerts; + peerCerts = gnutls_certificate_get_peers(sslSession_, &peerCertsLength); + if(!peerCerts || peerCertsLength == 0 ) { + handshakeErr = "certificate not found"; + return TLS_ERR_ERROR; + } + gnutls_x509_crt_t cert; + rv_ = gnutls_x509_crt_init(&cert); + if(rv_ != GNUTLS_E_SUCCESS) { + return TLS_ERR_ERROR; + } + auto_delete certDeleter(cert, gnutls_x509_crt_deinit); + rv_ = gnutls_x509_crt_import(cert, &peerCerts[0], GNUTLS_X509_FMT_DER); + if(rv_ != GNUTLS_E_SUCCESS) { + return TLS_ERR_ERROR; + } + std::string commonName; + std::vector dnsNames; + std::vector ipAddrs; + int ret = 0; + char altName[256]; + size_t altNameLen; + for(int i = 0; !(ret < 0); ++i) { + altNameLen = sizeof(altName); + ret = gnutls_x509_crt_get_subject_alt_name(cert, i, altName, + &altNameLen, 0); + if(ret == GNUTLS_SAN_DNSNAME) { + dnsNames.push_back(std::string(altName, altNameLen)); + } else if(ret == GNUTLS_SAN_IPADDRESS) { + ipAddrs.push_back(std::string(altName, altNameLen)); + } + } + altNameLen = sizeof(altName); + ret = gnutls_x509_crt_get_dn_by_oid(cert, + GNUTLS_OID_X520_COMMON_NAME, 0, 0, + altName, &altNameLen); + if(ret == 0) { + commonName.assign(altName, altNameLen); + } + if(!net::verifyHostname(hostname, dnsNames, ipAddrs, commonName)) { + handshakeErr = "hostname does not match"; + return TLS_ERR_ERROR; + } + } + return TLS_ERR_OK; +} + +int TLSSession::tlsAccept() +{ + rv_ = gnutls_handshake(sslSession_); + if(rv_ == GNUTLS_E_SUCCESS) { + return TLS_ERR_OK; + } else if(rv_ == GNUTLS_E_AGAIN) { + return TLS_ERR_WOULDBLOCK; + } else { + return TLS_ERR_ERROR; + } +} + +std::string TLSSession::getLastErrorString() +{ + return gnutls_strerror(rv_); +} + +} // namespace aria2 diff --git a/src/LibgnutlsTLSSession.h b/src/LibgnutlsTLSSession.h new file mode 100644 index 00000000..7118ab21 --- /dev/null +++ b/src/LibgnutlsTLSSession.h @@ -0,0 +1,73 @@ +/* */ +#ifndef LIBGNUTLS_TLS_SESSION_H +#define LIBGNUTLS_TLS_SESSION_H + +#include "common.h" + +#include + +#include + +#include "TLSSessionConst.h" +#include "a2netcompat.h" + +namespace aria2 { + +class TLSContext; + +class TLSSession { +public: + TLSSession(TLSContext* tlsContext); + ~TLSSession(); + int init(sock_t sockfd); + int setSNIHostname(const std::string& hostname); + int closeConnection(); + int checkDirection(); + ssize_t writeData(const void* data, size_t len); + ssize_t readData(void* data, size_t len); + int tlsConnect(const std::string& hostname, std::string& handshakeErr); + int tlsAccept(); + std::string getLastErrorString(); +private: + gnutls_session_t sslSession_; + TLSContext* tlsContext_; + // Last error code from gnutls library functions + int rv_; +}; + +} // namespace aria2 + +#endif // LIBGNUTLS_TLS_SESSION_H diff --git a/src/LibsslTLSSession.cc b/src/LibsslTLSSession.cc new file mode 100644 index 00000000..8cbcac5c --- /dev/null +++ b/src/LibsslTLSSession.cc @@ -0,0 +1,299 @@ +/* */ +#include "LibsslTLSSession.h" + +#include +#include +#include + +#include "TLSContext.h" +#include "util.h" +#include "SocketCore.h" + +namespace aria2 { + +TLSSession::TLSSession(TLSContext* tlsContext) + : ssl_(0), + tlsContext_(tlsContext), + rv_(1) +{} + +TLSSession::~TLSSession() +{ + if(ssl_) { + SSL_shutdown(ssl_); + } +} + +int TLSSession::init(sock_t sockfd) +{ + ERR_clear_error(); + ssl_ = SSL_new(tlsContext_->getSSLCtx()); + if(!ssl_) { + return TLS_ERR_ERROR; + } + rv_ = SSL_set_fd(ssl_, sockfd); + if(rv_ == 0) { + return TLS_ERR_ERROR; + } + return TLS_ERR_OK; +} + +int TLSSession::setSNIHostname(const std::string& hostname) +{ +#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME + ERR_clear_error(); + // TLS extensions: SNI. There is not documentation about the + // return code for this function (actually this is macro + // wrapping SSL_ctrl at the time of this writing). + SSL_set_tlsext_host_name(ssl_, hostname.c_str()); +#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME + return TLS_ERR_OK; +} + +int TLSSession::closeConnection() +{ + ERR_clear_error(); + SSL_shutdown(ssl_); + // TODO handle return value + return TLS_ERR_OK; +} + +int TLSSession::checkDirection() +{ + int error = SSL_get_error(ssl_, rv_); + if(error == SSL_ERROR_WANT_WRITE) { + return TLS_WANT_WRITE; + } else { + // TODO We ignore error other than SSL_ERR_WANT_READ here for now + return TLS_WANT_READ; + } +} + +namespace { +bool wouldblock(SSL* ssl, int rv) +{ + int error = SSL_get_error(ssl, rv); + return error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE; +} +} // namespace + +ssize_t TLSSession::writeData(const void* data, size_t len) +{ + ERR_clear_error(); + rv_ = SSL_write(ssl_, data, len); + if(rv_ <= 0) { + if(wouldblock(ssl_, rv_)) { + return TLS_ERR_WOULDBLOCK; + } else { + return TLS_ERR_ERROR; + } + } else { + ssize_t ret = rv_; + rv_ = 1; + return ret; + } +} + +ssize_t TLSSession::readData(void* data, size_t len) +{ + ERR_clear_error(); + rv_ = SSL_read(ssl_, data, len); + if(rv_ <= 0) { + if(wouldblock(ssl_, rv_)) { + return TLS_ERR_WOULDBLOCK; + } else { + return TLS_ERR_ERROR; + } + } else { + ssize_t ret = rv_; + rv_ = 1; + return ret; + } +} + +int TLSSession::handshake() +{ + ERR_clear_error(); + if(tlsContext_->getSide() == TLS_CLIENT) { + rv_ = SSL_connect(ssl_); + } else { + rv_ = SSL_accept(ssl_); + } + if(rv_ <= 0) { + int sslError = SSL_get_error(ssl_, rv_); + switch(sslError) { + case SSL_ERROR_NONE: + case SSL_ERROR_WANT_X509_LOOKUP: + case SSL_ERROR_ZERO_RETURN: + // TODO Now assume we are doing non-blocking. Then above 2 + // errors are OK. + break; + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + return TLS_ERR_WOULDBLOCK; + default: + return TLS_ERR_ERROR; + } + } + return TLS_ERR_OK; +} + +int TLSSession::tlsConnect(const std::string& hostname, + std::string& handshakeErr) +{ + handshakeErr = ""; + int ret; + ret = handshake(); + if(ret != TLS_ERR_OK) { + return ret; + } + if(tlsContext_->getSide() == TLS_CLIENT && + tlsContext_->peerVerificationEnabled()) { + // verify peer + X509* peerCert = SSL_get_peer_certificate(ssl_); + if(!peerCert) { + handshakeErr = "certificate not found"; + return TLS_ERR_ERROR; + } + auto_delete certDeleter(peerCert, X509_free); + long verifyResult = SSL_get_verify_result(ssl_); + if(verifyResult != X509_V_OK) { + handshakeErr = X509_verify_cert_error_string(verifyResult); + return TLS_ERR_ERROR; + } + std::string commonName; + std::vector dnsNames; + std::vector ipAddrs; + GENERAL_NAMES* altNames; + altNames = reinterpret_cast + (X509_get_ext_d2i(peerCert, NID_subject_alt_name, NULL, NULL)); + if(altNames) { + auto_delete altNamesDeleter + (altNames, GENERAL_NAMES_free); + size_t n = sk_GENERAL_NAME_num(altNames); + for(size_t i = 0; i < n; ++i) { + const GENERAL_NAME* altName = sk_GENERAL_NAME_value(altNames, i); + if(altName->type == GEN_DNS) { + const char* name = + reinterpret_cast(ASN1_STRING_data(altName->d.ia5)); + if(!name) { + continue; + } + size_t len = ASN1_STRING_length(altName->d.ia5); + dnsNames.push_back(std::string(name, len)); + } else if(altName->type == GEN_IPADD) { + const unsigned char* ipAddr = altName->d.iPAddress->data; + if(!ipAddr) { + continue; + } + size_t len = altName->d.iPAddress->length; + ipAddrs.push_back(std::string(reinterpret_cast(ipAddr), + len)); + } + } + } + X509_NAME* subjectName = X509_get_subject_name(peerCert); + if(!subjectName) { + handshakeErr = "could not get X509 name object from the certificate."; + return TLS_ERR_ERROR; + } + int lastpos = -1; + while(1) { + lastpos = X509_NAME_get_index_by_NID(subjectName, NID_commonName, + lastpos); + if(lastpos == -1) { + break; + } + X509_NAME_ENTRY* entry = X509_NAME_get_entry(subjectName, lastpos); + unsigned char* out; + int outlen = ASN1_STRING_to_UTF8(&out, + X509_NAME_ENTRY_get_data(entry)); + if(outlen < 0) { + continue; + } + commonName.assign(&out[0], &out[outlen]); + OPENSSL_free(out); + break; + } + if(!net::verifyHostname(hostname, dnsNames, ipAddrs, commonName)) { + handshakeErr = "hostname does not match"; + return TLS_ERR_ERROR; + } + } + return TLS_ERR_OK; +} + +int TLSSession::tlsAccept() +{ + return handshake(); +} + +std::string TLSSession::getLastErrorString() +{ + if(rv_ <= 0) { + int sslError = SSL_get_error(ssl_, rv_); + switch(sslError) { + case SSL_ERROR_NONE: + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_X509_LOOKUP: + case SSL_ERROR_ZERO_RETURN: + return ""; + case SSL_ERROR_SYSCALL: { + int err = ERR_get_error(); + if(err == 0) { + if(rv_ == 0) { + return "EOF was received"; + } else if(rv_ == -1) { + return "SSL I/O error"; + } else { + return "unknown syscall error"; + } + } else { + return ERR_error_string(err, 0); + } + } + case SSL_ERROR_SSL: + return "protocol error"; + default: + return "unknown error"; + } + } else { + return ""; + } +} + +} // namespace aria2 diff --git a/src/LibsslTLSSession.h b/src/LibsslTLSSession.h new file mode 100644 index 00000000..6a162031 --- /dev/null +++ b/src/LibsslTLSSession.h @@ -0,0 +1,74 @@ +/* */ +#ifndef LIBSSL_TLS_SESSION_H +#define LIBSSL_TLS_SESSION_H + +#include "common.h" + +#include + +#include + +#include "TLSSessionConst.h" +#include "a2netcompat.h" + +namespace aria2 { + +class TLSContext; + +class TLSSession { +public: + TLSSession(TLSContext* tlsContext); + ~TLSSession(); + int init(sock_t sockfd); + int setSNIHostname(const std::string& hostname); + int closeConnection(); + int checkDirection(); + ssize_t writeData(const void* data, size_t len); + ssize_t readData(void* data, size_t len); + int tlsConnect(const std::string& hostname, std::string& handshakeErr); + int tlsAccept(); + std::string getLastErrorString(); +private: + int handshake(); + SSL* ssl_; + TLSContext* tlsContext_; + // Last error code from openSSL library functions + int rv_; +}; + +} // namespace aria2 + +#endif // LIBSSL_TLS_SESSION_H diff --git a/src/Makefile.am b/src/Makefile.am index 87c61cc7..22be425b 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -299,11 +299,14 @@ SRCS += EpollEventPoll.cc EpollEventPoll.h endif # HAVE_EPOLL if ENABLE_SSL -SRCS += TLSContext.h +SRCS += TLSContext.h\ + TLSSession.h\ + TLSSessionConst.h endif # ENABLE_SSL if HAVE_LIBGNUTLS -SRCS += LibgnutlsTLSContext.cc LibgnutlsTLSContext.h +SRCS += LibgnutlsTLSContext.cc LibgnutlsTLSContext.h\ + LibgnutlsTLSSession.cc LibgnutlsTLSSession.h endif # HAVE_LIBGNUTLS if HAVE_LIBGCRYPT @@ -324,6 +327,7 @@ endif # HAVE_LIBGMP if HAVE_OPENSSL SRCS += LibsslTLSContext.cc LibsslTLSContext.h\ + LibsslTLSSession.cc LibsslTLSSession.h\ LibsslMessageDigestImpl.cc LibsslMessageDigestImpl.h\ LibsslARC4Encryptor.cc LibsslARC4Encryptor.h\ LibsslDHKeyExchange.cc LibsslDHKeyExchange.h diff --git a/src/SocketCore.cc b/src/SocketCore.cc index 88fdf581..165d3bdd 100644 --- a/src/SocketCore.cc +++ b/src/SocketCore.cc @@ -46,15 +46,6 @@ #include #include -#ifdef HAVE_OPENSSL -# include -# include -#endif // HAVE_OPENSSL - -#ifdef HAVE_LIBGNUTLS -# include -#endif // HAVE_LIBGNUTLS - #include "message.h" #include "DlRetryEx.h" #include "DlAbortEx.h" @@ -66,6 +57,7 @@ #include "A2STR.h" #ifdef ENABLE_SSL # include "TLSContext.h" +# include "TLSSession.h" #endif // ENABLE_SSL namespace aria2 { @@ -179,14 +171,6 @@ void SocketCore::init() wantRead_ = false; wantWrite_ = false; - -#ifdef HAVE_OPENSSL - // for SSL - ssl = NULL; -#endif // HAVE_OPENSSL -#ifdef HAVE_LIBGNUTLS - sslSession_ = 0; -#endif //HAVE_LIBGNUTLS } SocketCore::~SocketCore() { @@ -586,33 +570,15 @@ void SocketCore::setBlockingMode() void SocketCore::closeConnection() { -#ifdef HAVE_OPENSSL - // for SSL - if(secure_) { - SSL_shutdown(ssl); + if(tlsSession_) { + tlsSession_->closeConnection(); + tlsSession_.reset(); } -#endif // HAVE_OPENSSL -#ifdef HAVE_LIBGNUTLS - if(secure_) { - gnutls_bye(sslSession_, GNUTLS_SHUT_WR); - } -#endif // HAVE_LIBGNUTLS if(sockfd_ != (sock_t) -1) { shutdown(sockfd_, SHUT_WR); CLOSE(sockfd_); sockfd_ = -1; } -#ifdef HAVE_OPENSSL - // for SSL - if(secure_) { - SSL_free(ssl); - } -#endif // HAVE_OPENSSL -#ifdef HAVE_LIBGNUTLS - if(secure_) { - gnutls_deinit(sslSession_); - } -#endif // HAVE_LIBGNUTLS } #ifndef __MINGW32__ @@ -716,34 +682,6 @@ bool SocketCore::isReadable(time_t timeout) #endif // !HAVE_POLL } -#ifdef HAVE_OPENSSL -int SocketCore::sslHandleEAGAIN(int ret) -{ - int error = SSL_get_error(ssl, ret); - if(error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) { - ret = 0; - if(error == SSL_ERROR_WANT_READ) { - wantRead_ = true; - } else { - wantWrite_ = true; - } - } - return ret; -} -#endif // HAVE_OPENSSL - -#ifdef HAVE_LIBGNUTLS -void SocketCore::gnutlsRecordCheckDirection() -{ - int direction = gnutls_record_get_direction(sslSession_); - if(direction == 0) { - wantRead_ = true; - } else { // if(direction == 1) { - wantWrite_ = true; - } -} -#endif // HAVE_LIBGNUTLS - ssize_t SocketCore::writeVector(a2iovec *iov, size_t iovcnt) { ssize_t ret = 0; @@ -805,29 +743,21 @@ ssize_t SocketCore::writeData(const void* data, size_t len) } } } else { -#ifdef HAVE_OPENSSL - ERR_clear_error(); - ret = SSL_write(ssl, data, len); + ret = tlsSession_->writeData(data, len); if(ret < 0) { - ret = sslHandleEAGAIN(ret); + if(ret == TLS_ERR_WOULDBLOCK) { + if(tlsSession_->checkDirection() == TLS_WANT_READ) { + wantRead_ = true; + } else { + wantWrite_ = true; + } + ret = 0; + } else { + throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, + tlsSession_->getLastErrorString().c_str())); + } } - if(ret < 0) { - throw DL_RETRY_EX - (fmt(EX_SOCKET_SEND, ERR_error_string(ERR_get_error(), 0))); - } -#endif // HAVE_OPENSSL -#ifdef HAVE_LIBGNUTLS - while((ret = gnutls_record_send(sslSession_, data, len)) == - GNUTLS_E_INTERRUPTED); - if(ret == GNUTLS_E_AGAIN) { - gnutlsRecordCheckDirection(); - ret = 0; - } else if(ret < 0) { - throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, gnutls_strerror(ret))); - } -#endif // HAVE_LIBGNUTLS } - return ret; } @@ -851,31 +781,21 @@ void SocketCore::readData(void* data, size_t& len) } } } else { -#ifdef HAVE_OPENSSL - // for SSL - // TODO handling len == 0 case required - ERR_clear_error(); - ret = SSL_read(ssl, data, len); + ret = tlsSession_->readData(data, len); if(ret < 0) { - ret = sslHandleEAGAIN(ret); + if(ret == TLS_ERR_WOULDBLOCK) { + if(tlsSession_->checkDirection() == TLS_WANT_READ) { + wantRead_ = true; + } else { + wantWrite_ = true; + } + ret = 0; + } else { + throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, + tlsSession_->getLastErrorString().c_str())); + } } - if(ret < 0) { - throw DL_RETRY_EX - (fmt(EX_SOCKET_RECV, ERR_error_string(ERR_get_error(), 0))); - } -#endif // HAVE_OPENSSL -#ifdef HAVE_LIBGNUTLS - while((ret = gnutls_record_recv(sslSession_, data, len)) == - GNUTLS_E_INTERRUPTED); - if(ret == GNUTLS_E_AGAIN) { - gnutlsRecordCheckDirection(); - ret = 0; - } else if(ret < 0) { - throw DL_RETRY_EX(fmt(EX_SOCKET_RECV, gnutls_strerror(ret))); - } -#endif // HAVE_LIBGNUTLS } - len = ret; } @@ -893,324 +813,57 @@ bool SocketCore::tlsConnect(const std::string& hostname) bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname) { + int rv = 0; + std::string handshakeError; wantRead_ = false; wantWrite_ = false; -#ifdef HAVE_OPENSSL switch(secure_) { case A2_TLS_NONE: - ssl = SSL_new(tlsctx->getSSLCtx()); - if(!ssl) { - throw DL_ABORT_EX - (fmt(EX_SSL_INIT_FAILURE, ERR_error_string(ERR_get_error(), 0))); - } - if(SSL_set_fd(ssl, sockfd_) == 0) { - throw DL_ABORT_EX - (fmt(EX_SSL_INIT_FAILURE, ERR_error_string(ERR_get_error(), 0))); - } - // Fall through -#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME - if(tlsctx->getSide() == TLS_CLIENT && !util::isNumericHost(hostname)) { - // TLS extensions: SNI. There is not documentation about the - // return code for this function (actually this is macro - // wrapping SSL_ctrl at the time of this writing). - SSL_set_tlsext_host_name(ssl, hostname.c_str()); - } -#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME - secure_ = A2_TLS_HANDSHAKING; - // Fall through - case A2_TLS_HANDSHAKING: { - ERR_clear_error(); - int e; - if(tlsctx->getSide() == TLS_CLIENT) { - e = SSL_connect(ssl); - } else { - e = SSL_accept(ssl); - } - - if (e <= 0) { - int ssl_error = SSL_get_error(ssl, e); - switch(ssl_error) { - case SSL_ERROR_NONE: - break; - case SSL_ERROR_WANT_READ: - wantRead_ = true; - return false; - case SSL_ERROR_WANT_WRITE: - wantWrite_ = true; - return false; - case SSL_ERROR_WANT_X509_LOOKUP: - case SSL_ERROR_ZERO_RETURN: - if (blocking_) { - throw DL_ABORT_EX(fmt(EX_SSL_CONNECT_ERROR, ssl_error)); - } - break; - - case SSL_ERROR_SYSCALL: { - int sslErr = ERR_get_error(); - if(sslErr == 0) { - if(e == 0) { - throw DL_ABORT_EX("Got EOF in SSL handshake"); - } else if(e == -1) { - throw DL_ABORT_EX(fmt("SSL I/O error: %s", strerror(errno))); - } else { - throw DL_ABORT_EX(EX_SSL_IO_ERROR); - } - } else { - throw DL_ABORT_EX(fmt("SSL I/O error: %s", - ERR_error_string(sslErr, 0))); - } - } - case SSL_ERROR_SSL: - throw DL_ABORT_EX(EX_SSL_PROTOCOL_ERROR); - - default: - throw DL_ABORT_EX(fmt(EX_SSL_UNKNOWN_ERROR, ssl_error)); - } + tlsSession_.reset(new TLSSession(tlsctx)); + rv = tlsSession_->init(sockfd_); + if(rv != TLS_ERR_OK) { + std::string error = tlsSession_->getLastErrorString(); + tlsSession_.reset(); + throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, error.c_str())); } + // Check hostname is not numeric and it includes ".". Setting + // "localhost" will produce TLS alert with GNUTLS. if(tlsctx->getSide() == TLS_CLIENT && - tlsctx->peerVerificationEnabled()) { - // verify peer - X509* peerCert = SSL_get_peer_certificate(ssl); - if(!peerCert) { - throw DL_ABORT_EX(MSG_NO_CERT_FOUND); - } - auto_delete certDeleter(peerCert, X509_free); - - long verifyResult = SSL_get_verify_result(ssl); - if(verifyResult != X509_V_OK) { - throw DL_ABORT_EX - (fmt(MSG_CERT_VERIFICATION_FAILED, - X509_verify_cert_error_string(verifyResult))); - } - std::string commonName; - std::vector dnsNames; - std::vector ipAddrs; - GENERAL_NAMES* altNames; - altNames = reinterpret_cast - (X509_get_ext_d2i(peerCert, NID_subject_alt_name, NULL, NULL)); - if(altNames) { - auto_delete altNamesDeleter - (altNames, GENERAL_NAMES_free); - size_t n = sk_GENERAL_NAME_num(altNames); - for(size_t i = 0; i < n; ++i) { - const GENERAL_NAME* altName = sk_GENERAL_NAME_value(altNames, i); - if(altName->type == GEN_DNS) { - const char* name = - reinterpret_cast(ASN1_STRING_data(altName->d.ia5)); - if(!name) { - continue; - } - size_t len = ASN1_STRING_length(altName->d.ia5); - dnsNames.push_back(std::string(name, len)); - } else if(altName->type == GEN_IPADD) { - const unsigned char* ipAddr = altName->d.iPAddress->data; - if(!ipAddr) { - continue; - } - size_t len = altName->d.iPAddress->length; - ipAddrs.push_back(std::string(reinterpret_cast(ipAddr), - len)); - } - } - } - X509_NAME* subjectName = X509_get_subject_name(peerCert); - if(!subjectName) { - throw DL_ABORT_EX - ("Could not get X509 name object from the certificate."); - } - int lastpos = -1; - while(1) { - lastpos = X509_NAME_get_index_by_NID(subjectName, NID_commonName, - lastpos); - if(lastpos == -1) { - break; - } - X509_NAME_ENTRY* entry = X509_NAME_get_entry(subjectName, lastpos); - unsigned char* out; - int outlen = ASN1_STRING_to_UTF8(&out, - X509_NAME_ENTRY_get_data(entry)); - if(outlen < 0) { - continue; - } - commonName.assign(&out[0], &out[outlen]); - OPENSSL_free(out); - break; - } - if(!net::verifyHostname(hostname, dnsNames, ipAddrs, commonName)) { - throw DL_ABORT_EX(MSG_HOSTNAME_NOT_MATCH); - } - } - secure_ = A2_TLS_CONNECTED; - break; - } - default: - break; - } -#endif // HAVE_OPENSSL -#ifdef HAVE_LIBGNUTLS - switch(secure_) { - case A2_TLS_NONE: - int r; - gnutls_init(&sslSession_, - tlsctx->getSide() == TLS_CLIENT ? - GNUTLS_CLIENT : GNUTLS_SERVER); - // It seems err is not error message, but the argument string - // which causes syntax error. - const char* err; - // For client side, disables TLS1.1 here because there are servers - // that don't understand TLS1.1. TODO Is this still necessary? - r = gnutls_priority_set_direct(sslSession_, - tlsctx->getSide() == TLS_CLIENT ? - "NORMAL:-VERS-TLS1.1" : - "NORMAL", - &err); - if(r != GNUTLS_E_SUCCESS) { - throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, gnutls_strerror(r))); - } - // put the x509 credentials to the current session - gnutls_credentials_set(sslSession_, GNUTLS_CRD_CERTIFICATE, - tlsctx->getCertCred()); - gnutls_transport_set_ptr(sslSession_, (gnutls_transport_ptr_t)sockfd_); - if(tlsctx->getSide() == TLS_CLIENT) { - // Check hostname is not numeric and it includes ".". Setting - // "localhost" will produce TLS alert. - if(!util::isNumericHost(hostname) && - hostname.find(".") != std::string::npos) { - // TLS extensions: SNI - int ret = gnutls_server_name_set(sslSession_, GNUTLS_NAME_DNS, - hostname.c_str(), hostname.size()); - if(ret < 0) { - A2_LOG_WARN(fmt - ("Setting hostname in SNI extension failed. Cause: %s", - gnutls_strerror(ret))); - } + !util::isNumericHost(hostname) && + hostname.find(".") != std::string::npos) { + rv = tlsSession_->setSNIHostname(hostname); + if(rv != TLS_ERR_OK) { + throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, + tlsSession_->getLastErrorString().c_str())); } } secure_ = A2_TLS_HANDSHAKING; // Fall through - case A2_TLS_HANDSHAKING: { - int ret = gnutls_handshake(sslSession_); - if(ret == GNUTLS_E_AGAIN) { - gnutlsRecordCheckDirection(); + case A2_TLS_HANDSHAKING: + if(tlsctx->getSide() == TLS_CLIENT) { + rv = tlsSession_->tlsConnect(hostname, handshakeError); + } else { + rv = tlsSession_->tlsAccept(); + } + if(rv == TLS_ERR_OK) { + secure_ = A2_TLS_CONNECTED; + } else if(rv == TLS_ERR_WOULDBLOCK) { + if(tlsSession_->checkDirection() == TLS_WANT_READ) { + wantRead_ = true; + } else { + wantWrite_ = true; + } return false; - } else if(ret < 0) { - throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, gnutls_strerror(ret))); + } else { + throw DL_ABORT_EX(fmt("SSL/TLS handshake failure: %s", + handshakeError.empty() ? + tlsSession_->getLastErrorString().c_str() : + handshakeError.c_str())); } - - if(tlsctx->getSide() == TLS_CLIENT && tlsctx->peerVerificationEnabled()) { - // verify peer - unsigned int status; - ret = gnutls_certificate_verify_peers2(sslSession_, &status); - if(ret < 0) { - throw DL_ABORT_EX - (fmt("gnutls_certificate_verify_peer2() failed. Cause: %s", - gnutls_strerror(ret))); - } - if(status) { - std::string errors; - if(status & GNUTLS_CERT_INVALID) { - errors += " `not signed by known authorities or invalid'"; - } - if(status & GNUTLS_CERT_REVOKED) { - errors += " `revoked by its CA'"; - } - if(status & GNUTLS_CERT_SIGNER_NOT_FOUND) { - errors += " `issuer is not known'"; - } - // TODO should check GNUTLS_CERT_SIGNER_NOT_CA ? - if(status & GNUTLS_CERT_INSECURE_ALGORITHM) { - errors += " `insecure algorithm'"; - } - if(status & GNUTLS_CERT_NOT_ACTIVATED) { - errors += " `not activated yet'"; - } - if(status & GNUTLS_CERT_EXPIRED) { - errors += " `expired'"; - } - // TODO Add GNUTLS_CERT_SIGNATURE_FAILURE here - if(!errors.empty()) { - throw DL_ABORT_EX(fmt(MSG_CERT_VERIFICATION_FAILED, errors.c_str())); - } - } - // certificate type: only X509 is allowed. - if(gnutls_certificate_type_get(sslSession_) != GNUTLS_CRT_X509) { - throw DL_ABORT_EX("Certificate type is not X509."); - } - - unsigned int peerCertsLength; - const gnutls_datum_t* peerCerts = gnutls_certificate_get_peers - (sslSession_, &peerCertsLength); - if(!peerCerts || peerCertsLength == 0 ) { - throw DL_ABORT_EX(MSG_NO_CERT_FOUND); - } - Time now; - for(unsigned int i = 0; i < peerCertsLength; ++i) { - gnutls_x509_crt_t cert; - ret = gnutls_x509_crt_init(&cert); - if(ret < 0) { - throw DL_ABORT_EX - (fmt("gnutls_x509_crt_init() failed. Cause: %s", - gnutls_strerror(ret))); - } - auto_delete certDeleter - (cert, gnutls_x509_crt_deinit); - ret = gnutls_x509_crt_import(cert, &peerCerts[i], GNUTLS_X509_FMT_DER); - if(ret < 0) { - throw DL_ABORT_EX - (fmt("gnutls_x509_crt_import() failed. Cause: %s", - gnutls_strerror(ret))); - } - if(i == 0) { - std::string commonName; - std::vector dnsNames; - std::vector ipAddrs; - int ret = 0; - char altName[256]; - size_t altNameLen; - for(int j = 0; !(ret < 0); ++j) { - altNameLen = sizeof(altName); - ret = gnutls_x509_crt_get_subject_alt_name(cert, j, altName, - &altNameLen, 0); - if(ret == GNUTLS_SAN_DNSNAME) { - dnsNames.push_back(std::string(altName, altNameLen)); - } else if(ret == GNUTLS_SAN_IPADDRESS) { - ipAddrs.push_back(std::string(altName, altNameLen)); - } - } - altNameLen = sizeof(altName); - ret = gnutls_x509_crt_get_dn_by_oid(cert, - GNUTLS_OID_X520_COMMON_NAME, 0, 0, - altName, &altNameLen); - if(ret == 0) { - commonName.assign(altName, altNameLen); - } - if(!net::verifyHostname(hostname, dnsNames, ipAddrs, commonName)) { - throw DL_ABORT_EX(MSG_HOSTNAME_NOT_MATCH); - } - } - time_t activationTime = gnutls_x509_crt_get_activation_time(cert); - if(activationTime == -1) { - throw DL_ABORT_EX("Could not get activation time from certificate."); - } - if(now.getTime() < activationTime) { - throw DL_ABORT_EX("Certificate is not activated yet."); - } - time_t expirationTime = gnutls_x509_crt_get_expiration_time(cert); - if(expirationTime == -1) { - throw DL_ABORT_EX("Could not get expiration time from certificate."); - } - if(expirationTime < now.getTime()) { - throw DL_ABORT_EX("Certificate has expired."); - } - } - } - secure_ = A2_TLS_CONNECTED; break; - } default: break; } -#endif // HAVE_LIBGNUTLS return true; } diff --git a/src/SocketCore.h b/src/SocketCore.h index acabe77c..5d27fe2e 100644 --- a/src/SocketCore.h +++ b/src/SocketCore.h @@ -43,16 +43,6 @@ #include #include "a2netcompat.h" - -#ifdef HAVE_OPENSSL -// for SSL -# include -# include -#endif // HAVE_OPENSSL -#ifdef HAVE_LIBGNUTLS -# include -#endif // HAVE_LIBGNUTLS - #include "SharedHandle.h" #include "a2io.h" #include "a2netcompat.h" @@ -62,6 +52,7 @@ namespace aria2 { #ifdef ENABLE_SSL class TLSContext; +class TLSSession; #endif // ENABLE_SSL class SocketCore { @@ -89,27 +80,9 @@ private: static SharedHandle clTlsContext_; // TLS context for server side static SharedHandle svTlsContext_; -#endif // ENABLE_SSL -#ifdef HAVE_OPENSSL - // for SSL - SSL* ssl; + SharedHandle tlsSession_; - int sslHandleEAGAIN(int ret); -#endif // HAVE_OPENSSL -#ifdef HAVE_LIBGNUTLS - gnutls_session_t sslSession_; - - void gnutlsRecordCheckDirection(); -#endif // HAVE_LIBGNUTLS - - void init(); - - void bind(const struct sockaddr* addr, socklen_t addrlen); - - void setSockOpt(int level, int optname, void* optval, socklen_t optlen); - -#ifdef ENABLE_SSL /** * Makes this socket secure. The connection must be established * before calling this method. @@ -119,6 +92,12 @@ private: bool tlsHandshake(TLSContext* tlsctx, const std::string& hostname); #endif // ENABLE_SSL + void init(); + + void bind(const struct sockaddr* addr, socklen_t addrlen); + + void setSockOpt(int level, int optname, void* optval, socklen_t optlen); + SocketCore(sock_t sockfd, int sockType); public: SocketCore(int sockType = SOCK_STREAM); diff --git a/src/TLSSession.h b/src/TLSSession.h new file mode 100644 index 00000000..77c7c833 --- /dev/null +++ b/src/TLSSession.h @@ -0,0 +1,104 @@ +/* */ +#ifndef TLS_SESSION_H +#define TLS_SESSION_H + +#include "common.h" + +// To create another SSL/TLS backend, implement TLSSession class below. +// +// class TLSSession { +// public: +// TLSSession(TLSContext* tlsContext); +// +// // MUST deallocate all resources +// ~TLSSession(); +// +// // Initializes SSL/TLS session. The |sockfd| is the underlying +// // tranport socket. This function returns TLS_ERR_OK if it +// // succeeds, or TLS_ERR_ERROR. +// int init(sock_t sockfd); +// +// // Sets |hostname| for TLS SNI extension. This is only meaningful for +// // client side session. This function returns TLS_ERR_OK if it +// // succeeds, or TLS_ERR_ERROR. +// int setSNIHostname(const std::string& hostname); +// +// // Closes the SSL/TLS session. Don't close underlying transport +// // socket. This function returns TLS_ERR_OK if it succeeds, or +// // TLS_ERR_ERROR. +// int closeConnection(); +// +// // Returns TLS_WANT_READ if SSL/TLS session needs more data from +// // remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session +// // needs to write more data to proceed. If SSL/TLS session needs +// // neither read nor write data at the moment, return value is +// // undefined. +// int checkDirection(); +// +// // Sends |data| with length |len|. This function returns the number +// // of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the +// // underlying tranport blocks, or TLS_ERR_ERROR. +// ssize_t writeData(const void* data, size_t len); +// +// // Receives data into |data| with length |len|. This function returns +// // the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK +// // if the underlying tranport blocks, or TLS_ERR_ERROR. +// ssize_t readData(void* data, size_t len); +// +// // Performs client side handshake. The |hostname| is the hostname of +// // the remote endpoint and is used to verify its certificate. This +// // function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK +// // if the underlying transport blocks, or TLS_ERR_ERROR. +// // When returning TLS_ERR_ERROR, provide certificate validation error +// // in |handshakeErr|. +// int tlsConnect(const std::string& hostname, std::string& handshakeErr); +// +// // Performs server side handshake. This function returns TLS_ERR_OK +// // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport +// // blocks, or TLS_ERR_ERROR. +// int tlsAccept(); +// +// // Returns last error string +// std::string getLastErrorString(); +// }; + +#ifdef HAVE_OPENSSL +# include "LibsslTLSSession.h" +#elif defined HAVE_LIBGNUTLS +# include "LibgnutlsTLSSession.h" +#endif + +#endif // TLS_SESSION_H diff --git a/src/TLSSessionConst.h b/src/TLSSessionConst.h new file mode 100644 index 00000000..f3b8422d --- /dev/null +++ b/src/TLSSessionConst.h @@ -0,0 +1,55 @@ +/* */ +#ifndef TLS_SESSION_CONST_H +#define TLS_SESSION_CONST_H + +#include "common.h" + +namespace aria2 { + +enum TLSDirection { + TLS_WANT_READ = 1, + TLS_WANT_WRITE +}; + +enum TLSErrorCode { + TLS_ERR_OK = 0, + TLS_ERR_ERROR = -1, + TLS_ERR_WOULDBLOCK = -2 +}; + +} // namespace aria2 + +#endif // TLS_SESSION_CONST_H