diff --git a/configure.ac b/configure.ac index 58b71dad..4461139a 100644 --- a/configure.ac +++ b/configure.ac @@ -58,6 +58,7 @@ ARIA2_ARG_WITHOUT([libcares]) ARIA2_ARG_WITHOUT([libz]) ARIA2_ARG_WITH([tcmalloc]) ARIA2_ARG_WITH([jemalloc]) +ARIA2_ARG_WITHOUT([libssh2]) ARIA2_ARG_DISABLE([ssl]) ARIA2_ARG_DISABLE([bittorrent]) @@ -298,6 +299,20 @@ if test "x$with_sqlite3" = "xyes"; then fi fi +if test "x$with_libssh2" = "xyes"; then + PKG_CHECK_MODULES([LIBSSH2], [libssh2], [have_libssh2=yes], [have_libssh2=no]) + if test "x$have_libssh2" = "xyes"; then + AC_DEFINE([HAVE_LIBSSH2], [1], [Define to 1 if you have libssh2.]) + LIBS="$LIBSSH2_LIBS $LIBS" + CPPFLAGS="$LIBSSH2_CFLAGS $CPPFLAGS" + else + AC_MSG_WARN([$LIBSSH2_PKG_ERRORS]) + if test "x$with_libssh2_requested" = "yes"; then + ARIA2_DEP_NOT_MET([libssh2]) + fi + fi +fi + case "$host" in *darwin*) have_osx="yes" @@ -613,6 +628,9 @@ AM_CONDITIONAL([HAVE_ZLIB], [test "x$have_zlib" = "xyes"]) # Set conditional for sqlite3 AM_CONDITIONAL([HAVE_SQLITE3], [test "x$have_sqlite3" = "xyes"]) +# Set conditional for libssh2 +AM_CONDITIONAL([HAVE_LIBSSH2], [test "x$have_libssh2" = "xyes"]) + AC_SEARCH_LIBS([clock_gettime], [rt]) case "$host" in @@ -1062,6 +1080,7 @@ echo "LibXML2: $have_libxml2" echo "LibExpat: $have_libexpat" echo "LibCares: $have_libcares" echo "Zlib: $have_zlib" +echo "Libssh2: $have_libssh2" echo "Epoll: $have_epoll" echo "Bittorrent: $enable_bittorrent" echo "Metalink: $enable_metalink" diff --git a/src/Makefile.am b/src/Makefile.am index 236a055e..65cf6bb5 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -427,6 +427,10 @@ SRCS += \ Sqlite3CookieParserImpl.cc Sqlite3CookieParserImpl.h endif # HAVE_SQLITE3 +if HAVE_LIBSSH2 +SRCS += SSHSession.cc SSHSession.h +endif # HAVE_LIBSSH2 + if ENABLE_ASYNC_DNS SRCS += \ AsyncNameResolver.cc AsyncNameResolver.h\ diff --git a/src/Platform.cc b/src/Platform.cc index d93c48eb..8520440d 100644 --- a/src/Platform.cc +++ b/src/Platform.cc @@ -56,6 +56,10 @@ # include #endif // ENABLE_ASYNC_DNS +#ifdef HAVE_LIBSSH2 +# include +#endif // HAVE_LIBSSH2 + #include "a2netcompat.h" #include "DlAbortEx.h" #include "message.h" @@ -149,6 +153,15 @@ bool Platform::setUp() } #endif // CARES_HAVE_ARES_LIBRARY_INIT +#ifdef HAVE_LIBSSH2 + { + auto rv = libssh2_init(0); + if (rv != 0) { + throw DL_ABORT_EX(fmt("libssh2_init() failed, code: %d", rv)); + } + } +#endif // HAVE_LIBSSH2 + #ifdef HAVE_WINSOCK2_H WSADATA wsaData; memset(reinterpret_cast(&wsaData), 0, sizeof(wsaData)); @@ -181,6 +194,10 @@ bool Platform::tearDown() ares_library_cleanup(); #endif // CARES_HAVE_ARES_LIBRARY_CLEANUP +#ifdef HAVE_LIBSSH2 + libssh2_exit(); +#endif // HAVE_LIBSSH2 + #ifdef HAVE_WINSOCK2_H WSACleanup(); #endif // HAVE_WINSOCK2_H diff --git a/src/SSHSession.cc b/src/SSHSession.cc new file mode 100644 index 00000000..d320b02b --- /dev/null +++ b/src/SSHSession.cc @@ -0,0 +1,207 @@ +/* */ +#include "SSHSession.h" + +#include + +namespace aria2 { + +SSHSession::SSHSession() + : ssh2_(nullptr), + sftp_(nullptr), + sftph_(nullptr), + fd_(-1) +{} + +SSHSession::~SSHSession() +{ + closeConnection(); +} + +int SSHSession::closeConnection() +{ + if (sftph_) { + // TODO this could return LIBSSH2_ERROR_EAGAIN + libssh2_sftp_close(sftph_); + sftph_ = nullptr; + } + if (sftp_) { + // TODO this could return LIBSSH2_ERROR_EAGAIN + libssh2_sftp_shutdown(sftp_); + sftp_ = nullptr; + } + if (ssh2_) { + // TODO this could return LIBSSH2_ERROR_EAGAIN + libssh2_session_disconnect(ssh2_, "bye"); + libssh2_session_free(ssh2_); + ssh2_ = nullptr; + } + return SSH_ERR_OK; +} + +int SSHSession::gracefulShutdown() +{ + if (sftph_) { + auto rv = libssh2_sftp_close(sftph_); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + sftph_ = nullptr; + } + if (sftp_) { + auto rv = libssh2_sftp_shutdown(sftp_); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + sftp_ = nullptr; + } + if (ssh2_) { + auto rv = libssh2_session_disconnect(ssh2_, "bye"); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + libssh2_session_free(ssh2_); + ssh2_ = nullptr; + } + return SSH_ERR_OK; +} + +int SSHSession::init(sock_t sockfd) +{ + ssh2_ = libssh2_session_init(); + if (!ssh2_) { + return SSH_ERR_ERROR; + } + fd_ = sockfd; + return SSH_ERR_OK; +} + + +int SSHSession::checkDirection() +{ + auto dir = libssh2_session_block_directions(ssh2_); + if (dir & LIBSSH2_SESSION_BLOCK_OUTBOUND) { + return SSH_WANT_WRITE; + } + return SSH_WANT_READ; +} + +ssize_t SSHSession::writeData(const void* data, size_t len) +{ + // net implemented yet + assert(0); +} + +ssize_t SSHSession::readData(void* data, size_t len) +{ + auto nread = libssh2_sftp_read(sftph_, static_cast(data), len); + if (nread == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (nread < 0) { + return SSH_ERR_ERROR; + } + return nread; +} + +int SSHSession::handshake() +{ + auto rv = libssh2_session_handshake(ssh2_, fd_); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + // TODO we have to validate server's fingerprint + return SSH_ERR_OK; +} + +int SSHSession::authPassword(const std::string& user, + const std::string& password) +{ + auto rv = libssh2_userauth_password(ssh2_, user.c_str(), password.c_str()); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + return SSH_ERR_OK; +} + +int SSHSession::sftpOpen(const std::string& path) +{ + if (!sftp_) { + sftp_ = libssh2_sftp_init(ssh2_); + if (!sftp_) { + if (libssh2_session_last_errno(ssh2_) == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + return SSH_ERR_ERROR; + } + } + if (!sftph_) { + sftph_ = libssh2_sftp_open(sftp_, path.c_str(), LIBSSH2_FXF_READ, 0); + if (!sftph_) { + if (libssh2_session_last_errno(ssh2_) == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + return SSH_ERR_ERROR; + } + } + return SSH_ERR_OK; +} + +std::string SSHSession::getLastErrorString() +{ + if (!ssh2_) { + return "SSH session has not been initialized yet"; + } + char* errmsg; + libssh2_session_last_error(ssh2_, &errmsg, nullptr, 0); + return errmsg; +} + +} // namespace aria2 diff --git a/src/SSHSession.h b/src/SSHSession.h new file mode 100644 index 00000000..f7cc1cbb --- /dev/null +++ b/src/SSHSession.h @@ -0,0 +1,118 @@ +/* */ +#ifndef SSH_SESSION_H +#define SSH_SESSION_H + +#include "common.h" +#include "a2netcompat.h" + +#include + +#include +#include + +namespace aria2 { + +enum SSHDirection { + SSH_WANT_READ = 1, + SSH_WANT_WRITE +}; + +enum SSHErrorCode { + SSH_ERR_OK = 0, + SSH_ERR_ERROR = -1, + SSH_ERR_WOULDBLOCK = -2 +}; + +class SSHSession { +public: + SSHSession(); + + // MUST deallocate all resources + ~SSHSession(); + + SSHSession(const SSHSession&) = delete; + SSHSession& operator=(const SSHSession&) = delete; + + // Initializes SSH session. The |sockfd| is the underlying + // transport socket. This function returns SSH_ERR_OK if it + // succeeds, or SSH_ERR_ERROR. + int init(sock_t sockfd); + + // Closes the SSH session. Don't close underlying transport + // socket. This function returns SSH_ERR_OK if it succeeds, or + // SSH_ERR_ERROR. + int closeConnection(); + + int gracefulShutdown(); + + // Returns SSH_WANT_READ if SSH session needs more data from remote + // endpoint to proceed, or SSH_WANT_WRITE if SSH session needs to + // write more data to proceed. If SSH session needs neither read nor + // write data at the moment, SSH_WANT_READ must be returned. + int checkDirection(); + + // Sends |data| with length |len|. This function returns the number + // of bytes sent if it succeeds, or SSH_ERR_WOULDBLOCK if the + // underlying transport blocks, or SSH_ERR_ERROR. + ssize_t writeData(const void* data, size_t len); + + // Receives data into |data| with length |len|. This function + // returns the number of bytes received if it succeeds, or + // SSH_ERR_WOULDBLOCK if the underlying transport blocks, or + // SSH_ERR_ERROR. + ssize_t readData(void* data, size_t len); + + // Performs handshake. This function returns SSH_ERR_OK + // if it succeeds, or SSH_ERR_WOULDBLOCK if the underlying transport + // blocks, or SSH_ERR_ERROR. + int handshake(); + + int authPassword(const std::string& user, const std::string& password); + int sftpOpen(const std::string& path); + + // Returns last error string + std::string getLastErrorString(); + +private: + LIBSSH2_SESSION* ssh2_; + LIBSSH2_SFTP* sftp_; + LIBSSH2_SFTP_HANDLE* sftph_; + sock_t fd_; +}; + +} + +#endif // SSH_SESSION_H diff --git a/src/SocketCore.cc b/src/SocketCore.cc index b42c7ecd..5469b8bf 100644 --- a/src/SocketCore.cc +++ b/src/SocketCore.cc @@ -45,6 +45,7 @@ #include #include +#include #include #include "message.h" @@ -60,6 +61,9 @@ # include "TLSContext.h" # include "TLSSession.h" #endif // ENABLE_SSL +#ifdef HAVE_LIBSSH2 +# include "SSHSession.h" +#endif // HAVE_LIBSSH2 namespace aria2 { @@ -608,6 +612,14 @@ void SocketCore::closeConnection() tlsSession_.reset(); } #endif // ENABLE_SSL + +#ifdef HAVE_LIBSSH2 + if(sshSession_) { + sshSession_->closeConnection(); + sshSession_.reset(); + } +#endif // HAVE_LIBSSH2 + if(sockfd_ != (sock_t) -1) { shutdown(sockfd_, SHUT_WR); CLOSE(sockfd_); @@ -796,7 +808,23 @@ void SocketCore::readData(void* data, size_t& len) wantRead_ = false; wantWrite_ = false; - if(!secure_) { + if(sshSession_) { +#ifdef HAVE_LIBSSH2 + ret = sshSession_->readData(data, len); + if(ret < 0) { + if(ret != SSH_ERR_WOULDBLOCK) { + throw DL_RETRY_EX(fmt(EX_SOCKET_RECV, + sshSession_->getLastErrorString().c_str())); + } + if(sshSession_->checkDirection() == SSH_WANT_READ) { + wantRead_ = true; + } else { + wantWrite_ = true; + } + ret = 0; + } +#endif // HAVE_LIBSSH2 + } else if(!secure_) { // Cast for Windows recv() while((ret = recv(sockfd_, reinterpret_cast(data), len, 0)) == -1 && SOCKET_ERRNO == A2_EINTR); @@ -957,6 +985,97 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname) #endif // ENABLE_SSL +#ifdef HAVE_LIBSSH2 + +bool SocketCore::sshHandshake() +{ + wantRead_ = false; + wantWrite_ = false; + + if (!sshSession_) { + sshSession_ = make_unique(); + if (sshSession_->init(sockfd_) == SSH_ERR_ERROR) { + throw DL_ABORT_EX("Could not create SSH session"); + } + } + auto rv = sshSession_->handshake(); + if (rv == SSH_ERR_WOULDBLOCK) { + sshCheckDirection(); + return false; + } + if (rv == SSH_ERR_ERROR) { + throw DL_ABORT_EX(fmt("SSH handshake failure: %s", + sshSession_->getLastErrorString().c_str())); + } + return true; +} + +bool SocketCore::sshAuthPassword(const std::string& user, + const std::string& password) +{ + assert(sshSession_); + + wantRead_ = false; + wantWrite_ = false; + + auto rv = sshSession_->authPassword(user, password); + if (rv == SSH_ERR_WOULDBLOCK) { + sshCheckDirection(); + return false; + } + if (rv == SSH_ERR_ERROR) { + throw DL_ABORT_EX(fmt("SSH authentication failure: %s", + sshSession_->getLastErrorString().c_str())); + } + return true; +} + +bool SocketCore::sshSFTPOpen(const std::string& path) +{ + assert(sshSession_); + + wantRead_ = false; + wantWrite_ = false; + + auto rv = sshSession_->sftpOpen(path); + if (rv == SSH_ERR_WOULDBLOCK) { + sshCheckDirection(); + return false; + } + if (rv == SSH_ERR_ERROR) { + throw DL_ABORT_EX(fmt("SSH opening SFTP path %s failed: %s", + path.c_str(), + sshSession_->getLastErrorString().c_str())); + } + return true; +} + +bool SocketCore::sshGracefulShutdown() +{ + assert(sshSession_); + auto rv = sshSession_->gracefulShutdown(); + if (rv == SSH_ERR_WOULDBLOCK) { + sshCheckDirection(); + return false; + } + if (rv == SSH_ERR_ERROR) { + throw DL_ABORT_EX(fmt("SSH graceful shutdown failed: %s", + sshSession_->getLastErrorString().c_str())); + } + return true; +} + +void SocketCore::sshCheckDirection() +{ + if (sshSession_->checkDirection() == SSH_WANT_READ) { + wantRead_ = true; + } else { + wantWrite_ = true; + } +} + +#endif // HAVE_LIBSSH2 + ssize_t SocketCore::writeData(const void* data, size_t len, const std::string& host, uint16_t port) { diff --git a/src/SocketCore.h b/src/SocketCore.h index 35f0f83f..2cebae41 100644 --- a/src/SocketCore.h +++ b/src/SocketCore.h @@ -55,6 +55,10 @@ class TLSContext; class TLSSession; #endif // ENABLE_SSL +#ifdef HAVE_LIBSSH2 +class SSHSession; +#endif // HAVE_LIBSSH2 + class SocketCore { friend bool operator==(const SocketCore& s1, const SocketCore& s2); friend bool operator!=(const SocketCore& s1, const SocketCore& s2); @@ -95,6 +99,12 @@ private: bool tlsHandshake(TLSContext* tlsctx, const std::string& hostname); #endif // ENABLE_SSL +#ifdef HAVE_LIBSSH2 + std::unique_ptr sshSession_; + + void sshCheckDirection(); +#endif // HAVE_LIBSSH2 + void init(); void bind(const struct sockaddr* addr, socklen_t addrlen); @@ -290,6 +300,13 @@ public: bool tlsConnect(const std::string& hostname); #endif // ENABLE_SSL +#ifdef HAVE_LIBSSH2 + bool sshHandshake(); + bool sshAuthPassword(const std::string& user, const std::string& password); + bool sshSFTPOpen(const std::string& path); + bool sshGracefulShutdown(); +#endif // HAVE_LIBSSH2 + bool operator==(const SocketCore& s) { return sockfd_ == s.sockfd_; }