mirror of https://github.com/aria2/aria2
				
				
				
			Implement WinTLS
							parent
							
								
									3f1d293ed1
								
							
						
					
					
						commit
						00dd83b461
					
				
							
								
								
									
										42
									
								
								configure.ac
								
								
								
								
							
							
						
						
									
										42
									
								
								configure.ac
								
								
								
								
							| 
						 | 
				
			
			@ -40,7 +40,7 @@ AC_DEFINE_UNQUOTED([TARGET], ["$target"], [Define target-type])
 | 
			
		|||
# Checks for arguments.
 | 
			
		||||
ARIA2_ARG_WITHOUT([libuv])
 | 
			
		||||
ARIA2_ARG_WITHOUT([appletls])
 | 
			
		||||
ARIA2_ARG_WITHOUT([wintls])
 | 
			
		||||
ARIA2_ARG_WITH([wintls])
 | 
			
		||||
ARIA2_ARG_WITHOUT([gnutls])
 | 
			
		||||
ARIA2_ARG_WITHOUT([libnettle])
 | 
			
		||||
ARIA2_ARG_WITHOUT([libgmp])
 | 
			
		||||
| 
						 | 
				
			
			@ -337,23 +337,39 @@ if test "x$with_appletls" = "xyes"; then
 | 
			
		|||
fi
 | 
			
		||||
 | 
			
		||||
if test "x$with_wintls" = "xyes"; then
 | 
			
		||||
  AC_SEARCH_LIBS([CryptAcquireContextW], [advapi32], [
 | 
			
		||||
                  AC_CHECK_HEADER([wincrypt.h], [have_wincrypt=yes], [have_wincrypt=no],
 | 
			
		||||
                      [[
 | 
			
		||||
  AC_HAVE_LIBRARY([crypt32],[have_wintls_libs=yes],[have_wintls_libs=no])
 | 
			
		||||
  AC_HAVE_LIBRARY([secur32],[have_wintls_libs=$have_wintls_libs],[have_wintls_libs=no])
 | 
			
		||||
  AC_HAVE_LIBRARY([advapi32],[have_wintls_libs=$have_wintls_libs],[have_wintls_libs=no])
 | 
			
		||||
  AC_CHECK_HEADER([wincrypt.h], [have_wintls_headers=yes], [have_wintls_headers=no], [[
 | 
			
		||||
#ifdef HAVE_WINDOWS_H
 | 
			
		||||
# include <windows.h>
 | 
			
		||||
#endif
 | 
			
		||||
                      ]])
 | 
			
		||||
                  break;
 | 
			
		||||
                  ], [have_wincrypt=no])
 | 
			
		||||
  if test "x$have_wincrypt" != "xyes"; then
 | 
			
		||||
  ]])
 | 
			
		||||
  AC_CHECK_HEADER([security.h], [have_wintls_headers=$have_wintls_headers], [have_wintls_headers=no], [[
 | 
			
		||||
#ifdef HAVE_WINDOWS_H
 | 
			
		||||
# include <windows.h>
 | 
			
		||||
#endif
 | 
			
		||||
#ifndef SECURITY_WIN32
 | 
			
		||||
#define SECURITY_WIN32 1
 | 
			
		||||
#endif
 | 
			
		||||
  ]])
 | 
			
		||||
 | 
			
		||||
  if test "x$have_wintls_libs" = "xyes" &&
 | 
			
		||||
     test "x$have_wintls_headers" = "xyes"; then
 | 
			
		||||
    AC_DEFINE([SECURITY_WIN32], [1], [Use security.h in WIN32 mode])
 | 
			
		||||
    LIBS="$LIBS -lcrypt32 -lsecur32 -ladvapi32"
 | 
			
		||||
    have_wintls=yes
 | 
			
		||||
  else
 | 
			
		||||
    have_wintls=no
 | 
			
		||||
  fi
 | 
			
		||||
  if test "x$have_wintls" != "xyes"; then
 | 
			
		||||
    if test "x$with_wintls_requested" = "xyes"; then
 | 
			
		||||
      ARIA2_DEP_NOT_MET([wintls])
 | 
			
		||||
    fi
 | 
			
		||||
  fi
 | 
			
		||||
fi
 | 
			
		||||
 | 
			
		||||
if test "x$with_gnutls" = "xyes" && test "x$have_appletls" != "xyes"; then
 | 
			
		||||
