From 90515dfa50c676df37f440893de6e344fbfd2cdc Mon Sep 17 00:00:00 2001 From: Tatsuhiro Tsujikawa Date: Sun, 30 Sep 2012 00:50:14 +0900 Subject: [PATCH] RPC over SSL/TLS transport To enable RPC over SSL/TLS, specify server certificate and private key using --rpc-certificate and --rpc-private-key options and enable --rpc-secure option. After the encryption is enabled, use https and wss scheme to access RPC server. --- src/AbstractHttpServerResponseCommand.cc | 49 ++++++-- src/AbstractHttpServerResponseCommand.h | 8 +- src/DownloadEngineFactory.cc | 7 +- src/HttpListenCommand.cc | 8 +- src/HttpListenCommand.h | 7 +- src/HttpRequestCommand.cc | 3 +- src/HttpServer.cc | 27 +++-- src/HttpServer.h | 14 +++ src/HttpServerBodyCommand.cc | 23 +++- src/HttpServerBodyCommand.h | 7 +- src/HttpServerCommand.cc | 41 ++++++- src/HttpServerCommand.h | 9 +- src/LibgnutlsTLSContext.cc | 14 ++- src/LibgnutlsTLSContext.h | 14 ++- src/LibsslTLSContext.cc | 19 +-- src/LibsslTLSContext.h | 16 ++- src/MultiUrlRequestInfo.cc | 37 ++++-- src/OptionHandlerFactory.cc | 27 +++++ src/SocketCore.cc | 148 ++++++++++++++--------- src/SocketCore.h | 37 ++++-- src/TLSContext.h | 9 ++ src/WebSocketInteractionCommand.cc | 7 +- src/prefs.cc | 6 + src/prefs.h | 6 + src/usage_text.h | 18 +++ 25 files changed, 419 insertions(+), 142 deletions(-) diff --git a/src/AbstractHttpServerResponseCommand.cc b/src/AbstractHttpServerResponseCommand.cc index 7b89a254..5d9d8c11 100644 --- a/src/AbstractHttpServerResponseCommand.cc +++ b/src/AbstractHttpServerResponseCommand.cc @@ -55,15 +55,44 @@ AbstractHttpServerResponseCommand::AbstractHttpServerResponseCommand : Command(cuid), e_(e), socket_(socket), - httpServer_(httpServer) + httpServer_(httpServer), + readCheck_(false), + writeCheck_(true) { - setStatus(Command::STATUS_ONESHOT_REALTIME); + setStatus(Command::STATUS_ONESHOT_REALTIME); e_->addSocketForWriteCheck(socket_, this); } AbstractHttpServerResponseCommand::~AbstractHttpServerResponseCommand() { - e_->deleteSocketForWriteCheck(socket_, this); + if(readCheck_) { + e_->deleteSocketForReadCheck(socket_, this); + } + if(writeCheck_) { + e_->deleteSocketForWriteCheck(socket_, this); + } +} + +void AbstractHttpServerResponseCommand::updateReadWriteCheck() +{ + if(httpServer_->wantRead()) { + if(!readCheck_) { + readCheck_ = true; + e_->addSocketForReadCheck(socket_, this); + } + } else if(readCheck_) { + readCheck_ = false; + e_->deleteSocketForReadCheck(socket_, this); + } + if(httpServer_->wantWrite()) { + if(!writeCheck_) { + writeCheck_ = true; + e_->addSocketForWriteCheck(socket_, this); + } + } else if(writeCheck_) { + writeCheck_ = false; + e_->deleteSocketForWriteCheck(socket_, this); + } } bool AbstractHttpServerResponseCommand::execute() @@ -72,26 +101,30 @@ bool AbstractHttpServerResponseCommand::execute() return true; } try { - httpServer_->sendResponse(); + ssize_t len = httpServer_->sendResponse(); + if(len > 0) { + timeoutTimer_ = global::wallclock(); + } } catch(RecoverableException& e) { A2_LOG_INFO_EX - (fmt("CUID#%" PRId64 " - Error occurred while transmitting response body.", + (fmt("CUID#%"PRId64" - Error occurred while transmitting response body.", getCuid()), e); return true; } if(httpServer_->sendBufferIsEmpty()) { - A2_LOG_INFO(fmt("CUID#%" PRId64 " - HttpServer: all response transmitted.", + A2_LOG_INFO(fmt("CUID#%"PRId64" - HttpServer: all response transmitted.", getCuid())); afterSend(httpServer_, e_); return true; } else { - if(timeoutTimer_.difference(global::wallclock()) >= 10) { - A2_LOG_INFO(fmt("CUID#%" PRId64 " - HttpServer: Timeout while trasmitting" + if(timeoutTimer_.difference(global::wallclock()) >= 30) { + A2_LOG_INFO(fmt("CUID#%"PRId64" - HttpServer: Timeout while trasmitting" " response.", getCuid())); return true; } else { + updateReadWriteCheck(); e_->addCommand(this); return false; } diff --git a/src/AbstractHttpServerResponseCommand.h b/src/AbstractHttpServerResponseCommand.h index 7dbc71dd..652d6b05 100644 --- a/src/AbstractHttpServerResponseCommand.h +++ b/src/AbstractHttpServerResponseCommand.h @@ -51,6 +51,10 @@ private: SharedHandle socket_; SharedHandle httpServer_; Timer timeoutTimer_; + bool readCheck_; + bool writeCheck_; + + void updateReadWriteCheck(); protected: DownloadEngine* getDownloadEngine() { @@ -66,10 +70,10 @@ public: const SharedHandle& socket); virtual ~AbstractHttpServerResponseCommand(); - + virtual bool execute(); }; -} // namespace aria2 +} // namespace aria2 #endif // D_ABSTRACT_HTTP_SERVER_RESPONSE_COMMAND_H diff --git a/src/DownloadEngineFactory.cc b/src/DownloadEngineFactory.cc index 7cdb8e3e..0aba1494 100644 --- a/src/DownloadEngineFactory.cc +++ b/src/DownloadEngineFactory.cc @@ -74,6 +74,7 @@ #include "DlAbortEx.h" #include "FileAllocationEntry.h" #include "HttpListenCommand.h" +#include "LogFactory.h" namespace aria2 { @@ -170,11 +171,15 @@ DownloadEngineFactory::newDownloadEngine } if(op->getAsBool(PREF_ENABLE_RPC)) { bool ok = false; + bool secure = op->getAsBool(PREF_RPC_SECURE); + if(secure) { + A2_LOG_NOTICE("RPC transport will be encrypted."); + } static int families[] = { AF_INET, AF_INET6 }; size_t familiesLength = op->getAsBool(PREF_DISABLE_IPV6)?1:2; for(size_t i = 0; i < familiesLength; ++i) { HttpListenCommand* httpListenCommand = - new HttpListenCommand(e->newCUID(), e.get(), families[i]); + new HttpListenCommand(e->newCUID(), e.get(), families[i], secure); if(httpListenCommand->bindPort(op->getAsInt(PREF_RPC_LISTEN_PORT))){ e->addCommand(httpListenCommand); ok = true; diff --git a/src/HttpListenCommand.cc b/src/HttpListenCommand.cc index be650a1d..32c41ae1 100644 --- a/src/HttpListenCommand.cc +++ b/src/HttpListenCommand.cc @@ -50,10 +50,12 @@ namespace aria2 { -HttpListenCommand::HttpListenCommand(cuid_t cuid, DownloadEngine* e, int family) +HttpListenCommand::HttpListenCommand(cuid_t cuid, DownloadEngine* e, + int family, bool secure) : Command(cuid), e_(e), - family_(family) + family_(family), + secure_(secure) {} HttpListenCommand::~HttpListenCommand() @@ -80,7 +82,7 @@ bool HttpListenCommand::execute() peerInfo.first.c_str(), peerInfo.second)); HttpServerCommand* c = - new HttpServerCommand(e_->newCUID(), e_, socket); + new HttpServerCommand(e_->newCUID(), e_, socket, secure_); e_->setNoWait(true); e_->addCommand(c); } diff --git a/src/HttpListenCommand.h b/src/HttpListenCommand.h index 19a21848..c58700e5 100644 --- a/src/HttpListenCommand.h +++ b/src/HttpListenCommand.h @@ -48,16 +48,17 @@ private: DownloadEngine* e_; int family_; SharedHandle serverSocket_; + bool secure_; public: - HttpListenCommand(cuid_t cuid, DownloadEngine* e, int family); + HttpListenCommand(cuid_t cuid, DownloadEngine* e, int family, bool secure); virtual ~HttpListenCommand(); - + virtual bool execute(); bool bindPort(uint16_t port); }; -} // namespace aria2 +} // namespace aria2 #endif // D_HTTP_LISTEN_COMMAND_H diff --git a/src/HttpRequestCommand.cc b/src/HttpRequestCommand.cc index d96ead50..0a3f0375 100644 --- a/src/HttpRequestCommand.cc +++ b/src/HttpRequestCommand.cc @@ -123,8 +123,7 @@ createHttpRequest(const SharedHandle& req, bool HttpRequestCommand::executeInternal() { //socket->setBlockingMode(); if(getRequest()->getProtocol() == "https") { - getSocket()->prepareSecureConnection(); - if(!getSocket()->initiateSecureConnection(getRequest()->getHost())) { + if(!getSocket()->tlsConnect(getRequest()->getHost())) { setReadCheckSocketIf(getSocket(), getSocket()->wantRead()); setWriteCheckSocketIf(getSocket(), getSocket()->wantWrite()); getDownloadEngine()->addCommand(this); diff --git a/src/HttpServer.cc b/src/HttpServer.cc index afb1ab45..bfabb08e 100644 --- a/src/HttpServer.cc +++ b/src/HttpServer.cc @@ -148,13 +148,16 @@ SharedHandle HttpServer::receiveRequest() if(setupResponseRecv() < 0) { A2_LOG_INFO("Request path is invaild. Ignore the request body."); } - if(!util::parseLLIntNoThrow(lastContentLength_, - lastRequestHeader_-> - find(HttpHeader::CONTENT_LENGTH)) || - lastContentLength_ < 0) { - throw DL_ABORT_EX(fmt("Invalid Content-Length=%s", - lastRequestHeader_-> - find(HttpHeader::CONTENT_LENGTH).c_str())); + const std::string& contentLengthHdr = lastRequestHeader_-> + find(HttpHeader::CONTENT_LENGTH); + if(!contentLengthHdr.empty()) { + if(!util::parseLLIntNoThrow(lastContentLength_, contentLengthHdr) || + lastContentLength_ < 0) { + throw DL_ABORT_EX(fmt("Invalid Content-Length=%s", + contentLengthHdr.c_str())); + } + } else { + lastContentLength_ = 0; } headerProcessor_->clear(); @@ -386,4 +389,14 @@ bool HttpServer::supportsPersistentConnection() const lastRequestHeader_ && lastRequestHeader_->isKeepAlive(); } +bool HttpServer::wantRead() const +{ + return socket_->wantRead(); +} + +bool HttpServer::wantWrite() const +{ + return socket_->wantWrite(); +} + } // namespace aria2 diff --git a/src/HttpServer.h b/src/HttpServer.h index c5b01530..602458e2 100644 --- a/src/HttpServer.h +++ b/src/HttpServer.h @@ -82,6 +82,7 @@ private: std::string password_; bool acceptsGZip_; std::string allowOrigin_; + bool secure_; public: HttpServer(const SharedHandle& socket, DownloadEngine* e); @@ -178,6 +179,19 @@ public: { return lastRequestHeader_; } + + void setSecure(bool f) + { + secure_ = f; + } + + bool getSecure() const + { + return secure_; + } + + bool wantRead() const; + bool wantWrite() const; }; } // namespace aria2 diff --git a/src/HttpServerBodyCommand.cc b/src/HttpServerBodyCommand.cc index d986ef4d..23230d00 100644 --- a/src/HttpServerBodyCommand.cc +++ b/src/HttpServerBodyCommand.cc @@ -74,7 +74,8 @@ HttpServerBodyCommand::HttpServerBodyCommand : Command(cuid), e_(e), socket_(socket), - httpServer_(httpServer) + httpServer_(httpServer), + writeCheck_(false) { // To handle Content-Length == 0 case setStatus(Command::STATUS_ONESHOT_REALTIME); @@ -87,6 +88,9 @@ HttpServerBodyCommand::HttpServerBodyCommand HttpServerBodyCommand::~HttpServerBodyCommand() { e_->deleteSocketForReadCheck(socket_, this); + if(writeCheck_) { + e_->deleteSocketForWriteCheck(socket_, this); + } } namespace { @@ -144,6 +148,19 @@ void HttpServerBodyCommand::addHttpServerResponseCommand() e_->setNoWait(true); } +void HttpServerBodyCommand::updateWriteCheck() +{ + if(httpServer_->wantWrite()) { + if(!writeCheck_) { + writeCheck_ = true; + e_->addSocketForWriteCheck(socket_, this); + } + } else if(writeCheck_) { + writeCheck_ = false; + e_->deleteSocketForWriteCheck(socket_, this); + } +} + bool HttpServerBodyCommand::execute() { if(e_->getRequestGroupMan()->downloadFinished() || e_->isHaltRequested()) { @@ -151,6 +168,7 @@ bool HttpServerBodyCommand::execute() } try { if(socket_->isReadable(0) || + (writeCheck_ && socket_->isWritable(0)) || !httpServer_->getSocketRecvBuffer()->bufferEmpty() || httpServer_->getContentLength() == 0) { timeoutTimer_ = global::wallclock(); @@ -290,9 +308,10 @@ bool HttpServerBodyCommand::execute() return true; } } else { + updateWriteCheck(); e_->addCommand(this); return false; - } + } } else { if(timeoutTimer_.difference(global::wallclock()) >= 30) { A2_LOG_INFO("HTTP request body timeout."); diff --git a/src/HttpServerBodyCommand.h b/src/HttpServerBodyCommand.h index 30ed7851..46c86919 100644 --- a/src/HttpServerBodyCommand.h +++ b/src/HttpServerBodyCommand.h @@ -53,6 +53,8 @@ private: SharedHandle socket_; SharedHandle httpServer_; Timer timeoutTimer_; + bool writeCheck_; + void sendJsonRpcErrorResponse (const std::string& httpStatus, int code, @@ -66,6 +68,7 @@ private: (const std::vector& results, const std::string& callback); void addHttpServerResponseCommand(); + void updateWriteCheck(); public: HttpServerBodyCommand(cuid_t cuid, const SharedHandle& httpServer, @@ -73,10 +76,10 @@ public: const SharedHandle& socket); virtual ~HttpServerBodyCommand(); - + virtual bool execute(); }; -} // namespace aria2 +} // namespace aria2 #endif // D_HTTP_SERVER_BODY_COMMAND_H diff --git a/src/HttpServerCommand.cc b/src/HttpServerCommand.cc index ca0b9071..195b040b 100644 --- a/src/HttpServerCommand.cc +++ b/src/HttpServerCommand.cc @@ -64,14 +64,17 @@ namespace aria2 { HttpServerCommand::HttpServerCommand (cuid_t cuid, DownloadEngine* e, - const SharedHandle& socket) + const SharedHandle& socket, + bool secure) : Command(cuid), e_(e), socket_(socket), - httpServer_(new HttpServer(socket, e)) + httpServer_(new HttpServer(socket, e)), + writeCheck_(false) { setStatus(Command::STATUS_ONESHOT_REALTIME); e_->addSocketForReadCheck(socket_, this); + httpServer_->setSecure(secure); httpServer_->setUsernamePassword(e_->getOption()->get(PREF_RPC_USER), e_->getOption()->get(PREF_RPC_PASSWD)); if(e_->getOption()->getAsBool(PREF_RPC_ALLOW_ORIGIN_ALL)) { @@ -93,7 +96,8 @@ HttpServerCommand::HttpServerCommand : Command(cuid), e_(e), socket_(socket), - httpServer_(httpServer) + httpServer_(httpServer), + writeCheck_(false) { e_->addSocketForReadCheck(socket_, this); checkSocketRecvBuffer(); @@ -102,6 +106,9 @@ HttpServerCommand::HttpServerCommand HttpServerCommand::~HttpServerCommand() { e_->deleteSocketForReadCheck(socket_, this); + if(writeCheck_) { + e_->deleteSocketForWriteCheck(socket_, this); + } } void HttpServerCommand::checkSocketRecvBuffer() @@ -147,6 +154,19 @@ int websocketHandshake(const SharedHandle& header) #endif // ENABLE_WEBSOCKET +void HttpServerCommand::updateWriteCheck() +{ + if(httpServer_->wantWrite()) { + if(!writeCheck_) { + writeCheck_ = true; + e_->addSocketForWriteCheck(socket_, this); + } + } else if(writeCheck_) { + writeCheck_ = false; + e_->deleteSocketForWriteCheck(socket_, this); + } +} + bool HttpServerCommand::execute() { if(e_->getRequestGroupMan()->downloadFinished() || e_->isHaltRequested()) { @@ -154,13 +174,24 @@ bool HttpServerCommand::execute() } try { if(socket_->isReadable(0) || + (writeCheck_ && socket_->isWritable(0)) || !httpServer_->getSocketRecvBuffer()->bufferEmpty()) { timeoutTimer_ = global::wallclock(); + + if(httpServer_->getSecure()) { + // tlsAccept() just returns true if handshake has already + // finished. + if(!socket_->tlsAccept()) { + updateWriteCheck(); + e_->addCommand(this); + return false; + } + } + SharedHandle header; - header = httpServer_->receiveRequest(); - if(!header) { + updateWriteCheck(); e_->addCommand(this); return false; } diff --git a/src/HttpServerCommand.h b/src/HttpServerCommand.h index e54d5bbc..e8bf4cd0 100644 --- a/src/HttpServerCommand.h +++ b/src/HttpServerCommand.h @@ -51,11 +51,14 @@ private: SharedHandle socket_; SharedHandle httpServer_; Timer timeoutTimer_; + bool writeCheck_; void checkSocketRecvBuffer(); + void updateWriteCheck(); public: HttpServerCommand(cuid_t cuid, DownloadEngine* e, - const SharedHandle& socket); + const SharedHandle& socket, + bool secure); HttpServerCommand(cuid_t cuid, const SharedHandle& httpServer, @@ -63,10 +66,10 @@ public: const SharedHandle& socket); virtual ~HttpServerCommand(); - + virtual bool execute(); }; -} // namespace aria2 +} // namespace aria2 #endif // D_HTTP_SERVER_COMMAND_H diff --git a/src/LibgnutlsTLSContext.cc b/src/LibgnutlsTLSContext.cc index 9eefb342..9ca505b9 100644 --- a/src/LibgnutlsTLSContext.cc +++ b/src/LibgnutlsTLSContext.cc @@ -45,8 +45,9 @@ namespace aria2 { -TLSContext::TLSContext() +TLSContext::TLSContext(TLSSessionSide side) : certCred_(0), + side_(side), peerVerificationEnabled_(false) { int r = gnutls_certificate_allocate_credentials(&certCred_); @@ -79,19 +80,20 @@ bool TLSContext::bad() const return !good_; } -bool TLSContext::addClientKeyFile(const std::string& certfile, - const std::string& keyfile) +bool TLSContext::addCredentialFile(const std::string& certfile, + const std::string& keyfile) { int ret = gnutls_certificate_set_x509_key_file(certCred_, certfile.c_str(), keyfile.c_str(), GNUTLS_X509_FMT_PEM); if(ret == GNUTLS_E_SUCCESS) { - A2_LOG_INFO(fmt("Client Key File(cert=%s, key=%s) were successfully added.", - certfile.c_str(), keyfile.c_str())); + A2_LOG_INFO(fmt + ("Credential files(cert=%s, key=%s) were successfully added.", + certfile.c_str(), keyfile.c_str())); return true; } else { - A2_LOG_ERROR(fmt("Failed to load client certificate from %s and" + A2_LOG_ERROR(fmt("Failed to load certificate from %s and" " private key from %s. Cause: %s", certfile.c_str(), keyfile.c_str(), gnutls_strerror(ret))); diff --git a/src/LibgnutlsTLSContext.h b/src/LibgnutlsTLSContext.h index 9e0d5fa8..4fc49fbc 100644 --- a/src/LibgnutlsTLSContext.h +++ b/src/LibgnutlsTLSContext.h @@ -41,6 +41,7 @@ #include +#include "TLSContext.h" #include "DlAbortEx.h" namespace aria2 { @@ -49,17 +50,19 @@ class TLSContext { private: gnutls_certificate_credentials_t certCred_; + TLSSessionSide side_; + bool good_; bool peerVerificationEnabled_; public: - TLSContext(); + TLSContext(TLSSessionSide side); ~TLSContext(); // private key `keyfile' must be decrypted. - bool addClientKeyFile(const std::string& certfile, - const std::string& keyfile); + bool addCredentialFile(const std::string& certfile, + const std::string& keyfile); bool addSystemTrustedCACerts(); @@ -72,6 +75,11 @@ public: gnutls_certificate_credentials_t getCertCred() const; + TLSSessionSide getSide() const + { + return side_; + } + void enablePeerVerification(); void disablePeerVerification(); diff --git a/src/LibsslTLSContext.cc b/src/LibsslTLSContext.cc index 8feaf075..1561e4a5 100644 --- a/src/LibsslTLSContext.cc +++ b/src/LibsslTLSContext.cc @@ -43,11 +43,12 @@ namespace aria2 { -TLSContext::TLSContext() +TLSContext::TLSContext(TLSSessionSide side) : sslCtx_(0), + side_(side), peerVerificationEnabled_(false) { - sslCtx_ = SSL_CTX_new(SSLv23_client_method()); + sslCtx_ = SSL_CTX_new(SSLv23_method()); if(sslCtx_) { good_ = true; } else { @@ -55,15 +56,15 @@ TLSContext::TLSContext() A2_LOG_ERROR(fmt("SSL_CTX_new() failed. Cause: %s", ERR_error_string(ERR_get_error(), 0))); } - /* Disable SSLv2 and enable all workarounds for buggy servers */ + // Disable SSLv2 and enable all workarounds for buggy servers SSL_CTX_set_options(sslCtx_, SSL_OP_ALL|SSL_OP_NO_SSLv2| SSL_OP_NO_COMPRESSION); SSL_CTX_set_mode(sslCtx_, SSL_MODE_AUTO_RETRY); + SSL_CTX_set_mode(sslCtx_, SSL_MODE_ENABLE_PARTIAL_WRITE); #ifdef SSL_MODE_RELEASE_BUFFERS /* keep memory usage low */ SSL_CTX_set_mode(sslCtx_, SSL_MODE_RELEASE_BUFFERS); #endif - } TLSContext::~TLSContext() @@ -81,23 +82,23 @@ bool TLSContext::bad() const return !good_; } -bool TLSContext::addClientKeyFile(const std::string& certfile, - const std::string& keyfile) +bool TLSContext::addCredentialFile(const std::string& certfile, + const std::string& keyfile) { if(SSL_CTX_use_PrivateKey_file(sslCtx_, keyfile.c_str(), SSL_FILETYPE_PEM) != 1) { - A2_LOG_ERROR(fmt("Failed to load client private key from %s. Cause: %s", + A2_LOG_ERROR(fmt("Failed to load private key from %s. Cause: %s", keyfile.c_str(), ERR_error_string(ERR_get_error(), 0))); return false; } if(SSL_CTX_use_certificate_chain_file(sslCtx_, certfile.c_str()) != 1) { - A2_LOG_ERROR(fmt("Failed to load client certificate from %s. Cause: %s", + A2_LOG_ERROR(fmt("Failed to load certificate from %s. Cause: %s", certfile.c_str(), ERR_error_string(ERR_get_error(), 0))); return false; } - A2_LOG_INFO(fmt("Client Key File(cert=%s, key=%s) were successfully added.", + A2_LOG_INFO(fmt("Credential files(cert=%s, key=%s) were successfully added.", certfile.c_str(), keyfile.c_str())); return true; diff --git a/src/LibsslTLSContext.h b/src/LibsslTLSContext.h index 98aedbd9..038d8990 100644 --- a/src/LibsslTLSContext.h +++ b/src/LibsslTLSContext.h @@ -41,6 +41,7 @@ # include +#include "TLSContext.h" #include "DlAbortEx.h" namespace aria2 { @@ -49,17 +50,19 @@ class TLSContext { private: SSL_CTX* sslCtx_; + TLSSessionSide side_; + bool good_; bool peerVerificationEnabled_; public: - TLSContext(); + TLSContext(TLSSessionSide side); ~TLSContext(); // private key `keyfile' must be decrypted. - bool addClientKeyFile(const std::string& certfile, - const std::string& keyfile); + bool addCredentialFile(const std::string& certfile, + const std::string& keyfile); bool addSystemTrustedCACerts(); @@ -74,7 +77,12 @@ public: { return sslCtx_; } - + + TLSSessionSide getSide() const + { + return side_; + } + void enablePeerVerification(); void disablePeerVerification(); diff --git a/src/MultiUrlRequestInfo.cc b/src/MultiUrlRequestInfo.cc index d0528c73..284a08a7 100644 --- a/src/MultiUrlRequestInfo.cc +++ b/src/MultiUrlRequestInfo.cc @@ -137,6 +137,24 @@ error_code::Value MultiUrlRequestInfo::execute() Notifier notifier(wsSessionMan); SingletonHolder::instance(¬ifier); +#ifdef ENABLE_SSL + if(option_->getAsBool(PREF_ENABLE_RPC) && + option_->getAsBool(PREF_RPC_SECURE)) { + if(!option_->blank(PREF_RPC_CERTIFICATE) && + !option_->blank(PREF_RPC_PRIVATE_KEY)) { + // We set server TLS context to the SocketCore before creating + // DownloadEngine instance. + SharedHandle svTlsContext(new TLSContext(TLS_SERVER)); + svTlsContext->addCredentialFile(option_->get(PREF_RPC_CERTIFICATE), + option_->get(PREF_RPC_PRIVATE_KEY)); + SocketCore::setServerTLSContext(svTlsContext); + } else { + throw DL_ABORT_EX("Specify --rpc-certificate and --rpc-private-key " + "options in order to use secure RPC."); + } + } +#endif // ENABLE_SSL + SharedHandle e = DownloadEngineFactory().newDownloadEngine(option_.get(), requestGroups_); @@ -173,26 +191,27 @@ error_code::Value MultiUrlRequestInfo::execute() e->setAuthConfigFactory(authConfigFactory); #ifdef ENABLE_SSL - SharedHandle tlsContext(new TLSContext()); + SharedHandle clTlsContext(new TLSContext(TLS_CLIENT)); if(!option_->blank(PREF_CERTIFICATE) && !option_->blank(PREF_PRIVATE_KEY)) { - tlsContext->addClientKeyFile(option_->get(PREF_CERTIFICATE), - option_->get(PREF_PRIVATE_KEY)); + clTlsContext->addCredentialFile(option_->get(PREF_CERTIFICATE), + option_->get(PREF_PRIVATE_KEY)); } if(!option_->blank(PREF_CA_CERTIFICATE)) { - if(!tlsContext->addTrustedCACertFile(option_->get(PREF_CA_CERTIFICATE))) { + if(!clTlsContext->addTrustedCACertFile + (option_->get(PREF_CA_CERTIFICATE))) { A2_LOG_INFO(MSG_WARN_NO_CA_CERT); } } else if(option_->getAsBool(PREF_CHECK_CERTIFICATE)) { - if(!tlsContext->addSystemTrustedCACerts()) { + if(!clTlsContext->addSystemTrustedCACerts()) { A2_LOG_INFO(MSG_WARN_NO_CA_CERT); } } if(option_->getAsBool(PREF_CHECK_CERTIFICATE)) { - tlsContext->enablePeerVerification(); + clTlsContext->enablePeerVerification(); } - SocketCore::setTLSContext(tlsContext); + SocketCore::setClientTLSContext(clTlsContext); #endif #ifdef HAVE_ARES_ADDR_NODE ares_addr_node* asyncDNSServers = @@ -219,9 +238,9 @@ error_code::Value MultiUrlRequestInfo::execute() #endif // SIGHUP util::setGlobalSignalHandler(SIGINT, handler, 0); util::setGlobalSignalHandler(SIGTERM, handler, 0); - + e->run(); - + if(!option_->blank(PREF_SAVE_COOKIES)) { e->getCookieStorage()->saveNsFormat(option_->get(PREF_SAVE_COOKIES)); } diff --git a/src/OptionHandlerFactory.cc b/src/OptionHandlerFactory.cc index 3d440741..562f947b 100644 --- a/src/OptionHandlerFactory.cc +++ b/src/OptionHandlerFactory.cc @@ -747,6 +747,15 @@ std::vector OptionHandlerFactory::createOptionHandlers() op->addTag(TAG_RPC); handlers.push_back(op); } + { + OptionHandler* op(new LocalFilePathOptionHandler + (PREF_RPC_CERTIFICATE, + TEXT_RPC_CERTIFICATE, + NO_DEFAULT_VALUE, + false)); + op->addTag(TAG_RPC); + handlers.push_back(op); + } { OptionHandler* op(new BooleanOptionHandler (PREF_RPC_LISTEN_ALL, @@ -774,6 +783,24 @@ std::vector OptionHandlerFactory::createOptionHandlers() op->addTag(TAG_RPC); handlers.push_back(op); } + { + OptionHandler* op(new LocalFilePathOptionHandler + (PREF_RPC_PRIVATE_KEY, + TEXT_RPC_PRIVATE_KEY, + NO_DEFAULT_VALUE, + false)); + op->addTag(TAG_RPC); + handlers.push_back(op); + } + { + OptionHandler* op(new BooleanOptionHandler + (PREF_RPC_SECURE, + TEXT_RPC_SECURE, + A2_V_FALSE, + OptionHandler::OPT_ARG)); + op->addTag(TAG_RPC); + handlers.push_back(op); + } { OptionHandler* op(new DefaultOptionHandler (PREF_RPC_USER, diff --git a/src/SocketCore.cc b/src/SocketCore.cc index 969e15b3..658071ff 100644 --- a/src/SocketCore.cc +++ b/src/SocketCore.cc @@ -125,8 +125,6 @@ namespace { enum TlsState { // TLS object is not initialized. A2_TLS_NONE = 0, - // TLS object is initialized. Ready for handshake. - A2_TLS_INITIALIZED = 1, // TLS object is now handshaking. A2_TLS_HANDSHAKING = 2, // TLS object is now connected. @@ -140,11 +138,19 @@ std::vector > SocketCore::bindAddrs_; #ifdef ENABLE_SSL -SharedHandle SocketCore::tlsContext_; +SharedHandle SocketCore::clTlsContext_; +SharedHandle SocketCore::svTlsContext_; -void SocketCore::setTLSContext(const SharedHandle& tlsContext) +void SocketCore::setClientTLSContext +(const SharedHandle& tlsContext) { - tlsContext_ = tlsContext; + clTlsContext_ = tlsContext; +} + +void SocketCore::setServerTLSContext +(const SharedHandle& tlsContext) +{ + svTlsContext_ = tlsContext; } #endif // ENABLE_SSL @@ -818,12 +824,24 @@ void SocketCore::readData(char* data, size_t& len) len = ret; } -void SocketCore::prepareSecureConnection() +bool SocketCore::tlsAccept() { - if(!secure_) { + return tlsHandshake(svTlsContext_.get(), A2STR::NIL); +} + +bool SocketCore::tlsConnect(const std::string& hostname) +{ + return tlsHandshake(clTlsContext_.get(), hostname); +} + +bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname) +{ + wantRead_ = false; + wantWrite_ = false; #ifdef HAVE_OPENSSL - // for SSL - ssl = SSL_new(tlsContext_->getSSLCtx()); + switch(secure_) { + case A2_TLS_NONE: + ssl = SSL_new(tlsctx->getSSLCtx()); if(!ssl) { throw DL_ABORT_EX (fmt(EX_SSL_INIT_FAILURE, ERR_error_string(ERR_get_error(), 0))); @@ -832,48 +850,25 @@ void SocketCore::prepareSecureConnection() throw DL_ABORT_EX (fmt(EX_SSL_INIT_FAILURE, ERR_error_string(ERR_get_error(), 0))); } -#endif // HAVE_OPENSSL -#ifdef HAVE_LIBGNUTLS - int r; - gnutls_init(&sslSession_, GNUTLS_CLIENT); - // It seems err is not error message, but the argument string - // which causes syntax error. - const char* err; - // Disables TLS1.1 here because there are servers that don't - // understand TLS1.1. - r = gnutls_priority_set_direct(sslSession_, "NORMAL:!VERS-TLS1.1", &err); - if(r != GNUTLS_E_SUCCESS) { - throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, gnutls_strerror(r))); - } - // put the x509 credentials to the current session - gnutls_credentials_set(sslSession_, GNUTLS_CRD_CERTIFICATE, - tlsContext_->getCertCred()); - gnutls_transport_set_ptr(sslSession_, (gnutls_transport_ptr_t)sockfd_); -#endif // HAVE_LIBGNUTLS - secure_ = A2_TLS_INITIALIZED; - } -} - -bool SocketCore::initiateSecureConnection(const std::string& hostname) -{ - wantRead_ = false; - wantWrite_ = false; -#ifdef HAVE_OPENSSL - switch(secure_) { - case A2_TLS_INITIALIZED: - secure_ = A2_TLS_HANDSHAKING; + // Fall through #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME - if(!util::isNumericHost(hostname)) { + if(tlsctx->getSide() == TLS_CLIENT && !util::isNumericHost(hostname)) { // TLS extensions: SNI. There is not documentation about the // return code for this function (actually this is macro // wrapping SSL_ctrl at the time of this writing). SSL_set_tlsext_host_name(ssl, hostname.c_str()); } #endif // SSL_CTRL_SET_TLSEXT_HOSTNAME + secure_ = A2_TLS_HANDSHAKING; // Fall through case A2_TLS_HANDSHAKING: { ERR_clear_error(); - int e = SSL_connect(ssl); + int e; + if(tlsctx->getSide() == TLS_CLIENT) { + e = SSL_connect(ssl); + } else { + e = SSL_accept(ssl); + } if (e <= 0) { int ssl_error = SSL_get_error(ssl, e); @@ -893,9 +888,21 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname) } break; - case SSL_ERROR_SYSCALL: - throw DL_ABORT_EX(EX_SSL_IO_ERROR); - + case SSL_ERROR_SYSCALL: { + int sslErr = ERR_get_error(); + if(sslErr == 0) { + if(e == 0) { + throw DL_ABORT_EX("Got EOF in SSL handshake"); + } else if(e == -1) { + throw DL_ABORT_EX(fmt("SSL I/O error: %s", strerror(errno))); + } else { + throw DL_ABORT_EX(EX_SSL_IO_ERROR); + } + } else { + throw DL_ABORT_EX(fmt("SSL I/O error: %s", + ERR_error_string(sslErr, 0))); + } + } case SSL_ERROR_SSL: throw DL_ABORT_EX(EX_SSL_PROTOCOL_ERROR); @@ -903,7 +910,8 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname) throw DL_ABORT_EX(fmt(EX_SSL_UNKNOWN_ERROR, ssl_error)); } } - if(tlsContext_->peerVerificationEnabled()) { + if(tlsctx->getSide() == TLS_CLIENT && + tlsctx->peerVerificationEnabled()) { // verify peer X509* peerCert = SSL_get_peer_certificate(ssl); if(!peerCert) { @@ -984,20 +992,44 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname) #endif // HAVE_OPENSSL #ifdef HAVE_LIBGNUTLS switch(secure_) { - case A2_TLS_INITIALIZED: - secure_ = A2_TLS_HANDSHAKING; - // Check hostname is not numeric and it includes ".". Setting - // "localhost" will produce TLS alert. - if(!util::isNumericHost(hostname) && - hostname.find(".") != std::string::npos) { - // TLS extensions: SNI - int ret = gnutls_server_name_set(sslSession_, GNUTLS_NAME_DNS, - hostname.c_str(), hostname.size()); - if(ret < 0) { - A2_LOG_WARN(fmt("Setting hostname in SNI extension failed. Cause: %s", - gnutls_strerror(ret))); + case A2_TLS_NONE: + int r; + gnutls_init(&sslSession_, + tlsctx->getSide() == TLS_CLIENT ? + GNUTLS_CLIENT : GNUTLS_SERVER); + // It seems err is not error message, but the argument string + // which causes syntax error. + const char* err; + // For client side, disables TLS1.1 here because there are servers + // that don't understand TLS1.1. TODO Is this still necessary? + r = gnutls_priority_set_direct(sslSession_, + tlsctx->getSide() == TLS_CLIENT ? + "NORMAL:-VERS-TLS1.1" : + "NORMAL", + &err); + if(r != GNUTLS_E_SUCCESS) { + throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, gnutls_strerror(r))); + } + // put the x509 credentials to the current session + gnutls_credentials_set(sslSession_, GNUTLS_CRD_CERTIFICATE, + tlsctx->getCertCred()); + gnutls_transport_set_ptr(sslSession_, (gnutls_transport_ptr_t)sockfd_); + if(tlsctx->getSide() == TLS_CLIENT) { + // Check hostname is not numeric and it includes ".". Setting + // "localhost" will produce TLS alert. + if(!util::isNumericHost(hostname) && + hostname.find(".") != std::string::npos) { + // TLS extensions: SNI + int ret = gnutls_server_name_set(sslSession_, GNUTLS_NAME_DNS, + hostname.c_str(), hostname.size()); + if(ret < 0) { + A2_LOG_WARN(fmt + ("Setting hostname in SNI extension failed. Cause: %s", + gnutls_strerror(ret))); + } } } + secure_ = A2_TLS_HANDSHAKING; // Fall through case A2_TLS_HANDSHAKING: { int ret = gnutls_handshake(sslSession_); @@ -1008,7 +1040,7 @@ bool SocketCore::initiateSecureConnection(const std::string& hostname) throw DL_ABORT_EX(fmt(EX_SSL_INIT_FAILURE, gnutls_strerror(ret))); } - if(tlsContext_->peerVerificationEnabled()) { + if(tlsctx->getSide() == TLS_CLIENT && tlsctx->peerVerificationEnabled()) { // verify peer unsigned int status; ret = gnutls_certificate_verify_peers2(sslSession_, &status); diff --git a/src/SocketCore.h b/src/SocketCore.h index bbf2fe39..e33449f8 100644 --- a/src/SocketCore.h +++ b/src/SocketCore.h @@ -85,7 +85,10 @@ private: bool wantWrite_; #if ENABLE_SSL - static SharedHandle tlsContext_; + // TLS context for client side + static SharedHandle clTlsContext_; + // TLS context for server side + static SharedHandle svTlsContext_; #endif // ENABLE_SSL #ifdef HAVE_OPENSSL @@ -106,6 +109,15 @@ private: void setSockOpt(int level, int optname, void* optval, socklen_t optlen); + /** + * Makes this socket secure. + * If the system has not OpenSSL, then this method do nothing. + * connection must be established before calling this method. + * + * If you are going to verify peer's certificate, hostname must be supplied. + */ + bool tlsHandshake(TLSContext* tlsctx, const std::string& hostname); + SocketCore(sock_t sockfd, int sockType); public: SocketCore(int sockType = SOCK_STREAM); @@ -124,7 +136,7 @@ public: void joinMulticastGroup (const std::string& multicastAddr, uint16_t multicastPort, const std::string& localAddr); - + // Enables TCP_NODELAY socket option if f == true. void setTcpNodelay(bool f); @@ -293,16 +305,16 @@ public: return readDataFrom(reinterpret_cast(data), len, sender); } - /** - * Makes this socket secure. - * If the system has not OpenSSL, then this method do nothing. - * connection must be established before calling this method. - * - * If you are going to verify peer's certificate, hostname must be supplied. - */ - bool initiateSecureConnection(const std::string& hostname=""); + // Performs TLS server side handshake. If handshake is completed, + // returns true. If handshake has not been done yet, returns false. + bool tlsAccept(); - void prepareSecureConnection(); + // Performs TLS client side handshake. If handshake is completed, + // returns true. If handshake has not been done yet, returns false. + // + // If you are going to verify peer's certificate, hostname must be + // supplied. + bool tlsConnect(const std::string& hostname); bool operator==(const SocketCore& s) { return sockfd_ == s.sockfd_; @@ -332,7 +344,8 @@ public: bool wantWrite() const; #ifdef ENABLE_SSL - static void setTLSContext(const SharedHandle& tlsContext); + static void setClientTLSContext(const SharedHandle& tlsContext); + static void setServerTLSContext(const SharedHandle& tlsContext); #endif // ENABLE_SSL static void setProtocolFamily(int protocolFamily) diff --git a/src/TLSContext.h b/src/TLSContext.h index f3832b4f..6ac1e727 100644 --- a/src/TLSContext.h +++ b/src/TLSContext.h @@ -37,6 +37,15 @@ #include "common.h" +namespace aria2 { + +enum TLSSessionSide { + TLS_CLIENT, + TLS_SERVER +}; + +} // namespace aria2 + #ifdef HAVE_OPENSSL # include "LibsslTLSContext.h" #elif HAVE_LIBGNUTLS diff --git a/src/WebSocketInteractionCommand.cc b/src/WebSocketInteractionCommand.cc index 4d7f26f8..9dccc3c3 100644 --- a/src/WebSocketInteractionCommand.cc +++ b/src/WebSocketInteractionCommand.cc @@ -73,7 +73,7 @@ WebSocketInteractionCommand::~WebSocketInteractionCommand() void WebSocketInteractionCommand::updateWriteCheck() { - if(wsSession_->wantWrite()) { + if(socket_->wantWrite() || wsSession_->wantWrite()) { if(!writeCheck_) { writeCheck_ = true; e_->addSocketForWriteCheck(socket_, this); @@ -91,9 +91,10 @@ bool WebSocketInteractionCommand::execute() } if(wsSession_->onReadEvent() == -1 || wsSession_->onWriteEvent() == -1) { if(wsSession_->closeSent() || wsSession_->closeReceived()) { - A2_LOG_INFO(fmt("CUID#%" PRId64 " - WebSocket session terminated.", getCuid())); + A2_LOG_INFO(fmt("CUID#%"PRId64" - WebSocket session terminated.", + getCuid())); } else { - A2_LOG_INFO(fmt("CUID#%" PRId64 " - WebSocket session terminated" + A2_LOG_INFO(fmt("CUID#%"PRId64" - WebSocket session terminated" " (Possibly due to EOF).", getCuid())); } return true; diff --git a/src/prefs.cc b/src/prefs.cc index 822a23bf..160cd3c3 100644 --- a/src/prefs.cc +++ b/src/prefs.cc @@ -270,6 +270,12 @@ const Pref* PREF_RPC_MAX_REQUEST_SIZE = makePref("rpc-max-request-size"); const Pref* PREF_RPC_LISTEN_ALL = makePref("rpc-listen-all"); // value: true | false const Pref* PREF_RPC_ALLOW_ORIGIN_ALL = makePref("rpc-allow-origin-all"); +// value: string that your file system recognizes as a file name. +const Pref* PREF_RPC_CERTIFICATE = makePref("rpc-certificate"); +// value: string that your file system recognizes as a file name. +const Pref* PREF_RPC_PRIVATE_KEY = makePref("rpc-private-key"); +// value: true | false +const Pref* PREF_RPC_SECURE = makePref("rpc-secure"); // value: true | false const Pref* PREF_DRY_RUN = makePref("dry-run"); // value: true | false diff --git a/src/prefs.h b/src/prefs.h index 9d9f8462..7339b566 100644 --- a/src/prefs.h +++ b/src/prefs.h @@ -213,6 +213,12 @@ extern const Pref* PREF_RPC_MAX_REQUEST_SIZE; extern const Pref* PREF_RPC_LISTEN_ALL; // value: true | false extern const Pref* PREF_RPC_ALLOW_ORIGIN_ALL; +// value: string that your file system recognizes as a file name. +extern const Pref* PREF_RPC_CERTIFICATE; +// value: string that your file system recognizes as a file name. +extern const Pref* PREF_RPC_PRIVATE_KEY; +// value: true | false +extern const Pref* PREF_RPC_SECURE; // value: true | false extern const Pref* PREF_DRY_RUN; // value: true | false diff --git a/src/usage_text.h b/src/usage_text.h index de3f5a7e..513259df 100644 --- a/src/usage_text.h +++ b/src/usage_text.h @@ -880,3 +880,21 @@ " your disk.") #define TEXT_ENABLE_MMAP \ _(" --enable-mmap[=true|false] Map files into memory.") +#define TEXT_RPC_CERTIFICATE \ + _(" --rpc-certificate=FILE Use the certificate in FILE for RPC server.\n" \ + " The certificate must be in PEM format.\n" \ + " Use --rpc-private-key option to specify the\n" \ + " private key. Use --rpc-secure option to enable\n" \ + " encryption.") +#define TEXT_RPC_PRIVATE_KEY \ + _(" --rpc-private-key=FILE Use the private key in FILE for RPC server.\n" \ + " The private key must be decrypted and in PEM\n" \ + " format. Use --rpc-secure option to enable\n" \ + " encryption. See also --rpc-certificate option.") +#define TEXT_RPC_SECURE \ + _(" --rpc-secure[=true|false] RPC transport will be encrypted by SSL/TLS.\n" \ + " The RPC clients must use https scheme to access\n" \ + " the server. For WebSocket client, use wss\n" \ + " scheme. Use --rpc-certificate and\n" \ + " --rpc-private-key options to specify the\n" \ + " server certificate and private key.")