WinTLS: Rewrite writeData

We re-wrote WinTLSSession::writeData.  The major points are:

* Buffer is now preallocated once handshake is finished.  Previously,
  they are allocated each time when we send one TLS record.

* Schannel uses header, body and trailer for each secBuffer.  Now we
  send them off at once using WSASend which is windows counterpart of
  sendv.  Previously, we do memmove if some of them are truncated.

* We don't try to send application data in
  WinTLSSession::closeConnection, since semantically we need same
  application data used to create TLS record before.  Using 0 length
  data to finish sending buffered data looks like a hack.
pull/772/head
Tatsuhiro Tsujikawa 2016-11-11 22:20:29 +09:00
parent d289dc1108
commit d974c935cd
2 changed files with 182 additions and 125 deletions

View File

@ -86,16 +86,6 @@ static const ULONG kReqAFlags =
ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY |
ASC_REQ_EXTENDED_ERROR | ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_STREAM;
class TLSBuffer : public ::SecBuffer {
public:
explicit TLSBuffer(ULONG type, ULONG size, void* data)
{
cbBuffer = size;
BufferType = type;
pvBuffer = data;
}
};
class TLSBufferDesc : public ::SecBufferDesc {
public:
explicit TLSBufferDesc(SecBuffer* arr, ULONG buffers)
@ -142,7 +132,8 @@ WinTLSSession::WinTLSSession(WinTLSContext* ctx)
cred_(ctx->getCredHandle()),
writeBuffered_(0),
state_(st_constructed),
status_(SEC_E_OK)
status_(SEC_E_OK),
recordBytesSent_(0)
{
memset(&handle_, 0, sizeof(handle_));
}
@ -213,7 +204,8 @@ int WinTLSSession::closeConnection()
status_ = ::AcceptSecurityContext(cred_, &handle_, nullptr, kReqAFlags, 0,
&handle_, &desc, &flags, nullptr);
}
if (status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) {
if ((status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) &&
getLeftTLSRecordSize() == 0) {
size_t len = ctx.cbBuffer;
ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer);
::FreeContextBuffer(ctx.pvBuffer);
@ -228,17 +220,6 @@ int WinTLSSession::closeConnection()
}
}
// Send remaining data.
while (writeBuf_.size()) {
int rv = writeData(nullptr, 0);
if (rv == 0) {
break;
}
if (rv < 0) {
return rv;
}
}
A2_LOG_DEBUG("WinTLS: Closed Connection");
state_ = st_closed;
return TLS_ERR_OK;
@ -255,12 +236,82 @@ int WinTLSSession::checkDirection()
if (readBuf_.size() || decBuf_.size()) {
return TLS_WANT_READ;
}
if (writeBuf_.size()) {
if (getLeftTLSRecordSize() || writeBuf_.size()) {
return TLS_WANT_WRITE;
}
return TLS_WANT_READ;
}
namespace {
// Fills |iov| of length |len| to send remaining data in |buffers|.
// We have already sent |offset| bytes. This function returns the
// number of |iov| filled. It assumes the array |buffers| is at least
// |len| elements.
size_t fillSendIOV(a2iovec* iov, size_t len, TLSBuffer* buffers, size_t offset)
{
size_t iovcnt = 0;
for (size_t i = 0; i < len; ++i) {
if (offset < buffers[i].cbBuffer) {
iov[iovcnt].A2IOVEC_BASE =
static_cast<char*>(buffers[i].pvBuffer) + offset;
iov[iovcnt].A2IOVEC_LEN = buffers[i].cbBuffer - offset;
++iovcnt;
offset = 0;
}
else {
offset -= buffers[i].cbBuffer;
}
}
return iovcnt;
}
} // namespace
size_t WinTLSSession::getLeftTLSRecordSize() const
{
return sendRecordBuffers_[0].cbBuffer + sendRecordBuffers_[1].cbBuffer +
sendRecordBuffers_[2].cbBuffer - recordBytesSent_;
}
int WinTLSSession::sendTLSRecord()
{
A2_LOG_DEBUG(fmt("WinTLS: TLS record %" PRIu64 " bytes left",
static_cast<uint64_t>(getLeftTLSRecordSize())));
while (getLeftTLSRecordSize()) {
std::array<a2iovec, 3> iov;
auto iovcnt = fillSendIOV(iov.data(), iov.size(), sendRecordBuffers_.data(),
recordBytesSent_);
DWORD nwrite;
auto rv =
WSASend(sockfd_, iov.data(), iovcnt, &nwrite, 0, nullptr, nullptr);
if (rv != 0) {
auto errnum = ::WSAGetLastError();
if (errnum == WSAEINTR) {
continue;
}
if (errnum == WSAEWOULDBLOCK) {
return TLS_ERR_WOULDBLOCK;
}
A2_LOG_ERROR("WinTLS: Connection error while writing");
status_ = SEC_E_INCOMPLETE_MESSAGE;
state_ = st_error;
return TLS_ERR_ERROR;
}
recordBytesSent_ += nwrite;
}
recordBytesSent_ = 0;
sendRecordBuffers_[0].cbBuffer = 0;
sendRecordBuffers_[1].cbBuffer = 0;
sendRecordBuffers_[2].cbBuffer = 0;
return 0;
}
ssize_t WinTLSSession::writeData(const void* data, size_t len)
{
if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
@ -281,45 +332,15 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
}
A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64,
(uint64_t)len, (uint64_t)writeBuf_.size()));
(uint64_t)len, (uint64_t)recordBytesSent_));
// Write remaining buffered data, if any.
while (writeBuf_.size()) {
auto written = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
errno = ::WSAGetLastError();
if (written < 0 && errno == WSAEINTR) {
continue;
}
if (written < 0 && errno == WSAEWOULDBLOCK) {
return TLS_ERR_WOULDBLOCK;
}
if (written == 0) {
return written;
}
if (written < 0) {
status_ = SEC_E_INVALID_HANDLE;
state_ = st_error;
return TLS_ERR_ERROR;
}
writeBuf_.eat(written);
auto rv = sendTLSRecord();
if (rv != 0) {
return rv;
}
if (len == 0) {
return 0;
}
if (!streamSizes_) {
streamSizes_.reset(new SecPkgContext_StreamSizes());
status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
streamSizes_.get());
if (status_ != SEC_E_OK || !streamSizes_->cbMaximumMessage) {
state_ = st_error;
return TLS_ERR_ERROR;
}
}
size_t process = len;
auto bytes = reinterpret_cast<const char*>(data);
auto left = len;
auto bytes = static_cast<const char*>(data);
if (writeBuffered_) {
// There was buffered data, hence we need to "remove" that data from the
// incoming buffer to avoid writing it again
@ -331,10 +352,10 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
}
// just advance the buffer by writeBuffered_ bytes
bytes += writeBuffered_;
process -= writeBuffered_;
left -= writeBuffered_;
writeBuffered_ = 0;
}
if (!process) {
if (!left) {
// The buffer contained the full remainder. At this point, the buffer has
// been written, so the request is done in its entirety;
return len;
@ -342,23 +363,25 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
// Buffered data was already written ;)
// If there was no buffered data, this will be len - len = 0.
len = len - process;
while (process) {
len -= left;
while (left) {
// Set up an outgoing message, according to streamSizes_
writeBuffered_ = std::min(process, (size_t)streamSizes_->cbMaximumMessage);
size_t dl =
streamSizes_->cbHeader + writeBuffered_ + streamSizes_->cbTrailer;
auto buf = make_unique<char[]>(dl);
TLSBuffer buffers[] = {
TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_->cbHeader, buf.get()),
writeBuffered_ =
std::min(left, static_cast<size_t>(streamSizes_.cbMaximumMessage));
sendRecordBuffers_ = {
TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_.cbHeader,
sendBuffer_.data()),
TLSBuffer(SECBUFFER_DATA, writeBuffered_,
buf.get() + streamSizes_->cbHeader),
TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_->cbTrailer,
buf.get() + streamSizes_->cbHeader + writeBuffered_),
sendBuffer_.data() + streamSizes_.cbHeader),
TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_.cbTrailer,
sendBuffer_.data() + streamSizes_.cbHeader + writeBuffered_),
TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
};
TLSBufferDesc desc(buffers, 4);
memcpy(buffers[1].pvBuffer, bytes, writeBuffered_);
TLSBufferDesc desc(sendRecordBuffers_.data(), sendRecordBuffers_.size());
std::copy_n(bytes, writeBuffered_,
static_cast<char*>(sendRecordBuffers_[1].pvBuffer));
status_ = ::EncryptMessage(&handle_, 0, &desc, 0);
if (status_ != SEC_E_OK) {
A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s",
@ -367,61 +390,32 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
return TLS_ERR_ERROR;
}
// EncryptMessage may have truncated the buffers.
// Should rarely happen, if ever, except for the trailer.
dl = buffers[0].cbBuffer;
if (dl < streamSizes_->cbHeader) {
// Move message.
memmove(buf.get() + dl, buffers[1].pvBuffer, buffers[1].cbBuffer);
}
dl += buffers[1].cbBuffer;
if (dl < streamSizes_->cbHeader + writeBuffered_) {
// Move trailer.
memmove(buf.get() + dl, buffers[2].pvBuffer, buffers[2].cbBuffer);
}
dl += buffers[2].cbBuffer;
A2_LOG_DEBUG(fmt("WinTLS: Write TLS record header: %" PRIu64
" body: %" PRIu64 " trailer: %" PRIu64,
static_cast<uint64_t>(sendRecordBuffers_[0].cbBuffer),
static_cast<uint64_t>(sendRecordBuffers_[1].cbBuffer),
static_cast<uint64_t>(sendRecordBuffers_[2].cbBuffer)));
// Write (or buffer) the message.
char* p = buf.get();
while (dl) {
auto written = ::send(sockfd_, p, dl, 0);
errno = ::WSAGetLastError();
if (written < 0 && errno == WSAEINTR) {
continue;
auto rv = sendTLSRecord();
if (rv == TLS_ERR_WOULDBLOCK) {
if (len == 0) {
return TLS_ERR_WOULDBLOCK;
}
if (written < 0 && errno == WSAEWOULDBLOCK) {
// Buffer the rest of the message...
writeBuf_.write(p, dl);
// and return...
return len;
}
if (written == 0) {
A2_LOG_ERROR("WinTLS: Connection closed while writing");
status_ = SEC_E_INCOMPLETE_MESSAGE;
state_ = st_error;
return TLS_ERR_ERROR;
}
if (written < 0) {
A2_LOG_ERROR("WinTLS: Connection error while writing");
status_ = SEC_E_INCOMPLETE_MESSAGE;
state_ = st_error;
return TLS_ERR_ERROR;
}
dl -= written;
p += written;
return len;
}
if (rv != 0) {
return rv;
}
len += writeBuffered_;
bytes += writeBuffered_;
process -= writeBuffered_;
left -= writeBuffered_;
writeBuffered_ = 0;
}
A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64 " buffered: %" PRIu64,
(uint64_t)len, (uint64_t)writeBuf_.size()));
if (!len) {
return TLS_ERR_WOULDBLOCK;
}
A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64, (uint64_t)len));
return len;
}
@ -777,6 +771,11 @@ restart:
// Fall through
case st_handshake_done:
if (obtainTLSRecordSizes() != 0) {
return TLS_ERR_ERROR;
}
ensureSendBuffer();
// All ready now :D
state_ = st_connected;
A2_LOG_INFO(
@ -833,4 +832,26 @@ std::string WinTLSSession::getLastErrorString()
size_t WinTLSSession::getRecvBufferedLength() { return decBuf_.size(); }
int WinTLSSession::obtainTLSRecordSizes()
{
status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
&streamSizes_);
if (status_ != SEC_E_OK || !streamSizes_.cbMaximumMessage) {
A2_LOG_ERROR("WinTLS: Unable to obtain stream sizes");
state_ = st_error;
return -1;
}
return 0;
}
void WinTLSSession::ensureSendBuffer()
{
auto sum = streamSizes_.cbHeader + streamSizes_.cbMaximumMessage +
streamSizes_.cbTrailer;
if (sendBuffer_.size() < sum) {
sendBuffer_.resize(sum);
}
}
} // namespace aria2

