diff --git a/src/WinTLSSession.cc b/src/WinTLSSession.cc index 72fff2e1..78a919a7 100644 --- a/src/WinTLSSession.cc +++ b/src/WinTLSSession.cc @@ -86,16 +86,6 @@ static const ULONG kReqAFlags = ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY | ASC_REQ_EXTENDED_ERROR | ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_STREAM; -class TLSBuffer : public ::SecBuffer { -public: - explicit TLSBuffer(ULONG type, ULONG size, void* data) - { - cbBuffer = size; - BufferType = type; - pvBuffer = data; - } -}; - class TLSBufferDesc : public ::SecBufferDesc { public: explicit TLSBufferDesc(SecBuffer* arr, ULONG buffers) @@ -142,7 +132,8 @@ WinTLSSession::WinTLSSession(WinTLSContext* ctx) cred_(ctx->getCredHandle()), writeBuffered_(0), state_(st_constructed), - status_(SEC_E_OK) + status_(SEC_E_OK), + recordBytesSent_(0) { memset(&handle_, 0, sizeof(handle_)); } @@ -213,7 +204,8 @@ int WinTLSSession::closeConnection() status_ = ::AcceptSecurityContext(cred_, &handle_, nullptr, kReqAFlags, 0, &handle_, &desc, &flags, nullptr); } - if (status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) { + if ((status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) && + getLeftTLSRecordSize() == 0) { size_t len = ctx.cbBuffer; ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer); ::FreeContextBuffer(ctx.pvBuffer); @@ -228,17 +220,6 @@ int WinTLSSession::closeConnection() } } - // Send remaining data. - while (writeBuf_.size()) { - int rv = writeData(nullptr, 0); - if (rv == 0) { - break; - } - if (rv < 0) { - return rv; - } - } - A2_LOG_DEBUG("WinTLS: Closed Connection"); state_ = st_closed; return TLS_ERR_OK; @@ -255,12 +236,82 @@ int WinTLSSession::checkDirection() if (readBuf_.size() || decBuf_.size()) { return TLS_WANT_READ; } - if (writeBuf_.size()) { + if (getLeftTLSRecordSize() || writeBuf_.size()) { return TLS_WANT_WRITE; } return TLS_WANT_READ; } +namespace { +// Fills |iov| of length |len| to send remaining data in |buffers|. +// We have already sent |offset| bytes. This function returns the +// number of |iov| filled. It assumes the array |buffers| is at least +// |len| elements. +size_t fillSendIOV(a2iovec* iov, size_t len, TLSBuffer* buffers, size_t offset) +{ + size_t iovcnt = 0; + for (size_t i = 0; i < len; ++i) { + if (offset < buffers[i].cbBuffer) { + iov[iovcnt].A2IOVEC_BASE = + static_cast(buffers[i].pvBuffer) + offset; + iov[iovcnt].A2IOVEC_LEN = buffers[i].cbBuffer - offset; + ++iovcnt; + offset = 0; + } + else { + offset -= buffers[i].cbBuffer; + } + } + return iovcnt; +} +} // namespace + +size_t WinTLSSession::getLeftTLSRecordSize() const +{ + return sendRecordBuffers_[0].cbBuffer + sendRecordBuffers_[1].cbBuffer + + sendRecordBuffers_[2].cbBuffer - recordBytesSent_; +} + +int WinTLSSession::sendTLSRecord() +{ + A2_LOG_DEBUG(fmt("WinTLS: TLS record %" PRIu64 " bytes left", + static_cast(getLeftTLSRecordSize()))); + + while (getLeftTLSRecordSize()) { + std::array iov; + auto iovcnt = fillSendIOV(iov.data(), iov.size(), sendRecordBuffers_.data(), + recordBytesSent_); + + DWORD nwrite; + auto rv = + WSASend(sockfd_, iov.data(), iovcnt, &nwrite, 0, nullptr, nullptr); + if (rv != 0) { + auto errnum = ::WSAGetLastError(); + if (errnum == WSAEINTR) { + continue; + } + + if (errnum == WSAEWOULDBLOCK) { + return TLS_ERR_WOULDBLOCK; + } + + A2_LOG_ERROR("WinTLS: Connection error while writing"); + status_ = SEC_E_INCOMPLETE_MESSAGE; + state_ = st_error; + return TLS_ERR_ERROR; + } + + recordBytesSent_ += nwrite; + } + + recordBytesSent_ = 0; + sendRecordBuffers_[0].cbBuffer = 0; + sendRecordBuffers_[1].cbBuffer = 0; + sendRecordBuffers_[2].cbBuffer = 0; + + return 0; +} + ssize_t WinTLSSession::writeData(const void* data, size_t len) { if (state_ == st_handshake_write || state_ == st_handshake_write_last || @@ -281,45 +332,15 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len) } A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64, - (uint64_t)len, (uint64_t)writeBuf_.size())); + (uint64_t)len, (uint64_t)recordBytesSent_)); - // Write remaining buffered data, if any. - while (writeBuf_.size()) { - auto written = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0); - errno = ::WSAGetLastError(); - if (written < 0 && errno == WSAEINTR) { - continue; - } - if (written < 0 && errno == WSAEWOULDBLOCK) { - return TLS_ERR_WOULDBLOCK; - } - if (written == 0) { - return written; - } - if (written < 0) { - status_ = SEC_E_INVALID_HANDLE; - state_ = st_error; - return TLS_ERR_ERROR; - } - writeBuf_.eat(written); + auto rv = sendTLSRecord(); + if (rv != 0) { + return rv; } - if (len == 0) { - return 0; - } - - if (!streamSizes_) { - streamSizes_.reset(new SecPkgContext_StreamSizes()); - status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES, - streamSizes_.get()); - if (status_ != SEC_E_OK || !streamSizes_->cbMaximumMessage) { - state_ = st_error; - return TLS_ERR_ERROR; - } - } - - size_t process = len; - auto bytes = reinterpret_cast(data); + auto left = len; + auto bytes = static_cast(data); if (writeBuffered_) { // There was buffered data, hence we need to "remove" that data from the // incoming buffer to avoid writing it again @@ -331,10 +352,10 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len) } // just advance the buffer by writeBuffered_ bytes bytes += writeBuffered_; - process -= writeBuffered_; + left -= writeBuffered_; writeBuffered_ = 0; } - if (!process) { + if (!left) { // The buffer contained the full remainder. At this point, the buffer has // been written, so the request is done in its entirety; return len; @@ -342,23 +363,25 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len) // Buffered data was already written ;) // If there was no buffered data, this will be len - len = 0. - len = len - process; - while (process) { + len -= left; + while (left) { // Set up an outgoing message, according to streamSizes_ - writeBuffered_ = std::min(process, (size_t)streamSizes_->cbMaximumMessage); - size_t dl = - streamSizes_->cbHeader + writeBuffered_ + streamSizes_->cbTrailer; - auto buf = make_unique(dl); - TLSBuffer buffers[] = { - TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_->cbHeader, buf.get()), + writeBuffered_ = + std::min(left, static_cast(streamSizes_.cbMaximumMessage)); + + sendRecordBuffers_ = { + TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_.cbHeader, + sendBuffer_.data()), TLSBuffer(SECBUFFER_DATA, writeBuffered_, - buf.get() + streamSizes_->cbHeader), - TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_->cbTrailer, - buf.get() + streamSizes_->cbHeader + writeBuffered_), + sendBuffer_.data() + streamSizes_.cbHeader), + TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_.cbTrailer, + sendBuffer_.data() + streamSizes_.cbHeader + writeBuffered_), TLSBuffer(SECBUFFER_EMPTY, 0, nullptr), }; - TLSBufferDesc desc(buffers, 4); - memcpy(buffers[1].pvBuffer, bytes, writeBuffered_); + + TLSBufferDesc desc(sendRecordBuffers_.data(), sendRecordBuffers_.size()); + std::copy_n(bytes, writeBuffered_, + static_cast(sendRecordBuffers_[1].pvBuffer)); status_ = ::EncryptMessage(&handle_, 0, &desc, 0); if (status_ != SEC_E_OK) { A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s", @@ -367,61 +390,32 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len) return TLS_ERR_ERROR; } - // EncryptMessage may have truncated the buffers. - // Should rarely happen, if ever, except for the trailer. - dl = buffers[0].cbBuffer; - if (dl < streamSizes_->cbHeader) { - // Move message. - memmove(buf.get() + dl, buffers[1].pvBuffer, buffers[1].cbBuffer); - } - dl += buffers[1].cbBuffer; - if (dl < streamSizes_->cbHeader + writeBuffered_) { - // Move trailer. - memmove(buf.get() + dl, buffers[2].pvBuffer, buffers[2].cbBuffer); - } - dl += buffers[2].cbBuffer; + A2_LOG_DEBUG(fmt("WinTLS: Write TLS record header: %" PRIu64 + " body: %" PRIu64 " trailer: %" PRIu64, + static_cast(sendRecordBuffers_[0].cbBuffer), + static_cast(sendRecordBuffers_[1].cbBuffer), + static_cast(sendRecordBuffers_[2].cbBuffer))); - // Write (or buffer) the message. - char* p = buf.get(); - while (dl) { - auto written = ::send(sockfd_, p, dl, 0); - errno = ::WSAGetLastError(); - if (written < 0 && errno == WSAEINTR) { - continue; + auto rv = sendTLSRecord(); + if (rv == TLS_ERR_WOULDBLOCK) { + if (len == 0) { + return TLS_ERR_WOULDBLOCK; } - if (written < 0 && errno == WSAEWOULDBLOCK) { - // Buffer the rest of the message... - writeBuf_.write(p, dl); - // and return... - return len; - } - if (written == 0) { - A2_LOG_ERROR("WinTLS: Connection closed while writing"); - status_ = SEC_E_INCOMPLETE_MESSAGE; - state_ = st_error; - return TLS_ERR_ERROR; - } - if (written < 0) { - A2_LOG_ERROR("WinTLS: Connection error while writing"); - status_ = SEC_E_INCOMPLETE_MESSAGE; - state_ = st_error; - return TLS_ERR_ERROR; - } - dl -= written; - p += written; + return len; + } + + if (rv != 0) { + return rv; } len += writeBuffered_; bytes += writeBuffered_; - process -= writeBuffered_; + left -= writeBuffered_; writeBuffered_ = 0; } - A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64 " buffered: %" PRIu64, - (uint64_t)len, (uint64_t)writeBuf_.size())); - if (!len) { - return TLS_ERR_WOULDBLOCK; - } + A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64, (uint64_t)len)); + return len; } @@ -777,6 +771,11 @@ restart: // Fall through case st_handshake_done: + if (obtainTLSRecordSizes() != 0) { + return TLS_ERR_ERROR; + } + ensureSendBuffer(); + // All ready now :D state_ = st_connected; A2_LOG_INFO( @@ -833,4 +832,26 @@ std::string WinTLSSession::getLastErrorString() size_t WinTLSSession::getRecvBufferedLength() { return decBuf_.size(); } +int WinTLSSession::obtainTLSRecordSizes() +{ + status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES, + &streamSizes_); + if (status_ != SEC_E_OK || !streamSizes_.cbMaximumMessage) { + A2_LOG_ERROR("WinTLS: Unable to obtain stream sizes"); + state_ = st_error; + return -1; + } + + return 0; +} + +void WinTLSSession::ensureSendBuffer() +{ + auto sum = streamSizes_.cbHeader + streamSizes_.cbMaximumMessage + + streamSizes_.cbTrailer; + if (sendBuffer_.size() < sum) { + sendBuffer_.resize(sum); + } +} + } // namespace aria2 diff --git a/src/WinTLSSession.h b/src/WinTLSSession.h index 2db9cb21..9095bed2 100644 --- a/src/WinTLSSession.h +++ b/src/WinTLSSession.h @@ -100,6 +100,18 @@ public: }; } // namespace wintls +class TLSBuffer : public ::SecBuffer { +public: + TLSBuffer() : ::SecBuffer{}{} + + explicit TLSBuffer(ULONG type, ULONG size, void* data) + { + cbBuffer = size; + BufferType = type; + pvBuffer = data; + } +}; + class WinTLSSession : public TLSSession { enum state_t { st_constructed, @@ -172,16 +184,31 @@ public: virtual size_t getRecvBufferedLength() CXX11_OVERRIDE; private: + // Obtains TLS record size limits. This function returns 0 if it + // succeeds, or -1. status_ and state_ are updated according to the + // result. + int obtainTLSRecordSizes(); + // Ensures the buffer size so that maximum TLS record can be sent. + void ensureSendBuffer(); + // Sends TLS record specified in sendRecordBuffers_. It uses + // recordBytesSent_ to track down how many bytes have been sent. + // This function returns 0 if it succeeds, or negative error codes. + int sendTLSRecord(); + // Returns the number of bytes in the remaining TLS record size. + size_t getLeftTLSRecordSize() const; + std::string hostname_; sock_t sockfd_; TLSSessionSide side_; CredHandle* cred_; CtxtHandle handle_; - // Buffer for already encrypted writes + // Buffer for already encrypted writes. This is only used in + // handshake. wintls::Buffer writeBuf_; - // While the writeBuf_ holds encrypted messages, writeBuffered_ has the - // corresponding size of unencrypted data used to produce the messages. + // While the sendRecordBuffers_ holds encrypted messages, + // writeBuffered_ has the corresponding size of unencrypted data + // used to produce the messages. size_t writeBuffered_; // Buffer for still encrypted reads wintls::Buffer readBuf_; @@ -191,7 +218,16 @@ private: state_t state_; SECURITY_STATUS status_; - std::unique_ptr streamSizes_; + // The number of maximum size for TLS record header, body, and + // trailer. + SecPkgContext_StreamSizes streamSizes_; + // Underlying buffer for outgoing TLS record. + std::vector sendBuffer_; + // How many bytes has been sent for current TLS record held in + // sendRecordBuffers_. + size_t recordBytesSent_; + // This holds current outgoing TLS record. + std::array sendRecordBuffers_; }; } // namespace aria2