RPC over SSL/TLS transport

To enable RPC over SSL/TLS, specify server certificate and private key
using --rpc-certificate and --rpc-private-key options and enable
--rpc-secure option.  After the encryption is enabled, use https and
wss scheme to access RPC server.
pull/28/head
Tatsuhiro Tsujikawa 2012-09-30 00:50:14 +09:00
parent 4b94ede268
commit 90515dfa50
25 changed files with 419 additions and 142 deletions

View File

@ -55,7 +55,9 @@ AbstractHttpServerResponseCommand::AbstractHttpServerResponseCommand
: Command(cuid),
e_(e),
socket_(socket),
httpServer_(httpServer)
httpServer_(httpServer),
readCheck_(false),
writeCheck_(true)
{
setStatus(Command::STATUS_ONESHOT_REALTIME);
e_->addSocketForWriteCheck(socket_, this);
@ -63,8 +65,35 @@ AbstractHttpServerResponseCommand::AbstractHttpServerResponseCommand
AbstractHttpServerResponseCommand::~AbstractHttpServerResponseCommand()
{
if(readCheck_) {
e_->deleteSocketForReadCheck(socket_, this);
}
if(writeCheck_) {
e_->deleteSocketForWriteCheck(socket_, this);
}
}
void AbstractHttpServerResponseCommand::updateReadWriteCheck()
{
if(httpServer_->wantRead()) {
if(!readCheck_) {
readCheck_ = true;
e_->addSocketForReadCheck(socket_, this);
}
} else if(readCheck_) {
readCheck_ = false;
e_->deleteSocketForReadCheck(socket_, this);
}
if(httpServer_->wantWrite()) {
if(!writeCheck_) {
writeCheck_ = true;
e_->addSocketForWriteCheck(socket_, this);
}
} else if(writeCheck_) {
writeCheck_ = false;
e_->deleteSocketForWriteCheck(socket_, this);
}
}
bool AbstractHttpServerResponseCommand::execute()
{
@ -72,7 +101,10 @@ bool AbstractHttpServerResponseCommand::execute()
return true;
}
try {
httpServer_->sendResponse();
ssize_t len = httpServer_->sendResponse();
if(len > 0) {
timeoutTimer_ = global::wallclock();
}
} catch(RecoverableException& e) {
A2_LOG_INFO_EX
(fmt("CUID#%"PRId64" - Error occurred while transmitting response body.",
@ -86,12 +118,13 @@ bool AbstractHttpServerResponseCommand::execute()
afterSend(httpServer_, e_);
return true;
} else {
if(timeoutTimer_.difference(global::wallclock()) >= 10) {
if(timeoutTimer_.difference(global::wallclock()) >= 30) {
A2_LOG_INFO(fmt("CUID#%"PRId64" - HttpServer: Timeout while trasmitting"
" response.",
getCuid()));
return true;
} else {
updateReadWriteCheck();
e_->addCommand(this);
return false;
}

View File

@ -51,6 +51,10 @@ private:
SharedHandle<SocketCore> socket_;
SharedHandle<HttpServer> httpServer_;
Timer timeoutTimer_;
bool readCheck_;
bool writeCheck_;
void updateReadWriteCheck();
protected:
DownloadEngine* getDownloadEngine()
{

View File

@ -74,6 +74,7 @@
#include "DlAbortEx.h"
#include "FileAllocationEntry.h"
#include "HttpListenCommand.h"
#include "LogFactory.h"
namespace aria2 {
@ -170,11 +171,15 @@ DownloadEngineFactory::newDownloadEngine
}
if(op->getAsBool(PREF_ENABLE_RPC)) {
bool ok = false;
bool secure = op->getAsBool(PREF_RPC_SECURE);
if(secure) {
A2_LOG_NOTICE("RPC transport will be encrypted.");
}
static int families[] = { AF_INET, AF_INET6 };
size_t familiesLength = op->getAsBool(PREF_DISABLE_IPV6)?1:2;
for(size_t i = 0; i < familiesLength; ++i) {
HttpListenCommand* httpListenCommand =
new HttpListenCommand(e->newCUID(), e.get(), families[i]);
new HttpListenCommand(e->newCUID(), e.get(), families[i], secure);
if(httpListenCommand->bindPort(op->getAsInt(PREF_RPC_LISTEN_PORT))){
e->addCommand(httpListenCommand);
ok = true;

View File

@ -50,10 +50,12 @@
namespace aria2 {
HttpListenCommand::HttpListenCommand(cuid_t cuid, DownloadEngine* e, int family)
HttpListenCommand::HttpListenCommand(cuid_t cuid, DownloadEngine* e,
int family, bool secure)
: Command(cuid),
e_(e),
family_(family)
family_(family),
secure_(secure)
{}
HttpListenCommand::~HttpListenCommand()
@ -80,7 +82,7 @@ bool HttpListenCommand::execute()
peerInfo.first.c_str(), peerInfo.second));
HttpServerCommand* c =
new HttpServerCommand(e_->newCUID(), e_, socket);
new HttpServerCommand(e_->newCUID(), e_, socket, secure_);
e_->setNoWait(true);
e_->addCommand(c);
}

View File

@ -48,8 +48,9 @@ private:
DownloadEngine* e_;
int family_;
SharedHandle<SocketCore> serverSocket_;
bool secure_;
public:
HttpListenCommand(cuid_t cuid, DownloadEngine* e, int family);
HttpListenCommand(cuid_t cuid, DownloadEngine* e, int family, bool secure);
virtual ~HttpListenCommand();

View File

@ -123,8 +123,7 @@ createHttpRequest(const SharedHandle<Request>& req,
bool HttpRequestCommand::executeInternal() {
//socket->setBlockingMode();
if(getRequest()->getProtocol() == "https") {
getSocket()->prepareSecureConnection();
if(!getSocket()->initiateSecureConnection(getRequest()->getHost())) {
if(!getSocket()->tlsConnect(getRequest()->getHost())) {
setReadCheckSocketIf(getSocket(), getSocket()->wantRead());
setWriteCheckSocketIf(getSocket(), getSocket()->wantWrite());
getDownloadEngine()->addCommand(this);

View File

@ -148,13 +148,16 @@ SharedHandle<HttpHeader> HttpServer::receiveRequest()
if(setupResponseRecv() < 0) {
A2_LOG_INFO("Request path is invaild. Ignore the request body.");
}
if(!util::parseLLIntNoThrow(lastContentLength_,
lastRequestHeader_->
find(HttpHeader::CONTENT_LENGTH)) ||
const std::string& contentLengthHdr = lastRequestHeader_->
find(HttpHeader::CONTENT_LENGTH);
if(!contentLengthHdr.empty()) {
if(!util::parseLLIntNoThrow(lastContentLength_, contentLengthHdr) ||
lastContentLength_ < 0) {
throw DL_ABORT_EX(fmt("Invalid Content-Length=%s",
lastRequestHeader_->
find(HttpHeader::CONTENT_LENGTH).c_str()));
contentLengthHdr.c_str()));
}
} else {
lastContentLength_ = 0;
}
headerProcessor_->clear();
@ -386,4 +389,14 @@ bool HttpServer::supportsPersistentConnection() const
lastRequestHeader_ && lastRequestHeader_->isKeepAlive();
}
bool HttpServer::wantRead() const
{
return socket_->wantRead();
}
bool HttpServer::wantWrite() const
{
return socket_->wantWrite();
}
} // namespace aria2

View File

@ -82,6 +82,7 @@ private:
std::string password_;
bool acceptsGZip_;
std::string allowOrigin_;
bool secure_;
public:
HttpServer(const SharedHandle<SocketCore>& socket, DownloadEngine* e);
@ -178,6 +179,19 @@ public:
{
return lastRequestHeader_;
}
void setSecure(bool f)
{
secure_ = f;
}
bool getSecure() const
{
return secure_;
}
bool wantRead() const;
bool wantWrite() const;
};
} // namespace aria2

View File

@ -74,7 +74,8 @@ HttpServerBodyCommand::HttpServerBodyCommand
: Command(cuid),
e_(e),
socket_(socket),
httpServer_(httpServer)
httpServer_(httpServer),
writeCheck_(false)
{
// To handle Content-Length == 0 case
setStatus(Command::STATUS_ONESHOT_REALTIME);
@ -87,6 +88,9 @@ HttpServerBodyCommand::HttpServerBodyCommand
HttpServerBodyCommand::~HttpServerBodyCommand()
{
e_->deleteSocketForReadCheck(socket_, this);
if(writeCheck_) {
e_->deleteSocketForWriteCheck(socket_, this);
}
}
namespace {
@ -144,6 +148,19 @@ void HttpServerBodyCommand::addHttpServerResponseCommand()
e_->setNoWait(true);
}
void HttpServerBodyCommand::updateWriteCheck()
{
if(httpServer_->wantWrite()) {
if(!writeCheck_) {
writeCheck_ = true;
e_->addSocketForWriteCheck(socket_, this);
}
} else if(writeCheck_) {
writeCheck_ = false;
e_->deleteSocketForWriteCheck(socket_, this);
}
}
bool HttpServerBodyCommand::execute()
{
if(e_->getRequestGroupMan()->downloadFinished() || e_->isHaltRequested()) {
@ -151,6 +168,7 @@ bool HttpServerBodyCommand::execute()
}
try {
if(socket_->isReadable(0) ||
(writeCheck_ && socket_->isWritable(0)) ||
!httpServer_->getSocketRecvBuffer()->bufferEmpty() ||
httpServer_->getContentLength() == 0) {
timeoutTimer_ = global::wallclock();
@ -290,6 +308,7 @@ bool HttpServerBodyCommand::execute()
return true;
}
} else {
updateWriteCheck();
e_->addCommand(this);
return false;
}

View File

@ -53,6 +53,8 @@ private:
SharedHandle<SocketCore> socket_;
SharedHandle<HttpServer> httpServer_;
Timer timeoutTimer_;
bool writeCheck_;
void sendJsonRpcErrorResponse
(const std::string& httpStatus,
int code,
@ -66,6 +68,7 @@ private:
(const std::vector<rpc::RpcResponse>& results,
const std::string& callback);
void addHttpServerResponseCommand();
void updateWriteCheck();
public:
HttpServerBodyCommand(cuid_t cuid,
const SharedHandle<HttpServer>& httpServer,

View File

@ -64,14 +64,17 @@ namespace aria2 {
HttpServerCommand::HttpServerCommand
(cuid_t cuid,
DownloadEngine* e,
const SharedHandle<SocketCore>& socket)
const SharedHandle<SocketCore>& socket,
bool secure)
: Command(cuid),
e_(e),
socket_(socket),
httpServer_(new HttpServer(socket, e))
httpServer_(new HttpServer(socket, e)),
writeCheck_(false)
{
setStatus(Command::STATUS_ONESHOT_REALTIME);
e_->addSocketForReadCheck(socket_, this);
httpServer_->setSecure(secure);
httpServer_->setUsernamePassword(e_->getOption()->get(PREF_RPC_USER),
e_->getOption()->get(PREF_RPC_PASSWD));
if(e_->getOption()->getAsBool(PREF_RPC_ALLOW_ORIGIN_ALL)) {
@ -93,7 +96,8 @@ HttpServerCommand::HttpServerCommand
: Command(cuid),
e_(e),
socket_(socket),
httpServer_(httpServer)
httpServer_(httpServer),
writeCheck_(false)
{
e_->addSocketForReadCheck(socket_, this);
checkSocketRecvBuffer();
@ -102,6 +106,9 @@ HttpServerCommand::HttpServerCommand
HttpServerCommand::~HttpServerCommand()
{
e_->deleteSocketForReadCheck(socket_, this);
if(writeCheck_) {
e_->deleteSocketForWriteCheck(socket_, this);
}
}
void HttpServerCommand::checkSocketRecvBuffer()
@ -147,6 +154,19 @@ int websocketHandshake(const SharedHandle<HttpHeader>& header)
#endif // ENABLE_WEBSOCKET
void HttpServerCommand::updateWriteCheck()
{
if(httpServer_->wantWrite()) {
if(!writeCheck_) {
writeCheck_ = true;
e_->addSocketForWriteCheck(socket_, this);
}
} else if(writeCheck_) {
writeCheck_ = false;
e_->deleteSocketForWriteCheck(socket_, this);
}
}
bool HttpServerCommand::execute()
{
if(e_->getRequestGroupMan()->downloadFinished() || e_->isHaltRequested()) {
@ -154,13 +174,24 @@ bool HttpServerCommand::execute()
}
try {
if(socket_->isReadable(0) ||
(writeCheck_ && socket_->isWritable(0)) ||
!httpServer_->getSocketRecvBuffer()->bufferEmpty()) {
timeoutTimer_ = global::wallclock();
if(httpServer_->getSecure()) {
// tlsAccept() just returns true if handshake has already
// finished.
if(!socket_->tlsAccept()) {
updateWriteCheck();
e_->addCommand(this);
return false;
}
}
SharedHandle<HttpHeader> header;
header = httpServer_->receiveRequest();
if(!header) {
updateWriteCheck();
e_->addCommand(this);
return false;
}

View File

@ -51,11 +51,14 @@ private:
SharedHandle<SocketCore> socket_;
SharedHandle<HttpServer> httpServer_;
Timer timeoutTimer_;
bool writeCheck_;
void checkSocketRecvBuffer();
void updateWriteCheck();
public:
HttpServerCommand(cuid_t cuid, DownloadEngine* e,
const SharedHandle<SocketCore>& socket);
const SharedHandle<SocketCore>& socket,
bool secure);
HttpServerCommand(cuid_t cuid,
const SharedHandle<HttpServer>& httpServer,

View File

@ -45,8 +45,9 @@
namespace aria2 {
TLSContext::TLSContext()
TLSContext::TLSContext(TLSSessionSide side)
: certCred_(0),
side_(side),
peerVerificationEnabled_(false)
{
int r = gnutls_certificate_allocate_credentials(&certCred_);
@ -79,7 +80,7 @@ bool TLSContext::bad() const
return !good_;
}
bool TLSContext::addClientKeyFile(const std::string& certfile,
bool TLSContext::addCredentialFile(const std::string& certfile,
const std::string& keyfile)
{
int ret = gnutls_certificate_set_x509_key_file(certCred_,
@ -87,11 +88,12 @@ bool TLSContext::addClientKeyFile(const std::string& certfile,
keyfile.c_str(),
GNUTLS_X509_FMT_PEM);
if(ret == GNUTLS_E_SUCCESS) {
A2_LOG_INFO(fmt("Client Key File(cert=%s, key=%s) were successfully added.",
A2_LOG_INFO(fmt
("Credential files(cert=%s, key=%s) were successfully added.",
certfile.c_str(), keyfile.c_str()));
return true;
} else {
A2_LOG_ERROR(fmt("Failed to load client certificate from %s and"
A2_LOG_ERROR(fmt("Failed to load certificate from %s and"
" private key from %s. Cause: %s",
certfile.c_str(), keyfile.c_str(),
gnutls_strerror(ret)));

View File

@ -41,6 +41,7 @@
#include <gnutls/gnutls.h>
#include "TLSContext.h"
#include "DlAbortEx.h"
namespace aria2 {
@ -49,16 +50,18 @@ class TLSContext {
private:
gnutls_certificate_credentials_t certCred_;
TLSSessionSide side_;
bool good_;
bool peerVerificationEnabled_;
public:
TLSContext();
TLSContext(TLSSessionSide side);
~TLSContext();
// private key `keyfile' must be decrypted.
bool addClientKeyFile(const std::string& certfile,
bool addCredentialFile(const std::string& certfile,
const std::string& keyfile);
bool addSystemTrustedCACerts();
@ -72,6 +75,11 @@ public:
gnutls_certificate_credentials_t getCertCred() const;
TLSSessionSide getSide() const
{
return side_;
}
void enablePeerVerification();
void disablePeerVerification();

View File

@ -43,11 +43,12 @@
namespace aria2 {
TLSContext::TLSContext()
TLSContext::TLSContext(TLSSessionSide side)
: sslCtx_(0),
side_(side),
peerVerificationEnabled_(false)
{
sslCtx_ = SSL_CTX_new(SSLv23_client_method());
sslCtx_ = SSL_CTX_new(SSLv23_method());
if(sslCtx_) {
good_ = true;
} else {
@ -55,15 +56,15 @@ TLSContext::TLSContext()
A2_LOG_ERROR(fmt("SSL_CTX_new() failed. Cause: %s",
ERR_error_string(ERR_get_error(), 0)));
}
/* Disable SSLv2 and enable all workarounds for buggy servers */
// Disable SSLv2 and enable all workarounds for buggy servers
SSL_CTX_set_options(sslCtx_, SSL_OP_ALL|SSL_OP_NO_SSLv2|
SSL_OP_NO_COMPRESSION);
SSL_CTX_set_mode(sslCtx_, SSL_MODE_AUTO_RETRY);
SSL_CTX_set_mode(sslCtx_, SSL_MODE_ENABLE_PARTIAL_WRITE);
#ifdef SSL_MODE_RELEASE_BUFFERS
/* keep memory usage low */
SSL_CTX_set_mode(sslCtx_, SSL_MODE_RELEASE_BUFFERS);
#endif
}
TLSContext::~TLSContext()
@ -81,23 +82,23 @@ bool TLSContext::bad() const
return !good_;
}
bool TLSContext::addClientKeyFile(const std::string& certfile,
bool TLSContext::addCredentialFile(const std::string& certfile,
const std::string& keyfile)
{
if(SSL_CTX_use_PrivateKey_file(sslCtx_, keyfile.c_str(),
SSL_FILETYPE_PEM) != 1) {
A2_LOG_ERROR(fmt("Failed to load client private key from %s. Cause: %s",
A2_LOG_ERROR(fmt("Failed to load private key from %s. Cause: %s",
keyfile.c_str(),
ERR_error_string(ERR_get_error(), 0)));
return false;
}
if(SSL_CTX_use_certificate_chain_file(sslCtx_, certfile.c_str()) != 1) {
A2_LOG_ERROR(fmt("Failed to load client certificate from %s. Cause: %s",
A2_LOG_ERROR(fmt("Failed to load certificate from %s. Cause: %s",
certfile.c_str(),
ERR_error_string(ERR_get_error(), 0)));
return false;
}
A2_LOG_INFO(fmt("Client Key File(cert=%s, key=%s) were successfully added.",
A2_LOG_INFO(fmt("Credential files(cert=%s, key=%s) were successfully added.",
certfile.c_str(),
keyfile.c_str()));
return true;

View File

@ -41,6 +41,7 @@
# include <openssl/ssl.h>
#include "TLSContext.h"
#include "DlAbortEx.h"
namespace aria2 {
@ -49,16 +50,18 @@ class TLSContext {
private:
SSL_CTX* sslCtx_;
TLSSessionSide side_;
bool good_;
bool peerVerificationEnabled_;
public:
TLSContext();
TLSContext(TLSSessionSide side);
~TLSContext();
// private key `keyfile' must be decrypted.
bool addClientKeyFile(const std::string& certfile,
bool addCredentialFile(const std::string& certfile,
const std::string& keyfile);
bool addSystemTrustedCACerts();
@ -75,6 +78,11 @@ public:
return sslCtx_;
}
TLSSessionSide getSide() const
{
return side_;
}
void enablePeerVerification();
void disablePeerVerification();

View File

@ -137,6 +137,24 @@ error_code::Value MultiUrlRequestInfo::execute()
Notifier notifier(wsSessionMan);
SingletonHolder<Notifier>::instance(&notifier);
#ifdef ENABLE_SSL
if(option_->getAsBool(PREF_ENABLE_RPC) &&
option_->getAsBool(PREF_RPC_SECURE)) {
if(!option_->blank(PREF_RPC_CERTIFICATE) &&
!option_->blank(PREF_RPC_PRIVATE_KEY)) {
// We set server TLS context to the SocketCore before creating
// DownloadEngine instance.
SharedHandle<TLSContext> svTlsContext(new TLSContext(TLS_SERVER));
svTlsContext->addCredentialFile(option_->get(PREF_RPC_CERTIFICATE),
option_->get(PREF_RPC_PRIVATE_KEY));
SocketCore::setServerTLSContext(svTlsContext);
} else {
throw DL_ABORT_EX("Specify --rpc-certificate and --rpc-private-key "
"options in order to use secure RPC.");
}
}
#endif // ENABLE_SSL
SharedHandle<DownloadEngine> e =
DownloadEngineFactory().newDownloadEngine(option_.get(), requestGroups_);
@ -173,26 +191,27 @@ error_code::Value MultiUrlRequestInfo::execute()
e->setAuthConfigFactory(authConfigFactory);
#ifdef ENABLE_SSL
SharedHandle<TLSContext> tlsContext(new TLSContext());
SharedHandle<TLSContext> clTlsContext(new TLSContext(TLS_CLIENT));
if(!option_->blank(PREF_CERTIFICATE) &&
!option_->blank(PREF_PRIVATE_KEY)) {
tlsContext->addClientKeyFile(option_->get(PREF_CERTIFICATE),
clTlsContext->addCredentialFile(option_->get(PREF_CERTIFICATE),
option_->get(PREF_PRIVATE_KEY));
}
if(!option_->blank(PREF_CA_CERTIFICATE)) {
if(!tlsContext->addTrustedCACertFile(option_->get(PREF_CA_CERTIFICATE))) {
if(!clTlsContext->addTrustedCACertFile
(option_->get(PREF_CA_CERTIFICATE))) {
A2_LOG_INFO(MSG_WARN_NO_CA_CERT);
}
} else if(option_->getAsBool(PREF_CHECK_CERTIFICATE)) {
if(!tlsContext->addSystemTrustedCACerts()) {
if(!clTlsContext->addSystemTrustedCACerts()) {
A2_LOG_INFO(MSG_WARN_NO_CA_CERT);
}
}
if(option_->getAsBool(PREF_CHECK_CERTIFICATE)) {
tlsContext->enablePeerVerification();
clTlsContext->enablePeerVerification();
}
SocketCore::setTLSContext(tlsContext);
SocketCore::setClientTLSContext(clTlsContext);
#endif
#ifdef HAVE_ARES_ADDR_NODE
ares_addr_node* asyncDNSServers =

View File

@ -747,6 +747,15 @@ std::vector<OptionHandler*> OptionHandlerFactory::createOptionHandlers()
op->addTag(TAG_RPC);
handlers.push_back(op);
}
{
OptionHandler* op(new LocalFilePathOptionHandler
(PREF_RPC_CERTIFICATE,
TEXT_RPC_CERTIFICATE,
NO_DEFAULT_VALUE,
false));
op->addTag(TAG_RPC);
handlers.push_back(op);
}
{
OptionHandler* op(new BooleanOptionHandler
(PREF_RPC_LISTEN_ALL,
@ -774,6 +783,24 @@ std::vector<OptionHandler*> OptionHandlerFactory::createOptionHandlers()
op->addTag(TAG_RPC);
handlers.push_back(op);
}
{
OptionHandler* op(new LocalFilePathOptionHandler
(PREF_RPC_PRIVATE_KEY,
TEXT_RPC_PRIVATE_KEY,
NO_DEFAULT_VALUE,
false));
op->addTag(TAG_RPC);
handlers.push_back(op);
}
{
OptionHandler* op(new BooleanOptionHandler
(PREF_RPC_SECURE,
TEXT_RPC_SECURE,
A2_V_FALSE,
OptionHandler::OPT_ARG));
op->addTag(TAG_RPC);
handlers.push_back(op);
}
{
OptionHandler* op(new DefaultOptionHandler
(PREF_RPC_USER,

View File

@ -125,8 +125,6 @@ namespace {
enum TlsState {
// TLS object is not initialized.
A2_TLS_NONE = 0,
// TLS object is initialized. Ready for handshake.
A2_TLS_INITIALIZED = 1,
// TLS object is now handshaking.
A2_TLS_HANDSHAKING = 2,
// TLS object is now connected.
@ -140,11 +138,19 @@ std::vector<std::pair<sockaddr_union, socklen_t> >
SocketCore::bindAddrs_;
#ifdef ENABLE_SSL
SharedHandle<TLSContext> SocketCore::tlsContext_;
SharedHandle<TLSContext> SocketCore::clTlsContext_;
SharedHandle<TLSContext> SocketCore::svTlsContext_;
void SocketCore::setTLSContext(const SharedHandle<TLSContext>& tlsContext)
void SocketCore::setClientTLSContext
(const SharedHandle<TLSContext>& tlsContext)
{
tlsContext_ = tlsContext;
clTlsContext_ = tlsContext;
}
void SocketCore::setServerTLSContext
(const SharedHandle<TLSContext>& tlsContext)
{
svTlsContext_ = tlsContext;
}
#endif // ENABLE_SSL
@ -818,12 +824,24 @@ void SocketCore::readData(char* data, size_t& len)
len = ret;
}
void SocketCore::prepareSecureConnection()
bool SocketCore::tlsAccept()
{
if(!secure_) {
return tlsHandshake(svTlsContext_.get(), A2STR::NIL);
}
bool SocketCore::tlsConnect(const std::string& hostname)
{
return tlsHandshake(clTlsContext_.get(), hostname);
}
bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
{
wantRead_ = false;
wantWrite_ = false;
#ifdef HAVE_OPENSSL
// for SSL
ssl = SSL_new(tlsContext_->getSSLCtx());
switch(secure_) {
case A2_TLS_NONE:
ssl = SSL_new(tlsctx->getSSLCtx());
if(!ssl) {
throw DL_ABORT_EX
(fmt(EX_SSL_INIT_FAILURE, ERR_error_string(ERR_get_error(), 0)));
@ -832,48 +850,25 @@ void SocketCore::prepareSecureConnection()
throw DL_ABORT_EX
(fmt(EX_SSL_INIT_FAILURE, ERR_error_string(ERR_get_error(), 0)));
}
#endif // HAVE_OPENSSL
#ifdef HAVE_LIBGNUTLS
int r;
gnutls_init(&sslSession_, GNUTLS_CLIENT);
// It seems err is not error message, but the argument string
// which causes syntax error.
const char* err;
// Disables TLS1.1 here because there are servers that don't
// understand TLS1.1.
r = gnutls_priority_set_direct(sslSession_, "NORMAL:!VERS-TLS1.1", &err);
if(r != GNUTLS_E_SUCCESS) {
throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, gnutls_strerror(r)));
}
// put the x509 credentials to the current session
gnutls_credentials_set(sslSession_, GNUTLS_CRD_CERTIFICATE,
tlsContext_->getCertCred());
gnutls_transport_set_ptr(sslSession_, (gnutls_transport_ptr_t)sockfd_);
#endif // HAVE_LIBGNUTLS
secure_ = A2_TLS_INITIALIZED;
}
}
bool SocketCore::initiateSecureConnection(const std::string& hostname)
{
wantRead_ = false;
wantWrite_ = false;
#ifdef HAVE_OPENSSL
switch(secure_) {
case A2_TLS_INITIALIZED:
secure_ = A2_TLS_HANDSHAKING;
// Fall through
#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
if(!util::isNumericHost(hostname)) {
if(tlsctx->getSide() == TLS_CLIENT && !util::isNumericHost(hostname)) {
// TLS extensions: SNI. There is not documentation about the
// return code for this function (actually this is macro
// wrapping SSL_ctrl at the time of this writing).
SSL_set_tlsext_host_name(ssl, hostname.c_str());
}
#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME
secure_ = A2_TLS_HANDSHAKING;
// Fall through
case A2_TLS_HANDSHAKING: {
ERR_clear_error();
int e = SSL_connect(ssl);
int e;
if(tlsctx->getSide() == TLS_CLIENT) {
e = SSL_connect(ssl);
} else {
e = SSL_accept(ssl);
}
if (e <= 0) {
int ssl_error = SSL_get_error(ssl, e);
@ -893,9 +888,21 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname)
}
break;
case SSL_ERROR_SYSCALL:
case SSL_ERROR_SYSCALL: {
int sslErr = ERR_get_error();
if(sslErr == 0) {
if(e == 0) {
throw DL_ABORT_EX("Got EOF in SSL handshake");
} else if(e == -1) {
throw DL_ABORT_EX(fmt("SSL I/O error: %s", strerror(errno)));
} else {
throw DL_ABORT_EX(EX_SSL_IO_ERROR);
}
} else {
throw DL_ABORT_EX(fmt("SSL I/O error: %s",
ERR_error_string(sslErr, 0)));
}
}
case SSL_ERROR_SSL:
throw DL_ABORT_EX(EX_SSL_PROTOCOL_ERROR);
@ -903,7 +910,8 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname)
throw DL_ABORT_EX(fmt(EX_SSL_UNKNOWN_ERROR, ssl_error));
}
}
if(tlsContext_->peerVerificationEnabled()) {
if(tlsctx->getSide() == TLS_CLIENT &&
tlsctx->peerVerificationEnabled()) {
// verify peer
X509* peerCert = SSL_get_peer_certificate(ssl);
if(!peerCert) {
@ -984,8 +992,29 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname)
#endif // HAVE_OPENSSL
#ifdef HAVE_LIBGNUTLS
switch(secure_) {
case A2_TLS_INITIALIZED:
secure_ = A2_TLS_HANDSHAKING;
case A2_TLS_NONE:
int r;
gnutls_init(&sslSession_,
tlsctx->getSide() == TLS_CLIENT ?
GNUTLS_CLIENT : GNUTLS_SERVER);
// It seems err is not error message, but the argument string
// which causes syntax error.
const char* err;
// For client side, disables TLS1.1 here because there are servers
// that don't understand TLS1.1. TODO Is this still necessary?
r = gnutls_priority_set_direct(sslSession_,
tlsctx->getSide() == TLS_CLIENT ?
"NORMAL:-VERS-TLS1.1" :
"NORMAL",
&err);
if(r != GNUTLS_E_SUCCESS) {
throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, gnutls_strerror(r)));
}
// put the x509 credentials to the current session
gnutls_credentials_set(sslSession_, GNUTLS_CRD_CERTIFICATE,
tlsctx->getCertCred());
gnutls_transport_set_ptr(sslSession_, (gnutls_transport_ptr_t)sockfd_);
if(tlsctx->getSide() == TLS_CLIENT) {
// Check hostname is not numeric and it includes ".". Setting
// "localhost" will produce TLS alert.
if(!util::isNumericHost(hostname) &&
@ -994,10 +1023,13 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname)
int ret = gnutls_server_name_set(sslSession_, GNUTLS_NAME_DNS,
hostname.c_str(), hostname.size());
if(ret < 0) {
A2_LOG_WARN(fmt("Setting hostname in SNI extension failed. Cause: %s",
A2_LOG_WARN(fmt
("Setting hostname in SNI extension failed. Cause: %s",
gnutls_strerror(ret)));
}
}
}
secure_ = A2_TLS_HANDSHAKING;
// Fall through
case A2_TLS_HANDSHAKING: {
int ret = gnutls_handshake(sslSession_);
@ -1008,7 +1040,7 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname)
throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, gnutls_strerror(ret)));
}
if(tlsContext_->peerVerificationEnabled()) {
if(tlsctx->getSide() == TLS_CLIENT && tlsctx->peerVerificationEnabled()) {
// verify peer
unsigned int status;
ret = gnutls_certificate_verify_peers2(sslSession_, &status);

View File

@ -85,7 +85,10 @@ private:
bool wantWrite_;
#if ENABLE_SSL
static SharedHandle<TLSContext> tlsContext_;
// TLS context for client side
static SharedHandle<TLSContext> clTlsContext_;
// TLS context for server side
static SharedHandle<TLSContext> svTlsContext_;
#endif // ENABLE_SSL
#ifdef HAVE_OPENSSL
@ -106,6 +109,15 @@ private:
void setSockOpt(int level, int optname, void* optval, socklen_t optlen);
/**
* Makes this socket secure.
* If the system has not OpenSSL, then this method do nothing.
* connection must be established before calling this method.
*
* If you are going to verify peer's certificate, hostname must be supplied.
*/
bool tlsHandshake(TLSContext* tlsctx, const std::string& hostname);
SocketCore(sock_t sockfd, int sockType);
public:
SocketCore(int sockType = SOCK_STREAM);
@ -293,16 +305,16 @@ public:
return readDataFrom(reinterpret_cast<char*>(data), len, sender);
}
/**
* Makes this socket secure.
* If the system has not OpenSSL, then this method do nothing.
* connection must be established before calling this method.
*
* If you are going to verify peer's certificate, hostname must be supplied.
*/
bool initiateSecureConnection(const std::string& hostname="");
// Performs TLS server side handshake. If handshake is completed,
// returns true. If handshake has not been done yet, returns false.
bool tlsAccept();
void prepareSecureConnection();
// Performs TLS client side handshake. If handshake is completed,
// returns true. If handshake has not been done yet, returns false.
//
// If you are going to verify peer's certificate, hostname must be
// supplied.
bool tlsConnect(const std::string& hostname);
bool operator==(const SocketCore& s) {
return sockfd_ == s.sockfd_;
@ -332,7 +344,8 @@ public:
bool wantWrite() const;
#ifdef ENABLE_SSL
static void setTLSContext(const SharedHandle<TLSContext>& tlsContext);
static void setClientTLSContext(const SharedHandle<TLSContext>& tlsContext);
static void setServerTLSContext(const SharedHandle<TLSContext>& tlsContext);
#endif // ENABLE_SSL
static void setProtocolFamily(int protocolFamily)

View File

@ -37,6 +37,15 @@
#include "common.h"
namespace aria2 {
enum TLSSessionSide {
TLS_CLIENT,
TLS_SERVER
};
} // namespace aria2
#ifdef HAVE_OPENSSL
# include "LibsslTLSContext.h"
#elif HAVE_LIBGNUTLS

View File

@ -73,7 +73,7 @@ WebSocketInteractionCommand::~WebSocketInteractionCommand()
void WebSocketInteractionCommand::updateWriteCheck()
{
if(wsSession_->wantWrite()) {
if(socket_->wantWrite() || wsSession_->wantWrite()) {
if(!writeCheck_) {
writeCheck_ = true;
e_->addSocketForWriteCheck(socket_, this);
@ -91,7 +91,8 @@ bool WebSocketInteractionCommand::execute()
}
if(wsSession_->onReadEvent() == -1 || wsSession_->onWriteEvent() == -1) {
if(wsSession_->closeSent() || wsSession_->closeReceived()) {
A2_LOG_INFO(fmt("CUID#%" PRId64 " - WebSocket session terminated.", getCuid()));
A2_LOG_INFO(fmt("CUID#%"PRId64" - WebSocket session terminated.",
getCuid()));
} else {
A2_LOG_INFO(fmt("CUID#%"PRId64" - WebSocket session terminated"
" (Possibly due to EOF).", getCuid()));

View File

@ -270,6 +270,12 @@ const Pref* PREF_RPC_MAX_REQUEST_SIZE = makePref("rpc-max-request-size");
const Pref* PREF_RPC_LISTEN_ALL = makePref("rpc-listen-all");
// value: true | false
const Pref* PREF_RPC_ALLOW_ORIGIN_ALL = makePref("rpc-allow-origin-all");
// value: string that your file system recognizes as a file name.
const Pref* PREF_RPC_CERTIFICATE = makePref("rpc-certificate");
// value: string that your file system recognizes as a file name.
const Pref* PREF_RPC_PRIVATE_KEY = makePref("rpc-private-key");
// value: true | false
const Pref* PREF_RPC_SECURE = makePref("rpc-secure");
// value: true | false
const Pref* PREF_DRY_RUN = makePref("dry-run");
// value: true | false

View File

@ -213,6 +213,12 @@ extern const Pref* PREF_RPC_MAX_REQUEST_SIZE;
extern const Pref* PREF_RPC_LISTEN_ALL;
// value: true | false
extern const Pref* PREF_RPC_ALLOW_ORIGIN_ALL;
// value: string that your file system recognizes as a file name.
extern const Pref* PREF_RPC_CERTIFICATE;
// value: string that your file system recognizes as a file name.
extern const Pref* PREF_RPC_PRIVATE_KEY;
// value: true | false
extern const Pref* PREF_RPC_SECURE;
// value: true | false
extern const Pref* PREF_DRY_RUN;
// value: true | false

View File

@ -880,3 +880,21 @@
" your disk.")
#define TEXT_ENABLE_MMAP \
_(" --enable-mmap[=true|false] Map files into memory.")
#define TEXT_RPC_CERTIFICATE \
_(" --rpc-certificate=FILE Use the certificate in FILE for RPC server.\n" \
" The certificate must be in PEM format.\n" \
" Use --rpc-private-key option to specify the\n" \
" private key. Use --rpc-secure option to enable\n" \
" encryption.")
#define TEXT_RPC_PRIVATE_KEY \
_(" --rpc-private-key=FILE Use the private key in FILE for RPC server.\n" \
" The private key must be decrypted and in PEM\n" \
" format. Use --rpc-secure option to enable\n" \
" encryption. See also --rpc-certificate option.")
#define TEXT_RPC_SECURE \
_(" --rpc-secure[=true|false] RPC transport will be encrypted by SSL/TLS.\n" \
" The RPC clients must use https scheme to access\n" \
" the server. For WebSocket client, use wss\n" \
" scheme. Use --rpc-certificate and\n" \
" --rpc-private-key options to specify the\n" \
" server certificate and private key.")