/* */
#include "LibsslTLSSession.h"

#include <climits>
#include <string>
#include <vector>

#include <openssl/err.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>

#include "TLSContext.h"
#include "util.h"
#include "SocketCore.h"

namespace aria2 {

// Creates a session bound to |tlsContext|. The SSL object itself is not
// allocated until init() is called. rv_ holds the return value of the
// last OpenSSL call; it starts at 1 (a "success" value) so that
// getLastErrorString() reports no error before any call has been made.
TLSSession::TLSSession(TLSContext* tlsContext)
  : ssl_(0),
    tlsContext_(tlsContext),
    rv_(1)
{}

// Attempts to send close_notify, then releases the SSL object allocated
// by SSL_new() in init(). SSL_shutdown() alone does not free the SSL
// structure, so SSL_free() is required to avoid leaking it.
TLSSession::~TLSSession()
{
  if(ssl_) {
    SSL_shutdown(ssl_);
    SSL_free(ssl_);
  }
}

// Allocates the SSL object from the context's SSL_CTX and attaches it to
// |sockfd|. Returns TLS_ERR_OK on success and TLS_ERR_ERROR if either
// SSL_new() or SSL_set_fd() fails.
int TLSSession::init(sock_t sockfd)
{
  ERR_clear_error();
  ssl_ = SSL_new(tlsContext_->getSSLCtx());
  if(!ssl_) {
    return TLS_ERR_ERROR;
  }
  rv_ = SSL_set_fd(ssl_, sockfd);
  if(rv_ == 0) {
    return TLS_ERR_ERROR;
  }
  return TLS_ERR_OK;
}

// Sets the TLS SNI extension to |hostname| when the OpenSSL headers
// support it (SSL_CTRL_SET_TLSEXT_HOSTNAME defined); otherwise this is a
// no-op. Always returns TLS_ERR_OK.
int TLSSession::setSNIHostname(const std::string& hostname)
{
#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
  ERR_clear_error();
  // TLS extensions: SNI. There is no documentation about the return
  // code of this function (at the time of this writing it is a macro
  // wrapping SSL_ctrl), so the result is ignored.
  SSL_set_tlsext_host_name(ssl_, hostname.c_str());
#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME
  return TLS_ERR_OK;
}

// Starts a TLS shutdown via SSL_shutdown(). Does nothing if init() never
// produced an SSL object. Always returns TLS_ERR_OK.
int TLSSession::closeConnection()
{
  ERR_clear_error();
  if(ssl_) {
    SSL_shutdown(ssl_);
    // TODO handle return value
  }
  return TLS_ERR_OK;
}

// Reports which socket event the last OpenSSL call (recorded in rv_) is
// waiting for: TLS_WANT_WRITE for SSL_ERROR_WANT_WRITE, TLS_WANT_READ in
// every other case.
int TLSSession::checkDirection()
{
  int error = SSL_get_error(ssl_, rv_);
  if(error == SSL_ERROR_WANT_WRITE) {
    return TLS_WANT_WRITE;
  } else {
    // TODO We ignore errors other than SSL_ERROR_WANT_READ here for now
    return TLS_WANT_READ;
  }
}

namespace {
// Returns true if the OpenSSL call that produced |rv| only needs to be
// retried once the socket becomes readable or writable.
bool wouldblock(SSL* ssl, int rv)
{
  int error = SSL_get_error(ssl, rv);
  return error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE;
}

// SSL_read()/SSL_write() take an int byte count. A size_t above INT_MAX
// would otherwise be narrowed in an implementation-defined way (usually
// to a negative count, which OpenSSL rejects), so cap the request; the
// number of bytes actually transferred is still returned to the caller.
int clampLength(size_t len)
{
  return len > static_cast<size_t>(INT_MAX) ? INT_MAX
                                            : static_cast<int>(len);
}
} // namespace

// Writes up to |len| bytes from |data|. Returns the number of bytes
// written, TLS_ERR_WOULDBLOCK if the call must be retried later, or
// TLS_ERR_ERROR on failure.
ssize_t TLSSession::writeData(const void* data, size_t len)
{
  ERR_clear_error();
  rv_ = SSL_write(ssl_, data, clampLength(len));
  if(rv_ <= 0) {
    if(wouldblock(ssl_, rv_)) {
      return TLS_ERR_WOULDBLOCK;
    } else {
      return TLS_ERR_ERROR;
    }
  } else {
    ssize_t ret = rv_;
    // Reset so later error queries see a successful state.
    rv_ = 1;
    return ret;
  }
}

// Reads up to |len| bytes into |data|. Returns the number of bytes read,
// TLS_ERR_WOULDBLOCK if the call must be retried later, or TLS_ERR_ERROR
// on failure (including SSL_read() returning 0).
ssize_t TLSSession::readData(void* data, size_t len)
{
  ERR_clear_error();
  rv_ = SSL_read(ssl_, data, clampLength(len));
  if(rv_ <= 0) {
    if(wouldblock(ssl_, rv_)) {
      return TLS_ERR_WOULDBLOCK;
    } else {
      return TLS_ERR_ERROR;
    }
  } else {
    ssize_t ret = rv_;
    // Reset so later error queries see a successful state.
    rv_ = 1;
    return ret;
  }
}

// Drives the TLS handshake: SSL_connect() on the client side, SSL_accept()
// on the server side. Returns TLS_ERR_WOULDBLOCK when it must be called
// again once the socket is ready, TLS_ERR_ERROR on failure, and
// TLS_ERR_OK otherwise (see the TODO for results also mapped to OK).
int TLSSession::handshake()
{
  ERR_clear_error();
  if(tlsContext_->getSide() == TLS_CLIENT) {
    rv_ = SSL_connect(ssl_);
  } else {
    rv_ = SSL_accept(ssl_);
  }
  if(rv_ <= 0) {
    int sslError = SSL_get_error(ssl_, rv_);
    switch(sslError) {
    case SSL_ERROR_NONE:
    case SSL_ERROR_WANT_X509_LOOKUP:
    case SSL_ERROR_ZERO_RETURN:
      // TODO For now we assume non-blocking operation and treat these
      // three results as success.
      break;
    case SSL_ERROR_WANT_READ:
    case SSL_ERROR_WANT_WRITE:
      return TLS_ERR_WOULDBLOCK;
    default:
      return TLS_ERR_ERROR;
    }
  }
  return TLS_ERR_OK;
}

// Performs the handshake and, on the client side with peer verification
// enabled, validates the server certificate: it must be present, pass
// SSL_get_verify_result(), and match |hostname| according to
// net::verifyHostname() using its subjectAltName DNS/IP entries and
// subject commonName. On failure a human-readable reason is stored in
// |handshakeErr| (which is cleared on entry).
int TLSSession::tlsConnect(const std::string& hostname,
                           std::string& handshakeErr)
{
  handshakeErr = "";
  int ret = handshake();
  if(ret != TLS_ERR_OK) {
    return ret;
  }
  if(tlsContext_->getSide() == TLS_CLIENT &&
     tlsContext_->peerVerificationEnabled()) {
    // verify peer
    X509* peerCert = SSL_get_peer_certificate(ssl_);
    if(!peerCert) {
      handshakeErr = "certificate not found";
      return TLS_ERR_ERROR;
    }
    // auto_delete is a project helper; presumably it invokes the given
    // deleter (X509_free) when it goes out of scope.
    auto_delete<X509*> certDeleter(peerCert, X509_free);
    long verifyResult = SSL_get_verify_result(ssl_);
    if(verifyResult != X509_V_OK) {
      handshakeErr = X509_verify_cert_error_string(verifyResult);
      return TLS_ERR_ERROR;
    }
    std::string commonName;
    std::vector<std::string> dnsNames;
    std::vector<std::string> ipAddrs;
    GENERAL_NAMES* altNames = static_cast<GENERAL_NAMES*>
      (X509_get_ext_d2i(peerCert, NID_subject_alt_name, NULL, NULL));
    if(altNames) {
      auto_delete<GENERAL_NAMES*> altNamesDeleter
        (altNames, GENERAL_NAMES_free);
      // sk_GENERAL_NAME_num() returns int; keep the index type matching.
      int n = sk_GENERAL_NAME_num(altNames);
      for(int i = 0; i < n; ++i) {
        const GENERAL_NAME* altName = sk_GENERAL_NAME_value(altNames, i);
        if(altName->type == GEN_DNS) {
          const char* name =
            reinterpret_cast<const char*>(ASN1_STRING_data(altName->d.ia5));
          if(!name) {
            continue;
          }
          size_t len = ASN1_STRING_length(altName->d.ia5);
          dnsNames.push_back(std::string(name, len));
        } else if(altName->type == GEN_IPADD) {
          const unsigned char* ipAddr = altName->d.iPAddress->data;
          if(!ipAddr) {
            continue;
          }
          size_t len = altName->d.iPAddress->length;
          ipAddrs.push_back
            (std::string(reinterpret_cast<const char*>(ipAddr), len));
        }
      }
    }
    X509_NAME* subjectName = X509_get_subject_name(peerCert);
    if(!subjectName) {
      handshakeErr = "could not get X509 name object from the certificate.";
      return TLS_ERR_ERROR;
    }
    // Take the first commonName entry that converts to UTF-8.
    int lastpos = -1;
    while(1) {
      lastpos = X509_NAME_get_index_by_NID(subjectName, NID_commonName,
                                           lastpos);
      // Stop on "not found" (-1) and on any other negative result so a
      // lookup error can never restart the scan forever.
      if(lastpos < 0) {
        break;
      }
      X509_NAME_ENTRY* entry = X509_NAME_get_entry(subjectName, lastpos);
      unsigned char* out;
      int outlen = ASN1_STRING_to_UTF8(&out, X509_NAME_ENTRY_get_data(entry));
      if(outlen < 0) {
        continue;
      }
      commonName.assign(&out[0], &out[outlen]);
      OPENSSL_free(out);
      break;
    }
    if(!net::verifyHostname(hostname, dnsNames, ipAddrs, commonName)) {
      handshakeErr = "hostname does not match";
      return TLS_ERR_ERROR;
    }
  }
  return TLS_ERR_OK;
}

// Server-side entry point: runs the handshake; no hostname checks are
// performed here.
int TLSSession::tlsAccept()
{
  return handshake();
}

// Returns a human-readable description of the failure recorded in rv_,
// or an empty string when the last call succeeded or only needs to be
// retried.
std::string TLSSession::getLastErrorString()
{
  if(rv_ <= 0) {
    int sslError = SSL_get_error(ssl_, rv_);
    switch(sslError) {
    case SSL_ERROR_NONE:
    case SSL_ERROR_WANT_READ:
    case SSL_ERROR_WANT_WRITE:
    case SSL_ERROR_WANT_X509_LOOKUP:
    case SSL_ERROR_ZERO_RETURN:
      return "";
    case SSL_ERROR_SYSCALL: {
      // ERR_get_error() returns unsigned long; narrowing it to int could
      // corrupt error codes with high bits set before they are formatted.
      unsigned long err = ERR_get_error();
      if(err == 0) {
        if(rv_ == 0) {
          return "EOF was received";
        } else if(rv_ == -1) {
          return "SSL I/O error";
        } else {
          return "unknown syscall error";
        }
      } else {
        // ERR_error_string(err, 0) writes into a static buffer shared
        // across threads; format into a local buffer instead.
        char buf[256];
        ERR_error_string_n(err, buf, sizeof(buf));
        return buf;
      }
    }
    case SSL_ERROR_SSL:
      return "protocol error";
    default:
      return "unknown error";
    }
  } else {
    return "";
  }
}

} // namespace aria2