View File

@ -100,6 +100,18 @@ public:
};
} // namespace wintls
class TLSBuffer : public ::SecBuffer {
public:
TLSBuffer() : ::SecBuffer{}{}
explicit TLSBuffer(ULONG type, ULONG size, void* data)
{
cbBuffer = size;
BufferType = type;
pvBuffer = data;
}
};
class WinTLSSession : public TLSSession {
enum state_t {
st_constructed,
@ -172,16 +184,31 @@ public:
virtual size_t getRecvBufferedLength() CXX11_OVERRIDE;
private:
// Obtains TLS record size limits. This function returns 0 if it
// succeeds, or -1. status_ and state_ are updated according to the
// result.
int obtainTLSRecordSizes();
// Ensures the buffer size so that maximum TLS record can be sent.
void ensureSendBuffer();
// Sends TLS record specified in sendRecordBuffers_. It uses
// recordBytesSent_ to track down how many bytes have been sent.
// This function returns 0 if it succeeds, or negative error codes.
int sendTLSRecord();
// Returns the number of bytes in the remaining TLS record size.
size_t getLeftTLSRecordSize() const;
std::string hostname_;
sock_t sockfd_;
TLSSessionSide side_;
CredHandle* cred_;
CtxtHandle handle_;
// Buffer for already encrypted writes
// Buffer for already encrypted writes. This is only used in
// handshake.
wintls::Buffer writeBuf_;
// While the writeBuf_ holds encrypted messages, writeBuffered_ has the
// corresponding size of unencrypted data used to produce the messages.
// While the sendRecordBuffers_ holds encrypted messages,
// writeBuffered_ has the corresponding size of unencrypted data
// used to produce the messages.
size_t writeBuffered_;
// Buffer for still encrypted reads
wintls::Buffer readBuf_;
@ -191,7 +218,16 @@ private:
state_t state_;
SECURITY_STATUS status_;
std::unique_ptr<SecPkgContext_StreamSizes> streamSizes_;
// The number of maximum size for TLS record header, body, and
// trailer.
SecPkgContext_StreamSizes streamSizes_;
// Underlying buffer for outgoing TLS record.
std::vector<unsigned char> sendBuffer_;
// How many bytes has been sent for current TLS record held in
// sendRecordBuffers_.
size_t recordBytesSent_;
// This holds current outgoing TLS record.
std::array<TLSBuffer, 4> sendRecordBuffers_;
};
} // namespace aria2