diff --git a/src/SocketCore.cc b/src/SocketCore.cc index 17feeec8..a625b019 100644 --- a/src/SocketCore.cc +++ b/src/SocketCore.cc @@ -121,6 +121,19 @@ std::string errorMsg(int errNum) } } // namespace +namespace { +enum TlsState { + // TLS object is not initialized. + A2_TLS_NONE = 0, + // TLS object is initialized. Ready for handshake. + A2_TLS_INITIALIZED = 1, + // TLS object is now handshaking. + A2_TLS_HANDSHAKING = 2, + // TLS object is now connected. + A2_TLS_CONNECTED = 3 +}; +} // namespace + int SocketCore::protocolFamily_ = AF_UNSPEC; std::vector > @@ -152,7 +165,7 @@ SocketCore::SocketCore(sock_t sockfd, int sockType) void SocketCore::init() { blocking_ = true; - secure_ = 0; + secure_ = A2_TLS_NONE; wantRead_ = false; wantWrite_ = false; @@ -837,16 +850,28 @@ void SocketCore::prepareSecureConnection() tlsContext_->getCertCred()); gnutls_transport_set_ptr(sslSession_, (gnutls_transport_ptr_t)sockfd_); #endif // HAVE_LIBGNUTLS - secure_ = 1; + secure_ = A2_TLS_INITIALIZED; } } bool SocketCore::initiateSecureConnection(const std::string& hostname) { - if(secure_ == 1) { - wantRead_ = false; - wantWrite_ = false; + wantRead_ = false; + wantWrite_ = false; #ifdef HAVE_OPENSSL + switch(secure_) { + case A2_TLS_INITIALIZED: + secure_ = A2_TLS_HANDSHAKING; +#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME + if(!util::isNumericHost(hostname)) { + // TLS extensions: SNI. There is not documentation about the + // return code for this function (actually this is macro + // wrapping SSL_ctrl at the time of this writing). + SSL_set_tlsext_host_name(ssl, hostname.c_str()); + } +#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME + // Fall through + case A2_TLS_HANDSHAKING: { ERR_clear_error(); int e = SSL_connect(ssl); @@ -950,8 +975,28 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname) throw DL_ABORT_EX(MSG_HOSTNAME_NOT_MATCH); } } + secure_ = A2_TLS_CONNECTED; + break; + } + default: + break; + } #endif // HAVE_OPENSSL #ifdef HAVE_LIBGNUTLS + switch(secure_) { + case A2_TLS_INITIALIZED: + secure_ = A2_TLS_HANDSHAKING; + if(!util::isNumericHost(hostname)) { + // TLS extensions: SNI + int ret = gnutls_server_name_set(sslSession_, GNUTLS_NAME_DNS, + hostname.c_str(), hostname.size()); + if(ret < 0) { + A2_LOG_WARN(fmt("Setting hostname in SNI extension failed. Cause: %s", + gnutls_strerror(ret))); + } + } + // Fall through + case A2_TLS_HANDSHAKING: { int ret = gnutls_handshake(sslSession_); if(ret == GNUTLS_E_AGAIN) { gnutlsRecordCheckDirection(); @@ -1056,12 +1101,14 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname) } } } -#endif // HAVE_LIBGNUTLS - secure_ = 2; - return true; - } else { - return true; + secure_ = A2_TLS_CONNECTED; + break; } + default: + break; + } +#endif // HAVE_LIBGNUTLS + return true; } ssize_t SocketCore::writeData(const char* data, size_t len,