diff --git a/README.rst b/README.rst index 72e86bd4..4bd4e61c 100644 --- a/README.rst +++ b/README.rst @@ -33,10 +33,10 @@ Features Here is a list of features: * Command-line interface -* Download files through HTTP(S)/FTP/BitTorrent +* Download files through HTTP(S)/FTP/SFTP/BitTorrent * Segmented downloading -* Metalink version 4 (RFC 5854) support(HTTP/FTP/BitTorrent) -* Metalink version 3.0 support(HTTP/FTP/BitTorrent) +* Metalink version 4 (RFC 5854) support(HTTP/FTP/SFTP/BitTorrent) +* Metalink version 3.0 support(HTTP/FTP/SFTP/BitTorrent) * Metalink/HTTP (RFC 6249) support * HTTP/1.1 implementation * HTTP Proxy support @@ -54,7 +54,7 @@ Here is a list of features: * Save Cookies in the Mozilla/Firefox (1.x/2.x)/Netscape format. * Custom HTTP Header support * Persistent Connections support -* FTP through HTTP Proxy +* FTP/SFTP through HTTP Proxy * Download/Upload speed throttling * BitTorrent extensions: Fast extension, DHT, PEX, MSE/PSE, Multi-Tracker, UDP tracker @@ -98,6 +98,7 @@ Dependency features dependency ======================== ======================================== HTTPS OSX or GnuTLS or OpenSSL or Windows +SFTP libssh2 BitTorrent None. Optional: libnettle+libgmp or libgcrypt or OpenSSL (see note) Metalink libxml2 or Expat. @@ -183,6 +184,7 @@ distribution you use): * libgnutls-dev (Required for HTTPS, BitTorrent, Checksum support) * nettle-dev (Required for BitTorrent, Checksum support) * libgmp-dev (Required for BitTorrent) +* libssh2-1-dev (Required for SFTP support) * libc-ares-dev (Required for async DNS support) * libxml2-dev (Required for Metalink support) * zlib1g-dev (Required for gzip, deflate decoding support in HTTP) @@ -454,9 +456,9 @@ Other things should be noted Metalink -------- -The current implementation supports HTTP(S)/FTP/BitTorrent. The other -P2P protocols are ignored. Both Metalink4 (RFC 5854) and Metalink -version 3.0 documents are supported. +The current implementation supports HTTP(S)/FTP/SFTP/BitTorrent. The +other P2P protocols are ignored. Both Metalink4 (RFC 5854) and +Metalink version 3.0 documents are supported. For checksum verification, md5, sha-1, sha-224, sha-256, sha-384 and sha-512 are supported. If multiple hash algorithms are provided, aria2 @@ -502,9 +504,10 @@ which location you prefer, you can use ``--metalink-location`` option. netrc ----- -netrc support is enabled by default for HTTP(S)/FTP. To disable netrc -support, specify -n command-line option. Your .netrc file should have -correct permissions(600). + +netrc support is enabled by default for HTTP(S)/FTP/SFTP. To disable +netrc support, specify -n command-line option. Your .netrc file +should have correct permissions(600). WebSocket --------- diff --git a/configure.ac b/configure.ac index 58b71dad..4461139a 100644 --- a/configure.ac +++ b/configure.ac @@ -58,6 +58,7 @@ ARIA2_ARG_WITHOUT([libcares]) ARIA2_ARG_WITHOUT([libz]) ARIA2_ARG_WITH([tcmalloc]) ARIA2_ARG_WITH([jemalloc]) +ARIA2_ARG_WITHOUT([libssh2]) ARIA2_ARG_DISABLE([ssl]) ARIA2_ARG_DISABLE([bittorrent]) @@ -298,6 +299,20 @@ if test "x$with_sqlite3" = "xyes"; then fi fi +if test "x$with_libssh2" = "xyes"; then + PKG_CHECK_MODULES([LIBSSH2], [libssh2], [have_libssh2=yes], [have_libssh2=no]) + if test "x$have_libssh2" = "xyes"; then + AC_DEFINE([HAVE_LIBSSH2], [1], [Define to 1 if you have libssh2.]) + LIBS="$LIBSSH2_LIBS $LIBS" + CPPFLAGS="$LIBSSH2_CFLAGS $CPPFLAGS" + else + AC_MSG_WARN([$LIBSSH2_PKG_ERRORS]) + if test "x$with_libssh2_requested" = "yes"; then + ARIA2_DEP_NOT_MET([libssh2]) + fi + fi +fi + case "$host" in *darwin*) have_osx="yes" @@ -613,6 +628,9 @@ AM_CONDITIONAL([HAVE_ZLIB], [test "x$have_zlib" = "xyes"]) # Set conditional for sqlite3 AM_CONDITIONAL([HAVE_SQLITE3], [test "x$have_sqlite3" = "xyes"]) +# Set conditional for libssh2 +AM_CONDITIONAL([HAVE_LIBSSH2], [test "x$have_libssh2" = "xyes"]) + AC_SEARCH_LIBS([clock_gettime], [rt]) case "$host" in @@ -1062,6 +1080,7 @@ echo "LibXML2: $have_libxml2" echo "LibExpat: $have_libexpat" echo "LibCares: $have_libcares" echo "Zlib: $have_zlib" +echo "Libssh2: $have_libssh2" echo "Epoll: $have_epoll" echo "Bittorrent: $enable_bittorrent" echo "Metalink: $enable_metalink" diff --git a/src/AbstractCommand.cc b/src/AbstractCommand.cc index 2df0141b..ba1297bc 100644 --- a/src/AbstractCommand.cc +++ b/src/AbstractCommand.cc @@ -677,7 +677,7 @@ std::string getProxyUri(const std::string& protocol, const Option* option) option); } - if (protocol == "ftp") { + if (protocol == "ftp" || protocol == "sftp") { return getProxyOptionFor (PREF_FTP_PROXY, PREF_FTP_PROXY_USER, PREF_FTP_PROXY_PASSWD, option); } @@ -883,7 +883,8 @@ bool AbstractCommand::checkIfConnectionEstablished const std::string& AbstractCommand::resolveProxyMethod(const std::string& protocol) const { - if (getOption()->get(PREF_PROXY_METHOD) == V_TUNNEL || protocol == "https") { + if (getOption()->get(PREF_PROXY_METHOD) == V_TUNNEL || protocol == "https" || + protocol == "sftp") { return V_TUNNEL; } return V_GET; diff --git a/src/AuthConfigFactory.cc b/src/AuthConfigFactory.cc index e8d55b04..2f3c0573 100644 --- a/src/AuthConfigFactory.cc +++ b/src/AuthConfigFactory.cc @@ -87,7 +87,8 @@ AuthConfigFactory::createAuthConfig createHttpAuthResolver(op)->resolveAuthConfig(request->getHost()); } } - } else if(request->getProtocol() == "ftp") { + } else if(request->getProtocol() == "ftp" || + request->getProtocol() == "sftp") { if(!request->getUsername().empty()) { if(request->hasPassword()) { return AuthConfig::create(request->getUsername(), diff --git a/src/DownloadCommand.cc b/src/DownloadCommand.cc index 5be8d551..20337611 100644 --- a/src/DownloadCommand.cc +++ b/src/DownloadCommand.cc @@ -134,6 +134,7 @@ bool DownloadCommand::executeInternal() { || getRequestGroup()->doesDownloadSpeedExceed()) { addCommandSelf(); disableReadCheckSocket(); + disableWriteCheckSocket(); return false; } setReadCheckSocket(getSocket()); @@ -195,7 +196,8 @@ bool DownloadCommand::executeInternal() { // Note that GrowSegment::complete() always returns false. if(sinkFilterOnly_) { if(segment->complete() || - segment->getPositionToWrite() == getFileEntry()->getLastOffset()) { + (getFileEntry()->getLength() != 0 && + segment->getPositionToWrite() == getFileEntry()->getLastOffset())) { segmentPartComplete = true; } else if(segment->getLength() == 0 && eof) { segmentPartComplete = true; @@ -275,13 +277,17 @@ bool DownloadCommand::executeInternal() { return prepareForNextSegment(); } else { checkLowestDownloadSpeed(); - setWriteCheckSocketIf(getSocket(), getSocket()->wantWrite()); + setWriteCheckSocketIf(getSocket(), shouldEnableWriteCheck()); checkSocketRecvBuffer(); addCommandSelf(); return false; } } +bool DownloadCommand::shouldEnableWriteCheck() { + return getSocket()->wantWrite(); +} + void DownloadCommand::checkLowestDownloadSpeed() const { if(lowestDownloadSpeedLimit_ > 0 && diff --git a/src/DownloadCommand.h b/src/DownloadCommand.h index 3efaf5a0..485cdc74 100644 --- a/src/DownloadCommand.h +++ b/src/DownloadCommand.h @@ -76,6 +76,12 @@ protected: // This is file local offset virtual int64_t getRequestEndOffset() const = 0; + + // Returns true if socket should be monitored for writing. The + // default implementation is return the return value of + // getSocket()->wantWrite(). + virtual bool shouldEnableWriteCheck(); + public: DownloadCommand(cuid_t cuid, const std::shared_ptr& req, diff --git a/src/FeatureConfig.cc b/src/FeatureConfig.cc index 1d2d30bb..1b5b183a 100644 --- a/src/FeatureConfig.cc +++ b/src/FeatureConfig.cc @@ -67,7 +67,9 @@ #ifdef HAVE_SYS_UTSNAME_H # include #endif // HAVE_SYS_UTSNAME_H - +#ifdef HAVE_LIBSSH2 +# include +#endif // HAVE_LIBSSH2 #include "util.h" namespace aria2 { @@ -80,6 +82,8 @@ uint16_t getDefaultPort(const std::string& protocol) return 443; } else if(protocol == "ftp") { return 21; + } else if(protocol == "sftp") { + return 22; } else { return 0; } @@ -166,6 +170,14 @@ const char* strSupportedFeature(int feature) #endif // !ENABLE_XML_RPC break; + case(FEATURE_SFTP): +#ifdef HAVE_LIBSSH2 + return "SFTP"; +#else // !HAVE_LIBSSH2 + return nullptr; +#endif // !HAVE_LIBSSH2 + break; + default: return nullptr; } @@ -221,6 +233,11 @@ std::string usedLibs() #ifdef HAVE_LIBCARES res += "c-ares/" ARES_VERSION_STR " "; #endif // HAVE_LIBCARES + +#ifdef HAVE_LIBSSH2 + res += "libssh2/" LIBSSH2_VERSION " "; +#endif // HAVE_LIBSSH2 + if(!res.empty()) { res.erase(res.length()-1); } diff --git a/src/FeatureConfig.h b/src/FeatureConfig.h index 9c0a2798..214dca25 100644 --- a/src/FeatureConfig.h +++ b/src/FeatureConfig.h @@ -53,6 +53,7 @@ enum FeatureType { FEATURE_MESSAGE_DIGEST, FEATURE_METALINK, FEATURE_XML_RPC, + FEATURE_SFTP, MAX_FEATURE }; diff --git a/src/FtpInitiateConnectionCommand.cc b/src/FtpInitiateConnectionCommand.cc index 52fc07a4..ae3e20e1 100644 --- a/src/FtpInitiateConnectionCommand.cc +++ b/src/FtpInitiateConnectionCommand.cc @@ -60,6 +60,10 @@ #include "FtpNegotiationConnectChain.h" #include "FtpTunnelRequestConnectChain.h" #include "HttpRequestConnectChain.h" +#ifdef HAVE_LIBSSH2 +# include "SftpNegotiationConnectChain.h" +# include "SftpNegotiationCommand.h" +#endif // HAVE_LIBSSH2 namespace aria2 { @@ -82,6 +86,7 @@ std::unique_ptr FtpInitiateConnectionCommand::createNextCommandProxied std::shared_ptr pooledSocket; std::string proxyMethod = resolveProxyMethod(getRequest()->getProtocol()); + // sftp always use tunnel mode if(proxyMethod == V_GET) { pooledSocket = getDownloadEngine()->popPooledSocket (getRequest()->getHost(), getRequest()->getPort(), @@ -127,18 +132,33 @@ std::unique_ptr FtpInitiateConnectionCommand::createNextCommandProxied setConnectedAddrInfo(getRequest(), hostname, pooledSocket); if(proxyMethod == V_TUNNEL) { +#ifdef HAVE_LIBSSH2 + if (getRequest()->getProtocol() == "sftp") { + return make_unique + (getCuid(), + getRequest(), + getFileEntry(), + getRequestGroup(), + getDownloadEngine(), + pooledSocket, + SftpNegotiationCommand::SEQ_SFTP_OPEN); + } +#endif // HAVE_LIBSSH2 + // options contains "baseWorkingDir" return make_unique (getCuid(), - getRequest(), - getFileEntry(), - getRequestGroup(), - getDownloadEngine(), - pooledSocket, - FtpNegotiationCommand::SEQ_SEND_CWD_PREP, - options); + getRequest(), + getFileEntry(), + getRequestGroup(), + getDownloadEngine(), + pooledSocket, + FtpNegotiationCommand::SEQ_SEND_CWD_PREP, + options); } + assert(getRequest()->getProtocol() == "ftp"); + if(proxyMethod != V_GET) { assert(0); return nullptr; @@ -180,30 +200,51 @@ std::unique_ptr FtpInitiateConnectionCommand::createNextCommandPlain getSocket()->establishConnection(addr, port); getRequest()->setConnectedAddrInfo(hostname, addr, port); auto c = make_unique(getCuid(), - getRequest(), - nullptr, - getFileEntry(), - getRequestGroup(), - getDownloadEngine(), - getSocket()); + getRequest(), + nullptr, + getFileEntry(), + getRequestGroup(), + getDownloadEngine(), + getSocket()); - c->setControlChain(std::make_shared()); + if(getRequest()->getProtocol() == "sftp") { +#ifdef HAVE_LIBSSH2 + c->setControlChain(std::make_shared()); +#else // !HAVE_LIBSSH2 + assert(0); +#endif // !HAVE_LIBSSH2 + } else { + c->setControlChain(std::make_shared()); + } setupBackupConnection(hostname, addr, port, c.get()); return std::move(c); } - // options contains "baseWorkingDir" - auto command = make_unique - (getCuid(), - getRequest(), - getFileEntry(), - getRequestGroup(), - getDownloadEngine(), - pooledSocket, - FtpNegotiationCommand::SEQ_SEND_CWD_PREP, - options); setConnectedAddrInfo(getRequest(), hostname, pooledSocket); - return std::move(command); + +#ifdef HAVE_LIBSSH2 + if (getRequest()->getProtocol() == "sftp") { + return make_unique + (getCuid(), + getRequest(), + getFileEntry(), + getRequestGroup(), + getDownloadEngine(), + pooledSocket, + SftpNegotiationCommand::SEQ_SFTP_OPEN); + } +#endif // HAVE_LIBSSH2 + + // options contains "baseWorkingDir" + return make_unique + (getCuid(), + getRequest(), + getFileEntry(), + getRequestGroup(), + getDownloadEngine(), + pooledSocket, + FtpNegotiationCommand::SEQ_SEND_CWD_PREP, + options); } std::unique_ptr FtpInitiateConnectionCommand::createNextCommand diff --git a/src/FtpTunnelResponseCommand.cc b/src/FtpTunnelResponseCommand.cc index ed626f0d..057a5e89 100644 --- a/src/FtpTunnelResponseCommand.cc +++ b/src/FtpTunnelResponseCommand.cc @@ -40,6 +40,9 @@ #include "Segment.h" #include "SocketCore.h" #include "SocketRecvBuffer.h" +#ifdef HAVE_LIBSSH2 +# include "SftpNegotiationCommand.h" +#endif // HAVE_LIBSSH2 namespace aria2 { @@ -59,6 +62,15 @@ FtpTunnelResponseCommand::~FtpTunnelResponseCommand() {} std::unique_ptr FtpTunnelResponseCommand::getNextCommand() { +#ifdef HAVE_LIBSSH2 + if (getRequest()->getProtocol() == "sftp") { + return make_unique + (getCuid(), getRequest(), getFileEntry(), + getRequestGroup(), getDownloadEngine(), + getSocket()); + } +#endif // HAVE_LIBSSH2 + return make_unique (getCuid(), getRequest(), getFileEntry(), getRequestGroup(), getDownloadEngine(), diff --git a/src/InitiateConnectionCommandFactory.cc b/src/InitiateConnectionCommandFactory.cc index 2856b244..b4615d7a 100644 --- a/src/InitiateConnectionCommandFactory.cc +++ b/src/InitiateConnectionCommandFactory.cc @@ -71,10 +71,14 @@ InitiateConnectionCommandFactory::createInitiateConnectionCommand return make_unique(cuid, req, fileEntry, requestGroup, e); - } else if(req->getProtocol() == "ftp") { + } else if(req->getProtocol() == "ftp" +#ifdef HAVE_LIBSSH2 + || req->getProtocol() == "sftp" +#endif // HAVE_LIBSSH2 + ) { if(req->getFile().empty()) { throw DL_ABORT_EX - (fmt("FTP URI %s doesn't contain file path.", + (fmt("FTP/SFTP URI %s doesn't contain file path.", req->getUri().c_str())); } return make_unique(cuid, req, fileEntry, diff --git a/src/Makefile.am b/src/Makefile.am index 236a055e..bfa99857 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -427,6 +427,14 @@ SRCS += \ Sqlite3CookieParserImpl.cc Sqlite3CookieParserImpl.h endif # HAVE_SQLITE3 +if HAVE_LIBSSH2 +SRCS += SSHSession.cc SSHSession.h \ + SftpNegotiationCommand.cc SftpNegotiationCommand.h \ + SftpNegotiationConnectChain.h \ + SftpDownloadCommand.cc SftpDownloadCommand.h \ + SftpFinishDownloadCommand.cc SftpFinishDownloadCommand.h +endif # HAVE_LIBSSH2 + if ENABLE_ASYNC_DNS SRCS += \ AsyncNameResolver.cc AsyncNameResolver.h\ diff --git a/src/MetalinkParserController.cc b/src/MetalinkParserController.cc index 91ce7053..fec63aa0 100644 --- a/src/MetalinkParserController.cc +++ b/src/MetalinkParserController.cc @@ -193,7 +193,7 @@ void MetalinkParserController::setTypeOfResource(std::string type) if(!tResource_) { return; } - if(type == "ftp") { + if(type == "ftp" || type == "sftp") { tResource_->type = MetalinkResource::TYPE_FTP; } else if(type == "http") { tResource_->type = MetalinkResource::TYPE_HTTP; diff --git a/src/Platform.cc b/src/Platform.cc index d93c48eb..8520440d 100644 --- a/src/Platform.cc +++ b/src/Platform.cc @@ -56,6 +56,10 @@ # include #endif // ENABLE_ASYNC_DNS +#ifdef HAVE_LIBSSH2 +# include +#endif // HAVE_LIBSSH2 + #include "a2netcompat.h" #include "DlAbortEx.h" #include "message.h" @@ -149,6 +153,15 @@ bool Platform::setUp() } #endif // CARES_HAVE_ARES_LIBRARY_INIT +#ifdef HAVE_LIBSSH2 + { + auto rv = libssh2_init(0); + if (rv != 0) { + throw DL_ABORT_EX(fmt("libssh2_init() failed, code: %d", rv)); + } + } +#endif // HAVE_LIBSSH2 + #ifdef HAVE_WINSOCK2_H WSADATA wsaData; memset(reinterpret_cast(&wsaData), 0, sizeof(wsaData)); @@ -181,6 +194,10 @@ bool Platform::tearDown() ares_library_cleanup(); #endif // CARES_HAVE_ARES_LIBRARY_CLEANUP +#ifdef HAVE_LIBSSH2 + libssh2_exit(); +#endif // HAVE_LIBSSH2 + #ifdef HAVE_WINSOCK2_H WSACleanup(); #endif // HAVE_WINSOCK2_H diff --git a/src/SSHSession.cc b/src/SSHSession.cc new file mode 100644 index 00000000..1b939de6 --- /dev/null +++ b/src/SSHSession.cc @@ -0,0 +1,240 @@ +/* */ +#include "SSHSession.h" + +#include + +namespace aria2 { + +SSHSession::SSHSession() + : ssh2_(nullptr), + sftp_(nullptr), + sftph_(nullptr), + fd_(-1) +{} + +SSHSession::~SSHSession() +{ + closeConnection(); +} + +int SSHSession::closeConnection() +{ + if (sftph_) { + // TODO this could return LIBSSH2_ERROR_EAGAIN + libssh2_sftp_close(sftph_); + sftph_ = nullptr; + } + if (sftp_) { + // TODO this could return LIBSSH2_ERROR_EAGAIN + libssh2_sftp_shutdown(sftp_); + sftp_ = nullptr; + } + if (ssh2_) { + // TODO this could return LIBSSH2_ERROR_EAGAIN + libssh2_session_disconnect(ssh2_, "bye"); + libssh2_session_free(ssh2_); + ssh2_ = nullptr; + } + return SSH_ERR_OK; +} + +int SSHSession::gracefulShutdown() +{ + if (sftph_) { + auto rv = libssh2_sftp_close(sftph_); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + sftph_ = nullptr; + } + if (sftp_) { + auto rv = libssh2_sftp_shutdown(sftp_); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + sftp_ = nullptr; + } + if (ssh2_) { + auto rv = libssh2_session_disconnect(ssh2_, "bye"); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + libssh2_session_free(ssh2_); + ssh2_ = nullptr; + } + return SSH_ERR_OK; +} + +int SSHSession::sftpClose() +{ + if (!sftph_) { + return SSH_ERR_OK; + } + + auto rv = libssh2_sftp_close(sftph_); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + sftph_ = nullptr; + return SSH_ERR_OK; +} + +int SSHSession::init(sock_t sockfd) +{ + ssh2_ = libssh2_session_init(); + if (!ssh2_) { + return SSH_ERR_ERROR; + } + libssh2_session_set_blocking(ssh2_, 0); + fd_ = sockfd; + return SSH_ERR_OK; +} + + +int SSHSession::checkDirection() +{ + auto dir = libssh2_session_block_directions(ssh2_); + if (dir & LIBSSH2_SESSION_BLOCK_OUTBOUND) { + return SSH_WANT_WRITE; + } + return SSH_WANT_READ; +} + +ssize_t SSHSession::writeData(const void* data, size_t len) +{ + // net implemented yet + assert(0); +} + +ssize_t SSHSession::readData(void* data, size_t len) +{ + auto nread = libssh2_sftp_read(sftph_, static_cast(data), len); + if (nread == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (nread < 0) { + return SSH_ERR_ERROR; + } + return nread; +} + +int SSHSession::handshake() +{ + auto rv = libssh2_session_handshake(ssh2_, fd_); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + // TODO we have to validate server's fingerprint + return SSH_ERR_OK; +} + +int SSHSession::authPassword(const std::string& user, + const std::string& password) +{ + auto rv = libssh2_userauth_password(ssh2_, user.c_str(), password.c_str()); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + return SSH_ERR_OK; +} + +int SSHSession::sftpOpen(const std::string& path) +{ + if (!sftp_) { + sftp_ = libssh2_sftp_init(ssh2_); + if (!sftp_) { + if (libssh2_session_last_errno(ssh2_) == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + return SSH_ERR_ERROR; + } + } + if (!sftph_) { + sftph_ = libssh2_sftp_open(sftp_, path.c_str(), LIBSSH2_FXF_READ, 0); + if (!sftph_) { + if (libssh2_session_last_errno(ssh2_) == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + return SSH_ERR_ERROR; + } + } + return SSH_ERR_OK; +} + +int SSHSession::sftpStat(int64_t& totalLength, time_t& mtime) +{ + LIBSSH2_SFTP_ATTRIBUTES attrs; + auto rv = libssh2_sftp_fstat_ex(sftph_, &attrs, 0); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + totalLength = attrs.filesize; + mtime = attrs.mtime; + return SSH_ERR_OK; +} + +std::string SSHSession::getLastErrorString() +{ + if (!ssh2_) { + return "SSH session has not been initialized yet"; + } + char* errmsg; + libssh2_session_last_error(ssh2_, &errmsg, nullptr, 0); + return errmsg; +} + +} // namespace aria2 diff --git a/src/SSHSession.h b/src/SSHSession.h new file mode 100644 index 00000000..31941e5d --- /dev/null +++ b/src/SSHSession.h @@ -0,0 +1,137 @@ +/* */ +#ifndef SSH_SESSION_H +#define SSH_SESSION_H + +#include "common.h" +#include "a2netcompat.h" + +#include + +#include +#include + +namespace aria2 { + +enum SSHDirection { + SSH_WANT_READ = 1, + SSH_WANT_WRITE +}; + +enum SSHErrorCode { + SSH_ERR_OK = 0, + SSH_ERR_ERROR = -1, + SSH_ERR_WOULDBLOCK = -2 +}; + +class SSHSession { +public: + SSHSession(); + + // MUST deallocate all resources + ~SSHSession(); + + SSHSession(const SSHSession&) = delete; + SSHSession& operator=(const SSHSession&) = delete; + + // Initializes SSH session. The |sockfd| is the underlying + // transport socket. This function returns SSH_ERR_OK if it + // succeeds, or SSH_ERR_ERROR. + int init(sock_t sockfd); + + // Closes the SSH session. Don't close underlying transport + // socket. This function returns SSH_ERR_OK if it succeeds, or + // SSH_ERR_ERROR. + int closeConnection(); + + int gracefulShutdown(); + + // Returns SSH_WANT_READ if SSH session needs more data from remote + // endpoint to proceed, or SSH_WANT_WRITE if SSH session needs to + // write more data to proceed. If SSH session needs neither read nor + // write data at the moment, SSH_WANT_READ must be returned. + int checkDirection(); + + // Sends |data| with length |len|. This function returns the number + // of bytes sent if it succeeds, or SSH_ERR_WOULDBLOCK if the + // underlying transport blocks, or SSH_ERR_ERROR. + ssize_t writeData(const void* data, size_t len); + + // Receives data into |data| with length |len|. This function + // returns the number of bytes received if it succeeds, or + // SSH_ERR_WOULDBLOCK if the underlying transport blocks, or + // SSH_ERR_ERROR. + ssize_t readData(void* data, size_t len); + + // Performs handshake. This function returns SSH_ERR_OK + // if it succeeds, or SSH_ERR_WOULDBLOCK if the underlying transport + // blocks, or SSH_ERR_ERROR. + int handshake(); + + // Performs authentication using username and password. This + // function returns SSH_ERR_OK if it succeeds, or SSH_ERR_WOULDBLOCK + // if the underlying transport blocks, or SSH_ERR_ERROR. + int authPassword(const std::string& user, const std::string& password); + + // Starts SFTP session and opens remote file |path|. This function + // returns SSH_ERR_OK if it succeeds, or SSH_ERR_WOULDBLOCK if the + // underlying transport blocks, or SSH_ERR_ERROR. + int sftpOpen(const std::string& path); + + // Closes remote file opened by sftpOpen(). This function returns + // SSH_ERR_OK if it succeeds, or SSH_ERR_WOULDBLOCK if the + // underlying transport blocks, or SSH_ERR_ERROR. + int sftpClose(); + + // Gets total length and modified time of opened file by sftpOpen(). + // On success, total length and modified time are assigned to + // |totalLength| and |mtime|. This function returns SSH_ERR_OK if + // it succeeds, or SSH_ERR_WOULDBLOCK if the underlying transport + // blocks, or SSH_ERR_ERROR. + int sftpStat(int64_t& totalLength, time_t& mtime); + + // Returns last error string + std::string getLastErrorString(); + +private: + LIBSSH2_SESSION* ssh2_; + LIBSSH2_SFTP* sftp_; + LIBSSH2_SFTP_HANDLE* sftph_; + sock_t fd_; +}; + +} + +#endif // SSH_SESSION_H diff --git a/src/SftpDownloadCommand.cc b/src/SftpDownloadCommand.cc new file mode 100644 index 00000000..84bfa913 --- /dev/null +++ b/src/SftpDownloadCommand.cc @@ -0,0 +1,99 @@ +/* */ +#include "SftpDownloadCommand.h" +#include "Request.h" +#include "SocketCore.h" +#include "Segment.h" +#include "DownloadEngine.h" +#include "RequestGroup.h" +#include "Option.h" +#include "FileEntry.h" +#include "SocketRecvBuffer.h" +#include "AuthConfig.h" +#include "SftpFinishDownloadCommand.h" + +namespace aria2 { + +SftpDownloadCommand::SftpDownloadCommand +(cuid_t cuid, + const std::shared_ptr& req, + const std::shared_ptr& fileEntry, + RequestGroup* requestGroup, + DownloadEngine* e, + const std::shared_ptr& socket, + std::unique_ptr authConfig) + : DownloadCommand(cuid, req, fileEntry, requestGroup, e, socket, + std::make_shared(socket)), + authConfig_(std::move(authConfig)) +{ + setWriteCheckSocket(getSocket()); +} + +SftpDownloadCommand::~SftpDownloadCommand() {} + +bool SftpDownloadCommand::prepareForNextSegment() +{ + if(getOption()->getAsBool(PREF_FTP_REUSE_CONNECTION) && + getFileEntry()->gtoloff(getSegments().front()->getPositionToWrite()) == + getFileEntry()->getLength()) { + + auto c = make_unique + (getCuid(), getRequest(), getFileEntry(), getRequestGroup(), + getDownloadEngine(), getSocket()); + + c->setStatus(Command::STATUS_ONESHOT_REALTIME); + getDownloadEngine()->setNoWait(true); + getDownloadEngine()->addCommand(std::move(c)); + + if(getRequestGroup()->downloadFinished()) { + // To run checksum checking, we had to call following function here. + DownloadCommand::prepareForNextSegment(); + } + return true; + } + + return DownloadCommand::prepareForNextSegment(); +} + +int64_t SftpDownloadCommand::getRequestEndOffset() const +{ + return getFileEntry()->getLength(); +} + +bool SftpDownloadCommand::shouldEnableWriteCheck() { + return getSocket()->wantWrite() || !getSocket()->wantRead(); +} + +} // namespace aria2 diff --git a/src/SftpDownloadCommand.h b/src/SftpDownloadCommand.h new file mode 100644 index 00000000..cbb6372f --- /dev/null +++ b/src/SftpDownloadCommand.h @@ -0,0 +1,66 @@ +/* */ +#ifndef D_SFTP_DOWNLOAD_COMMAND_H +#define D_SFTP_DOWNLOAD_COMMAND_H + +#include "DownloadCommand.h" + +namespace aria2 { + +class AuthConfig; + +class SftpDownloadCommand : public DownloadCommand { +private: + std::unique_ptr authConfig_; + +protected: + virtual bool prepareForNextSegment() CXX11_OVERRIDE; + virtual int64_t getRequestEndOffset() const CXX11_OVERRIDE; + virtual bool shouldEnableWriteCheck() CXX11_OVERRIDE; + +public: + SftpDownloadCommand(cuid_t cuid, + const std::shared_ptr& req, + const std::shared_ptr& fileEntry, + RequestGroup* requestGroup, + DownloadEngine* e, + const std::shared_ptr& socket, + std::unique_ptr authConfig); + virtual ~SftpDownloadCommand(); +}; + +} // namespace aria2 + +#endif // D_SFTP_DOWNLOAD_COMMAND_H diff --git a/src/SftpFinishDownloadCommand.cc b/src/SftpFinishDownloadCommand.cc new file mode 100644 index 00000000..399f2755 --- /dev/null +++ b/src/SftpFinishDownloadCommand.cc @@ -0,0 +1,119 @@ +/* */ +#include "SftpFinishDownloadCommand.h" + +#include "Request.h" +#include "DownloadEngine.h" +#include "prefs.h" +#include "Option.h" +#include "message.h" +#include "fmt.h" +#include "DlAbortEx.h" +#include "SocketCore.h" +#include "RequestGroup.h" +#include "Logger.h" +#include "LogFactory.h" +#include "wallclock.h" +#include "AuthConfigFactory.h" +#include "AuthConfig.h" + +namespace aria2 { + +SftpFinishDownloadCommand::SftpFinishDownloadCommand +(cuid_t cuid, + const std::shared_ptr& req, + const std::shared_ptr& fileEntry, + RequestGroup* requestGroup, + DownloadEngine* e, + const std::shared_ptr& socket) + : AbstractCommand(cuid, req, fileEntry, requestGroup, e, socket) +{ + disableReadCheckSocket(); + setWriteCheckSocket(getSocket()); +} + +SftpFinishDownloadCommand::~SftpFinishDownloadCommand() {} + +// overrides AbstractCommand::execute(). +// AbstractCommand::_segments is empty. +bool SftpFinishDownloadCommand::execute() +{ + if(getRequestGroup()->isHaltRequested()) { + return true; + } + try { + if(readEventEnabled() || writeEventEnabled() || hupEventEnabled()) { + getCheckPoint() = global::wallclock(); + + if (!getSocket()->sshSFTPClose()) { + setWriteCheckSocketIf(getSocket(), getSocket()->wantWrite()); + setReadCheckSocketIf(getSocket(), getSocket()->wantRead()); + addCommandSelf(); + return false; + } + + auto authConfig = + getDownloadEngine()->getAuthConfigFactory()->createAuthConfig + (getRequest(), getRequestGroup()->getOption().get()); + + getDownloadEngine()->poolSocket + (getRequest(), authConfig->getUser(), createProxyRequest(), + getSocket(), ""); + } else if(getCheckPoint().difference(global::wallclock()) >= getTimeout()) { + A2_LOG_INFO(fmt("CUID#%" PRId64 + " - Timeout before receiving transfer complete.", + getCuid())); + } else { + addCommandSelf(); + return false; + } + } catch(RecoverableException& e) { + A2_LOG_INFO_EX(fmt("CUID#%" PRId64 + " - Exception was thrown, but download was" + " finished, so we can ignore the exception.", + getCuid()), + e); + } + if(getRequestGroup()->downloadFinished()) { + return true; + } else { + return prepareForRetry(0); + } +} + +// This function never be called. +bool SftpFinishDownloadCommand::executeInternal() { return true; } + +} // namespace aria2 diff --git a/src/SftpFinishDownloadCommand.h b/src/SftpFinishDownloadCommand.h new file mode 100644 index 00000000..fbdd86f4 --- /dev/null +++ b/src/SftpFinishDownloadCommand.h @@ -0,0 +1,59 @@ +/* */ +#ifndef D_SFTP_FINISH_DOWNLOAD_COMMAND_H +#define D_SFTP_FINISH_DOWNLOAD_COMMAND_H + +#include "AbstractCommand.h" + +namespace aria2 { + +class SftpFinishDownloadCommand : public AbstractCommand { +protected: + virtual bool execute() CXX11_OVERRIDE; + + virtual bool executeInternal() CXX11_OVERRIDE; +public: + SftpFinishDownloadCommand(cuid_t cuid, + const std::shared_ptr& req, + const std::shared_ptr& fileEntry, + RequestGroup* requestGroup, + DownloadEngine* e, + const std::shared_ptr& socket); + virtual ~SftpFinishDownloadCommand(); +}; + +} // namespace aria2 + +#endif // D_SFTP_FINISH_DOWNLOAD_COMMAND_H diff --git a/src/SftpNegotiationCommand.cc b/src/SftpNegotiationCommand.cc new file mode 100644 index 00000000..5a477025 --- /dev/null +++ b/src/SftpNegotiationCommand.cc @@ -0,0 +1,320 @@ +/* */ +#include "SftpNegotiationCommand.h" + +#include +#include + +#include "Request.h" +#include "DownloadEngine.h" +#include "RequestGroup.h" +#include "PieceStorage.h" +#include "FileEntry.h" +#include "message.h" +#include "util.h" +#include "Option.h" +#include "Logger.h" +#include "LogFactory.h" +#include "Segment.h" +#include "DownloadContext.h" +#include "DefaultBtProgressInfoFile.h" +#include "RequestGroupMan.h" +#include "SocketCore.h" +#include "fmt.h" +#include "DiskAdaptor.h" +#include "SegmentMan.h" +#include "AuthConfigFactory.h" +#include "AuthConfig.h" +#include "a2functional.h" +#include "URISelector.h" +#include "CheckIntegrityEntry.h" +#include "NullProgressInfoFile.h" +#include "ChecksumCheckIntegrityEntry.h" +#include "SftpDownloadCommand.h" + +namespace aria2 { + +SftpNegotiationCommand::SftpNegotiationCommand +(cuid_t cuid, + const std::shared_ptr& req, + const std::shared_ptr& fileEntry, + RequestGroup* requestGroup, + DownloadEngine* e, + const std::shared_ptr& socket, + Seq seq) + : AbstractCommand(cuid, req, fileEntry, requestGroup, e, socket), + sequence_(seq), + authConfig_(e->getAuthConfigFactory()->createAuthConfig + (req, requestGroup->getOption().get())) + +{ + path_ = getPath(); + disableReadCheckSocket(); + setWriteCheckSocket(getSocket()); +} + +SftpNegotiationCommand::~SftpNegotiationCommand() {} + +bool SftpNegotiationCommand::executeInternal() { + disableWriteCheckSocket(); + for (;;) { + switch(sequence_) { + case SEQ_HANDSHAKE: + setReadCheckSocket(getSocket()); + if (!getSocket()->sshHandshake()) { + goto again; + } + A2_LOG_DEBUG(fmt("CUID#%" PRId64 " - SSH handshake success", getCuid())); + sequence_ = SEQ_AUTH_PASSWORD; + break; + case SEQ_AUTH_PASSWORD: + if (!getSocket()->sshAuthPassword(authConfig_->getUser(), + authConfig_->getPassword())) { + goto again; + } + A2_LOG_DEBUG(fmt("CUID#%" PRId64 " - SSH authentication success", + getCuid())); + sequence_ = SEQ_SFTP_OPEN; + break; + case SEQ_SFTP_OPEN: { + if (!getSocket()->sshSFTPOpen(path_)) { + goto again; + } + A2_LOG_DEBUG(fmt("CUID#%" PRId64 " - SFTP file %s opened", getCuid(), + path_.c_str())); + sequence_ = SEQ_SFTP_STAT; + break; + } + case SEQ_SFTP_STAT: { + int64_t totalLength; + time_t mtime; + if (!getSocket()->sshSFTPStat(totalLength, mtime, path_)) { + goto again; + } + Time t(mtime); + A2_LOG_INFO(fmt("CUID#%" PRId64 " - SFTP File %s, size=%" PRId64 + ", mtime=%s", + getCuid(), path_.c_str(), totalLength, + t.toHTTPDate().c_str())); + if (!getPieceStorage()) { + getRequestGroup()->updateLastModifiedTime(Time(mtime)); + onFileSizeDetermined(totalLength); + } else { + getRequestGroup()->validateTotalLength(getFileEntry()->getLength(), + totalLength); + sequence_ = SEQ_NEGOTIATION_COMPLETED; + } + break; + } + case SEQ_FILE_PREPARATION: + sequence_ = SEQ_NEGOTIATION_COMPLETED; + disableReadCheckSocket(); + disableWriteCheckSocket(); + return false; + case SEQ_NEGOTIATION_COMPLETED: { + auto command = make_unique + (getCuid(), getRequest(), getFileEntry(), getRequestGroup(), + getDownloadEngine(), getSocket(), std::move(authConfig_)); + command->setStartupIdleTime + (getOption()->getAsInt(PREF_STARTUP_IDLE_TIME)); + command->setLowestDownloadSpeedLimit + (getOption()->getAsInt(PREF_LOWEST_SPEED_LIMIT)); + command->setStatus(Command::STATUS_ONESHOT_REALTIME); + + getDownloadEngine()->setNoWait(true); + + if(getFileEntry()->isUniqueProtocol()) { + getFileEntry()->removeURIWhoseHostnameIs(getRequest()->getHost()); + } + getRequestGroup()->getURISelector()->tuneDownloadCommand + (getFileEntry()->getRemainingUris(), command.get()); + getDownloadEngine()->addCommand(std::move(command)); + return true; + } + case SEQ_DOWNLOAD_ALREADY_COMPLETED: + case SEQ_HEAD_OK: + case SEQ_EXIT: + return true; + }; + } + again: + addCommandSelf(); + if (getSocket()->wantWrite()) { + setWriteCheckSocket(getSocket()); + } + return false; +} + +void SftpNegotiationCommand::onFileSizeDetermined(int64_t totalLength) +{ + getFileEntry()->setLength(totalLength); + if(getFileEntry()->getPath().empty()) { + auto suffixPath = util::createSafePath + (util::percentDecode(std::begin(getRequest()->getFile()), + std::end(getRequest()->getFile()))); + + getFileEntry()->setPath + (util::applyDir(getOption()->get(PREF_DIR), suffixPath)); + getFileEntry()->setSuffixPath(suffixPath); + } + getRequestGroup()->preDownloadProcessing(); + + if(totalLength == 0) { + sequence_ = SEQ_NEGOTIATION_COMPLETED; + + if(getOption()->getAsBool(PREF_DRY_RUN)) { + getRequestGroup()->initPieceStorage(); + onDryRunFileFound(); + return; + } + + if(getDownloadContext()->knowsTotalLength() && + getRequestGroup()->downloadFinishedByFileLength()) { + // TODO Known issue: if .aria2 file exists, it will not be + // deleted on successful verification, because .aria2 file is + // not loaded. See also + // HttpResponseCommand::handleOtherEncoding() + getRequestGroup()->initPieceStorage(); + if(getDownloadContext()->isChecksumVerificationNeeded()) { + A2_LOG_DEBUG("Zero length file exists. Verify checksum."); + auto entry = make_unique + (getRequestGroup()); + entry->initValidator(); + getPieceStorage()->getDiskAdaptor()->openExistingFile(); + getDownloadEngine()->getCheckIntegrityMan()->pushEntry + (std::move(entry)); + sequence_ = SEQ_EXIT; + } + else { + getPieceStorage()->markAllPiecesDone(); + getDownloadContext()->setChecksumVerified(true); + sequence_ = SEQ_DOWNLOAD_ALREADY_COMPLETED; + A2_LOG_NOTICE + (fmt(MSG_DOWNLOAD_ALREADY_COMPLETED, + GroupId::toHex(getRequestGroup()->getGID()).c_str(), + getRequestGroup()->getFirstFilePath().c_str())); + } + poolConnection(); + return; + } + + getRequestGroup()->adjustFilename + (std::make_shared()); + getRequestGroup()->initPieceStorage(); + getPieceStorage()->getDiskAdaptor()->initAndOpenFile(); + + if(getDownloadContext()->knowsTotalLength()) { + A2_LOG_DEBUG("File length becomes zero and it means download completed."); + // TODO Known issue: if .aria2 file exists, it will not be + // deleted on successful verification, because .aria2 file is + // not loaded. See also + // HttpResponseCommand::handleOtherEncoding() + if(getDownloadContext()->isChecksumVerificationNeeded()) { + A2_LOG_DEBUG("Verify checksum for zero-length file"); + auto entry = make_unique + (getRequestGroup()); + entry->initValidator(); + getDownloadEngine()->getCheckIntegrityMan()->pushEntry + (std::move(entry)); + sequence_ = SEQ_EXIT; + } else + { + sequence_ = SEQ_DOWNLOAD_ALREADY_COMPLETED; + getPieceStorage()->markAllPiecesDone(); + } + poolConnection(); + return; + } + // We have to make sure that command that has Request object must + // have segment after PieceStorage is initialized. See + // AbstractCommand::execute() + getSegmentMan()->getSegmentWithIndex(getCuid(), 0); + return; + } else { + auto progressInfoFile = std::make_shared + (getDownloadContext(), nullptr, getOption().get()); + getRequestGroup()->adjustFilename(progressInfoFile); + getRequestGroup()->initPieceStorage(); + + if(getOption()->getAsBool(PREF_DRY_RUN)) { + onDryRunFileFound(); + return; + } + + auto checkIntegrityEntry = getRequestGroup()->createCheckIntegrityEntry(); + if(!checkIntegrityEntry) { + sequence_ = SEQ_DOWNLOAD_ALREADY_COMPLETED; + poolConnection(); + return; + } + checkIntegrityEntry->pushNextCommand(std::unique_ptr(this)); + // We have to make sure that command that has Request object must + // have segment after PieceStorage is initialized. See + // AbstractCommand::execute() + getSegmentMan()->getSegmentWithIndex(getCuid(), 0); + + prepareForNextAction(std::move(checkIntegrityEntry)); + + disableReadCheckSocket(); + + sequence_ = SEQ_FILE_PREPARATION; + } +} + +void SftpNegotiationCommand::poolConnection() const +{ + if(getOption()->getAsBool(PREF_FTP_REUSE_CONNECTION)) { + // TODO we don't need options. Probably, we need to pool socket + // using scheme, port and auth info as key + getDownloadEngine()->poolSocket(getRequest(), authConfig_->getUser(), + createProxyRequest(), getSocket(), ""); + } +} + +void SftpNegotiationCommand::onDryRunFileFound() +{ + getPieceStorage()->markAllPiecesDone(); + getDownloadContext()->setChecksumVerified(true); + poolConnection(); + sequence_ = SEQ_HEAD_OK; +} + +std::string SftpNegotiationCommand::getPath() const { + auto &req = getRequest(); + auto path = req->getDir() + req->getFile(); + return util::percentDecode(std::begin(path), std::end(path)); +} + +} // namespace aria2 diff --git a/src/SftpNegotiationCommand.h b/src/SftpNegotiationCommand.h new file mode 100644 index 00000000..e6b77ab6 --- /dev/null +++ b/src/SftpNegotiationCommand.h @@ -0,0 +1,87 @@ +/* */ +#ifndef D_SFTP_NEGOTIATION_COMMAND_H +#define D_SFTP_NEGOTIATION_COMMAND_H + +#include "AbstractCommand.h" + +namespace aria2 { + +class SocketCore; +class AuthConfig; + +class SftpNegotiationCommand : public AbstractCommand { +public: + enum Seq { + SEQ_HANDSHAKE, + SEQ_AUTH_PASSWORD, + SEQ_SFTP_OPEN, + SEQ_SFTP_STAT, + SEQ_NEGOTIATION_COMPLETED, + SEQ_DOWNLOAD_ALREADY_COMPLETED, + SEQ_HEAD_OK, + SEQ_FILE_PREPARATION, + SEQ_EXIT, + }; + +private: + void onFileSizeDetermined(int64_t totalLength); + void poolConnection() const; + void onDryRunFileFound(); + std::string getPath() const; + + std::shared_ptr socket_; + Seq sequence_; + std::unique_ptr authConfig_; + // remote file path + std::string path_; + +protected: + virtual bool executeInternal() CXX11_OVERRIDE; + +public: + SftpNegotiationCommand(cuid_t cuid, + const std::shared_ptr& req, + const std::shared_ptr& fileEntry, + RequestGroup* requestGroup, + DownloadEngine* e, + const std::shared_ptr& s, + Seq seq = SEQ_HANDSHAKE); + virtual ~SftpNegotiationCommand(); +}; + +} // namespace aria2 + +#endif // D_SFTP_NEGOTIATION_COMMAND_H diff --git a/src/SftpNegotiationConnectChain.h b/src/SftpNegotiationConnectChain.h new file mode 100644 index 00000000..e78081ce --- /dev/null +++ b/src/SftpNegotiationConnectChain.h @@ -0,0 +1,66 @@ +/* */ +#ifndef SFTP_NEGOTIATION_CONNECT_CHAIN_H +#define SFTP_NEGOTIATION_CONNECT_CHAIN_H + +#include "ControlChain.h" +#include "ConnectCommand.h" +#include "DownloadEngine.h" +#include "SftpNegotiationCommand.h" + +namespace aria2 { + +struct SftpNegotiationConnectChain : public ControlChain { + SftpNegotiationConnectChain() {} + virtual ~SftpNegotiationConnectChain() {} + virtual int run(ConnectCommand* t, DownloadEngine* e) CXX11_OVERRIDE + { + auto c = make_unique + (t->getCuid(), + t->getRequest(), + t->getFileEntry(), + t->getRequestGroup(), + t->getDownloadEngine(), + t->getSocket()); + c->setStatus(Command::STATUS_ONESHOT_REALTIME); + e->setNoWait(true); + e->addCommand(std::move(c)); + return 0; + } +}; + +} // namespace aria2 + +#endif // SFTP_NEGOTIATION_CONNECT_CHAIN_H diff --git a/src/SocketCore.cc b/src/SocketCore.cc index b42c7ecd..05ee28bf 100644 --- a/src/SocketCore.cc +++ b/src/SocketCore.cc @@ -45,6 +45,7 @@ #include #include +#include #include #include "message.h" @@ -60,6 +61,9 @@ # include "TLSContext.h" # include "TLSSession.h" #endif // ENABLE_SSL +#ifdef HAVE_LIBSSH2 +# include "SSHSession.h" +#endif // HAVE_LIBSSH2 namespace aria2 { @@ -608,6 +612,14 @@ void SocketCore::closeConnection() tlsSession_.reset(); } #endif // ENABLE_SSL + +#ifdef HAVE_LIBSSH2 + if(sshSession_) { + sshSession_->closeConnection(); + sshSession_.reset(); + } +#endif // HAVE_LIBSSH2 + if(sockfd_ != (sock_t) -1) { shutdown(sockfd_, SHUT_WR); CLOSE(sockfd_); @@ -796,7 +808,23 @@ void SocketCore::readData(void* data, size_t& len) wantRead_ = false; wantWrite_ = false; - if(!secure_) { + if(sshSession_) { +#ifdef HAVE_LIBSSH2 + ret = sshSession_->readData(data, len); + if(ret < 0) { + if(ret != SSH_ERR_WOULDBLOCK) { + throw DL_RETRY_EX(fmt(EX_SOCKET_RECV, + sshSession_->getLastErrorString().c_str())); + } + if(sshSession_->checkDirection() == SSH_WANT_READ) { + wantRead_ = true; + } else { + wantWrite_ = true; + } + ret = 0; + } +#endif // HAVE_LIBSSH2 + } else if(!secure_) { // Cast for Windows recv() while((ret = recv(sockfd_, reinterpret_cast(data), len, 0)) == -1 && SOCKET_ERRNO == A2_EINTR); @@ -957,6 +985,137 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname) #endif // ENABLE_SSL +#ifdef HAVE_LIBSSH2 + +bool SocketCore::sshHandshake() +{ + wantRead_ = false; + wantWrite_ = false; + + if (!sshSession_) { + sshSession_ = make_unique(); + if (sshSession_->init(sockfd_) == SSH_ERR_ERROR) { + throw DL_ABORT_EX("Could not create SSH session"); + } + } + auto rv = sshSession_->handshake(); + if (rv == SSH_ERR_WOULDBLOCK) { + sshCheckDirection(); + return false; + } + if (rv == SSH_ERR_ERROR) { + throw DL_ABORT_EX(fmt("SSH handshake failure: %s", + sshSession_->getLastErrorString().c_str())); + } + return true; +} + +bool SocketCore::sshAuthPassword(const std::string& user, + const std::string& password) +{ + assert(sshSession_); + + wantRead_ = false; + wantWrite_ = false; + + auto rv = sshSession_->authPassword(user, password); + if (rv == SSH_ERR_WOULDBLOCK) { + sshCheckDirection(); + return false; + } + if (rv == SSH_ERR_ERROR) { + throw DL_ABORT_EX(fmt("SSH authentication failure: %s", + sshSession_->getLastErrorString().c_str())); + } + return true; +} + +bool SocketCore::sshSFTPOpen(const std::string& path) +{ + assert(sshSession_); + + wantRead_ = false; + wantWrite_ = false; + + auto rv = sshSession_->sftpOpen(path); + if (rv == SSH_ERR_WOULDBLOCK) { + sshCheckDirection(); + return false; + } + if (rv == SSH_ERR_ERROR) { + throw DL_ABORT_EX(fmt("SSH opening SFTP path %s failed: %s", + path.c_str(), + sshSession_->getLastErrorString().c_str())); + } + return true; +} + +bool SocketCore::sshSFTPClose() +{ + assert(sshSession_); + + wantRead_ = false; + wantWrite_ = false; + + auto rv = sshSession_->sftpClose(); + if (rv == SSH_ERR_WOULDBLOCK) { + sshCheckDirection(); + return false; + } + if (rv == SSH_ERR_ERROR) { + throw DL_ABORT_EX(fmt("SSH closing SFTP failed: %s", + sshSession_->getLastErrorString().c_str())); + } + return true; +} + +bool SocketCore::sshSFTPStat(int64_t& totalLength, time_t& mtime, + const std::string& path) +{ + assert(sshSession_); + + wantRead_ = false; + wantWrite_ = false; + + auto rv = sshSession_->sftpStat(totalLength, mtime); + if (rv == SSH_ERR_WOULDBLOCK) { + sshCheckDirection(); + return false; + } + if (rv == SSH_ERR_ERROR) { + throw DL_ABORT_EX(fmt("SSH stat SFTP path %s filed: %s", + path.c_str(), + sshSession_->getLastErrorString().c_str())); + } + return true; +} + +bool SocketCore::sshGracefulShutdown() +{ + assert(sshSession_); + auto rv = sshSession_->gracefulShutdown(); + if (rv == SSH_ERR_WOULDBLOCK) { + sshCheckDirection(); + return false; + } + if (rv == SSH_ERR_ERROR) { + throw DL_ABORT_EX(fmt("SSH graceful shutdown failed: %s", + sshSession_->getLastErrorString().c_str())); + } + return true; +} + +void SocketCore::sshCheckDirection() +{ + if (sshSession_->checkDirection() == SSH_WANT_READ) { + wantRead_ = true; + } else { + wantWrite_ = true; + } +} + +#endif // HAVE_LIBSSH2 + ssize_t SocketCore::writeData(const void* data, size_t len, const std::string& host, uint16_t port) { diff --git a/src/SocketCore.h b/src/SocketCore.h index 35f0f83f..92bc67e3 100644 --- a/src/SocketCore.h +++ b/src/SocketCore.h @@ -55,6 +55,10 @@ class TLSContext; class TLSSession; #endif // ENABLE_SSL +#ifdef HAVE_LIBSSH2 +class SSHSession; +#endif // HAVE_LIBSSH2 + class SocketCore { friend bool operator==(const SocketCore& s1, const SocketCore& s2); friend bool operator!=(const SocketCore& s1, const SocketCore& s2); @@ -95,6 +99,12 @@ private: bool tlsHandshake(TLSContext* tlsctx, const std::string& hostname); #endif // ENABLE_SSL +#ifdef HAVE_LIBSSH2 + std::unique_ptr sshSession_; + + void sshCheckDirection(); +#endif // HAVE_LIBSSH2 + void init(); void bind(const struct sockaddr* addr, socklen_t addrlen); @@ -290,6 +300,22 @@ public: bool tlsConnect(const std::string& hostname); #endif // ENABLE_SSL +#ifdef HAVE_LIBSSH2 + // Performs SSH handshake + bool sshHandshake(); + // Performs SSH authentication using username and password. + bool sshAuthPassword(const std::string& user, const std::string& password); + // Starts sftp session and open remote file |path|. + bool sshSFTPOpen(const std::string& path); + // Closes sftp remote file gracefully + bool sshSFTPClose(); + // Gets total length and modified time for remote file currently + // opened. |path| is used for logging. + bool sshSFTPStat(int64_t& totalLength, time_t& mtime, + const std::string& path); + bool sshGracefulShutdown(); +#endif // HAVE_LIBSSH2 + bool operator==(const SocketCore& s) { return sockfd_ == s.sockfd_; } diff --git a/test/FeatureConfigTest.cc b/test/FeatureConfigTest.cc index 3b8b8460..65de79f9 100644 --- a/test/FeatureConfigTest.cc +++ b/test/FeatureConfigTest.cc @@ -30,6 +30,7 @@ void FeatureConfigTest::testGetDefaultPort() { CPPUNIT_ASSERT_EQUAL((uint16_t)80, getDefaultPort("http")); CPPUNIT_ASSERT_EQUAL((uint16_t)443, getDefaultPort("https")); CPPUNIT_ASSERT_EQUAL((uint16_t)21, getDefaultPort("ftp")); + CPPUNIT_ASSERT_EQUAL((uint16_t)22, getDefaultPort("sftp")); } void FeatureConfigTest::testStrSupportedFeature() { @@ -40,6 +41,13 @@ void FeatureConfigTest::testStrSupportedFeature() { CPPUNIT_ASSERT(!https); #endif // ENABLE_SSL CPPUNIT_ASSERT(!strSupportedFeature(MAX_FEATURE)); + + auto sftp = strSupportedFeature(FEATURE_SFTP); +#ifdef HAVE_LIBSSH2 + CPPUNIT_ASSERT(sftp); +#else // !HAVE_LIBSSH2 + CPPUNIT_ASSERT(!sftp); +#endif // !HAVE_LIBSSH2 } void FeatureConfigTest::testFeatureSummary() { @@ -75,6 +83,9 @@ void FeatureConfigTest::testFeatureSummary() { "XML-RPC", #endif // ENABLE_XML_RPC +#ifdef HAVE_LIBSSH2 + "SFTP", +#endif // HAVE_LIBSSH2 }; std::string featuresString = strjoin(std::begin(features),