diff --git a/src/LibgnutlsTLSContext.h b/src/LibgnutlsTLSContext.h index 281f2c5e..ef2b005c 100644 --- a/src/LibgnutlsTLSContext.h +++ b/src/LibgnutlsTLSContext.h @@ -53,7 +53,7 @@ public: // private key `keyfile' must be decrypted. virtual bool addCredentialFile(const std::string& certfile, const std::string& keyfile) CXX11_OVERRIDE; - virtual bool addP12CredentialFile(const std::string& p12file); + bool addP12CredentialFile(const std::string& p12file); virtual bool addSystemTrustedCACerts() CXX11_OVERRIDE; diff --git a/src/LibsslTLSContext.cc b/src/LibsslTLSContext.cc index e358b9cc..f9c293db 100644 --- a/src/LibsslTLSContext.cc +++ b/src/LibsslTLSContext.cc @@ -52,11 +52,31 @@ namespace { if (b) BIO_free(b); } }; + using bio_t = std::unique_ptr; struct p12_deleter { void operator()(PKCS12 *p) { if (p) PKCS12_free(p); } }; + using p12_t = std::unique_ptr; + struct pkey_deleter { + void operator()(EVP_PKEY *x) { + if (x) EVP_PKEY_free(x); + } + }; + using pkey_t = std::unique_ptr; + struct x509_deleter { + void operator()(X509 *x) { + if (x) X509_free(x); + } + }; + using x509_t = std::unique_ptr; + struct x509_sk_deleter { + void operator()(STACK_OF(X509) *x) { + if (x) sk_X509_pop_free(x, X509_free); + } + }; + using x509_sk_t = std::unique_ptr; } // namespace namespace aria2 { @@ -138,14 +158,13 @@ bool OpenSSLTLSContext::addP12CredentialFile(const std::string& p12file) auto data = ss.str(); void *ptr = const_cast(data.c_str()); size_t len = data.length(); - std::unique_ptr bio(BIO_new_mem_buf(ptr, len)); - A2_LOG_DEBUG(fmt("p12 size: %" PRIu64, len)); + bio_t bio(BIO_new_mem_buf(ptr, len)); if (!bio) { A2_LOG_ERROR("Failed to open p12 file: no memory"); return false; } - std::unique_ptr p12(d2i_PKCS12_bio(bio.get(), nullptr)); + p12_t p12(d2i_PKCS12_bio(bio.get(), nullptr)); if (!p12) { A2_LOG_ERROR(fmt("Failed to open p12 file: %s", ERR_error_string(ERR_get_error(), nullptr))); @@ -160,44 +179,33 @@ bool OpenSSLTLSContext::addP12CredentialFile(const std::string& p12file) return false; } - bool rv = false; - if (pkey && cert) { - rv = SSL_CTX_use_PrivateKey(sslCtx_, pkey); - if (!rv) { - A2_LOG_ERROR(fmt("Failed to use p12 file pkey: %s", - ERR_error_string(ERR_get_error(), nullptr))); - } - if (rv) { - rv = SSL_CTX_use_certificate(sslCtx_, cert); - if (!rv) { - A2_LOG_ERROR(fmt("Failed to use p12 file cert: %s", - ERR_error_string(ERR_get_error(), nullptr))); - } - } - if (rv && ca && sk_X509_num(ca)) { - rv = SSL_CTX_add_extra_chain_cert(sslCtx_, ca); - if (!rv) { - A2_LOG_ERROR(fmt("Failed to use p12 file chain: %s", - ERR_error_string(ERR_get_error(), nullptr))); - } - } - } - else { + pkey_t pkey_holder(pkey); + x509_t cert_holder(cert); + x509_sk_t ca_holder(ca); + + if (!pkey || !cert) { A2_LOG_ERROR(fmt("Failed to use p12 file: no pkey or cert %s", ERR_error_string(ERR_get_error(), nullptr))); + return false; + } + if (!SSL_CTX_use_PrivateKey(sslCtx_, pkey)) { + A2_LOG_ERROR(fmt("Failed to use p12 file pkey: %s", + ERR_error_string(ERR_get_error(), nullptr))); + return false; + } + if (!SSL_CTX_use_certificate(sslCtx_, cert)) { + A2_LOG_ERROR(fmt("Failed to use p12 file cert: %s", + ERR_error_string(ERR_get_error(), nullptr))); + return false; + } + if (ca && sk_X509_num(ca) && !SSL_CTX_add_extra_chain_cert(sslCtx_, ca)) { + A2_LOG_ERROR(fmt("Failed to use p12 file chain: %s", + ERR_error_string(ERR_get_error(), nullptr))); + return false; } - if (pkey) EVP_PKEY_free(pkey); - if (cert) X509_free(cert); - if (ca) sk_X509_pop_free(ca, X509_free); - if (!rv) { - A2_LOG_ERROR(fmt("Failed to use p12 file: %s", - ERR_error_string(ERR_get_error(), nullptr))); - } - else { - A2_LOG_INFO("Using certificate and key from p12 file"); - } - return rv; + A2_LOG_INFO("Using certificate and key from p12 file"); + return true; } bool OpenSSLTLSContext::addSystemTrustedCACerts() diff --git a/src/LibsslTLSContext.h b/src/LibsslTLSContext.h index 326f6a63..65582366 100644 --- a/src/LibsslTLSContext.h +++ b/src/LibsslTLSContext.h @@ -55,7 +55,7 @@ public: // private key `keyfile' must be decrypted. virtual bool addCredentialFile(const std::string& certfile, const std::string& keyfile) CXX11_OVERRIDE; - virtual bool addP12CredentialFile(const std::string& p12file); + bool addP12CredentialFile(const std::string& p12file); virtual bool addSystemTrustedCACerts() CXX11_OVERRIDE;