diff --git a/src/AbstractCommand.cc b/src/AbstractCommand.cc index 2df0141b..ba1297bc 100644 --- a/src/AbstractCommand.cc +++ b/src/AbstractCommand.cc @@ -677,7 +677,7 @@ std::string getProxyUri(const std::string& protocol, const Option* option) option); } - if (protocol == "ftp") { + if (protocol == "ftp" || protocol == "sftp") { return getProxyOptionFor (PREF_FTP_PROXY, PREF_FTP_PROXY_USER, PREF_FTP_PROXY_PASSWD, option); } @@ -883,7 +883,8 @@ bool AbstractCommand::checkIfConnectionEstablished const std::string& AbstractCommand::resolveProxyMethod(const std::string& protocol) const { - if (getOption()->get(PREF_PROXY_METHOD) == V_TUNNEL || protocol == "https") { + if (getOption()->get(PREF_PROXY_METHOD) == V_TUNNEL || protocol == "https" || + protocol == "sftp") { return V_TUNNEL; } return V_GET; diff --git a/src/AuthConfigFactory.cc b/src/AuthConfigFactory.cc index e8d55b04..2f3c0573 100644 --- a/src/AuthConfigFactory.cc +++ b/src/AuthConfigFactory.cc @@ -87,7 +87,8 @@ AuthConfigFactory::createAuthConfig createHttpAuthResolver(op)->resolveAuthConfig(request->getHost()); } } - } else if(request->getProtocol() == "ftp") { + } else if(request->getProtocol() == "ftp" || + request->getProtocol() == "sftp") { if(!request->getUsername().empty()) { if(request->hasPassword()) { return AuthConfig::create(request->getUsername(), diff --git a/src/DownloadCommand.cc b/src/DownloadCommand.cc index 5be8d551..20337611 100644 --- a/src/DownloadCommand.cc +++ b/src/DownloadCommand.cc @@ -134,6 +134,7 @@ bool DownloadCommand::executeInternal() { || getRequestGroup()->doesDownloadSpeedExceed()) { addCommandSelf(); disableReadCheckSocket(); + disableWriteCheckSocket(); return false; } setReadCheckSocket(getSocket()); @@ -195,7 +196,8 @@ bool DownloadCommand::executeInternal() { // Note that GrowSegment::complete() always returns false. if(sinkFilterOnly_) { if(segment->complete() || - segment->getPositionToWrite() == getFileEntry()->getLastOffset()) { + (getFileEntry()->getLength() != 0 && + segment->getPositionToWrite() == getFileEntry()->getLastOffset())) { segmentPartComplete = true; } else if(segment->getLength() == 0 && eof) { segmentPartComplete = true; @@ -275,13 +277,17 @@ bool DownloadCommand::executeInternal() { return prepareForNextSegment(); } else { checkLowestDownloadSpeed(); - setWriteCheckSocketIf(getSocket(), getSocket()->wantWrite()); + setWriteCheckSocketIf(getSocket(), shouldEnableWriteCheck()); checkSocketRecvBuffer(); addCommandSelf(); return false; } } +bool DownloadCommand::shouldEnableWriteCheck() { + return getSocket()->wantWrite(); +} + void DownloadCommand::checkLowestDownloadSpeed() const { if(lowestDownloadSpeedLimit_ > 0 && diff --git a/src/DownloadCommand.h b/src/DownloadCommand.h index 3efaf5a0..485cdc74 100644 --- a/src/DownloadCommand.h +++ b/src/DownloadCommand.h @@ -76,6 +76,12 @@ protected: // This is file local offset virtual int64_t getRequestEndOffset() const = 0; + + // Returns true if socket should be monitored for writing. The + // default implementation is return the return value of + // getSocket()->wantWrite(). + virtual bool shouldEnableWriteCheck(); + public: DownloadCommand(cuid_t cuid, const std::shared_ptr& req, diff --git a/src/FeatureConfig.cc b/src/FeatureConfig.cc index 1d2d30bb..aa8cd8f2 100644 --- a/src/FeatureConfig.cc +++ b/src/FeatureConfig.cc @@ -80,6 +80,8 @@ uint16_t getDefaultPort(const std::string& protocol) return 443; } else if(protocol == "ftp") { return 21; + } else if(protocol == "sftp") { + return 22; } else { return 0; } diff --git a/src/FtpInitiateConnectionCommand.cc b/src/FtpInitiateConnectionCommand.cc index 52fc07a4..3950fed0 100644 --- a/src/FtpInitiateConnectionCommand.cc +++ b/src/FtpInitiateConnectionCommand.cc @@ -60,6 +60,10 @@ #include "FtpNegotiationConnectChain.h" #include "FtpTunnelRequestConnectChain.h" #include "HttpRequestConnectChain.h" +#ifdef HAVE_LIBSSH2 +# include "SftpNegotiationConnectChain.h" +# include "SftpNegotiationCommand.h" +#endif // HAVE_LIBSSH2 namespace aria2 { @@ -82,6 +86,7 @@ std::unique_ptr FtpInitiateConnectionCommand::createNextCommandProxied std::shared_ptr pooledSocket; std::string proxyMethod = resolveProxyMethod(getRequest()->getProtocol()); + // sftp always use tunnel mode if(proxyMethod == V_GET) { pooledSocket = getDownloadEngine()->popPooledSocket (getRequest()->getHost(), getRequest()->getPort(), @@ -180,30 +185,45 @@ std::unique_ptr FtpInitiateConnectionCommand::createNextCommandPlain getSocket()->establishConnection(addr, port); getRequest()->setConnectedAddrInfo(hostname, addr, port); auto c = make_unique(getCuid(), - getRequest(), - nullptr, - getFileEntry(), - getRequestGroup(), - getDownloadEngine(), - getSocket()); + getRequest(), + nullptr, + getFileEntry(), + getRequestGroup(), + getDownloadEngine(), + getSocket()); - c->setControlChain(std::make_shared()); + if(getRequest()->getProtocol() == "sftp") { + c->setControlChain(std::make_shared()); + } else { + c->setControlChain(std::make_shared()); + } setupBackupConnection(hostname, addr, port, c.get()); return std::move(c); } - // options contains "baseWorkingDir" - auto command = make_unique - (getCuid(), - getRequest(), - getFileEntry(), - getRequestGroup(), - getDownloadEngine(), - pooledSocket, - FtpNegotiationCommand::SEQ_SEND_CWD_PREP, - options); setConnectedAddrInfo(getRequest(), hostname, pooledSocket); - return std::move(command); + + if (getRequest()->getProtocol() == "sftp") { + return make_unique + (getCuid(), + getRequest(), + getFileEntry(), + getRequestGroup(), + getDownloadEngine(), + pooledSocket, + SftpNegotiationCommand::SEQ_SFTP_OPEN); + } + + // options contains "baseWorkingDir" + return make_unique + (getCuid(), + getRequest(), + getFileEntry(), + getRequestGroup(), + getDownloadEngine(), + pooledSocket, + FtpNegotiationCommand::SEQ_SEND_CWD_PREP, + options); } std::unique_ptr FtpInitiateConnectionCommand::createNextCommand diff --git a/src/InitiateConnectionCommandFactory.cc b/src/InitiateConnectionCommandFactory.cc index 2856b244..b4615d7a 100644 --- a/src/InitiateConnectionCommandFactory.cc +++ b/src/InitiateConnectionCommandFactory.cc @@ -71,10 +71,14 @@ InitiateConnectionCommandFactory::createInitiateConnectionCommand return make_unique(cuid, req, fileEntry, requestGroup, e); - } else if(req->getProtocol() == "ftp") { + } else if(req->getProtocol() == "ftp" +#ifdef HAVE_LIBSSH2 + || req->getProtocol() == "sftp" +#endif // HAVE_LIBSSH2 + ) { if(req->getFile().empty()) { throw DL_ABORT_EX - (fmt("FTP URI %s doesn't contain file path.", + (fmt("FTP/SFTP URI %s doesn't contain file path.", req->getUri().c_str())); } return make_unique(cuid, req, fileEntry, diff --git a/src/Makefile.am b/src/Makefile.am index 65cf6bb5..bfa99857 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -428,7 +428,11 @@ SRCS += \ endif # HAVE_SQLITE3 if HAVE_LIBSSH2 -SRCS += SSHSession.cc SSHSession.h +SRCS += SSHSession.cc SSHSession.h \ + SftpNegotiationCommand.cc SftpNegotiationCommand.h \ + SftpNegotiationConnectChain.h \ + SftpDownloadCommand.cc SftpDownloadCommand.h \ + SftpFinishDownloadCommand.cc SftpFinishDownloadCommand.h endif # HAVE_LIBSSH2 if ENABLE_ASYNC_DNS diff --git a/src/MetalinkParserController.cc b/src/MetalinkParserController.cc index 91ce7053..fec63aa0 100644 --- a/src/MetalinkParserController.cc +++ b/src/MetalinkParserController.cc @@ -193,7 +193,7 @@ void MetalinkParserController::setTypeOfResource(std::string type) if(!tResource_) { return; } - if(type == "ftp") { + if(type == "ftp" || type == "sftp") { tResource_->type = MetalinkResource::TYPE_FTP; } else if(type == "http") { tResource_->type = MetalinkResource::TYPE_HTTP; diff --git a/src/SSHSession.cc b/src/SSHSession.cc index d320b02b..1b939de6 100644 --- a/src/SSHSession.cc +++ b/src/SSHSession.cc @@ -107,12 +107,30 @@ int SSHSession::gracefulShutdown() return SSH_ERR_OK; } +int SSHSession::sftpClose() +{ + if (!sftph_) { + return SSH_ERR_OK; + } + + auto rv = libssh2_sftp_close(sftph_); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + sftph_ = nullptr; + return SSH_ERR_OK; +} + int SSHSession::init(sock_t sockfd) { ssh2_ = libssh2_session_init(); if (!ssh2_) { return SSH_ERR_ERROR; } + libssh2_session_set_blocking(ssh2_, 0); fd_ = sockfd; return SSH_ERR_OK; } @@ -194,6 +212,21 @@ int SSHSession::sftpOpen(const std::string& path) return SSH_ERR_OK; } +int SSHSession::sftpStat(int64_t& totalLength, time_t& mtime) +{ + LIBSSH2_SFTP_ATTRIBUTES attrs; + auto rv = libssh2_sftp_fstat_ex(sftph_, &attrs, 0); + if (rv == LIBSSH2_ERROR_EAGAIN) { + return SSH_ERR_WOULDBLOCK; + } + if (rv != 0) { + return SSH_ERR_ERROR; + } + totalLength = attrs.filesize; + mtime = attrs.mtime; + return SSH_ERR_OK; +} + std::string SSHSession::getLastErrorString() { if (!ssh2_) { diff --git a/src/SSHSession.h b/src/SSHSession.h index f7cc1cbb..31941e5d 100644 --- a/src/SSHSession.h +++ b/src/SSHSession.h @@ -100,9 +100,28 @@ public: // blocks, or SSH_ERR_ERROR. int handshake(); + // Performs authentication using username and password. This + // function returns SSH_ERR_OK if it succeeds, or SSH_ERR_WOULDBLOCK + // if the underlying transport blocks, or SSH_ERR_ERROR. int authPassword(const std::string& user, const std::string& password); + + // Starts SFTP session and opens remote file |path|. This function + // returns SSH_ERR_OK if it succeeds, or SSH_ERR_WOULDBLOCK if the + // underlying transport blocks, or SSH_ERR_ERROR. int sftpOpen(const std::string& path); + // Closes remote file opened by sftpOpen(). This function returns + // SSH_ERR_OK if it succeeds, or SSH_ERR_WOULDBLOCK if the + // underlying transport blocks, or SSH_ERR_ERROR. + int sftpClose(); + + // Gets total length and modified time of opened file by sftpOpen(). + // On success, total length and modified time are assigned to + // |totalLength| and |mtime|. This function returns SSH_ERR_OK if + // it succeeds, or SSH_ERR_WOULDBLOCK if the underlying transport + // blocks, or SSH_ERR_ERROR. + int sftpStat(int64_t& totalLength, time_t& mtime); + // Returns last error string std::string getLastErrorString(); diff --git a/src/SftpDownloadCommand.cc b/src/SftpDownloadCommand.cc new file mode 100644 index 00000000..84bfa913 --- /dev/null +++ b/src/SftpDownloadCommand.cc @@ -0,0 +1,99 @@ +/* */ +#include "SftpDownloadCommand.h" +#include "Request.h" +#include "SocketCore.h" +#include "Segment.h" +#include "DownloadEngine.h" +#include "RequestGroup.h" +#include "Option.h" +#include "FileEntry.h" +#include "SocketRecvBuffer.h" +#include "AuthConfig.h" +#include "SftpFinishDownloadCommand.h" + +namespace aria2 { + +SftpDownloadCommand::SftpDownloadCommand +(cuid_t cuid, + const std::shared_ptr& req, + const std::shared_ptr& fileEntry, + RequestGroup* requestGroup, + DownloadEngine* e, + const std::shared_ptr& socket, + std::unique_ptr authConfig) + : DownloadCommand(cuid, req, fileEntry, requestGroup, e, socket, + std::make_shared(socket)), + authConfig_(std::move(authConfig)) +{ + setWriteCheckSocket(getSocket()); +} + +SftpDownloadCommand::~SftpDownloadCommand() {} + +bool SftpDownloadCommand::prepareForNextSegment() +{ + if(getOption()->getAsBool(PREF_FTP_REUSE_CONNECTION) && + getFileEntry()->gtoloff(getSegments().front()->getPositionToWrite()) == + getFileEntry()->getLength()) { + + auto c = make_unique + (getCuid(), getRequest(), getFileEntry(), getRequestGroup(), + getDownloadEngine(), getSocket()); + + c->setStatus(Command::STATUS_ONESHOT_REALTIME); + getDownloadEngine()->setNoWait(true); + getDownloadEngine()->addCommand(std::move(c)); + + if(getRequestGroup()->downloadFinished()) { + // To run checksum checking, we had to call following function here. + DownloadCommand::prepareForNextSegment(); + } + return true; + } + + return DownloadCommand::prepareForNextSegment(); +} + +int64_t SftpDownloadCommand::getRequestEndOffset() const +{ + return getFileEntry()->getLength(); +} + +bool SftpDownloadCommand::shouldEnableWriteCheck() { + return getSocket()->wantWrite() || !getSocket()->wantRead(); +} + +} // namespace aria2 diff --git a/src/SftpDownloadCommand.h b/src/SftpDownloadCommand.h new file mode 100644 index 00000000..cbb6372f --- /dev/null +++ b/src/SftpDownloadCommand.h @@ -0,0 +1,66 @@ +/* */ +#ifndef D_SFTP_DOWNLOAD_COMMAND_H +#define D_SFTP_DOWNLOAD_COMMAND_H + +#include "DownloadCommand.h" + +namespace aria2 { + +class AuthConfig; + +class SftpDownloadCommand : public DownloadCommand { +private: + std::unique_ptr authConfig_; + +protected: + virtual bool prepareForNextSegment() CXX11_OVERRIDE; + virtual int64_t getRequestEndOffset() const CXX11_OVERRIDE; + virtual bool shouldEnableWriteCheck() CXX11_OVERRIDE; + +public: + SftpDownloadCommand(cuid_t cuid, + const std::shared_ptr& req, + const std::shared_ptr& fileEntry, + RequestGroup* requestGroup, + DownloadEngine* e, + const std::shared_ptr& socket, + std::unique_ptr authConfig); + virtual ~SftpDownloadCommand(); +}; + +} // namespace aria2 + +#endif // D_SFTP_DOWNLOAD_COMMAND_H diff --git a/src/SftpFinishDownloadCommand.cc b/src/SftpFinishDownloadCommand.cc new file mode 100644 index 00000000..399f2755 --- /dev/null +++ b/src/SftpFinishDownloadCommand.cc @@ -0,0 +1,119 @@ +/* */ +#include "SftpFinishDownloadCommand.h" + +#include "Request.h" +#include "DownloadEngine.h" +#include "prefs.h" +#include "Option.h" +#include "message.h" +#include "fmt.h" +#include "DlAbortEx.h" +#include "SocketCore.h" +#include "RequestGroup.h" +#include "Logger.h" +#include "LogFactory.h" +#include "wallclock.h" +#include "AuthConfigFactory.h" +#include "AuthConfig.h" + +namespace aria2 { + +SftpFinishDownloadCommand::SftpFinishDownloadCommand +(cuid_t cuid, + const std::shared_ptr& req, + const std::shared_ptr& fileEntry, + RequestGroup* requestGroup, + DownloadEngine* e, + const std::shared_ptr& socket) + : AbstractCommand(cuid, req, fileEntry, requestGroup, e, socket) +{ + disableReadCheckSocket(); + setWriteCheckSocket(getSocket()); +} + +SftpFinishDownloadCommand::~SftpFinishDownloadCommand() {} + +// overrides AbstractCommand::execute(). +// AbstractCommand::_segments is empty. +bool SftpFinishDownloadCommand::execute() +{ + if(getRequestGroup()->isHaltRequested()) { + return true; + } + try { + if(readEventEnabled() || writeEventEnabled() || hupEventEnabled()) { + getCheckPoint() = global::wallclock(); + + if (!getSocket()->sshSFTPClose()) { + setWriteCheckSocketIf(getSocket(), getSocket()->wantWrite()); + setReadCheckSocketIf(getSocket(), getSocket()->wantRead()); + addCommandSelf(); + return false; + } + + auto authConfig = + getDownloadEngine()->getAuthConfigFactory()->createAuthConfig + (getRequest(), getRequestGroup()->getOption().get()); + + getDownloadEngine()->poolSocket + (getRequest(), authConfig->getUser(), createProxyRequest(), + getSocket(), ""); + } else if(getCheckPoint().difference(global::wallclock()) >= getTimeout()) { + A2_LOG_INFO(fmt("CUID#%" PRId64 + " - Timeout before receiving transfer complete.", + getCuid())); + } else { + addCommandSelf(); + return false; + } + } catch(RecoverableException& e) { + A2_LOG_INFO_EX(fmt("CUID#%" PRId64 + " - Exception was thrown, but download was" + " finished, so we can ignore the exception.", + getCuid()), + e); + } + if(getRequestGroup()->downloadFinished()) { + return true; + } else { + return prepareForRetry(0); + } +} + +// This function never be called. +bool SftpFinishDownloadCommand::executeInternal() { return true; } + +} // namespace aria2 diff --git a/src/SftpFinishDownloadCommand.h b/src/SftpFinishDownloadCommand.h new file mode 100644 index 00000000..fbdd86f4 --- /dev/null +++ b/src/SftpFinishDownloadCommand.h @@ -0,0 +1,59 @@ +/* */ +#ifndef D_SFTP_FINISH_DOWNLOAD_COMMAND_H +#define D_SFTP_FINISH_DOWNLOAD_COMMAND_H + +#include "AbstractCommand.h" + +namespace aria2 { + +class SftpFinishDownloadCommand : public AbstractCommand { +protected: + virtual bool execute() CXX11_OVERRIDE; + + virtual bool executeInternal() CXX11_OVERRIDE; +public: + SftpFinishDownloadCommand(cuid_t cuid, + const std::shared_ptr& req, + const std::shared_ptr& fileEntry, + RequestGroup* requestGroup, + DownloadEngine* e, + const std::shared_ptr& socket); + virtual ~SftpFinishDownloadCommand(); +}; + +} // namespace aria2 + +#endif // D_SFTP_FINISH_DOWNLOAD_COMMAND_H diff --git a/src/SftpNegotiationCommand.cc b/src/SftpNegotiationCommand.cc new file mode 100644 index 00000000..5a477025 --- /dev/null +++ b/src/SftpNegotiationCommand.cc @@ -0,0 +1,320 @@ +/* */ +#include "SftpNegotiationCommand.h" + +#include +#include + +#include "Request.h" +#include "DownloadEngine.h" +#include "RequestGroup.h" +#include "PieceStorage.h" +#include "FileEntry.h" +#include "message.h" +#include "util.h" +#include "Option.h" +#include "Logger.h" +#include "LogFactory.h" +#include "Segment.h" +#include "DownloadContext.h" +#include "DefaultBtProgressInfoFile.h" +#include "RequestGroupMan.h" +#include "SocketCore.h" +#include "fmt.h" +#include "DiskAdaptor.h" +#include "SegmentMan.h" +#include "AuthConfigFactory.h" +#include "AuthConfig.h" +#include "a2functional.h" +#include "URISelector.h" +#include "CheckIntegrityEntry.h" +#include "NullProgressInfoFile.h" +#include "ChecksumCheckIntegrityEntry.h" +#include "SftpDownloadCommand.h" + +namespace aria2 { + +SftpNegotiationCommand::SftpNegotiationCommand +(cuid_t cuid, + const std::shared_ptr& req, + const std::shared_ptr& fileEntry, + RequestGroup* requestGroup, + DownloadEngine* e, + const std::shared_ptr& socket, + Seq seq) + : AbstractCommand(cuid, req, fileEntry, requestGroup, e, socket), + sequence_(seq), + authConfig_(e->getAuthConfigFactory()->createAuthConfig + (req, requestGroup->getOption().get())) + +{ + path_ = getPath(); + disableReadCheckSocket(); + setWriteCheckSocket(getSocket()); +} + +SftpNegotiationCommand::~SftpNegotiationCommand() {} + +bool SftpNegotiationCommand::executeInternal() { + disableWriteCheckSocket(); + for (;;) { + switch(sequence_) { + case SEQ_HANDSHAKE: + setReadCheckSocket(getSocket()); + if (!getSocket()->sshHandshake()) { + goto again; + } + A2_LOG_DEBUG(fmt("CUID#%" PRId64 " - SSH handshake success", getCuid())); + sequence_ = SEQ_AUTH_PASSWORD; + break; + case SEQ_AUTH_PASSWORD: + if (!getSocket()->sshAuthPassword(authConfig_->getUser(), + authConfig_->getPassword())) { + goto again; + } + A2_LOG_DEBUG(fmt("CUID#%" PRId64 " - SSH authentication success", + getCuid())); + sequence_ = SEQ_SFTP_OPEN; + break; + case SEQ_SFTP_OPEN: { + if (!getSocket()->sshSFTPOpen(path_)) { + goto again; + } + A2_LOG_DEBUG(fmt("CUID#%" PRId64 " - SFTP file %s opened", getCuid(), + path_.c_str())); + sequence_ = SEQ_SFTP_STAT; + break; + } + case SEQ_SFTP_STAT: { + int64_t totalLength; + time_t mtime; + if (!getSocket()->sshSFTPStat(totalLength, mtime, path_)) { + goto again; + } + Time t(mtime); + A2_LOG_INFO(fmt("CUID#%" PRId64 " - SFTP File %s, size=%" PRId64 + ", mtime=%s", + getCuid(), path_.c_str(), totalLength, + t.toHTTPDate().c_str())); + if (!getPieceStorage()) { + getRequestGroup()->updateLastModifiedTime(Time(mtime)); + onFileSizeDetermined(totalLength); + } else { + getRequestGroup()->validateTotalLength(getFileEntry()->getLength(), + totalLength); + sequence_ = SEQ_NEGOTIATION_COMPLETED; + } + break; + } + case SEQ_FILE_PREPARATION: + sequence_ = SEQ_NEGOTIATION_COMPLETED; + disableReadCheckSocket(); + disableWriteCheckSocket(); + return false; + case SEQ_NEGOTIATION_COMPLETED: { + auto command = make_unique + (getCuid(), getRequest(), getFileEntry(), getRequestGroup(), + getDownloadEngine(), getSocket(), std::move(authConfig_)); + command->setStartupIdleTime + (getOption()->getAsInt(PREF_STARTUP_IDLE_TIME)); + command->setLowestDownloadSpeedLimit + (getOption()->getAsInt(PREF_LOWEST_SPEED_LIMIT)); + command->setStatus(Command::STATUS_ONESHOT_REALTIME); + + getDownloadEngine()->setNoWait(true); + + if(getFileEntry()->isUniqueProtocol()) { + getFileEntry()->removeURIWhoseHostnameIs(getRequest()->getHost()); + } + getRequestGroup()->getURISelector()->tuneDownloadCommand + (getFileEntry()->getRemainingUris(), command.get()); + getDownloadEngine()->addCommand(std::move(command)); + return true; + } + case SEQ_DOWNLOAD_ALREADY_COMPLETED: + case SEQ_HEAD_OK: + case SEQ_EXIT: + return true; + }; + } + again: + addCommandSelf(); + if (getSocket()->wantWrite()) { + setWriteCheckSocket(getSocket()); + } + return false; +} + +void SftpNegotiationCommand::onFileSizeDetermined(int64_t totalLength) +{ + getFileEntry()->setLength(totalLength); + if(getFileEntry()->getPath().empty()) { + auto suffixPath = util::createSafePath + (util::percentDecode(std::begin(getRequest()->getFile()), + std::end(getRequest()->getFile()))); + + getFileEntry()->setPath + (util::applyDir(getOption()->get(PREF_DIR), suffixPath)); + getFileEntry()->setSuffixPath(suffixPath); + } + getRequestGroup()->preDownloadProcessing(); + + if(totalLength == 0) { + sequence_ = SEQ_NEGOTIATION_COMPLETED; + + if(getOption()->getAsBool(PREF_DRY_RUN)) { + getRequestGroup()->initPieceStorage(); + onDryRunFileFound(); + return; + } + + if(getDownloadContext()->knowsTotalLength() && + getRequestGroup()->downloadFinishedByFileLength()) { + // TODO Known issue: if .aria2 file exists, it will not be + // deleted on successful verification, because .aria2 file is + // not loaded. See also + // HttpResponseCommand::handleOtherEncoding() + getRequestGroup()->initPieceStorage(); + if(getDownloadContext()->isChecksumVerificationNeeded()) { + A2_LOG_DEBUG("Zero length file exists. Verify checksum."); + auto entry = make_unique + (getRequestGroup()); + entry->initValidator(); + getPieceStorage()->getDiskAdaptor()->openExistingFile(); + getDownloadEngine()->getCheckIntegrityMan()->pushEntry + (std::move(entry)); + sequence_ = SEQ_EXIT; + } + else { + getPieceStorage()->markAllPiecesDone(); + getDownloadContext()->setChecksumVerified(true); + sequence_ = SEQ_DOWNLOAD_ALREADY_COMPLETED; + A2_LOG_NOTICE + (fmt(MSG_DOWNLOAD_ALREADY_COMPLETED, + GroupId::toHex(getRequestGroup()->getGID()).c_str(), + getRequestGroup()->getFirstFilePath().c_str())); + } + poolConnection(); + return; + } + + getRequestGroup()->adjustFilename + (std::make_shared()); + getRequestGroup()->initPieceStorage(); + getPieceStorage()->getDiskAdaptor()->initAndOpenFile(); + + if(getDownloadContext()->knowsTotalLength()) { + A2_LOG_DEBUG("File length becomes zero and it means download completed."); + // TODO Known issue: if .aria2 file exists, it will not be + // deleted on successful verification, because .aria2 file is + // not loaded. See also + // HttpResponseCommand::handleOtherEncoding() + if(getDownloadContext()->isChecksumVerificationNeeded()) { + A2_LOG_DEBUG("Verify checksum for zero-length file"); + auto entry = make_unique + (getRequestGroup()); + entry->initValidator(); + getDownloadEngine()->getCheckIntegrityMan()->pushEntry + (std::move(entry)); + sequence_ = SEQ_EXIT; + } else + { + sequence_ = SEQ_DOWNLOAD_ALREADY_COMPLETED; + getPieceStorage()->markAllPiecesDone(); + } + poolConnection(); + return; + } + // We have to make sure that command that has Request object must + // have segment after PieceStorage is initialized. See + // AbstractCommand::execute() + getSegmentMan()->getSegmentWithIndex(getCuid(), 0); + return; + } else { + auto progressInfoFile = std::make_shared + (getDownloadContext(), nullptr, getOption().get()); + getRequestGroup()->adjustFilename(progressInfoFile); + getRequestGroup()->initPieceStorage(); + + if(getOption()->getAsBool(PREF_DRY_RUN)) { + onDryRunFileFound(); + return; + } + + auto checkIntegrityEntry = getRequestGroup()->createCheckIntegrityEntry(); + if(!checkIntegrityEntry) { + sequence_ = SEQ_DOWNLOAD_ALREADY_COMPLETED; + poolConnection(); + return; + } + checkIntegrityEntry->pushNextCommand(std::unique_ptr(this)); + // We have to make sure that command that has Request object must + // have segment after PieceStorage is initialized. See + // AbstractCommand::execute() + getSegmentMan()->getSegmentWithIndex(getCuid(), 0); + + prepareForNextAction(std::move(checkIntegrityEntry)); + + disableReadCheckSocket(); + + sequence_ = SEQ_FILE_PREPARATION; + } +} + +void SftpNegotiationCommand::poolConnection() const +{ + if(getOption()->getAsBool(PREF_FTP_REUSE_CONNECTION)) { + // TODO we don't need options. Probably, we need to pool socket + // using scheme, port and auth info as key + getDownloadEngine()->poolSocket(getRequest(), authConfig_->getUser(), + createProxyRequest(), getSocket(), ""); + } +} + +void SftpNegotiationCommand::onDryRunFileFound() +{ + getPieceStorage()->markAllPiecesDone(); + getDownloadContext()->setChecksumVerified(true); + poolConnection(); + sequence_ = SEQ_HEAD_OK; +} + +std::string SftpNegotiationCommand::getPath() const { + auto &req = getRequest(); + auto path = req->getDir() + req->getFile(); + return util::percentDecode(std::begin(path), std::end(path)); +} + +} // namespace aria2 diff --git a/src/SftpNegotiationCommand.h b/src/SftpNegotiationCommand.h new file mode 100644 index 00000000..e6b77ab6 --- /dev/null +++ b/src/SftpNegotiationCommand.h @@ -0,0 +1,87 @@ +/* */ +#ifndef D_SFTP_NEGOTIATION_COMMAND_H +#define D_SFTP_NEGOTIATION_COMMAND_H + +#include "AbstractCommand.h" + +namespace aria2 { + +class SocketCore; +class AuthConfig; + +class SftpNegotiationCommand : public AbstractCommand { +public: + enum Seq { + SEQ_HANDSHAKE, + SEQ_AUTH_PASSWORD, + SEQ_SFTP_OPEN, + SEQ_SFTP_STAT, + SEQ_NEGOTIATION_COMPLETED, + SEQ_DOWNLOAD_ALREADY_COMPLETED, + SEQ_HEAD_OK, + SEQ_FILE_PREPARATION, + SEQ_EXIT, + }; + +private: + void onFileSizeDetermined(int64_t totalLength); + void poolConnection() const; + void onDryRunFileFound(); + std::string getPath() const; + + std::shared_ptr socket_; + Seq sequence_; + std::unique_ptr authConfig_; + // remote file path + std::string path_; + +protected: + virtual bool executeInternal() CXX11_OVERRIDE; + +public: + SftpNegotiationCommand(cuid_t cuid, + const std::shared_ptr& req, + const std::shared_ptr& fileEntry, + RequestGroup* requestGroup, + DownloadEngine* e, + const std::shared_ptr& s, + Seq seq = SEQ_HANDSHAKE); + virtual ~SftpNegotiationCommand(); +}; + +} // namespace aria2 + +#endif // D_SFTP_NEGOTIATION_COMMAND_H diff --git a/src/SftpNegotiationConnectChain.h b/src/SftpNegotiationConnectChain.h new file mode 100644 index 00000000..e78081ce --- /dev/null +++ b/src/SftpNegotiationConnectChain.h @@ -0,0 +1,66 @@ +/* */ +#ifndef SFTP_NEGOTIATION_CONNECT_CHAIN_H +#define SFTP_NEGOTIATION_CONNECT_CHAIN_H + +#include "ControlChain.h" +#include "ConnectCommand.h" +#include "DownloadEngine.h" +#include "SftpNegotiationCommand.h" + +namespace aria2 { + +struct SftpNegotiationConnectChain : public ControlChain { + SftpNegotiationConnectChain() {} + virtual ~SftpNegotiationConnectChain() {} + virtual int run(ConnectCommand* t, DownloadEngine* e) CXX11_OVERRIDE + { + auto c = make_unique + (t->getCuid(), + t->getRequest(), + t->getFileEntry(), + t->getRequestGroup(), + t->getDownloadEngine(), + t->getSocket()); + c->setStatus(Command::STATUS_ONESHOT_REALTIME); + e->setNoWait(true); + e->addCommand(std::move(c)); + return 0; + } +}; + +} // namespace aria2 + +#endif // SFTP_NEGOTIATION_CONNECT_CHAIN_H diff --git a/src/SocketCore.cc b/src/SocketCore.cc index 5469b8bf..05ee28bf 100644 --- a/src/SocketCore.cc +++ b/src/SocketCore.cc @@ -1050,6 +1050,46 @@ bool SocketCore::sshSFTPOpen(const std::string& path) return true; } +bool SocketCore::sshSFTPClose() +{ + assert(sshSession_); + + wantRead_ = false; + wantWrite_ = false; + + auto rv = sshSession_->sftpClose(); + if (rv == SSH_ERR_WOULDBLOCK) { + sshCheckDirection(); + return false; + } + if (rv == SSH_ERR_ERROR) { + throw DL_ABORT_EX(fmt("SSH closing SFTP failed: %s", + sshSession_->getLastErrorString().c_str())); + } + return true; +} + +bool SocketCore::sshSFTPStat(int64_t& totalLength, time_t& mtime, + const std::string& path) +{ + assert(sshSession_); + + wantRead_ = false; + wantWrite_ = false; + + auto rv = sshSession_->sftpStat(totalLength, mtime); + if (rv == SSH_ERR_WOULDBLOCK) { + sshCheckDirection(); + return false; + } + if (rv == SSH_ERR_ERROR) { + throw DL_ABORT_EX(fmt("SSH stat SFTP path %s filed: %s", + path.c_str(), + sshSession_->getLastErrorString().c_str())); + } + return true; +} + bool SocketCore::sshGracefulShutdown() { assert(sshSession_); diff --git a/src/SocketCore.h b/src/SocketCore.h index 2cebae41..92bc67e3 100644 --- a/src/SocketCore.h +++ b/src/SocketCore.h @@ -301,9 +301,18 @@ public: #endif // ENABLE_SSL #ifdef HAVE_LIBSSH2 + // Performs SSH handshake bool sshHandshake(); + // Performs SSH authentication using username and password. bool sshAuthPassword(const std::string& user, const std::string& password); + // Starts sftp session and open remote file |path|. bool sshSFTPOpen(const std::string& path); + // Closes sftp remote file gracefully + bool sshSFTPClose(); + // Gets total length and modified time for remote file currently + // opened. |path| is used for logging. + bool sshSFTPStat(int64_t& totalLength, time_t& mtime, + const std::string& path); bool sshGracefulShutdown(); #endif // HAVE_LIBSSH2