mirror of https://github.com/aria2/aria2
Implement WinTLS
parent
3f1d293ed1
commit
00dd83b461
42
configure.ac
42
configure.ac
|
@ -40,7 +40,7 @@ AC_DEFINE_UNQUOTED([TARGET], ["$target"], [Define target-type])
|
|||
# Checks for arguments.
|
||||
ARIA2_ARG_WITHOUT([libuv])
|
||||
ARIA2_ARG_WITHOUT([appletls])
|
||||
ARIA2_ARG_WITHOUT([wintls])
|
||||
ARIA2_ARG_WITH([wintls])
|
||||
ARIA2_ARG_WITHOUT([gnutls])
|
||||
ARIA2_ARG_WITHOUT([libnettle])
|
||||
ARIA2_ARG_WITHOUT([libgmp])
|
||||
|
@ -337,23 +337,39 @@ if test "x$with_appletls" = "xyes"; then
|
|||
fi
|
||||
|
||||
if test "x$with_wintls" = "xyes"; then
|
||||
AC_SEARCH_LIBS([CryptAcquireContextW], [advapi32], [
|
||||
AC_CHECK_HEADER([wincrypt.h], [have_wincrypt=yes], [have_wincrypt=no],
|
||||
[[
|
||||
AC_HAVE_LIBRARY([crypt32],[have_wintls_libs=yes],[have_wintls_libs=no])
|
||||
AC_HAVE_LIBRARY([secur32],[have_wintls_libs=$have_wintls_libs],[have_wintls_libs=no])
|
||||
AC_HAVE_LIBRARY([advapi32],[have_wintls_libs=$have_wintls_libs],[have_wintls_libs=no])
|
||||
AC_CHECK_HEADER([wincrypt.h], [have_wintls_headers=yes], [have_wintls_headers=no], [[
|
||||
#ifdef HAVE_WINDOWS_H
|
||||
# include <windows.h>
|
||||
#endif
|
||||
]])
|
||||
break;
|
||||
], [have_wincrypt=no])
|
||||
if test "x$have_wincrypt" != "xyes"; then
|
||||
]])
|
||||
AC_CHECK_HEADER([security.h], [have_wintls_headers=$have_wintls_headers], [have_wintls_headers=no], [[
|
||||
#ifdef HAVE_WINDOWS_H
|
||||
# include <windows.h>
|
||||
#endif
|
||||
#ifndef SECURITY_WIN32
|
||||
#define SECURITY_WIN32 1
|
||||
#endif
|
||||
]])
|
||||
|
||||
if test "x$have_wintls_libs" = "xyes" &&
|
||||
test "x$have_wintls_headers" = "xyes"; then
|
||||
AC_DEFINE([SECURITY_WIN32], [1], [Use security.h in WIN32 mode])
|
||||
LIBS="$LIBS -lcrypt32 -lsecur32 -ladvapi32"
|
||||
have_wintls=yes
|
||||
else
|
||||
have_wintls=no
|
||||
fi
|
||||
if test "x$have_wintls" != "xyes"; then
|
||||
if test "x$with_wintls_requested" = "xyes"; then
|
||||
ARIA2_DEP_NOT_MET([wintls])
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
if test "x$with_gnutls" = "xyes" && test "x$have_appletls" != "xyes"; then
|
||||
if test "x$with_gnutls" = "xyes" && test "x$have_appletls" != "xyes" && test "x$have_wintls" != "xyes"; then
|
||||
# gnutls >= 2.8 doesn't have libgnutls-config anymore. We require
|
||||
# 2.2.0 because we use gnutls_priority_set_direct()
|
||||
PKG_CHECK_MODULES([LIBGNUTLS], [gnutls >= 2.2.0],
|
||||
|
@ -371,7 +387,7 @@ if test "x$with_gnutls" = "xyes" && test "x$have_appletls" != "xyes"; then
|
|||
fi
|
||||
fi
|
||||
|
||||
if test "x$with_openssl" = "xyes" && test "x$have_appletls" != "xyes" && test "x$have_libgnutls" != "xyes"; then
|
||||
if test "x$with_openssl" = "xyes" && test "x$have_appletls" != "xyes" && test "x$have_wintls" != "xyes" && test "x$have_libgnutls" != "xyes"; then
|
||||
PKG_CHECK_MODULES([OPENSSL], [openssl >= 0.9.8],
|
||||
[have_openssl=yes], [have_openssl=no])
|
||||
if test "x$have_openssl" = "xyes"; then
|
||||
|
@ -448,7 +464,7 @@ if test "x$have_appletls" == "xyes"; then
|
|||
use_md="apple"
|
||||
AC_DEFINE([USE_APPLE_MD], [1], [What message digest implementation to use])
|
||||
else
|
||||
if test "x$have_wincrypt" == "xyes"; then
|
||||
if test "x$have_wintls" == "xyes"; then
|
||||
use_md="windows"
|
||||
AC_DEFINE([USE_WINDOWS_MD], [1], [What message digest implementation to use])
|
||||
else
|
||||
|
@ -473,7 +489,7 @@ else
|
|||
fi
|
||||
|
||||
# Define variables based on the result of the checks for libraries.
|
||||
if test "x$have_appletls" = "xyes" || test "x$have_libgnutls" = "xyes" || test "x$have_openssl" = "xyes"; then
|
||||
if test "x$have_appletls" = "xyes" || test "x$have_wintls" == "xyes" || test "x$have_libgnutls" = "xyes" || test "x$have_openssl" = "xyes"; then
|
||||
have_ssl="yes"
|
||||
AC_DEFINE([ENABLE_SSL], [1], [Define to 1 if ssl support is enabled.])
|
||||
AM_CONDITIONAL([ENABLE_SSL], true)
|
||||
|
@ -485,6 +501,7 @@ fi
|
|||
|
||||
AM_CONDITIONAL([HAVE_OSX], [ test "x$have_osx" = "xyes" ])
|
||||
AM_CONDITIONAL([HAVE_APPLETLS], [ test "x$have_appletls" = "xyes" ])
|
||||
AM_CONDITIONAL([HAVE_WINTLS], [ test "x$have_wintls" = "xyes" ])
|
||||
AM_CONDITIONAL([USE_APPLE_MD], [ test "x$use_md" = "xapple" ])
|
||||
AM_CONDITIONAL([USE_WINDOWS_MD], [ test "x$use_md" = "xwindows" ])
|
||||
AM_CONDITIONAL([HAVE_LIBGNUTLS], [ test "x$have_libgnutls" = "xyes" ])
|
||||
|
@ -985,6 +1002,7 @@ echo "LibUV: $have_libuv"
|
|||
echo "SQLite3: $have_sqlite3"
|
||||
echo "SSL Support: $have_ssl"
|
||||
echo "AppleTLS: $have_appletls"
|
||||
echo "WinTLS: $have_wintls"
|
||||
echo "GnuTLS: $have_libgnutls"
|
||||
echo "OpenSSL: $have_openssl"
|
||||
echo "CA Bundle: $ca_bundle"
|
||||
|
|
|
@ -333,6 +333,11 @@ if USE_WINDOWS_MD
|
|||
SRCS += WinMessageDigestImpl.cc
|
||||
endif # USE_WINDOWS_MD
|
||||
|
||||
if HAVE_WINTLS
|
||||
SRCS += WinTLSContext.cc WinTLSContext.h \
|
||||
WinTLSSession.cc WinTLSSession.h
|
||||
endif # HAVE_WINTLS
|
||||
|
||||
if USE_INTERNAL_BIGNUM
|
||||
SRCS += InternalDHKeyExchange.cc InternalDHKeyExchange.h bignum.h
|
||||
endif
|
||||
|
|
|
@ -779,7 +779,7 @@ void SocketCore::readData(void* data, size_t& len)
|
|||
ret = tlsSession_->readData(data, len);
|
||||
if(ret < 0) {
|
||||
if(ret != TLS_ERR_WOULDBLOCK) {
|
||||
throw DL_RETRY_EX(fmt(EX_SOCKET_SEND,
|
||||
throw DL_RETRY_EX(fmt(EX_SOCKET_RECV,
|
||||
tlsSession_->getLastErrorString().c_str()));
|
||||
}
|
||||
if(tlsSession_->checkDirection() == TLS_WANT_READ) {
|
||||
|
@ -814,6 +814,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
|
|||
wantWrite_ = false;
|
||||
switch(secure_) {
|
||||
case A2_TLS_NONE:
|
||||
A2_LOG_DEBUG("Creating TLS session");
|
||||
tlsSession_.reset(TLSSession::make(tlsctx));
|
||||
rv = tlsSession_->init(sockfd_);
|
||||
if(rv != TLS_ERR_OK) {
|
||||
|
@ -835,6 +836,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
|
|||
secure_ = A2_TLS_HANDSHAKING;
|
||||
// Fall through
|
||||
case A2_TLS_HANDSHAKING:
|
||||
A2_LOG_DEBUG("TLS Handshaking");
|
||||
if(tlsctx->getSide() == TLS_CLIENT) {
|
||||
rv = tlsSession_->tlsConnect(hostname, handshakeError);
|
||||
} else {
|
||||
|
@ -857,6 +859,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
|
|||
}
|
||||
return false;
|
||||
default:
|
||||
A2_LOG_DEBUG("TLS else");
|
||||
break;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -0,0 +1,189 @@
|
|||
/* <!-- copyright */
|
||||
/*
|
||||
* aria2 - The high speed download utility
|
||||
*
|
||||
* Copyright (C) 2013 Nils Maier
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation; either version 2 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program; if not, write to the Free Software
|
||||
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
*
|
||||
* In addition, as a special exception, the copyright holders give
|
||||
* permission to link the code of portions of this program with the
|
||||
* OpenSSL library under certain conditions as described in each
|
||||
* individual source file, and distribute linked combinations
|
||||
* including the two.
|
||||
* You must obey the GNU General Public License in all respects
|
||||
* for all of the code used other than OpenSSL. If you modify
|
||||
* file(s) with this exception, you may extend this exception to your
|
||||
* version of the file(s), but you are not obligated to do so. If you
|
||||
* do not wish to do so, delete this exception statement from your
|
||||
* version. If you delete this exception statement from all source
|
||||
* files in the program, then also delete it here.
|
||||
*/
|
||||
/* copyright --> */
|
||||
|
||||
#include "WinTLSContext.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "BufferedFile.h"
|
||||
#include "LogFactory.h"
|
||||
#include "Logger.h"
|
||||
#include "fmt.h"
|
||||
#include "message.h"
|
||||
#include "util.h"
|
||||
|
||||
namespace aria2 {
|
||||
|
||||
WinTLSContext::WinTLSContext(TLSSessionSide side)
|
||||
: side_(side), store_(0)
|
||||
{
|
||||
memset(&credentials_, 0, sizeof(credentials_));
|
||||
credentials_.dwVersion = SCHANNEL_CRED_VERSION;
|
||||
if (side_ == TLS_CLIENT) {
|
||||
credentials_.grbitEnabledProtocols =
|
||||
SP_PROT_SSL3_CLIENT |
|
||||
SP_PROT_TLS1_CLIENT |
|
||||
SP_PROT_TLS1_1_CLIENT |
|
||||
SP_PROT_TLS1_2_CLIENT;
|
||||
}
|
||||
else {
|
||||
credentials_.grbitEnabledProtocols =
|
||||
SP_PROT_SSL3_SERVER |
|
||||
SP_PROT_TLS1_SERVER |
|
||||
SP_PROT_TLS1_1_SERVER |
|
||||
SP_PROT_TLS1_2_SERVER;
|
||||
}
|
||||
credentials_.dwMinimumCipherStrength = 128; // bit
|
||||
|
||||
setVerifyPeer(side_ == TLS_CLIENT);
|
||||
}
|
||||
|
||||
TLSContext* TLSContext::make(TLSSessionSide side)
|
||||
{
|
||||
return new WinTLSContext(side);
|
||||
}
|
||||
|
||||
WinTLSContext::~WinTLSContext()
|
||||
{
|
||||
if (store_) {
|
||||
CertCloseStore(store_, 0);
|
||||
store_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool WinTLSContext::getVerifyPeer() const
|
||||
{
|
||||
return credentials_.dwFlags & SCH_CRED_AUTO_CRED_VALIDATION;
|
||||
}
|
||||
|
||||
void WinTLSContext::setVerifyPeer(bool verify)
|
||||
{
|
||||
if (side_ == TLS_CLIENT && verify) {
|
||||
credentials_.dwFlags =
|
||||
SCH_CRED_NO_DEFAULT_CREDS |
|
||||
SCH_CRED_AUTO_CRED_VALIDATION |
|
||||
SCH_CRED_REVOCATION_CHECK_CHAIN;
|
||||
}
|
||||
else {
|
||||
credentials_.dwFlags =
|
||||
SCH_CRED_NO_DEFAULT_CREDS |
|
||||
SCH_CRED_MANUAL_CRED_VALIDATION |
|
||||
SCH_CRED_IGNORE_NO_REVOCATION_CHECK |
|
||||
SCH_CRED_IGNORE_REVOCATION_OFFLINE |
|
||||
SCH_CRED_NO_SERVERNAME_CHECK;
|
||||
}
|
||||
|
||||
// Need to initialize cred_ early, because later on it will segfault deep
|
||||
// within AcquireCredentialsHandle for whatever reason.
|
||||
cred_.reset();
|
||||
getCredHandle();
|
||||
}
|
||||
|
||||
CredHandle* WinTLSContext::getCredHandle()
|
||||
{
|
||||
if (cred_) {
|
||||
return cred_.get();
|
||||
}
|
||||
|
||||
TimeStamp ts;
|
||||
cred_.reset(new CredHandle());
|
||||
SECURITY_STATUS status = ::AcquireCredentialsHandleW(
|
||||
nullptr,
|
||||
(SEC_WCHAR*)UNISP_NAME_W,
|
||||
side_ == TLS_CLIENT ? SECPKG_CRED_OUTBOUND : SECPKG_CRED_INBOUND,
|
||||
nullptr,
|
||||
&credentials_,
|
||||
nullptr,
|
||||
nullptr,
|
||||
cred_.get(),
|
||||
&ts);
|
||||
if (status != SEC_E_OK) {
|
||||
cred_.reset();
|
||||
throw DL_ABORT_EX("Failed to initialize WinTLS context handle");
|
||||
}
|
||||
return cred_.get();
|
||||
}
|
||||
|
||||
bool WinTLSContext::addCredentialFile(const std::string& certfile,
|
||||
const std::string& keyfile)
|
||||
{
|
||||
std::stringstream ss;
|
||||
BufferedFile(certfile.c_str(), "rb").transfer(ss);
|
||||
auto data = ss.str();
|
||||
CRYPT_DATA_BLOB blob = {
|
||||
(DWORD)data.length(),
|
||||
(BYTE*)data.c_str()
|
||||
};
|
||||
if (!PFXIsPFXBlob(&blob)) {
|
||||
A2_LOG_ERROR("Not a valid PKCS12 file");
|
||||
return false;
|
||||
}
|
||||
store_ = ::PFXImportCertStore(&blob, L"",
|
||||
CRYPT_EXPORTABLE | CRYPT_USER_KEYSET);
|
||||
if (!store_) {
|
||||
store_ = ::PFXImportCertStore(&blob, nullptr,
|
||||
CRYPT_EXPORTABLE | CRYPT_USER_KEYSET);
|
||||
}
|
||||
if (!store_) {
|
||||
A2_LOG_ERROR("Failed to import PKCS12 store");
|
||||
return false;
|
||||
}
|
||||
|
||||
const CERT_CONTEXT* ctx = ::CertEnumCertificatesInStore(store_, nullptr);
|
||||
if (!ctx) {
|
||||
A2_LOG_ERROR("Failed to read any certificates from the PKCS12 store");
|
||||
return false;
|
||||
}
|
||||
credentials_.cCreds = 1;
|
||||
credentials_.paCred = &ctx;
|
||||
|
||||
// Need to initialize cred_ early, because later on it will segfault deep
|
||||
// within AcquireCredentialsHandle for whatever reason.
|
||||
cred_.reset();
|
||||
getCredHandle();
|
||||
|
||||
CertFreeCertificateContext(ctx);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool WinTLSContext::addTrustedCACertFile(const std::string& certfile)
|
||||
{
|
||||
A2_LOG_INFO("TLS CA bundle files are not supported. "
|
||||
"The system trust store will be used.");
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace aria2
|
|
@ -0,0 +1,116 @@
|
|||
/* <!-- copyright */
|
||||
/*
|
||||
* aria2 - The high speed download utility
|
||||
*
|
||||
* Copyright (C) 2013 Nils Maier
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation; either version 2 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program; if not, write to the Free Software
|
||||
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
*
|
||||
* In addition, as a special exception, the copyright holders give
|
||||
* permission to link the code of portions of this program with the
|
||||
* OpenSSL library under certain conditions as described in each
|
||||
* individual source file, and distribute linked combinations
|
||||
* including the two.
|
||||
* You must obey the GNU General Public License in all respects
|
||||
* for all of the code used other than OpenSSL. If you modify
|
||||
* file(s) with this exception, you may extend this exception to your
|
||||
* version of the file(s), but you are not obligated to do so. If you
|
||||
* do not wish to do so, delete this exception statement from your
|
||||
* version. If you delete this exception statement from all source
|
||||
* files in the program, then also delete it here.
|
||||
*/
|
||||
/* copyright --> */
|
||||
|
||||
#ifndef D_WIN_TLS_CONTEXT_H
|
||||
#define D_WIN_TLS_CONTEXT_H
|
||||
|
||||
#include "common.h"
|
||||
#include "config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <windows.h>
|
||||
#include <security.h>
|
||||
#include <schnlsp.h>
|
||||
|
||||
#include "TLSContext.h"
|
||||
#include "DlAbortEx.h"
|
||||
|
||||
#ifndef SP_PROT_TLS1_1_CLIENT
|
||||
#define SP_PROT_TLS1_1_CLIENT 0x00000200
|
||||
#endif
|
||||
#ifndef SP_PROT_TLS1_1_SERVER
|
||||
#define SP_PROT_TLS1_1_SERVER 0x00000100
|
||||
#endif
|
||||
#ifndef SP_PROT_TLS1_2_CLIENT
|
||||
#define SP_PROT_TLS1_2_CLIENT 0x00000800
|
||||
#endif
|
||||
#ifndef SP_PROT_TLS1_2_SERVER
|
||||
#define SP_PROT_TLS1_2_SERVER 0x00000400
|
||||
#endif
|
||||
|
||||
namespace aria2 {
|
||||
|
||||
namespace wintls {
|
||||
struct cred_deleter{
|
||||
void operator()(CredHandle* handle) {
|
||||
if (handle) {
|
||||
FreeCredentialsHandle(handle);
|
||||
delete handle;
|
||||
}
|
||||
}
|
||||
};
|
||||
typedef std::unique_ptr<CredHandle, cred_deleter> CredPtr;
|
||||
} // namespace wintls
|
||||
|
||||
class WinTLSContext : public TLSContext {
|
||||
public:
|
||||
WinTLSContext(TLSSessionSide side);
|
||||
virtual ~WinTLSContext();
|
||||
|
||||
// private key `keyfile' must be decrypted.
|
||||
virtual bool addCredentialFile(const std::string& certfile,
|
||||
const std::string& keyfile) CXX11_OVERRIDE;
|
||||
|
||||
virtual bool addSystemTrustedCACerts() CXX11_OVERRIDE {
|
||||
return true;
|
||||
}
|
||||
|
||||
// certfile can contain multiple certificates.
|
||||
virtual bool addTrustedCACertFile(const std::string& certfile)
|
||||
CXX11_OVERRIDE;
|
||||
|
||||
virtual bool good() const CXX11_OVERRIDE {
|
||||
return true;
|
||||
}
|
||||
virtual TLSSessionSide getSide() const CXX11_OVERRIDE {
|
||||
return side_;
|
||||
}
|
||||
|
||||
virtual bool getVerifyPeer() const CXX11_OVERRIDE;
|
||||
virtual void setVerifyPeer(bool verify) CXX11_OVERRIDE;
|
||||
|
||||
CredHandle* getCredHandle();
|
||||
|
||||
private:
|
||||
TLSSessionSide side_;
|
||||
SCHANNEL_CRED credentials_;
|
||||
HCERTSTORE store_;
|
||||
wintls::CredPtr cred_;
|
||||
};
|
||||
|
||||
} // namespace aria2
|
||||
|
||||
#endif // D_LIBSSL_TLS_CONTEXT_H
|
|
@ -0,0 +1,816 @@
|
|||
/* <!-- copyright */
|
||||
/*
|
||||
* aria2 - The high speed download utility
|
||||
*
|
||||
* Copyright (C) 2013 Nils Maier
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation; either version 2 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program; if not, write to the Free Software
|
||||
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
*
|
||||
* In addition, as a special exception, the copyright holders give
|
||||
* permission to link the code of portions of this program with the
|
||||
* OpenSSL library under certain conditions as described in each
|
||||
* individual source file, and distribute linked combinations
|
||||
* including the two.
|
||||
* You must obey the GNU General Public License in all respects
|
||||
* for all of the code used other than OpenSSL. If you modify
|
||||
* file(s) with this exception, you may extend this exception to your
|
||||
* version of the file(s), but you are not obligated to do so. If you
|
||||
* do not wish to do so, delete this exception statement from your
|
||||
* version. If you delete this exception statement from all source
|
||||
* files in the program, then also delete it here.
|
||||
*/
|
||||
/* copyright --> */
|
||||
|
||||
#include "WinTLSSession.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "LogFactory.h"
|
||||
#include "a2functional.h"
|
||||
#include "fmt.h"
|
||||
#include "util.h"
|
||||
|
||||
#ifndef SECBUFFER_ALERT
|
||||
#define SECBUFFER_ALERT 17
|
||||
#endif
|
||||
|
||||
#ifndef SZ_ALG_MAX_SIZE
|
||||
#define SZ_ALG_MAX_SIZE 64
|
||||
#endif
|
||||
#ifndef SECPKGCONTEXT_CIPHERINFO_V1
|
||||
#define SECPKGCONTEXT_CIPHERINFO_V1 1
|
||||
#endif
|
||||
#ifndef SECPKG_ATTR_CIPHER_INFO
|
||||
#define SECPKG_ATTR_CIPHER_INFO 0x64
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
using namespace aria2;
|
||||
|
||||
struct WinSecPkgContext_CipherInfo {
|
||||
DWORD dwVersion;
|
||||
DWORD dwProtocol;
|
||||
DWORD dwCipherSuite;
|
||||
DWORD dwBaseCipherSuite;
|
||||
WCHAR szCipherSuite[SZ_ALG_MAX_SIZE];
|
||||
WCHAR szCipher[SZ_ALG_MAX_SIZE];
|
||||
DWORD dwCipherLen;
|
||||
DWORD dwCipherBlockLen; // in bytes
|
||||
WCHAR szHash[SZ_ALG_MAX_SIZE];
|
||||
DWORD dwHashLen;
|
||||
WCHAR szExchange[SZ_ALG_MAX_SIZE];
|
||||
DWORD dwMinExchangeLen;
|
||||
DWORD dwMaxExchangeLen;
|
||||
WCHAR szCertificate[SZ_ALG_MAX_SIZE];
|
||||
DWORD dwKeyType;
|
||||
};
|
||||
|
||||
static const ULONG kReqFlags = ISC_REQ_SEQUENCE_DETECT |
|
||||
ISC_REQ_REPLAY_DETECT |
|
||||
ISC_REQ_CONFIDENTIALITY |
|
||||
ISC_REQ_ALLOCATE_MEMORY |
|
||||
ISC_REQ_STREAM;
|
||||
static const ULONG kReqAFlags = ASC_REQ_SEQUENCE_DETECT |
|
||||
ASC_REQ_REPLAY_DETECT |
|
||||
ASC_REQ_CONFIDENTIALITY |
|
||||
ASC_REQ_EXTENDED_ERROR |
|
||||
ASC_REQ_ALLOCATE_MEMORY |
|
||||
ASC_REQ_STREAM;
|
||||
|
||||
class TLSBuffer : public ::SecBuffer {
|
||||
public:
|
||||
explicit TLSBuffer(ULONG type, ULONG size, void *data)
|
||||
{
|
||||
cbBuffer = size;
|
||||
BufferType = type;
|
||||
pvBuffer = data;
|
||||
}
|
||||
};
|
||||
|
||||
class TLSBufferDesc: public ::SecBufferDesc {
|
||||
public:
|
||||
explicit TLSBufferDesc(SecBuffer *arr, ULONG buffers)
|
||||
{
|
||||
ulVersion = SECBUFFER_VERSION;
|
||||
cBuffers = buffers;
|
||||
pBuffers = arr;
|
||||
}
|
||||
};
|
||||
|
||||
inline static std::string getCipherSuite(CtxtHandle *handle)
|
||||
{
|
||||
WinSecPkgContext_CipherInfo info = { SECPKGCONTEXT_CIPHERINFO_V1 };
|
||||
if (QueryContextAttributes(handle, SECPKG_ATTR_CIPHER_INFO, &info) ==
|
||||
SEC_E_OK) {
|
||||
return wCharToUtf8(info.szCipherSuite);
|
||||
}
|
||||
return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
namespace aria2 {
|
||||
|
||||
TLSSession* TLSSession::make(TLSContext* ctx)
|
||||
{
|
||||
return new WinTLSSession(static_cast<WinTLSContext*>(ctx));
|
||||
}
|
||||
|
||||
WinTLSSession::WinTLSSession(WinTLSContext* ctx)
|
||||
: sockfd_(0),
|
||||
side_(ctx->getSide()),
|
||||
cred_(ctx->getCredHandle()),
|
||||
writeBuffered_(0),
|
||||
state_(st_constructed),
|
||||
status_(SEC_E_OK)
|
||||
{
|
||||
memset(&handle_, 0, sizeof(handle_));
|
||||
}
|
||||
|
||||
WinTLSSession::~WinTLSSession()
|
||||
{
|
||||
::DeleteSecurityContext(&handle_);
|
||||
state_ = st_error;
|
||||
}
|
||||
|
||||
int WinTLSSession::init(sock_t sockfd)
|
||||
{
|
||||
if (state_ != st_constructed) {
|
||||
status_ = SEC_E_INVALID_HANDLE;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
sockfd_ = sockfd;
|
||||
state_ = st_initialized;
|
||||
|
||||
return TLS_ERR_OK;
|
||||
}
|
||||
|
||||
int WinTLSSession::setSNIHostname(const std::string& hostname)
|
||||
{
|
||||
if (state_ != st_initialized) {
|
||||
status_ = SEC_E_INVALID_HANDLE;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
hostname_ = hostname;
|
||||
return TLS_ERR_OK;
|
||||
}
|
||||
|
||||
int WinTLSSession::closeConnection()
|
||||
{
|
||||
if (state_ != st_connected || state_ != st_closing) {
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
|
||||
if (state_ == st_connected) {
|
||||
state_ = st_closing;
|
||||
|
||||
DWORD dwShut = SCHANNEL_SHUTDOWN;
|
||||
TLSBuffer shut(SECBUFFER_TOKEN, sizeof(dwShut), &dwShut);
|
||||
TLSBufferDesc shutDesc(&shut, 1);
|
||||
status_ = ::ApplyControlToken(&handle_, &shutDesc);
|
||||
if (status_ != SEC_E_OK) {
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
TLSBuffer ctx(SECBUFFER_EMPTY, 0, nullptr);
|
||||
TLSBufferDesc desc(&ctx, 1);
|
||||
ULONG flags = 0;
|
||||
if (side_ == TLS_CLIENT) {
|
||||
SEC_CHAR* host = hostname_.empty() ?
|
||||
nullptr :
|
||||
const_cast<SEC_CHAR*>(hostname_.c_str());
|
||||
status_ = ::InitializeSecurityContext(
|
||||
cred_,
|
||||
&handle_,
|
||||
host,
|
||||
kReqFlags,
|
||||
0,
|
||||
0,
|
||||
nullptr,
|
||||
0,
|
||||
&handle_,
|
||||
&desc,
|
||||
&flags,
|
||||
nullptr);
|
||||
}
|
||||
else {
|
||||
status_ = ::AcceptSecurityContext(
|
||||
cred_,
|
||||
&handle_,
|
||||
nullptr,
|
||||
kReqAFlags,
|
||||
0,
|
||||
&handle_,
|
||||
&desc,
|
||||
&flags,
|
||||
nullptr);
|
||||
}
|
||||
if (status_ == SEC_E_OK || status_== SEC_I_CONTEXT_EXPIRED) {
|
||||
size_t len = ctx.cbBuffer;
|
||||
ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer);
|
||||
::FreeContextBuffer(ctx.pvBuffer);
|
||||
if (rv == TLS_ERR_WOULDBLOCK) {
|
||||
return rv;
|
||||
}
|
||||
|
||||
// Alright data is sent or buffered
|
||||
if (rv - len != 0) {
|
||||
return TLS_ERR_WOULDBLOCK;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send remaining data.
|
||||
while (writeBuf_.size()) {
|
||||
int rv = writeData(nullptr, 0);
|
||||
if (rv == TLS_ERR_WOULDBLOCK) {
|
||||
return rv;
|
||||
}
|
||||
}
|
||||
|
||||
state_ = st_closed;
|
||||
return TLS_ERR_OK;
|
||||
}
|
||||
|
||||
int WinTLSSession::checkDirection()
|
||||
{
|
||||
if (state_ == st_handshake_write || state_ == st_handshake_write_last) {
|
||||
return TLS_WANT_WRITE;
|
||||
}
|
||||
if (state_ == st_handshake_read) {
|
||||
return TLS_WANT_READ;
|
||||
}
|
||||
if (readBuf_.size() || decBuf_.size()) {
|
||||
return TLS_WANT_READ;
|
||||
}
|
||||
if (writeBuf_.size()) {
|
||||
return TLS_WANT_WRITE;
|
||||
}
|
||||
return TLS_WANT_READ;
|
||||
}
|
||||
|
||||
ssize_t WinTLSSession::writeData(const void* data, size_t len)
|
||||
{
|
||||
if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
|
||||
state_ == st_handshake_read) {
|
||||
// Renegotiating
|
||||
std::string hn, err;
|
||||
auto connect = tlsConnect(hn, err);
|
||||
if (connect != TLS_ERR_OK) {
|
||||
return connect;
|
||||
}
|
||||
// Continue.
|
||||
}
|
||||
|
||||
if (state_ != st_connected && state_ != st_closing) {
|
||||
status_ = SEC_E_INVALID_HANDLE;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
|
||||
A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64,
|
||||
(uint64_t)len, (uint64_t)writeBuf_.size()));
|
||||
|
||||
// Write remaining buffered data, if any.
|
||||
size_t written = 0;
|
||||
while (writeBuf_.size()) {
|
||||
written = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
|
||||
errno = ::WSAGetLastError();
|
||||
if (written < 0 && errno == WSAEINTR) {
|
||||
continue;
|
||||
}
|
||||
if (written < 0 && errno == WSAEWOULDBLOCK) {
|
||||
return TLS_ERR_WOULDBLOCK;
|
||||
}
|
||||
if (written == 0) {
|
||||
return written;
|
||||
}
|
||||
if (written < 0) {
|
||||
status_ = SEC_E_INVALID_HANDLE;
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
writeBuf_.eat(written);
|
||||
}
|
||||
|
||||
if (len == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!streamSizes_) {
|
||||
streamSizes_.reset(new SecPkgContext_StreamSizes());
|
||||
status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
|
||||
streamSizes_.get());
|
||||
if (status_ != SEC_E_OK || !streamSizes_->cbMaximumMessage) {
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
size_t process = len;
|
||||
auto bytes = reinterpret_cast<const char*>(data);
|
||||
if (writeBuffered_) {
|
||||
// There was buffered data, hence we need to "remove" that data from the
|
||||
// incoming buffer to avoid writing it again
|
||||
if (len < writeBuffered_) {
|
||||
// We didn't get called with the same data again, obviously.
|
||||
status_ = SEC_E_INVALID_HANDLE;
|
||||
status_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
// just advance the buffer by writeBuffered_ bytes
|
||||
bytes += writeBuffered_;
|
||||
process -= writeBuffered_;
|
||||
writeBuffered_ = 0;
|
||||
}
|
||||
if (!process) {
|
||||
// The buffer contained the full remainder. At this point, the buffer has
|
||||
// been written, so the request is done in its entirety;
|
||||
return len;
|
||||
}
|
||||
|
||||
// Buffered data was already written ;)
|
||||
// If there was no buffered data, this will be len - len = 0.
|
||||
len = len - process;
|
||||
while (process) {
|
||||
// Set up an outgoing message, according to streamSizes_
|
||||
writeBuffered_ = std::min(process, (size_t)streamSizes_->cbMaximumMessage);
|
||||
size_t dl = streamSizes_->cbHeader + writeBuffered_ +
|
||||
streamSizes_->cbTrailer;
|
||||
auto buf = make_unique<char[]>(dl);
|
||||
TLSBuffer buffers[] = {
|
||||
TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_->cbHeader, buf.get()),
|
||||
TLSBuffer(SECBUFFER_DATA, writeBuffered_,
|
||||
buf.get() + streamSizes_->cbHeader),
|
||||
TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_->cbTrailer,
|
||||
buf.get() + streamSizes_->cbHeader + writeBuffered_),
|
||||
TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
|
||||
};
|
||||
TLSBufferDesc desc(buffers, 4);
|
||||
memcpy(buffers[1].pvBuffer, bytes, writeBuffered_);
|
||||
status_ = ::EncryptMessage(&handle_, 0, &desc, 0);
|
||||
if (status_ != SEC_E_OK) {
|
||||
A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s",
|
||||
getLastErrorString().c_str()));
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
|
||||
// EncryptMessage may have truncated the buffers.
|
||||
// Should rarely happen, if ever, except for the trailer.
|
||||
dl = buffers[0].cbBuffer;
|
||||
if (dl < streamSizes_->cbHeader) {
|
||||
// Move message.
|
||||
memmove(buf.get() + dl, buffers[1].pvBuffer, buffers[1].cbBuffer);
|
||||
}
|
||||
dl += buffers[1].cbBuffer;
|
||||
if (dl < streamSizes_->cbHeader + writeBuffered_) {
|
||||
// Move tailer.
|
||||
memmove(buf.get() + dl, buffers[2].pvBuffer, buffers[2].cbBuffer);
|
||||
}
|
||||
dl += buffers[2].cbBuffer;
|
||||
|
||||
// Write (or buffer) the message.
|
||||
char* p = buf.get();
|
||||
while (dl) {
|
||||
written = ::send(sockfd_, p, dl, 0);
|
||||
errno = ::WSAGetLastError();
|
||||
if (written < 0 && errno == WSAEINTR) {
|
||||
continue;
|
||||
}
|
||||
if (written < 0 && errno == WSAEWOULDBLOCK) {
|
||||
// Buffer the rest of the message...
|
||||
writeBuf_.write(p, dl);
|
||||
// and return...
|
||||
return len;
|
||||
}
|
||||
if (written == 0) {
|
||||
A2_LOG_ERROR("WinTLS: Connection closed while writing");
|
||||
status_ = SEC_E_INCOMPLETE_MESSAGE;
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
if (written < 0) {
|
||||
A2_LOG_ERROR("WinTLS: Connection error while writing");
|
||||
status_ = SEC_E_INCOMPLETE_MESSAGE;
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
dl -= written;
|
||||
p += written;
|
||||
}
|
||||
|
||||
len += writeBuffered_;
|
||||
bytes += writeBuffered_;
|
||||
process -= writeBuffered_;
|
||||
writeBuffered_ = 0;
|
||||
}
|
||||
|
||||
A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64 " buffered: %" PRIu64,
|
||||
(uint64_t)len, (uint64_t)writeBuf_.size()));
|
||||
if (!len) {
|
||||
return TLS_ERR_WOULDBLOCK;
|
||||
}
|
||||
return len;
|
||||
}
|
||||
|
||||
ssize_t WinTLSSession::readData(void* data, size_t len)
|
||||
{
|
||||
A2_LOG_DEBUG(fmt("WinTLS: Read request: %" PRIu64 " buffered: %" PRIu64,
|
||||
(uint64_t)len, (uint64_t)readBuf_.size()));
|
||||
if (len == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Can be filled from decBuffer entirely?
|
||||
if (decBuf_.size() >= len) {
|
||||
A2_LOG_DEBUG("WinTLS: Fullfilling req from buffer");
|
||||
memcpy(data, decBuf_.data(), len);
|
||||
decBuf_.eat(len);
|
||||
return len;
|
||||
}
|
||||
|
||||
if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
|
||||
state_ == st_handshake_read) {
|
||||
// Renegotiating
|
||||
std::string hn, err;
|
||||
auto connect = tlsConnect(hn, err);
|
||||
if (connect != TLS_ERR_OK) {
|
||||
return connect;
|
||||
}
|
||||
// Continue.
|
||||
}
|
||||
if (state_ != st_connected) {
|
||||
status_ = SEC_E_INVALID_HANDLE;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
|
||||
// Read as many bytes as available from the connection, up to len + 4k.
|
||||
readBuf_.resize(len + 4096);
|
||||
while (readBuf_.free()) {
|
||||
ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
|
||||
errno = ::WSAGetLastError();
|
||||
if (read < 0 && errno == WSAEINTR) {
|
||||
continue;
|
||||
}
|
||||
if (read < 0 && errno == WSAEWOULDBLOCK) {
|
||||
break;
|
||||
}
|
||||
if (read == 0) {
|
||||
break;
|
||||
}
|
||||
if (read < 0) {
|
||||
status_ = SEC_E_INCOMPLETE_MESSAGE;
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
readBuf_.advance(read);
|
||||
}
|
||||
|
||||
// Try to decrypt as many messages as possible from the readBuf_.
|
||||
while (readBuf_.size()) {
|
||||
TLSBuffer bufs[] = {
|
||||
TLSBuffer(SECBUFFER_DATA, readBuf_.size(), readBuf_.data()),
|
||||
TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
|
||||
TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
|
||||
TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
|
||||
};
|
||||
TLSBufferDesc desc(bufs, 4);
|
||||
status_ = ::DecryptMessage(&handle_, &desc, 0, nullptr);
|
||||
if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
|
||||
// Need to stop now, and wait for more bytes to arrive on the socket.
|
||||
break;
|
||||
}
|
||||
|
||||
if (status_ != SEC_E_OK && status_ != SEC_I_CONTEXT_EXPIRED &&
|
||||
status_ != SEC_I_RENEGOTIATE) {
|
||||
A2_LOG_ERROR(fmt("WinTLS: Failed to decrypt a message! %s",
|
||||
getLastErrorString().c_str()));
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
|
||||
// Decrypted message successfully.
|
||||
bool ate = false;
|
||||
for (auto& buf : bufs) {
|
||||
if (buf.BufferType == SECBUFFER_DATA && buf.cbBuffer > 0) {
|
||||
decBuf_.write(buf.pvBuffer, buf.cbBuffer);
|
||||
}
|
||||
else if (buf.BufferType == SECBUFFER_EXTRA && buf.cbBuffer > 0) {
|
||||
readBuf_.eat(readBuf_.size() - buf.cbBuffer);
|
||||
ate = true;
|
||||
}
|
||||
}
|
||||
if (!ate) {
|
||||
readBuf_.clear();
|
||||
}
|
||||
|
||||
if (status_ == SEC_I_RENEGOTIATE) {
|
||||
// Renegotiation basically means performing another handshake
|
||||
state_ = st_initialized;
|
||||
A2_LOG_INFO("WinTLS: Renegotiate");
|
||||
std::string hn, err;
|
||||
auto connect = tlsConnect(hn, err);
|
||||
if (connect == TLS_ERR_WOULDBLOCK) {
|
||||
break;
|
||||
}
|
||||
if (connect == TLS_ERR_ERROR) {
|
||||
return connect;
|
||||
}
|
||||
// Still good.
|
||||
}
|
||||
if (status_ == SEC_I_CONTEXT_EXPIRED) {
|
||||
// Connection is gone now, but the buffered bytes are still valid.
|
||||
A2_LOG_DEBUG("WinTLS: Connection closed!");
|
||||
closeConnection();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
len = std::min(decBuf_.size(), len);
|
||||
if (len == 0) {
|
||||
return TLS_ERR_WOULDBLOCK;
|
||||
}
|
||||
memcpy(data, decBuf_.data(), len);
|
||||
decBuf_.eat(len);
|
||||
return len;
|
||||
}
|
||||
|
||||
int WinTLSSession::tlsConnect(const std::string& hostname,
|
||||
std::string& handshakeErr)
|
||||
{
|
||||
// Handshaking will require sending multiple read/write exchanges until the
|
||||
// handshake is actually done. The client will first generate the initial
|
||||
// handshake message, then write that to the server, read the response
|
||||
// message, and write and/or read additional messages until the handshake is
|
||||
// either complete and successful, or something went wrong.
|
||||
// The server works analog to that.
|
||||
|
||||
A2_LOG_DEBUG("WinTLS: Starting/Resuming TLS Connect");
|
||||
ULONG flags = 0;
|
||||
|
||||
restart:
|
||||
|
||||
switch (state_) {
|
||||
default:
|
||||
A2_LOG_ERROR("WinTLS: Invalid state");
|
||||
status_ = SEC_E_INVALID_HANDLE;
|
||||
return TLS_ERR_ERROR;
|
||||
|
||||
case st_initialized: {
|
||||
if (side_ == TLS_SERVER) {
|
||||
goto read;
|
||||
}
|
||||
|
||||
if (!hostname.empty()) {
|
||||
setSNIHostname(hostname);
|
||||
}
|
||||
A2_LOG_DEBUG("WinTLS: Initializing handshake");
|
||||
TLSBuffer buf(SECBUFFER_EMPTY, 0, nullptr);
|
||||
TLSBufferDesc desc(&buf, 1);
|
||||
SEC_CHAR* host = hostname_.empty() ?
|
||||
nullptr :
|
||||
const_cast<SEC_CHAR*>(hostname_.c_str());
|
||||
status_ = ::InitializeSecurityContext(
|
||||
cred_,
|
||||
nullptr,
|
||||
host,
|
||||
kReqFlags,
|
||||
0,
|
||||
0,
|
||||
nullptr,
|
||||
0,
|
||||
&handle_,
|
||||
&desc,
|
||||
&flags,
|
||||
nullptr);
|
||||
if (status_ != SEC_I_CONTINUE_NEEDED) {
|
||||
// Has to be SEC_I_CONTINUE_NEEDED, as we did not actually send data
|
||||
// at this point.
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
|
||||
// Queue the initial message...
|
||||
writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
|
||||
FreeContextBuffer(buf.pvBuffer);
|
||||
|
||||
// ... and start sending it
|
||||
state_ = st_handshake_write;
|
||||
}
|
||||
// Fall through
|
||||
|
||||
case st_handshake_write_last:
|
||||
case st_handshake_write: {
|
||||
A2_LOG_DEBUG("WinTLS: Writing handshake");
|
||||
|
||||
// Write the currently queued handshake message until all data is sent.
|
||||
while(writeBuf_.size()) {
|
||||
ssize_t writ = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
|
||||
errno = ::WSAGetLastError();
|
||||
if (writ < 0 && errno == WSAEINTR) {
|
||||
continue;
|
||||
}
|
||||
if (writ < 0 && errno == WSAEWOULDBLOCK) {
|
||||
return TLS_ERR_WOULDBLOCK;
|
||||
}
|
||||
if (writ <= 0) {
|
||||
status_ = SEC_E_INCOMPLETE_MESSAGE;
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
writeBuf_.eat(writ);
|
||||
}
|
||||
|
||||
if (state_ == st_handshake_write_last) {
|
||||
state_ = st_handshake_done;
|
||||
goto restart;
|
||||
}
|
||||
|
||||
// Have to read one or more response messages.
|
||||
state_ = st_handshake_read;
|
||||
}
|
||||
// Fall through
|
||||
|
||||
case st_handshake_read: {
|
||||
read:
|
||||
A2_LOG_DEBUG("WinTLS: Reading handshake...");
|
||||
|
||||
// All write buffered data is invalid at this point!
|
||||
writeBuf_.clear();
|
||||
|
||||
// Read as many bytes as possible, up to 4k new bytes.
|
||||
// We do not know how many bytes will arrive from the server at this
|
||||
// point.
|
||||
readBuf_.resize(readBuf_.size() + 4096);
|
||||
while (readBuf_.free()) {
|
||||
ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
|
||||
errno = ::WSAGetLastError();
|
||||
if (read < 0 && errno == WSAEINTR) {
|
||||
continue;
|
||||
}
|
||||
if (read < 0 && errno == WSAEWOULDBLOCK) {
|
||||
break;
|
||||
}
|
||||
if (read <= 0) {
|
||||
status_ = SEC_E_INCOMPLETE_MESSAGE;
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
readBuf_.advance(read);
|
||||
break;
|
||||
}
|
||||
if (!readBuf_.size()) {
|
||||
return TLS_ERR_WOULDBLOCK;
|
||||
}
|
||||
|
||||
// Need to copy the data, as Schannel is free to mess with it. But we
|
||||
// might later need unmodified data from the original read buffer.
|
||||
auto bufcopy = make_unique<char[]>(readBuf_.size());
|
||||
memcpy(bufcopy.get(), readBuf_.data(), readBuf_.size());
|
||||
|
||||
// Set up buffers. inbufs will be the raw bytes the library has to decode.
|
||||
// outbufs will contain generated responses, if any.
|
||||
TLSBuffer inbufs[] = {
|
||||
TLSBuffer(SECBUFFER_TOKEN, readBuf_.size(), bufcopy.get()),
|
||||
TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
|
||||
};
|
||||
TLSBufferDesc indesc(inbufs, 2);
|
||||
TLSBuffer outbufs[] = {
|
||||
TLSBuffer(SECBUFFER_TOKEN, 0, nullptr),
|
||||
TLSBuffer(SECBUFFER_ALERT, 0, nullptr),
|
||||
};
|
||||
TLSBufferDesc outdesc(outbufs, 2);
|
||||
if (side_ == TLS_CLIENT) {
|
||||
SEC_CHAR* host = hostname_.empty() ?
|
||||
nullptr :
|
||||
const_cast<SEC_CHAR*>(hostname_.c_str());
|
||||
status_ = ::InitializeSecurityContext(
|
||||
cred_,
|
||||
&handle_,
|
||||
host,
|
||||
kReqFlags,
|
||||
0,
|
||||
0,
|
||||
&indesc,
|
||||
0,
|
||||
nullptr,
|
||||
&outdesc,
|
||||
&flags,
|
||||
nullptr);
|
||||
}
|
||||
else {
|
||||
status_ = ::AcceptSecurityContext(
|
||||
cred_,
|
||||
state_ == st_initialized ? nullptr : &handle_,
|
||||
&indesc,
|
||||
kReqAFlags,
|
||||
0,
|
||||
state_ == st_initialized ? &handle_ : nullptr,
|
||||
&outdesc,
|
||||
&flags,
|
||||
nullptr);
|
||||
}
|
||||
if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
|
||||
// Not enough raw bytes read yet to decode a full message.
|
||||
return TLS_ERR_WOULDBLOCK;
|
||||
}
|
||||
if (status_ != SEC_E_OK && status_ != SEC_I_CONTINUE_NEEDED) {
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
|
||||
// Raw bytes where not entirely consumed, i.e. readBuf_ still contains
|
||||
// unprocessed data from the next message?
|
||||
if (inbufs[1].BufferType == SECBUFFER_EXTRA && inbufs[1].cbBuffer > 0) {
|
||||
readBuf_.eat(readBuf_.size() - inbufs[1].cbBuffer);
|
||||
}
|
||||
else {
|
||||
readBuf_.clear();
|
||||
}
|
||||
|
||||
// Check if the library produced a new outgoing message and queue it.
|
||||
for (auto& buf : outbufs) {
|
||||
if (buf.BufferType == SECBUFFER_TOKEN && buf.cbBuffer > 0) {
|
||||
writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
|
||||
FreeContextBuffer(buf.pvBuffer);
|
||||
state_ = st_handshake_write;
|
||||
}
|
||||
}
|
||||
|
||||
// Need to read additional messages?
|
||||
if (status_ == SEC_I_CONTINUE_NEEDED) {
|
||||
A2_LOG_DEBUG("WinTLS: Continuing with handshake");
|
||||
goto restart;
|
||||
}
|
||||
|
||||
if (side_ == TLS_CLIENT && flags != kReqFlags) {
|
||||
A2_LOG_ERROR(fmt("WinTLS: Channel setup failed. Schannel provider did "
|
||||
"not fulfill requested flags. "
|
||||
"Excepted: %lu Actual: %lu",
|
||||
kReqFlags, flags));
|
||||
status_ = SEC_E_INTERNAL_ERROR;
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
|
||||
if (state_ == st_handshake_write) {
|
||||
A2_LOG_DEBUG("WinTLS: Continuing with handshake (last write)");
|
||||
state_ = st_handshake_write_last;
|
||||
goto restart;
|
||||
}
|
||||
}
|
||||
// Fall through
|
||||
|
||||
case st_handshake_done:
|
||||
// All ready now :D
|
||||
state_ = st_connected;
|
||||
A2_LOG_INFO(fmt("WinTLS: connected with: %s",
|
||||
getCipherSuite(&handle_).c_str()));
|
||||
return TLS_ERR_OK;
|
||||
}
|
||||
|
||||
A2_LOG_ERROR("WinTLS: Unreachable reached during tlsConnect! This is a bug!");
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
|
||||
int WinTLSSession::tlsAccept()
|
||||
{
|
||||
std::string host, err;
|
||||
return tlsConnect(host, err);
|
||||
}
|
||||
|
||||
std::string WinTLSSession::getLastErrorString()
|
||||
{
|
||||
std::stringstream ss;
|
||||
wchar_t* buf = nullptr;
|
||||
if (FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER |
|
||||
FORMAT_MESSAGE_FROM_SYSTEM |
|
||||
FORMAT_MESSAGE_IGNORE_INSERTS,
|
||||
nullptr,
|
||||
status_,
|
||||
MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
|
||||
(LPWSTR)&buf,
|
||||
1024,
|
||||
nullptr) && buf) {
|
||||
ss << "Error: " << wCharToUtf8(buf);
|
||||
LocalFree(buf);
|
||||
}
|
||||
else {
|
||||
ss << "Error: " << std::hex << status_;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
} // namespace aria2
|
|
@ -0,0 +1,194 @@
|
|||
/* <!-- copyright */
|
||||
/*
|
||||
* aria2 - The high speed download utility
|
||||
*
|
||||
* Copyright (C) 2013 Nils Maier
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation; either version 2 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program; if not, write to the Free Software
|
||||
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
*
|
||||
* In addition, as a special exception, the copyright holders give
|
||||
* permission to link the code of portions of this program with the
|
||||
* OpenSSL library under certain conditions as described in each
|
||||
* individual source file, and distribute linked combinations
|
||||
* including the two.
|
||||
* You must obey the GNU General Public License in all respects
|
||||
* for all of the code used other than OpenSSL. If you modify
|
||||
* file(s) with this exception, you may extend this exception to your
|
||||
* version of the file(s), but you are not obligated to do so. If you
|
||||
* do not wish to do so, delete this exception statement from your
|
||||
* version. If you delete this exception statement from all source
|
||||
* files in the program, then also delete it here.
|
||||
*/
|
||||
/* copyright --> */
|
||||
|
||||
#ifndef WIN_TLS_SESSION_H
|
||||
#define WIN_TLS_SESSION_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "common.h"
|
||||
#include "TLSSession.h"
|
||||
#include "WinTLSContext.h"
|
||||
|
||||
namespace aria2 {
|
||||
|
||||
namespace wintls {
|
||||
struct Buffer {
|
||||
private:
|
||||
size_t off_, free_, cap_;
|
||||
std::vector<char> buf_;
|
||||
|
||||
public:
|
||||
inline Buffer() : off_(0), free_(0), cap_(0) {}
|
||||
|
||||
inline size_t size() const {
|
||||
return off_;
|
||||
}
|
||||
inline size_t free() const {
|
||||
return free_;
|
||||
}
|
||||
inline void resize(size_t len) {
|
||||
if (cap_ >= len) {
|
||||
return;
|
||||
}
|
||||
buf_.resize(len);
|
||||
cap_ = buf_.size();
|
||||
free_ = cap_ - off_;
|
||||
}
|
||||
inline char* data() {
|
||||
return buf_.data();
|
||||
}
|
||||
inline char* end() {
|
||||
return buf_.data() + off_;
|
||||
}
|
||||
inline void eat(size_t len) {
|
||||
off_ -= len;
|
||||
if (off_) {
|
||||
memmove(buf_.data(), buf_.data() + len, off_);
|
||||
}
|
||||
free_ = cap_ - off_;
|
||||
}
|
||||
inline void clear() {
|
||||
eat(off_);
|
||||
}
|
||||
inline void advance(size_t len) {
|
||||
off_ += len;
|
||||
free_ = cap_ - off_;
|
||||
}
|
||||
inline void write(const void* data, size_t len) {
|
||||
if (!len) {
|
||||
return;
|
||||
}
|
||||
resize(off_ + len);
|
||||
memcpy(end(), data, len);
|
||||
advance(len);
|
||||
}
|
||||
};
|
||||
} // namespace wintls
|
||||
|
||||
class WinTLSSession : public TLSSession {
|
||||
enum state_t {
|
||||
st_constructed,
|
||||
st_initialized,
|
||||
st_handshake_write,
|
||||
st_handshake_write_last,
|
||||
st_handshake_read,
|
||||
st_handshake_done,
|
||||
st_connected,
|
||||
st_closing,
|
||||
st_closed,
|
||||
st_error
|
||||
};
|
||||
|
||||
public:
|
||||
WinTLSSession(WinTLSContext* ctx);
|
||||
|
||||
// MUST deallocate all resources
|
||||
virtual ~WinTLSSession();
|
||||
|
||||
// Initializes SSL/TLS session. The |sockfd| is the underlying
|
||||
// tranport socket. This function returns TLS_ERR_OK if it
|
||||
// succeeds, or TLS_ERR_ERROR.
|
||||
virtual int init(sock_t sockfd) CXX11_OVERRIDE;
|
||||
|
||||
// Sets |hostname| for TLS SNI extension. This is only meaningful for
|
||||
// client side session. This function returns TLS_ERR_OK if it
|
||||
// succeeds, or TLS_ERR_ERROR.
|
||||
virtual int setSNIHostname(const std::string& hostname) CXX11_OVERRIDE;
|
||||
|
||||
// Closes the SSL/TLS session. Don't close underlying transport
|
||||
// socket. This function returns TLS_ERR_OK if it succeeds, or
|
||||
// TLS_ERR_ERROR.
|
||||
virtual int closeConnection() CXX11_OVERRIDE;
|
||||
|
||||
// Returns TLS_WANT_READ if SSL/TLS session needs more data from
|
||||
// remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session
|
||||
// needs to write more data to proceed. If SSL/TLS session needs
|
||||
// neither read nor write data at the moment, return value is
|
||||
// undefined.
|
||||
virtual int checkDirection() CXX11_OVERRIDE;
|
||||
|
||||
// Sends |data| with length |len|. This function returns the number
|
||||
// of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the
|
||||
// underlying tranport blocks, or TLS_ERR_ERROR.
|
||||
virtual ssize_t writeData(const void* data, size_t len) CXX11_OVERRIDE;
|
||||
|
||||
// Receives data into |data| with length |len|. This function returns
|
||||
// the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK
|
||||
// if the underlying tranport blocks, or TLS_ERR_ERROR.
|
||||
virtual ssize_t readData(void* data, size_t len) CXX11_OVERRIDE;
|
||||
|
||||
// Performs client side handshake. The |hostname| is the hostname of
|
||||
// the remote endpoint and is used to verify its certificate. This
|
||||
// function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK
|
||||
// if the underlying transport blocks, or TLS_ERR_ERROR.
|
||||
// When returning TLS_ERR_ERROR, provide certificate validation error
|
||||
// in |handshakeErr|.
|
||||
virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr) CXX11_OVERRIDE;
|
||||
|
||||
// Performs server side handshake. This function returns TLS_ERR_OK
|
||||
// if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport
|
||||
// blocks, or TLS_ERR_ERROR.
|
||||
virtual int tlsAccept() CXX11_OVERRIDE;
|
||||
|
||||
// Returns last error string
|
||||
virtual std::string getLastErrorString() CXX11_OVERRIDE;
|
||||
|
||||
private:
|
||||
std::string hostname_;
|
||||
sock_t sockfd_;
|
||||
TLSSessionSide side_;
|
||||
CredHandle* cred_;
|
||||
CtxtHandle handle_;
|
||||
|
||||
// Buffer for already encrypted writes
|
||||
wintls::Buffer writeBuf_;
|
||||
// While the writeBuf_ holds encrypted messages, writeBuffered_ has the
|
||||
// corresponding size of unencrpted data used to procude the messages.
|
||||
size_t writeBuffered_;
|
||||
// Buffer for still encrypted reads
|
||||
wintls::Buffer readBuf_;
|
||||
// Buffer for already decrypted reads
|
||||
wintls::Buffer decBuf_;
|
||||
|
||||
state_t state_;
|
||||
|
||||
SECURITY_STATUS status_;
|
||||
std::unique_ptr<SecPkgContext_StreamSizes> streamSizes_;
|
||||
};
|
||||
|
||||
} // namespace aria2
|
||||
|
||||
#endif // TLS_SESSION_H
|
Loading…
Reference in New Issue