if test "x$with_gnutls" = "xyes" && test "x$have_appletls" != "xyes" && test "x$have_wintls" != "xyes"; then
 | 
			
		||||
  # gnutls >= 2.8 doesn't have libgnutls-config anymore. We require
 | 
			
		||||
  # 2.2.0 because we use gnutls_priority_set_direct()
 | 
			
		||||
  PKG_CHECK_MODULES([LIBGNUTLS], [gnutls >= 2.2.0],
 | 
			
		||||
| 
						 | 
				
			
			@ -371,7 +387,7 @@ if test "x$with_gnutls" = "xyes" && test "x$have_appletls" != "xyes"; then
 | 
			
		|||
  fi
 | 
			
		||||
fi
 | 
			
		||||
 | 
			
		||||
if test "x$with_openssl" = "xyes" && test "x$have_appletls" != "xyes" && test "x$have_libgnutls" != "xyes"; then
 | 
			
		||||
if test "x$with_openssl" = "xyes" && test "x$have_appletls" != "xyes" && test "x$have_wintls" != "xyes" && test "x$have_libgnutls" != "xyes"; then
 | 
			
		||||
  PKG_CHECK_MODULES([OPENSSL], [openssl >= 0.9.8],
 | 
			
		||||
                    [have_openssl=yes], [have_openssl=no])
 | 
			
		||||
  if test "x$have_openssl" = "xyes"; then
 | 
			
		||||
| 
						 | 
				
			
			@ -448,7 +464,7 @@ if test "x$have_appletls" == "xyes"; then
 | 
			
		|||
  use_md="apple"
 | 
			
		||||
  AC_DEFINE([USE_APPLE_MD], [1], [What message digest implementation to use])
 | 
			
		||||
else
 | 
			
		||||
  if test "x$have_wincrypt" == "xyes"; then
 | 
			
		||||
  if test "x$have_wintls" == "xyes"; then
 | 
			
		||||
    use_md="windows"
 | 
			
		||||
    AC_DEFINE([USE_WINDOWS_MD], [1], [What message digest implementation to use])
 | 
			
		||||
  else
 | 
			
		||||
| 
						 | 
				
			
			@ -473,7 +489,7 @@ else
 | 
			
		|||
fi
 | 
			
		||||
 | 
			
		||||
# Define variables based on the result of the checks for libraries.
 | 
			
		||||
if test "x$have_appletls" = "xyes" || test "x$have_libgnutls" = "xyes" || test "x$have_openssl" = "xyes"; then
 | 
			
		||||
if test "x$have_appletls" = "xyes" || test "x$have_wintls" == "xyes" || test "x$have_libgnutls" = "xyes" || test "x$have_openssl" = "xyes"; then
 | 
			
		||||
  have_ssl="yes"
 | 
			
		||||
  AC_DEFINE([ENABLE_SSL], [1], [Define to 1 if ssl support is enabled.])
 | 
			
		||||
  AM_CONDITIONAL([ENABLE_SSL], true)
 | 
			
		||||
| 
						 | 
				
			
			@ -485,6 +501,7 @@ fi
 | 
			
		|||
 | 
			
		||||
AM_CONDITIONAL([HAVE_OSX], [ test "x$have_osx" = "xyes" ])
 | 
			
		||||
AM_CONDITIONAL([HAVE_APPLETLS], [ test "x$have_appletls" = "xyes" ])
 | 
			
		||||
AM_CONDITIONAL([HAVE_WINTLS], [ test "x$have_wintls" = "xyes" ])
 | 
			
		||||
AM_CONDITIONAL([USE_APPLE_MD], [ test "x$use_md" = "xapple" ])
 | 
			
		||||
AM_CONDITIONAL([USE_WINDOWS_MD], [ test "x$use_md" = "xwindows" ])
 | 
			
		||||
AM_CONDITIONAL([HAVE_LIBGNUTLS], [ test "x$have_libgnutls" = "xyes" ])
 | 
			
		||||
| 
						 | 
				
			
			@ -985,6 +1002,7 @@ echo "LibUV:          $have_libuv"
 | 
			
		|||
echo "SQLite3:        $have_sqlite3"
 | 
			
		||||
echo "SSL Support:    $have_ssl"
 | 
			
		||||
echo "AppleTLS:       $have_appletls"
 | 
			
		||||
echo "WinTLS:         $have_wintls"
 | 
			
		||||
echo "GnuTLS:         $have_libgnutls"
 | 
			
		||||
echo "OpenSSL:        $have_openssl"
 | 
			
		||||
echo "CA Bundle:      $ca_bundle"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -333,6 +333,11 @@ if USE_WINDOWS_MD
 | 
			
		|||
SRCS += WinMessageDigestImpl.cc
 | 
			
		||||
endif # USE_WINDOWS_MD
 | 
			
		||||
 | 
			
		||||
if HAVE_WINTLS
 | 
			
		||||
SRCS += WinTLSContext.cc WinTLSContext.h \
 | 
			
		||||
	WinTLSSession.cc WinTLSSession.h
 | 
			
		||||
endif # HAVE_WINTLS
 | 
			
		||||
 | 
			
		||||
if USE_INTERNAL_BIGNUM
 | 
			
		||||
SRCS += InternalDHKeyExchange.cc InternalDHKeyExchange.h bignum.h
 | 
			
		||||
endif
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -779,7 +779,7 @@ void SocketCore::readData(void* data, size_t& len)
 | 
			
		|||
    ret = tlsSession_->readData(data, len);
 | 
			
		||||
    if(ret < 0) {
 | 
			
		||||
      if(ret != TLS_ERR_WOULDBLOCK) {
 | 
			
		||||
        throw DL_RETRY_EX(fmt(EX_SOCKET_SEND,
 | 
			
		||||
        throw DL_RETRY_EX(fmt(EX_SOCKET_RECV,
 | 
			
		||||
                              tlsSession_->getLastErrorString().c_str()));
 | 
			
		||||
      }
 | 
			
		||||
      if(tlsSession_->checkDirection() == TLS_WANT_READ) {
 | 
			
		||||
| 
						 | 
				
			
			@ -814,6 +814,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
 | 
			
		|||
  wantWrite_ = false;
 | 
			
		||||
  switch(secure_) {
 | 
			
		||||
  case A2_TLS_NONE:
 | 
			
		||||
    A2_LOG_DEBUG("Creating TLS session");
 | 
			
		||||
    tlsSession_.reset(TLSSession::make(tlsctx));
 | 
			
		||||
    rv = tlsSession_->init(sockfd_);
 | 
			
		||||
    if(rv != TLS_ERR_OK) {
 | 
			
		||||
| 
						 | 
				
			
			@ -835,6 +836,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
 | 
			
		|||
    secure_ = A2_TLS_HANDSHAKING;
 | 
			
		||||
    // Fall through
 | 
			
		||||
  case A2_TLS_HANDSHAKING:
 | 
			
		||||
    A2_LOG_DEBUG("TLS Handshaking");
 | 
			
		||||
    if(tlsctx->getSide() == TLS_CLIENT) {
 | 
			
		||||
      rv = tlsSession_->tlsConnect(hostname, handshakeError);
 | 
			
		||||
    } else {
 | 
			
		||||
| 
						 | 
				
			
			@ -857,6 +859,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
 | 
			
		|||
    }
 | 
			
		||||
    return false;
 | 
			
		||||
  default:
 | 
			
		||||
    A2_LOG_DEBUG("TLS else");
 | 
			
		||||
    break;
 | 
			
		||||
  }
 | 
			
		||||
  return true;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,189 @@
 | 
			
		|||
/* <!-- copyright */
 | 
			
		||||
/*
 | 
			
		||||
 * aria2 - The high speed download utility
 | 
			
		||||
 *
 | 
			
		||||
 * Copyright (C) 2013 Nils Maier
 | 
			
		||||
 *
 | 
			
		||||
 * This program is free software; you can redistribute it and/or modify
 | 
			
		||||
 * it under the terms of the GNU General Public License as published by
 | 
			
		||||
 * the Free Software Foundation; either version 2 of the License, or
 | 
			
		||||
 * (at your option) any later version.
 | 
			
		||||
 *
 | 
			
		||||
 * This program is distributed in the hope that it will be useful,
 | 
			
		||||
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
 * GNU General Public License for more details.
 | 
			
		||||
 *
 | 
			
		||||
 * You should have received a copy of the GNU General Public License
 | 
			
		||||
 * along with this program; if not, write to the Free Software
 | 
			
		||||
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 | 
			
		||||
 *
 | 
			
		||||
 * In addition, as a special exception, the copyright holders give
 | 
			
		||||
 * permission to link the code of portions of this program with the
 | 
			
		||||
 * OpenSSL library under certain conditions as described in each
 | 
			
		||||
 * individual source file, and distribute linked combinations
 | 
			
		||||
 * including the two.
 | 
			
		||||
 * You must obey the GNU General Public License in all respects
 | 
			
		||||
 * for all of the code used other than OpenSSL.  If you modify
 | 
			
		||||
 * file(s) with this exception, you may extend this exception to your
 | 
			
		||||
 * version of the file(s), but you are not obligated to do so.  If you
 | 
			
		||||
 * do not wish to do so, delete this exception statement from your
 | 
			
		||||
 * version.  If you delete this exception statement from all source
 | 
			
		||||
 * files in the program, then also delete it here.
 | 
			
		||||
 */
 | 
			
		||||
/* copyright --> */
 | 
			
		||||
 | 
			
		||||
#include "WinTLSContext.h"
 | 
			
		||||
 | 
			
		||||
#include <sstream>
 | 
			
		||||
 | 
			
		||||
#include "BufferedFile.h"
 | 
			
		||||
#include "LogFactory.h"
 | 
			
		||||
#include "Logger.h"
 | 
			
		||||
#include "fmt.h"
 | 
			
		||||
#include "message.h"
 | 
			
		||||
#include "util.h"
 | 
			
		||||
 | 
			
		||||
namespace aria2 {
 | 
			
		||||
 | 
			
		||||
WinTLSContext::WinTLSContext(TLSSessionSide side)
 | 
			
		||||
  : side_(side), store_(0)
 | 
			
		||||
{
 | 
			
		||||
  memset(&credentials_, 0, sizeof(credentials_));
 | 
			
		||||
  credentials_.dwVersion = SCHANNEL_CRED_VERSION;
 | 
			
		||||
  if (side_ == TLS_CLIENT) {
 | 
			
		||||
    credentials_.grbitEnabledProtocols =
 | 
			
		||||
      SP_PROT_SSL3_CLIENT |
 | 
			
		||||
      SP_PROT_TLS1_CLIENT |
 | 
			
		||||
      SP_PROT_TLS1_1_CLIENT |
 | 
			
		||||
      SP_PROT_TLS1_2_CLIENT;
 | 
			
		||||
  }
 | 
			
		||||
  else {
 | 
			
		||||
    credentials_.grbitEnabledProtocols =
 | 
			
		||||
      SP_PROT_SSL3_SERVER |
 | 
			
		||||
      SP_PROT_TLS1_SERVER |
 | 
			
		||||
      SP_PROT_TLS1_1_SERVER |
 | 
			
		||||
      SP_PROT_TLS1_2_SERVER;
 | 
			
		||||
  }
 | 
			
		||||
  credentials_.dwMinimumCipherStrength = 128; // bit
 | 
			
		||||
 | 
			
		||||
  setVerifyPeer(side_ == TLS_CLIENT);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TLSContext* TLSContext::make(TLSSessionSide side)
 | 
			
		||||
{
 | 
			
		||||
  return new WinTLSContext(side);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
WinTLSContext::~WinTLSContext()
 | 
			
		||||
{
 | 
			
		||||
  if (store_) {
 | 
			
		||||
    CertCloseStore(store_, 0);
 | 
			
		||||
    store_ = 0;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool WinTLSContext::getVerifyPeer() const
 | 
			
		||||
{
 | 
			
		||||
  return credentials_.dwFlags & SCH_CRED_AUTO_CRED_VALIDATION;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void WinTLSContext::setVerifyPeer(bool verify)
 | 
			
		||||
{
 | 
			
		||||
  if (side_ == TLS_CLIENT && verify) {
 | 
			
		||||
    credentials_.dwFlags =
 | 
			
		||||
      SCH_CRED_NO_DEFAULT_CREDS |
 | 
			
		||||
      SCH_CRED_AUTO_CRED_VALIDATION |
 | 
			
		||||
      SCH_CRED_REVOCATION_CHECK_CHAIN;
 | 
			
		||||
  }
 | 
			
		||||
  else {
 | 
			
		||||
    credentials_.dwFlags =
 | 
			
		||||
      SCH_CRED_NO_DEFAULT_CREDS |
 | 
			
		||||
      SCH_CRED_MANUAL_CRED_VALIDATION |
 | 
			
		||||
      SCH_CRED_IGNORE_NO_REVOCATION_CHECK |
 | 
			
		||||
      SCH_CRED_IGNORE_REVOCATION_OFFLINE |
 | 
			
		||||
      SCH_CRED_NO_SERVERNAME_CHECK;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Need to initialize cred_ early, because later on it will segfault deep
 | 
			
		||||
  // within AcquireCredentialsHandle for whatever reason.
 | 
			
		||||
  cred_.reset();
 | 
			
		||||
  getCredHandle();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
CredHandle* WinTLSContext::getCredHandle()
 | 
			
		||||
{
 | 
			
		||||
  if (cred_) {
 | 
			
		||||
    return cred_.get();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TimeStamp ts;
 | 
			
		||||
  cred_.reset(new CredHandle());
 | 
			
		||||
  SECURITY_STATUS status = ::AcquireCredentialsHandleW(
 | 
			
		||||
      nullptr,
 | 
			
		||||
      (SEC_WCHAR*)UNISP_NAME_W,
 | 
			
		||||
      side_ == TLS_CLIENT ? SECPKG_CRED_OUTBOUND : SECPKG_CRED_INBOUND,
 | 
			
		||||
      nullptr,
 | 
			
		||||
      &credentials_,
 | 
			
		||||
      nullptr,
 | 
			
		||||
      nullptr,
 | 
			
		||||
      cred_.get(),
 | 
			
		||||
      &ts);
 | 
			
		||||
  if (status != SEC_E_OK) {
 | 
			
		||||
    cred_.reset();
 | 
			
		||||
    throw DL_ABORT_EX("Failed to initialize WinTLS context handle");
 | 
			
		||||
  }
 | 
			
		||||
  return cred_.get();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool WinTLSContext::addCredentialFile(const std::string& certfile,
 | 
			
		||||
                                        const std::string& keyfile)
 | 
			
		||||
{
 | 
			
		||||
  std::stringstream ss;
 | 
			
		||||
  BufferedFile(certfile.c_str(), "rb").transfer(ss);
 | 
			
		||||
  auto data = ss.str();
 | 
			
		||||
  CRYPT_DATA_BLOB blob = {
 | 
			
		||||
    (DWORD)data.length(),
 | 
			
		||||
    (BYTE*)data.c_str()
 | 
			
		||||
  };
 | 
			
		||||
  if (!PFXIsPFXBlob(&blob)) {
 | 
			
		||||
    A2_LOG_ERROR("Not a valid PKCS12 file");
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  store_ = ::PFXImportCertStore(&blob, L"",
 | 
			
		||||
                                CRYPT_EXPORTABLE | CRYPT_USER_KEYSET);
 | 
			
		||||
  if (!store_) {
 | 
			
		||||
    store_ = ::PFXImportCertStore(&blob, nullptr,
 | 
			
		||||
                                  CRYPT_EXPORTABLE | CRYPT_USER_KEYSET);
 | 
			
		||||
  }
 | 
			
		||||
  if (!store_) {
 | 
			
		||||
    A2_LOG_ERROR("Failed to import PKCS12 store");
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const CERT_CONTEXT* ctx = ::CertEnumCertificatesInStore(store_, nullptr);
 | 
			
		||||
  if (!ctx) {
 | 
			
		||||
    A2_LOG_ERROR("Failed to read any certificates from the PKCS12 store");
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  credentials_.cCreds = 1;
 | 
			
		||||
  credentials_.paCred = &ctx;
 | 
			
		||||
 | 
			
		||||
  // Need to initialize cred_ early, because later on it will segfault deep
 | 
			
		||||
  // within AcquireCredentialsHandle for whatever reason.
 | 
			
		||||
  cred_.reset();
 | 
			
		||||
  getCredHandle();
 | 
			
		||||
 | 
			
		||||
  CertFreeCertificateContext(ctx);
 | 
			
		||||
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool WinTLSContext::addTrustedCACertFile(const std::string& certfile)
 | 
			
		||||
{
 | 
			
		||||
  A2_LOG_INFO("TLS CA bundle files are not supported. "
 | 
			
		||||
              "The system trust store will be used.");
 | 
			
		||||
  return false;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace aria2
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,116 @@
 | 
			
		|||
/* <!-- copyright */
 | 
			
		||||
/*
 | 
			
		||||
 * aria2 - The high speed download utility
 | 
			
		||||
 *
 | 
			
		||||
 * Copyright (C) 2013 Nils Maier
 | 
			
		||||
 *
 | 
			
		||||
 * This program is free software; you can redistribute it and/or modify
 | 
			
		||||
 * it under the terms of the GNU General Public License as published by
 | 
			
		||||
 * the Free Software Foundation; either version 2 of the License, or
 | 
			
		||||
 * (at your option) any later version.
 | 
			
		||||
 *
 | 
			
		||||
 * This program is distributed in the hope that it will be useful,
 | 
			
		||||
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
 * GNU General Public License for more details.
 | 
			
		||||
 *
 | 
			
		||||
 * You should have received a copy of the GNU General Public License
 | 
			
		||||
 * along with this program; if not, write to the Free Software
 | 
			
		||||
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 | 
			
		||||
 *
 | 
			
		||||
 * In addition, as a special exception, the copyright holders give
 | 
			
		||||
 * permission to link the code of portions of this program with the
 | 
			
		||||
 * OpenSSL library under certain conditions as described in each
 | 
			
		||||
 * individual source file, and distribute linked combinations
 | 
			
		||||
 * including the two.
 | 
			
		||||
 * You must obey the GNU General Public License in all respects
 | 
			
		||||
 * for all of the code used other than OpenSSL.  If you modify
 | 
			
		||||
 * file(s) with this exception, you may extend this exception to your
 | 
			
		||||
 * version of the file(s), but you are not obligated to do so.  If you
 | 
			
		||||
 * do not wish to do so, delete this exception statement from your
 | 
			
		||||
 * version.  If you delete this exception statement from all source
 | 
			
		||||
 * files in the program, then also delete it here.
 | 
			
		||||
 */
 | 
			
		||||
/* copyright --> */
 | 
			
		||||
 | 
			
		||||
#ifndef D_WIN_TLS_CONTEXT_H
 | 
			
		||||
#define D_WIN_TLS_CONTEXT_H
 | 
			
		||||
 | 
			
		||||
#include "common.h"
 | 
			
		||||
#include "config.h"
 | 
			
		||||
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
#include <windows.h>
 | 
			
		||||
#include <security.h>
 | 
			
		||||
#include <schnlsp.h>
 | 
			
		||||
 | 
			
		||||
#include "TLSContext.h"
 | 
			
		||||
#include "DlAbortEx.h"
 | 
			
		||||
 | 
			
		||||
#ifndef SP_PROT_TLS1_1_CLIENT
 | 
			
		||||
#define SP_PROT_TLS1_1_CLIENT 0x00000200
 | 
			
		||||
#endif
 | 
			
		||||
#ifndef SP_PROT_TLS1_1_SERVER
 | 
			
		||||
#define SP_PROT_TLS1_1_SERVER 0x00000100
 | 
			
		||||
#endif
 | 
			
		||||
#ifndef SP_PROT_TLS1_2_CLIENT
 | 
			
		||||
#define SP_PROT_TLS1_2_CLIENT 0x00000800
 | 
			
		||||
#endif
 | 
			
		||||
#ifndef SP_PROT_TLS1_2_SERVER
 | 
			
		||||
#define SP_PROT_TLS1_2_SERVER 0x00000400
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
namespace aria2 {
 | 
			
		||||
 | 
			
		||||
namespace wintls {
 | 
			
		||||
  struct cred_deleter{
 | 
			
		||||
    void operator()(CredHandle* handle) {
 | 
			
		||||
      if (handle) {
 | 
			
		||||
        FreeCredentialsHandle(handle);
 | 
			
		||||
        delete handle;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
  typedef std::unique_ptr<CredHandle, cred_deleter> CredPtr;
 | 
			
		||||
} // namespace wintls
 | 
			
		||||
 | 
			
		||||
class WinTLSContext : public TLSContext {
 | 
			
		||||
public:
 | 
			
		||||
  WinTLSContext(TLSSessionSide side);
 | 
			
		||||
  virtual ~WinTLSContext();
 | 
			
		||||
 | 
			
		||||
  // private key `keyfile' must be decrypted.
 | 
			
		||||
  virtual bool addCredentialFile(const std::string& certfile,
 | 
			
		||||
                                 const std::string& keyfile) CXX11_OVERRIDE;
 | 
			
		||||
 | 
			
		||||
  virtual bool addSystemTrustedCACerts() CXX11_OVERRIDE {
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // certfile can contain multiple certificates.
 | 
			
		||||
  virtual bool addTrustedCACertFile(const std::string& certfile)
 | 
			
		||||
    CXX11_OVERRIDE;
 | 
			
		||||
 | 
			
		||||
  virtual bool good() const CXX11_OVERRIDE {
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
  virtual TLSSessionSide getSide() const CXX11_OVERRIDE {
 | 
			
		||||
    return side_;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual bool getVerifyPeer() const CXX11_OVERRIDE;
 | 
			
		||||
  virtual void setVerifyPeer(bool verify) CXX11_OVERRIDE;
 | 
			
		||||
 | 
			
		||||
  CredHandle* getCredHandle();
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  TLSSessionSide side_;
 | 
			
		||||
  SCHANNEL_CRED credentials_;
 | 
			
		||||
  HCERTSTORE store_;
 | 
			
		||||
  wintls::CredPtr cred_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace aria2
 | 
			
		||||
 | 
			
		||||
#endif // D_LIBSSL_TLS_CONTEXT_H
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,816 @@
 | 
			
		|||
/* <!-- copyright */
 | 
			
		||||
/*
 | 
			
		||||
 * aria2 - The high speed download utility
 | 
			
		||||
 *
 | 
			
		||||
 * Copyright (C) 2013 Nils Maier
 | 
			
		||||
 *
 | 
			
		||||
 * This program is free software; you can redistribute it and/or modify
 | 
			
		||||
 * it under the terms of the GNU General Public License as published by
 | 
			
		||||
 * the Free Software Foundation; either version 2 of the License, or
 | 
			
		||||
 * (at your option) any later version.
 | 
			
		||||
 *
 | 
			
		||||
 * This program is distributed in the hope that it will be useful,
 | 
			
		||||
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
 * GNU General Public License for more details.
 | 
			
		||||
 *
 | 
			
		||||
 * You should have received a copy of the GNU General Public License
 | 
			
		||||
 * along with this program; if not, write to the Free Software
 | 
			
		||||
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 | 
			
		||||
 *
 | 
			
		||||
 * In addition, as a special exception, the copyright holders give
 | 
			
		||||
 * permission to link the code of portions of this program with the
 | 
			
		||||
 * OpenSSL library under certain conditions as described in each
 | 
			
		||||
 * individual source file, and distribute linked combinations
 | 
			
		||||
 * including the two.
 | 
			
		||||
 * You must obey the GNU General Public License in all respects
 | 
			
		||||
 * for all of the code used other than OpenSSL.  If you modify
 | 
			
		||||
 * file(s) with this exception, you may extend this exception to your
 | 
			
		||||
 * version of the file(s), but you are not obligated to do so.  If you
 | 
			
		||||
 * do not wish to do so, delete this exception statement from your
 | 
			
		||||
 * version.  If you delete this exception statement from all source
 | 
			
		||||
 * files in the program, then also delete it here.
 | 
			
		||||
 */
 | 
			
		||||
/* copyright --> */
 | 
			
		||||
 | 
			
		||||
#include "WinTLSSession.h"
 | 
			
		||||
 | 
			
		||||
#include <sstream>
 | 
			
		||||
 | 
			
		||||
#include "LogFactory.h"
 | 
			
		||||
#include "a2functional.h"
 | 
			
		||||
#include "fmt.h"
 | 
			
		||||
#include "util.h"
 | 
			
		||||
 | 
			
		||||
#ifndef SECBUFFER_ALERT
 | 
			
		||||
#define SECBUFFER_ALERT 17
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#ifndef SZ_ALG_MAX_SIZE
 | 
			
		||||
#define SZ_ALG_MAX_SIZE 64
 | 
			
		||||
#endif
 | 
			
		||||
#ifndef SECPKGCONTEXT_CIPHERINFO_V1
 | 
			
		||||
#define SECPKGCONTEXT_CIPHERINFO_V1 1
 | 
			
		||||
#endif
 | 
			
		||||
#ifndef SECPKG_ATTR_CIPHER_INFO
 | 
			
		||||
#define SECPKG_ATTR_CIPHER_INFO  0x64
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
  using namespace aria2;
 | 
			
		||||
 | 
			
		||||
  struct WinSecPkgContext_CipherInfo {
 | 
			
		||||
      DWORD dwVersion;
 | 
			
		||||
      DWORD dwProtocol;
 | 
			
		||||
      DWORD dwCipherSuite;
 | 
			
		||||
      DWORD dwBaseCipherSuite;
 | 
			
		||||
      WCHAR szCipherSuite[SZ_ALG_MAX_SIZE];
 | 
			
		||||
      WCHAR szCipher[SZ_ALG_MAX_SIZE];
 | 
			
		||||
      DWORD dwCipherLen;
 | 
			
		||||
      DWORD dwCipherBlockLen;    // in bytes
 | 
			
		||||
      WCHAR szHash[SZ_ALG_MAX_SIZE];
 | 
			
		||||
      DWORD dwHashLen;
 | 
			
		||||
      WCHAR szExchange[SZ_ALG_MAX_SIZE];
 | 
			
		||||
      DWORD dwMinExchangeLen;
 | 
			
		||||
      DWORD dwMaxExchangeLen;
 | 
			
		||||
      WCHAR szCertificate[SZ_ALG_MAX_SIZE];
 | 
			
		||||
      DWORD dwKeyType;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  static const ULONG kReqFlags = ISC_REQ_SEQUENCE_DETECT |
 | 
			
		||||
                                 ISC_REQ_REPLAY_DETECT |
 | 
			
		||||
                                 ISC_REQ_CONFIDENTIALITY |
 | 
			
		||||
                                 ISC_REQ_ALLOCATE_MEMORY |
 | 
			
		||||
                                 ISC_REQ_STREAM;
 | 
			
		||||
  static const ULONG kReqAFlags = ASC_REQ_SEQUENCE_DETECT |
 | 
			
		||||
                                  ASC_REQ_REPLAY_DETECT |
 | 
			
		||||
                                  ASC_REQ_CONFIDENTIALITY |
 | 
			
		||||
                                  ASC_REQ_EXTENDED_ERROR |
 | 
			
		||||
                                  ASC_REQ_ALLOCATE_MEMORY |
 | 
			
		||||
                                  ASC_REQ_STREAM;
 | 
			
		||||
 | 
			
		||||
  class TLSBuffer : public ::SecBuffer {
 | 
			
		||||
  public:
 | 
			
		||||
    explicit TLSBuffer(ULONG type, ULONG size, void *data)
 | 
			
		||||
    {
 | 
			
		||||
      cbBuffer = size;
 | 
			
		||||
      BufferType = type;
 | 
			
		||||
      pvBuffer = data;
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  class TLSBufferDesc: public ::SecBufferDesc {
 | 
			
		||||
  public:
 | 
			
		||||
    explicit TLSBufferDesc(SecBuffer *arr, ULONG buffers)
 | 
			
		||||
    {
 | 
			
		||||
      ulVersion = SECBUFFER_VERSION;
 | 
			
		||||
      cBuffers = buffers;
 | 
			
		||||
      pBuffers = arr;
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  inline static std::string getCipherSuite(CtxtHandle *handle)
 | 
			
		||||
  {
 | 
			
		||||
    WinSecPkgContext_CipherInfo info = { SECPKGCONTEXT_CIPHERINFO_V1 };
 | 
			
		||||
    if (QueryContextAttributes(handle, SECPKG_ATTR_CIPHER_INFO, &info) ==
 | 
			
		||||
        SEC_E_OK) {
 | 
			
		||||
      return wCharToUtf8(info.szCipherSuite);
 | 
			
		||||
    }
 | 
			
		||||
    return "Unknown";
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace aria2 {
 | 
			
		||||
 | 
			
		||||
TLSSession* TLSSession::make(TLSContext* ctx)
 | 
			
		||||
{
 | 
			
		||||
  return new WinTLSSession(static_cast<WinTLSContext*>(ctx));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
WinTLSSession::WinTLSSession(WinTLSContext* ctx)
 | 
			
		||||
  : sockfd_(0),
 | 
			
		||||
    side_(ctx->getSide()),
 | 
			
		||||
    cred_(ctx->getCredHandle()),
 | 
			
		||||
    writeBuffered_(0),
 | 
			
		||||
    state_(st_constructed),
 | 
			
		||||
    status_(SEC_E_OK)
 | 
			
		||||
{
 | 
			
		||||
  memset(&handle_, 0, sizeof(handle_));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
WinTLSSession::~WinTLSSession()
 | 
			
		||||
{
 | 
			
		||||
  ::DeleteSecurityContext(&handle_);
 | 
			
		||||
  state_ = st_error;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int WinTLSSession::init(sock_t sockfd)
 | 
			
		||||
{
 | 
			
		||||
  if (state_ != st_constructed) {
 | 
			
		||||
    status_ = SEC_E_INVALID_HANDLE;
 | 
			
		||||
    return TLS_ERR_ERROR;
 | 
			
		||||
  }
 | 
			
		||||
  sockfd_ = sockfd;
 | 
			
		||||
  state_ = st_initialized;
 | 
			
		||||
 | 
			
		||||
  return TLS_ERR_OK;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int WinTLSSession::setSNIHostname(const std::string& hostname)
 | 
			
		||||
{
 | 
			
		||||
  if (state_ != st_initialized) {
 | 
			
		||||
    status_ = SEC_E_INVALID_HANDLE;
 | 
			
		||||
    return TLS_ERR_ERROR;
 | 
			
		||||
  }
 | 
			
		||||
  hostname_ = hostname;
 | 
			
		||||
  return TLS_ERR_OK;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int WinTLSSession::closeConnection()
 | 
			
		||||
{
 | 
			
		||||
  if (state_ != st_connected || state_ != st_closing) {
 | 
			
		||||
    return TLS_ERR_ERROR;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (state_ == st_connected) {
 | 
			
		||||
    state_ = st_closing;
 | 
			
		||||
 | 
			
		||||
    DWORD dwShut = SCHANNEL_SHUTDOWN;
 | 
			
		||||
    TLSBuffer shut(SECBUFFER_TOKEN, sizeof(dwShut), &dwShut);
 | 
			
		||||
    TLSBufferDesc shutDesc(&shut, 1);
 | 
			
		||||
    status_ = ::ApplyControlToken(&handle_, &shutDesc);
 | 
			
		||||
    if (status_ != SEC_E_OK) {
 | 
			
		||||
      state_ = st_error;
 | 
			
		||||
      return TLS_ERR_ERROR;
 | 
			
		||||
    }
 | 
			
		||||
    TLSBuffer ctx(SECBUFFER_EMPTY, 0, nullptr);
 | 
			
		||||
    TLSBufferDesc desc(&ctx, 1);
 | 
			
		||||
    ULONG flags = 0;
 | 
			
		||||
    if (side_ == TLS_CLIENT) {
 | 
			
		||||
      SEC_CHAR* host = hostname_.empty() ?
 | 
			
		||||
        nullptr :
 | 
			
		||||
        const_cast<SEC_CHAR*>(hostname_.c_str());
 | 
			
		||||
      status_ = ::InitializeSecurityContext(
 | 
			
		||||
          cred_,
 | 
			
		||||
          &handle_,
 | 
			
		||||
          host,
 | 
			
		||||
          kReqFlags,
 | 
			
		||||
          0,
 | 
			
		||||
          0,
 | 
			
		||||
          nullptr,
 | 
			
		||||
          0,
 | 
			
		||||
          &handle_,
 | 
			
		||||
          &desc,
 | 
			
		||||
          &flags,
 | 
			
		||||
          nullptr);
 | 
			
		||||
    }
 | 
			
		||||
    else {
 | 
			
		||||
      status_ = ::AcceptSecurityContext(
 | 
			
		||||
          cred_,
 | 
			
		||||
          &handle_,
 | 
			
		||||
          nullptr,
 | 
			
		||||
          kReqAFlags,
 | 
			
		||||
          0,
 | 
			
		||||
          &handle_,
 | 
			
		||||
          &desc,
 | 
			
		||||
          &flags,
 | 
			
		||||
          nullptr);
 | 
			
		||||
    }
 | 
			
		||||
    if (status_ == SEC_E_OK || status_== SEC_I_CONTEXT_EXPIRED) {
 | 
			
		||||
      size_t len = ctx.cbBuffer;
 | 
			
		||||
      ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer);
 | 
			
		||||
      ::FreeContextBuffer(ctx.pvBuffer);
 | 
			
		||||
      if (rv == TLS_ERR_WOULDBLOCK) {
 | 
			
		||||
        return rv;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Alright data is sent or buffered
 | 
			
		||||
      if (rv - len != 0) {
 | 
			
		||||
        return TLS_ERR_WOULDBLOCK;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Send remaining data.
 | 
			
		||||
  while (writeBuf_.size()) {
 | 
			
		||||
    int rv = writeData(nullptr, 0);
 | 
			
		||||
    if (rv == TLS_ERR_WOULDBLOCK) {
 | 
			
		||||
      return rv;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  state_ = st_closed;
 | 
			
		||||
  return TLS_ERR_OK;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int WinTLSSession::checkDirection()
 | 
			
		||||
{
 | 
			
		||||
  if (state_ == st_handshake_write || state_  == st_handshake_write_last) {
 | 
			
		||||
    return TLS_WANT_WRITE;
 | 
			
		||||
  }
 | 
			
		||||
  if (state_ == st_handshake_read) {
 | 
			
		||||
    return TLS_WANT_READ;
 | 
			
		||||
  }
 | 
			
		||||
  if (readBuf_.size() || decBuf_.size()) {
 | 
			
		||||
    return TLS_WANT_READ;
 | 
			
		||||
  }
 | 
			
		||||
  if (writeBuf_.size()) {
 | 
			
		||||
    return TLS_WANT_WRITE;
 | 
			
		||||
  }
 | 
			
		||||
  return TLS_WANT_READ;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ssize_t WinTLSSession::writeData(const void* data, size_t len)
 | 
			
		||||
{
 | 
			
		||||
  if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
 | 
			
		||||
      state_ == st_handshake_read) {
 | 
			
		||||
    // Renegotiating
 | 
			
		||||
    std::string hn, err;
 | 
			
		||||
    auto connect = tlsConnect(hn, err);
 | 
			
		||||
    if (connect != TLS_ERR_OK) {
 | 
			
		||||
      return connect;
 | 
			
		||||
    }
 | 
			
		||||
    // Continue.
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (state_ != st_connected && state_ != st_closing) {
 | 
			
		||||
    status_ = SEC_E_INVALID_HANDLE;
 | 
			
		||||
    return TLS_ERR_ERROR;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64,
 | 
			
		||||
                   (uint64_t)len, (uint64_t)writeBuf_.size()));
 | 
			
		||||
 | 
			
		||||
  // Write remaining buffered data, if any.
 | 
			
		||||
  size_t written = 0;
 | 
			
		||||
  while (writeBuf_.size()) {
 | 
			
		||||
    written = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
 | 
			
		||||
    errno = ::WSAGetLastError();
 | 
			
		||||
    if (written < 0 && errno == WSAEINTR) {
 | 
			
		||||
      continue;
 | 
			
		||||
    }
 | 
			
		||||
    if (written < 0 && errno == WSAEWOULDBLOCK) {
 | 
			
		||||
      return TLS_ERR_WOULDBLOCK;
 | 
			
		||||
    }
 | 
			
		||||
    if (written == 0) {
 | 
			
		||||
      return written;
 | 
			
		||||
    }
 | 
			
		||||
    if (written < 0) {
 | 
			
		||||
      status_ = SEC_E_INVALID_HANDLE;
 | 
			
		||||
      state_ = st_error;
 | 
			
		||||
      return TLS_ERR_ERROR;
 | 
			
		||||
    }
 | 
			
		||||
    writeBuf_.eat(written);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (len == 0) {
 | 
			
		||||
    return 0;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (!streamSizes_) {
 | 
			
		||||
    streamSizes_.reset(new SecPkgContext_StreamSizes());
 | 
			
		||||
    status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
 | 
			
		||||
                                       streamSizes_.get());
 | 
			
		||||
    if (status_ != SEC_E_OK || !streamSizes_->cbMaximumMessage) {
 | 
			
		||||
      state_ = st_error;
 | 
			
		||||
      return TLS_ERR_ERROR;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  size_t process = len;
 | 
			
		||||
  auto bytes = reinterpret_cast<const char*>(data);
 | 
			
		||||
  if (writeBuffered_) {
 | 
			
		||||
    // There was buffered data, hence we need to "remove" that data from the
 | 
			
		||||
    // incoming buffer to avoid writing it again
 | 
			
		||||
    if (len < writeBuffered_) {
 | 
			
		||||
      // We didn't get called with the same data again, obviously.
 | 
			
		||||
      status_ = SEC_E_INVALID_HANDLE;
 | 
			
		||||
      status_ = st_error;
 | 
			
		||||
      return TLS_ERR_ERROR;
 | 
			
		||||
    }
 | 
			
		||||
    // just advance the buffer by writeBuffered_ bytes
 | 
			
		||||
    bytes += writeBuffered_;
 | 
			
		||||
    process -= writeBuffered_;
 | 
			
		||||
    writeBuffered_ = 0;
 | 
			
		||||
  }
 | 
			
		||||
  if (!process) {
 | 
			
		||||
    // The buffer contained the full remainder. At this point, the buffer has
 | 
			
		||||
    // been written, so the request is done in its entirety;
 | 
			
		||||
    return len;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Buffered data was already written ;)
 | 
			
		||||
  // If there was no buffered data, this will be len - len = 0.
 | 
			
		||||
  len = len - process;
 | 
			
		||||
  while (process) {
 | 
			
		||||
    // Set up an outgoing message, according to streamSizes_
 | 
			
		||||
    writeBuffered_ = std::min(process, (size_t)streamSizes_->cbMaximumMessage);
 | 
			
		||||
    size_t dl = streamSizes_->cbHeader + writeBuffered_ +
 | 
			
		||||
                streamSizes_->cbTrailer;
 | 
			
		||||
    auto buf = make_unique<char[]>(dl);
 | 
			
		||||
    TLSBuffer buffers[] = {
 | 
			
		||||
      TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_->cbHeader, buf.get()),
 | 
			
		||||
      TLSBuffer(SECBUFFER_DATA, writeBuffered_,
 | 
			
		||||
                buf.get() + streamSizes_->cbHeader),
 | 
			
		||||
      TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_->cbTrailer,
 | 
			
		||||
                buf.get() + streamSizes_->cbHeader + writeBuffered_),
 | 
			
		||||
      TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
 | 
			
		||||
    };
 | 
			
		||||
    TLSBufferDesc desc(buffers, 4);
 | 
			
		||||
    memcpy(buffers[1].pvBuffer, bytes, writeBuffered_);
 | 
			
		||||
    status_ = ::EncryptMessage(&handle_, 0, &desc, 0);
 | 
			
		||||
    if (status_ != SEC_E_OK) {
 | 
			
		||||
      A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s",
 | 
			
		||||
                       getLastErrorString().c_str()));
 | 
			
		||||
      state_ = st_error;
 | 
			
		||||
      return TLS_ERR_ERROR;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // EncryptMessage may have truncated the buffers.
 | 
			
		||||
    // Should rarely happen, if ever, except for the trailer.
 | 
			
		||||
    dl = buffers[0].cbBuffer;
 | 
			
		||||
    if (dl < streamSizes_->cbHeader) {
 | 
			
		||||
      // Move message.
 | 
			
		||||
      memmove(buf.get() + dl, buffers[1].pvBuffer, buffers[1].cbBuffer);
 | 
			
		||||
    }
 | 
			
		||||
    dl += buffers[1].cbBuffer;
 | 
			
		||||
    if (dl < streamSizes_->cbHeader + writeBuffered_) {
 | 
			
		||||
      // Move tailer.
 | 
			
		||||
      memmove(buf.get() + dl, buffers[2].pvBuffer, buffers[2].cbBuffer);
 | 
			
		||||
    }
 | 
			
		||||
    dl += buffers[2].cbBuffer;
 | 
			
		||||
 | 
			
		||||
    // Write (or buffer) the message.
 | 
			
		||||
    char* p = buf.get();
 | 
			
		||||
    while (dl) {
 | 
			
		||||
      written = ::send(sockfd_, p, dl, 0);
 | 
			
		||||
      errno = ::WSAGetLastError();
 | 
			
		||||
      if (written < 0 && errno == WSAEINTR) {
 | 
			
		||||
        continue;
 | 
			
		||||
      }
 | 
			
		||||
      if (written < 0 && errno == WSAEWOULDBLOCK) {
 | 
			
		||||
        // Buffer the rest of the message...
 | 
			
		||||
        writeBuf_.write(p, dl);
 | 
			
		||||
        // and return...
 | 
			
		||||
        return len;
 | 
			
		||||
      }
 | 
			
		||||
      if (written == 0) {
 | 
			
		||||
        A2_LOG_ERROR("WinTLS: Connection closed while writing");
 | 
			
		||||
        status_ = SEC_E_INCOMPLETE_MESSAGE;
 | 
			
		||||
        state_ = st_error;
 | 
			
		||||
        return TLS_ERR_ERROR;
 | 
			
		||||
      }
 | 
			
		||||
      if (written < 0) {
 | 
			
		||||
        A2_LOG_ERROR("WinTLS: Connection error while writing");
 | 
			
		||||
        status_ = SEC_E_INCOMPLETE_MESSAGE;
 | 
			
		||||
        state_ = st_error;
 | 
			
		||||
        return TLS_ERR_ERROR;
 | 
			
		||||
      }
 | 
			
		||||
      dl -= written;
 | 
			
		||||
      p += written;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    len += writeBuffered_;
 | 
			
		||||
    bytes += writeBuffered_;
 | 
			
		||||
    process -= writeBuffered_;
 | 
			
		||||
    writeBuffered_ = 0;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64 " buffered: %" PRIu64,
 | 
			
		||||
                   (uint64_t)len, (uint64_t)writeBuf_.size()));
 | 
			
		||||
  if (!len) {
 | 
			
		||||
    return TLS_ERR_WOULDBLOCK;
 | 
			
		||||
  }
 | 
			
		||||
  return len;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ssize_t WinTLSSession::readData(void* data, size_t len)
 | 
			
		||||
{
 | 
			
		||||
  A2_LOG_DEBUG(fmt("WinTLS: Read request: %" PRIu64 " buffered: %" PRIu64,
 | 
			
		||||
                   (uint64_t)len, (uint64_t)readBuf_.size()));
 | 
			
		||||
  if (len == 0) {
 | 
			
		||||
    return 0;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Can be filled from decBuffer entirely?
 | 
			
		||||
  if (decBuf_.size() >= len) {
 | 
			
		||||
    A2_LOG_DEBUG("WinTLS: Fullfilling req from buffer");
 | 
			
		||||
    memcpy(data, decBuf_.data(), len);
 | 
			
		||||
    decBuf_.eat(len);
 | 
			
		||||
    return len;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
 | 
			
		||||
      state_ == st_handshake_read) {
 | 
			
		||||
    // Renegotiating
 | 
			
		||||
    std::string hn, err;
 | 
			
		||||
    auto connect = tlsConnect(hn, err);
 | 
			
		||||
    if (connect != TLS_ERR_OK) {
 | 
			
		||||
      return connect;
 | 
			
		||||
    }
 | 
			
		||||
    // Continue.
 | 
			
		||||
  }
 | 
			
		||||
  if (state_ != st_connected) {
 | 
			
		||||
    status_ = SEC_E_INVALID_HANDLE;
 | 
			
		||||
    return TLS_ERR_ERROR;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Read as many bytes as available from the connection, up to len + 4k.
 | 
			
		||||
  readBuf_.resize(len + 4096);
 | 
			
		||||
  while (readBuf_.free()) {
 | 
			
		||||
    ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
 | 
			
		||||
    errno = ::WSAGetLastError();
 | 
			
		||||
    if (read < 0 && errno == WSAEINTR) {
 | 
			
		||||
      continue;
 | 
			
		||||
    }
 | 
			
		||||
    if (read < 0 && errno == WSAEWOULDBLOCK) {
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    if (read == 0) {
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    if (read < 0) {
 | 
			
		||||
      status_ = SEC_E_INCOMPLETE_MESSAGE;
 | 
			
		||||
      state_ = st_error;
 | 
			
		||||
      return TLS_ERR_ERROR;
 | 
			
		||||
    }
 | 
			
		||||
    readBuf_.advance(read);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Try to decrypt as many messages as possible from the readBuf_.
 | 
			
		||||
  while (readBuf_.size()) {
 | 
			
		||||
    TLSBuffer bufs[] = {
 | 
			
		||||
      TLSBuffer(SECBUFFER_DATA, readBuf_.size(), readBuf_.data()),
 | 
			
		||||
      TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
 | 
			
		||||
      TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
 | 
			
		||||
      TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
 | 
			
		||||
    };
 | 
			
		||||
    TLSBufferDesc desc(bufs, 4);
 | 
			
		||||
    status_ = ::DecryptMessage(&handle_, &desc, 0, nullptr);
 | 
			
		||||
    if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
 | 
			
		||||
      // Need to stop now, and wait for more bytes to arrive on the socket.
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (status_ != SEC_E_OK && status_ != SEC_I_CONTEXT_EXPIRED &&
 | 
			
		||||
        status_ != SEC_I_RENEGOTIATE) {
 | 
			
		||||
      A2_LOG_ERROR(fmt("WinTLS: Failed to decrypt a message! %s",
 | 
			
		||||
                       getLastErrorString().c_str()));
 | 
			
		||||
      state_ = st_error;
 | 
			
		||||
      return TLS_ERR_ERROR;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Decrypted message successfully.
 | 
			
		||||
    bool ate = false;
 | 
			
		||||
    for (auto& buf : bufs) {
 | 
			
		||||
      if (buf.BufferType == SECBUFFER_DATA && buf.cbBuffer > 0) {
 | 
			
		||||
        decBuf_.write(buf.pvBuffer, buf.cbBuffer);
 | 
			
		||||
      }
 | 
			
		||||
      else if (buf.BufferType == SECBUFFER_EXTRA && buf.cbBuffer > 0) {
 | 
			
		||||
        readBuf_.eat(readBuf_.size() - buf.cbBuffer);
 | 
			
		||||
        ate = true;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    if (!ate) {
 | 
			
		||||
      readBuf_.clear();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (status_ == SEC_I_RENEGOTIATE) {
 | 
			
		||||
      // Renegotiation basically means performing another handshake
 | 
			
		||||
      state_ = st_initialized;
 | 
			
		||||
      A2_LOG_INFO("WinTLS: Renegotiate");
 | 
			
		||||
      std::string hn, err;
 | 
			
		||||
      auto connect = tlsConnect(hn, err);
 | 
			
		||||
      if (connect == TLS_ERR_WOULDBLOCK) {
 | 
			
		||||
        break;
 | 
			
		||||
      }
 | 
			
		||||
      if (connect == TLS_ERR_ERROR) {
 | 
			
		||||
        return connect;
 | 
			
		||||
      }
 | 
			
		||||
      // Still good.
 | 
			
		||||
    }
 | 
			
		||||
    if (status_ == SEC_I_CONTEXT_EXPIRED) {
 | 
			
		||||
      // Connection is gone now, but the buffered bytes are still valid.
 | 
			
		||||
      A2_LOG_DEBUG("WinTLS: Connection closed!");
 | 
			
		||||
      closeConnection();
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  len = std::min(decBuf_.size(), len);
 | 
			
		||||
  if (len == 0) {
 | 
			
		||||
    return TLS_ERR_WOULDBLOCK;
 | 
			
		||||
  }
 | 
			
		||||
  memcpy(data, decBuf_.data(), len);
 | 
			
		||||
  decBuf_.eat(len);
 | 
			
		||||
  return len;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int WinTLSSession::tlsConnect(const std::string& hostname,
 | 
			
		||||
                              std::string& handshakeErr)
 | 
			
		||||
{
 | 
			
		||||
  // Handshaking will require sending multiple read/write exchanges until the
 | 
			
		||||
  // handshake is actually done. The client will first generate the initial
 | 
			
		||||
  // handshake message, then write that to the server, read the response
 | 
			
		||||
  // message, and write and/or read additional messages until the handshake is
 | 
			
		||||
  // either complete and successful, or something went wrong.
 | 
			
		||||
  // The server works analog to that.
 | 
			
		||||
 | 
			
		||||
  A2_LOG_DEBUG("WinTLS: Starting/Resuming TLS Connect");
 | 
			
		||||
  ULONG flags = 0;
 | 
			
		||||
 | 
			
		||||
restart:
 | 
			
		||||
 | 
			
		||||
  switch (state_) {
 | 
			
		||||
    default:
 | 
			
		||||
      A2_LOG_ERROR("WinTLS: Invalid state");
 | 
			
		||||
      status_ = SEC_E_INVALID_HANDLE;
 | 
			
		||||
      return TLS_ERR_ERROR;
 | 
			
		||||
 | 
			
		||||
    case st_initialized: {
 | 
			
		||||
      if (side_ == TLS_SERVER) {
 | 
			
		||||
        goto read;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      if (!hostname.empty()) {
 | 
			
		||||
        setSNIHostname(hostname);
 | 
			
		||||
      }
 | 
			
		||||
      A2_LOG_DEBUG("WinTLS: Initializing handshake");
 | 
			
		||||
      TLSBuffer buf(SECBUFFER_EMPTY, 0, nullptr);
 | 
			
		||||
      TLSBufferDesc desc(&buf, 1);
 | 
			
		||||
      SEC_CHAR* host = hostname_.empty() ?
 | 
			
		||||
        nullptr :
 | 
			
		||||
        const_cast<SEC_CHAR*>(hostname_.c_str());
 | 
			
		||||
      status_ = ::InitializeSecurityContext(
 | 
			
		||||
          cred_,
 | 
			
		||||
          nullptr,
 | 
			
		||||
          host,
 | 
			
		||||
          kReqFlags,
 | 
			
		||||
          0,
 | 
			
		||||
          0,
 | 
			
		||||
          nullptr,
 | 
			
		||||
          0,
 | 
			
		||||
          &handle_,
 | 
			
		||||
          &desc,
 | 
			
		||||
          &flags,
 | 
			
		||||
          nullptr);
 | 
			
		||||
      if (status_ != SEC_I_CONTINUE_NEEDED) {
 | 
			
		||||
        // Has to be SEC_I_CONTINUE_NEEDED, as we did not actually send data
 | 
			
		||||
        // at this point.
 | 
			
		||||
        state_ = st_error;
 | 
			
		||||
        return TLS_ERR_ERROR;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Queue the initial message...
 | 
			
		||||
      writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
 | 
			
		||||
      FreeContextBuffer(buf.pvBuffer);
 | 
			
		||||
 | 
			
		||||
      // ... and start sending it
 | 
			
		||||
      state_ = st_handshake_write;
 | 
			
		||||
    }
 | 
			
		||||
    // Fall through
 | 
			
		||||
 | 
			
		||||
    case st_handshake_write_last:
 | 
			
		||||
    case st_handshake_write: {
 | 
			
		||||
      A2_LOG_DEBUG("WinTLS: Writing handshake");
 | 
			
		||||
 | 
			
		||||
      // Write the currently queued handshake message until all data is sent.
 | 
			
		||||
      while(writeBuf_.size()) {
 | 
			
		||||
        ssize_t writ = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
 | 
			
		||||
        errno = ::WSAGetLastError();
 | 
			
		||||
        if (writ < 0 && errno == WSAEINTR) {
 | 
			
		||||
          continue;
 | 
			
		||||
        }
 | 
			
		||||
        if (writ < 0 && errno == WSAEWOULDBLOCK) {
 | 
			
		||||
          return TLS_ERR_WOULDBLOCK;
 | 
			
		||||
        }
 | 
			
		||||
        if (writ <= 0) {
 | 
			
		||||
          status_ = SEC_E_INCOMPLETE_MESSAGE;
 | 
			
		||||
          state_ = st_error;
 | 
			
		||||
          return TLS_ERR_ERROR;
 | 
			
		||||
        }
 | 
			
		||||
        writeBuf_.eat(writ);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      if (state_ == st_handshake_write_last) {
 | 
			
		||||
        state_ = st_handshake_done;
 | 
			
		||||
        goto restart;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Have to read one or more response messages.
 | 
			
		||||
      state_ = st_handshake_read;
 | 
			
		||||
    }
 | 
			
		||||
    // Fall through
 | 
			
		||||
 | 
			
		||||
    case st_handshake_read: {
 | 
			
		||||
read:
 | 
			
		||||
      A2_LOG_DEBUG("WinTLS: Reading handshake...");
 | 
			
		||||
 | 
			
		||||
      // All write buffered data is invalid at this point!
 | 
			
		||||
      writeBuf_.clear();
 | 
			
		||||
 | 
			
		||||
      // Read as many bytes as possible, up to 4k new bytes.
 | 
			
		||||
      // We do not know how many bytes will arrive from the server at this
 | 
			
		||||
      // point.
 | 
			
		||||
      readBuf_.resize(readBuf_.size() + 4096);
 | 
			
		||||
      while (readBuf_.free()) {
 | 
			
		||||
        ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
 | 
			
		||||
        errno = ::WSAGetLastError();
 | 
			
		||||
        if (read < 0 && errno == WSAEINTR) {
 | 
			
		||||
          continue;
 | 
			
		||||
        }
 | 
			
		||||
        if (read < 0 && errno == WSAEWOULDBLOCK) {
 | 
			
		||||
          break;
 | 
			
		||||
        }
 | 
			
		||||
        if (read <= 0) {
 | 
			
		||||
          status_ = SEC_E_INCOMPLETE_MESSAGE;
 | 
			
		||||
          state_ = st_error;
 | 
			
		||||
          return TLS_ERR_ERROR;
 | 
			
		||||
        }
 | 
			
		||||
        readBuf_.advance(read);
 | 
			
		||||
        break;
 | 
			
		||||
      }
 | 
			
		||||
      if (!readBuf_.size()) {
 | 
			
		||||
        return TLS_ERR_WOULDBLOCK;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Need to copy the data, as Schannel is free to mess with it. But we
 | 
			
		||||
      // might later need unmodified data from the original read buffer.
 | 
			
		||||
      auto bufcopy = make_unique<char[]>(readBuf_.size());
 | 
			
		||||
      memcpy(bufcopy.get(), readBuf_.data(), readBuf_.size());
 | 
			
		||||
 | 
			
		||||
      // Set up buffers. inbufs will be the raw bytes the library has to decode.
 | 
			
		||||
      // outbufs will contain generated responses, if any.
 | 
			
		||||
      TLSBuffer inbufs[] = {
 | 
			
		||||
        TLSBuffer(SECBUFFER_TOKEN, readBuf_.size(), bufcopy.get()),
 | 
			
		||||
        TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
 | 
			
		||||
      };
 | 
			
		||||
      TLSBufferDesc indesc(inbufs, 2);
 | 
			
		||||
      TLSBuffer outbufs[] = {
 | 
			
		||||
        TLSBuffer(SECBUFFER_TOKEN, 0, nullptr),
 | 
			
		||||
        TLSBuffer(SECBUFFER_ALERT, 0, nullptr),
 | 
			
		||||
      };
 | 
			
		||||
      TLSBufferDesc outdesc(outbufs, 2);
 | 
			
		||||
      if (side_ == TLS_CLIENT) {
 | 
			
		||||
        SEC_CHAR* host = hostname_.empty() ?
 | 
			
		||||
          nullptr :
 | 
			
		||||
          const_cast<SEC_CHAR*>(hostname_.c_str());
 | 
			
		||||
        status_ = ::InitializeSecurityContext(
 | 
			
		||||
            cred_,
 | 
			
		||||
            &handle_,
 | 
			
		||||
            host,
 | 
			
		||||
            kReqFlags,
 | 
			
		||||
            0,
 | 
			
		||||
            0,
 | 
			
		||||
            &indesc,
 | 
			
		||||
            0,
 | 
			
		||||
            nullptr,
 | 
			
		||||
            &outdesc,
 | 
			
		||||
            &flags,
 | 
			
		||||
            nullptr);
 | 
			
		||||
      }
 | 
			
		||||
      else {
 | 
			
		||||
        status_ = ::AcceptSecurityContext(
 | 
			
		||||
            cred_,
 | 
			
		||||
            state_ == st_initialized ? nullptr : &handle_,
 | 
			
		||||
            &indesc,
 | 
			
		||||
            kReqAFlags,
 | 
			
		||||
            0,
 | 
			
		||||
            state_ == st_initialized ? &handle_ : nullptr,
 | 
			
		||||
            &outdesc,
 | 
			
		||||
            &flags,
 | 
			
		||||
            nullptr);
 | 
			
		||||
      }
 | 
			
		||||
      if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
 | 
			
		||||
        // Not enough raw bytes read yet to decode a full message.
 | 
			
		||||
        return TLS_ERR_WOULDBLOCK;
 | 
			
		||||
      }
 | 
			
		||||
      if (status_ != SEC_E_OK && status_ != SEC_I_CONTINUE_NEEDED) {
 | 
			
		||||
        state_ = st_error;
 | 
			
		||||
        return TLS_ERR_ERROR;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Raw bytes where not entirely consumed, i.e. readBuf_ still contains
 | 
			
		||||
      // unprocessed data from the next message?
 | 
			
		||||
      if (inbufs[1].BufferType == SECBUFFER_EXTRA && inbufs[1].cbBuffer > 0) {
 | 
			
		||||
        readBuf_.eat(readBuf_.size() - inbufs[1].cbBuffer);
 | 
			
		||||
      }
 | 
			
		||||
      else {
 | 
			
		||||
        readBuf_.clear();
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Check if the library produced a new outgoing message and queue it.
 | 
			
		||||
      for (auto& buf : outbufs) {
 | 
			
		||||
        if (buf.BufferType == SECBUFFER_TOKEN && buf.cbBuffer > 0) {
 | 
			
		||||
          writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
 | 
			
		||||
          FreeContextBuffer(buf.pvBuffer);
 | 
			
		||||
          state_ = st_handshake_write;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Need to read additional messages?
 | 
			
		||||
      if (status_ == SEC_I_CONTINUE_NEEDED) {
 | 
			
		||||
        A2_LOG_DEBUG("WinTLS: Continuing with handshake");
 | 
			
		||||
        goto restart;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      if (side_ == TLS_CLIENT && flags != kReqFlags) {
 | 
			
		||||
        A2_LOG_ERROR(fmt("WinTLS: Channel setup failed. Schannel provider did "
 | 
			
		||||
                         "not fulfill requested flags. "
 | 
			
		||||
                         "Excepted: %lu Actual: %lu",
 | 
			
		||||
                         kReqFlags, flags));
 | 
			
		||||
        status_ = SEC_E_INTERNAL_ERROR;
 | 
			
		||||
        state_ = st_error;
 | 
			
		||||
        return TLS_ERR_ERROR;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      if (state_ == st_handshake_write) {
 | 
			
		||||
        A2_LOG_DEBUG("WinTLS: Continuing with handshake (last write)");
 | 
			
		||||
        state_ = st_handshake_write_last;
 | 
			
		||||
        goto restart;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    // Fall through
 | 
			
		||||
 | 
			
		||||
    case st_handshake_done:
 | 
			
		||||
      // All ready now :D
 | 
			
		||||
      state_ = st_connected;
 | 
			
		||||
      A2_LOG_INFO(fmt("WinTLS: connected with: %s",
 | 
			
		||||
                      getCipherSuite(&handle_).c_str()));
 | 
			
		||||
      return TLS_ERR_OK;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  A2_LOG_ERROR("WinTLS: Unreachable reached during tlsConnect! This is a bug!");
 | 
			
		||||
  state_ = st_error;
 | 
			
		||||
  return TLS_ERR_ERROR;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int WinTLSSession::tlsAccept()
 | 
			
		||||
{
 | 
			
		||||
  std::string host, err;
 | 
			
		||||
  return tlsConnect(host, err);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::string WinTLSSession::getLastErrorString()
 | 
			
		||||
{
 | 
			
		||||
  std::stringstream ss;
 | 
			
		||||
  wchar_t* buf = nullptr;
 | 
			
		||||
  if (FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER |
 | 
			
		||||
                    FORMAT_MESSAGE_FROM_SYSTEM |
 | 
			
		||||
                    FORMAT_MESSAGE_IGNORE_INSERTS,
 | 
			
		||||
                    nullptr,
 | 
			
		||||
                    status_,
 | 
			
		||||
                    MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
 | 
			
		||||
                    (LPWSTR)&buf,
 | 
			
		||||
                    1024,
 | 
			
		||||
                    nullptr) && buf) {
 | 
			
		||||
    ss << "Error: " << wCharToUtf8(buf);
 | 
			
		||||
    LocalFree(buf);
 | 
			
		||||
  }
 | 
			
		||||
  else {
 | 
			
		||||
    ss << "Error: " << std::hex << status_;
 | 
			
		||||
  }
 | 
			
		||||
  return ss.str();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace aria2
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,194 @@
 | 
			
		|||
/* <!-- copyright */
 | 
			
		||||
/*
 | 
			
		||||
 * aria2 - The high speed download utility
 | 
			
		||||
 *
 | 
			
		||||
 * Copyright (C) 2013 Nils Maier
 | 
			
		||||
 *
 | 
			
		||||
 * This program is free software; you can redistribute it and/or modify
 | 
			
		||||
 * it under the terms of the GNU General Public License as published by
 | 
			
		||||
 * the Free Software Foundation; either version 2 of the License, or
 | 
			
		||||
 * (at your option) any later version.
 | 
			
		||||
 *
 | 
			
		||||
 * This program is distributed in the hope that it will be useful,
 | 
			
		||||
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
 * GNU General Public License for more details.
 | 
			
		||||
 *
 | 
			
		||||
 * You should have received a copy of the GNU General Public License
 | 
			
		||||
 * along with this program; if not, write to the Free Software
 | 
			
		||||
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 | 
			
		||||
 *
 | 
			
		||||
 * In addition, as a special exception, the copyright holders give
 | 
			
		||||
 * permission to link the code of portions of this program with the
 | 
			
		||||
 * OpenSSL library under certain conditions as described in each
 | 
			
		||||
 * individual source file, and distribute linked combinations
 | 
			
		||||
 * including the two.
 | 
			
		||||
 * You must obey the GNU General Public License in all respects
 | 
			
		||||
 * for all of the code used other than OpenSSL.  If you modify
 | 
			
		||||
 * file(s) with this exception, you may extend this exception to your
 | 
			
		||||
 * version of the file(s), but you are not obligated to do so.  If you
 | 
			
		||||
 * do not wish to do so, delete this exception statement from your
 | 
			
		||||
 * version.  If you delete this exception statement from all source
 | 
			
		||||
 * files in the program, then also delete it here.
 | 
			
		||||
 */
 | 
			
		||||
/* copyright --> */
 | 
			
		||||
 | 
			
		||||
#ifndef WIN_TLS_SESSION_H
 | 
			
		||||
#define WIN_TLS_SESSION_H
 | 
			
		||||
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "common.h"
 | 
			
		||||
#include "TLSSession.h"
 | 
			
		||||
#include "WinTLSContext.h"
 | 
			
		||||
 | 
			
		||||
namespace aria2 {
 | 
			
		||||
 | 
			
		||||
namespace wintls {
 | 
			
		||||
  struct Buffer {
 | 
			
		||||
  private:
 | 
			
		||||
    size_t off_, free_, cap_;
 | 
			
		||||
    std::vector<char> buf_;
 | 
			
		||||
 | 
			
		||||
  public:
 | 
			
		||||
    inline Buffer() : off_(0), free_(0), cap_(0) {}
 | 
			
		||||
 | 
			
		||||
    inline size_t size() const {
 | 
			
		||||
      return off_;
 | 
			
		||||
    }
 | 
			
		||||
    inline size_t free() const {
 | 
			
		||||
      return free_;
 | 
			
		||||
    }
 | 
			
		||||
    inline void resize(size_t len) {
 | 
			
		||||
      if (cap_ >= len) {
 | 
			
		||||
        return;
 | 
			
		||||
      }
 | 
			
		||||
      buf_.resize(len);
 | 
			
		||||
      cap_ = buf_.size();
 | 
			
		||||
      free_ = cap_ - off_;
 | 
			
		||||
    }
 | 
			
		||||
    inline char* data() {
 | 
			
		||||
      return buf_.data();
 | 
			
		||||
    }
 | 
			
		||||
    inline char* end() {
 | 
			
		||||
      return buf_.data() + off_;
 | 
			
		||||
    }
 | 
			
		||||
    inline void eat(size_t len) {
 | 
			
		||||
      off_ -= len;
 | 
			
		||||
      if (off_) {
 | 
			
		||||
        memmove(buf_.data(), buf_.data() + len, off_);
 | 
			
		||||
      }
 | 
			
		||||
      free_ = cap_ - off_;
 | 
			
		||||
    }
 | 
			
		||||
    inline void clear() {
 | 
			
		||||
      eat(off_);
 | 
			
		||||
    }
 | 
			
		||||
    inline void advance(size_t len) {
 | 
			
		||||
      off_ += len;
 | 
			
		||||
      free_ = cap_ - off_;
 | 
			
		||||
    }
 | 
			
		||||
    inline void write(const void* data, size_t len) {
 | 
			
		||||
      if (!len) {
 | 
			
		||||
        return;
 | 
			
		||||
      }
 | 
			
		||||
      resize(off_ + len);
 | 
			
		||||
      memcpy(end(), data, len);
 | 
			
		||||
      advance(len);
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
} // namespace wintls
 | 
			
		||||
 | 
			
		||||
class WinTLSSession : public TLSSession {
 | 
			
		||||
  enum state_t {
 | 
			
		||||
    st_constructed,
 | 
			
		||||
    st_initialized,
 | 
			
		||||
    st_handshake_write,
 | 
			
		||||
    st_handshake_write_last,
 | 
			
		||||
    st_handshake_read,
 | 
			
		||||
    st_handshake_done,
 | 
			
		||||
    st_connected,
 | 
			
		||||
    st_closing,
 | 
			
		||||
    st_closed,
 | 
			
		||||
    st_error
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
public:
 | 
			
		||||
  WinTLSSession(WinTLSContext* ctx);
 | 
			
		||||
 | 
			
		||||
  // MUST deallocate all resources
 | 
			
		||||
  virtual ~WinTLSSession();
 | 
			
		||||
 | 
			
		||||
  // Initializes SSL/TLS session. The |sockfd| is the underlying
 | 
			
		||||
  // tranport socket. This function returns TLS_ERR_OK if it
 | 
			
		||||
  // succeeds, or TLS_ERR_ERROR.
 | 
			
		||||
  virtual int init(sock_t sockfd) CXX11_OVERRIDE;
 | 
			
		||||
 | 
			
		||||
  // Sets |hostname| for TLS SNI extension. This is only meaningful for
 | 
			
		||||
  // client side session. This function returns TLS_ERR_OK if it
 | 
			
		||||
  // succeeds, or TLS_ERR_ERROR.
 | 
			
		||||
  virtual int setSNIHostname(const std::string& hostname) CXX11_OVERRIDE;
 | 
			
		||||
 | 
			
		||||
  // Closes the SSL/TLS session. Don't close underlying transport
 | 
			
		||||
  // socket. This function returns TLS_ERR_OK if it succeeds, or
 | 
			
		||||
  // TLS_ERR_ERROR.
 | 
			
		||||
  virtual int closeConnection() CXX11_OVERRIDE;
 | 
			
		||||
 | 
			
		||||
  // Returns TLS_WANT_READ if SSL/TLS session needs more data from
 | 
			
		||||
  // remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session
 | 
			
		||||
  // needs to write more data to proceed. If SSL/TLS session needs
 | 
			
		||||
  // neither read nor write data at the moment, return value is
 | 
			
		||||
  // undefined.
 | 
			
		||||
  virtual int checkDirection() CXX11_OVERRIDE;
 | 
			
		||||
 | 
			
		||||
  // Sends |data| with length |len|. This function returns the number
 | 
			
		||||
  // of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the
 | 
			
		||||
  // underlying tranport blocks, or TLS_ERR_ERROR.
 | 
			
		||||
  virtual ssize_t writeData(const void* data, size_t len) CXX11_OVERRIDE;
 | 
			
		||||
 | 
			
		||||
  // Receives data into |data| with length |len|. This function returns
 | 
			
		||||
  // the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK
 | 
			
		||||
  // if the underlying tranport blocks, or TLS_ERR_ERROR.
 | 
			
		||||
  virtual ssize_t readData(void* data, size_t len) CXX11_OVERRIDE;
 | 
			
		||||
 | 
			
		||||
  // Performs client side handshake. The |hostname| is the hostname of
 | 
			
		||||
  // the remote endpoint and is used to verify its certificate. This
 | 
			
		||||
  // function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK
 | 
			
		||||
  // if the underlying transport blocks, or TLS_ERR_ERROR.
 | 
			
		||||
  // When returning TLS_ERR_ERROR, provide certificate validation error
 | 
			
		||||
  // in |handshakeErr|.
 | 
			
		||||
  virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr) CXX11_OVERRIDE;
 | 
			
		||||
 | 
			
		||||
  // Performs server side handshake. This function returns TLS_ERR_OK
 | 
			
		||||
  // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport
 | 
			
		||||
  // blocks, or TLS_ERR_ERROR.
 | 
			
		||||
  virtual int tlsAccept() CXX11_OVERRIDE;
 | 
			
		||||
 | 
			
		||||
  // Returns last error string
 | 
			
		||||
  virtual std::string getLastErrorString() CXX11_OVERRIDE;
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  std::string hostname_;
 | 
			
		||||
  sock_t sockfd_;
 | 
			
		||||
  TLSSessionSide side_;
 | 
			
		||||
  CredHandle* cred_;
 | 
			
		||||
  CtxtHandle handle_;
 | 
			
		||||
 | 
			
		||||
  // Buffer for already encrypted writes
 | 
			
		||||
  wintls::Buffer writeBuf_;
 | 
			
		||||
  // While the writeBuf_ holds encrypted messages, writeBuffered_ has the
 | 
			
		||||
  // corresponding size of unencrpted data used to procude the messages.
 | 
			
		||||
  size_t writeBuffered_;
 | 
			
		||||
  // Buffer for still encrypted reads
 | 
			
		||||
  wintls::Buffer readBuf_;
 | 
			
		||||
  // Buffer for already decrypted reads
 | 
			
		||||
  wintls::Buffer decBuf_;
 | 
			
		||||
 | 
			
		||||
  state_t state_;
 | 
			
		||||
 | 
			
		||||
  SECURITY_STATUS status_;
 | 
			
		||||
  std::unique_ptr<SecPkgContext_StreamSizes> streamSizes_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
} // namespace aria2
 | 
			
		||||
 | 
			
		||||
#endif // TLS_SESSION_H
 | 
			
		||||
		Loading…
	
		Reference in New Issue