Merge pull request #772 from aria2/refactor-wintls-write

WinTLS: Rewrite writeData
pull/786/head
Tatsuhiro Tsujikawa 2016-11-23 22:57:58 +09:00 committed by GitHub
commit 9df50804d4
2 changed files with 182 additions and 125 deletions

View File

@ -86,16 +86,6 @@ static const ULONG kReqAFlags =
ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY |
ASC_REQ_EXTENDED_ERROR | ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_STREAM;
class TLSBuffer : public ::SecBuffer {
public:
explicit TLSBuffer(ULONG type, ULONG size, void* data)
{
cbBuffer = size;
BufferType = type;
pvBuffer = data;
}
};
class TLSBufferDesc : public ::SecBufferDesc {
public:
explicit TLSBufferDesc(SecBuffer* arr, ULONG buffers)
@ -142,7 +132,8 @@ WinTLSSession::WinTLSSession(WinTLSContext* ctx)
cred_(ctx->getCredHandle()),
writeBuffered_(0),
state_(st_constructed),
status_(SEC_E_OK)
status_(SEC_E_OK),
recordBytesSent_(0)
{
memset(&handle_, 0, sizeof(handle_));
}
@ -213,7 +204,8 @@ int WinTLSSession::closeConnection()
status_ = ::AcceptSecurityContext(cred_, &handle_, nullptr, kReqAFlags, 0,
&handle_, &desc, &flags, nullptr);
}
if (status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) {
if ((status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) &&
getLeftTLSRecordSize() == 0) {
size_t len = ctx.cbBuffer;
ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer);
::FreeContextBuffer(ctx.pvBuffer);
@ -228,17 +220,6 @@ int WinTLSSession::closeConnection()
}
}
// Send remaining data.
while (writeBuf_.size()) {
int rv = writeData(nullptr, 0);
if (rv == 0) {
break;
}
if (rv < 0) {
return rv;
}
}
A2_LOG_DEBUG("WinTLS: Closed Connection");
state_ = st_closed;
return TLS_ERR_OK;
@ -255,12 +236,82 @@ int WinTLSSession::checkDirection()
if (readBuf_.size() || decBuf_.size()) {
return TLS_WANT_READ;
}
if (writeBuf_.size()) {
if (getLeftTLSRecordSize() || writeBuf_.size()) {
return TLS_WANT_WRITE;
}
return TLS_WANT_READ;
}
namespace {
// Fills |iov| of length |len| to send remaining data in |buffers|.
// We have already sent |offset| bytes. This function returns the
// number of |iov| filled. It assumes the array |buffers| is at least
// |len| elements.
size_t fillSendIOV(a2iovec* iov, size_t len, TLSBuffer* buffers, size_t offset)
{
size_t iovcnt = 0;
for (size_t i = 0; i < len; ++i) {
if (offset < buffers[i].cbBuffer) {
iov[iovcnt].A2IOVEC_BASE =
static_cast<char*>(buffers[i].pvBuffer) + offset;
iov[iovcnt].A2IOVEC_LEN = buffers[i].cbBuffer - offset;
++iovcnt;
offset = 0;
}
else {
offset -= buffers[i].cbBuffer;
}
}
return iovcnt;
}
} // namespace
size_t WinTLSSession::getLeftTLSRecordSize() const
{
return sendRecordBuffers_[0].cbBuffer + sendRecordBuffers_[1].cbBuffer +
sendRecordBuffers_[2].cbBuffer - recordBytesSent_;
}
int WinTLSSession::sendTLSRecord()
{
A2_LOG_DEBUG(fmt("WinTLS: TLS record %" PRIu64 " bytes left",
static_cast<uint64_t>(getLeftTLSRecordSize())));
while (getLeftTLSRecordSize()) {
std::array<a2iovec, 3> iov;
auto iovcnt = fillSendIOV(iov.data(), iov.size(), sendRecordBuffers_.data(),
recordBytesSent_);
DWORD nwrite;
auto rv =
WSASend(sockfd_, iov.data(), iovcnt, &nwrite, 0, nullptr, nullptr);
if (rv != 0) {
auto errnum = ::WSAGetLastError();
if (errnum == WSAEINTR) {
continue;
}
if (errnum == WSAEWOULDBLOCK) {
return TLS_ERR_WOULDBLOCK;
}
A2_LOG_ERROR("WinTLS: Connection error while writing");
status_ = SEC_E_INCOMPLETE_MESSAGE;
state_ = st_error;
return TLS_ERR_ERROR;
}
recordBytesSent_ += nwrite;
}
recordBytesSent_ = 0;
sendRecordBuffers_[0].cbBuffer = 0;
sendRecordBuffers_[1].cbBuffer = 0;
sendRecordBuffers_[2].cbBuffer = 0;
return 0;
}
ssize_t WinTLSSession::writeData(const void* data, size_t len)
{
if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
@ -281,45 +332,15 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
}
A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64,
(uint64_t)len, (uint64_t)writeBuf_.size()));
(uint64_t)len, (uint64_t)recordBytesSent_));
// Write remaining buffered data, if any.
while (writeBuf_.size()) {
auto written = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
errno = ::WSAGetLastError();
if (written < 0 && errno == WSAEINTR) {
continue;
}
if (written < 0 && errno == WSAEWOULDBLOCK) {
return TLS_ERR_WOULDBLOCK;
}
if (written == 0) {
return written;
}
if (written < 0) {
status_ = SEC_E_INVALID_HANDLE;
state_ = st_error;
return TLS_ERR_ERROR;
}
writeBuf_.eat(written);
auto rv = sendTLSRecord();
if (rv != 0) {
return rv;
}
if (len == 0) {
return 0;
}
if (!streamSizes_) {
streamSizes_.reset(new SecPkgContext_StreamSizes());
status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
streamSizes_.get());
if (status_ != SEC_E_OK || !streamSizes_->cbMaximumMessage) {
state_ = st_error;
return TLS_ERR_ERROR;
}
}
size_t process = len;
auto bytes = reinterpret_cast<const char*>(data);
auto left = len;
auto bytes = static_cast<const char*>(data);
if (writeBuffered_) {
// There was buffered data, hence we need to "remove" that data from the
// incoming buffer to avoid writing it again
@ -331,10 +352,10 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
}
// just advance the buffer by writeBuffered_ bytes
bytes += writeBuffered_;
process -= writeBuffered_;
left -= writeBuffered_;
writeBuffered_ = 0;
}
if (!process) {
if (!left) {
// The buffer contained the full remainder. At this point, the buffer has
// been written, so the request is done in its entirety;
return len;
@ -342,23 +363,25 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
// Buffered data was already written ;)
// If there was no buffered data, this will be len - len = 0.
len = len - process;
while (process) {
len -= left;
while (left) {
// Set up an outgoing message, according to streamSizes_
writeBuffered_ = std::min(process, (size_t)streamSizes_->cbMaximumMessage);
size_t dl =
streamSizes_->cbHeader + writeBuffered_ + streamSizes_->cbTrailer;
auto buf = make_unique<char[]>(dl);
TLSBuffer buffers[] = {
TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_->cbHeader, buf.get()),
writeBuffered_ =
std::min(left, static_cast<size_t>(streamSizes_.cbMaximumMessage));
sendRecordBuffers_ = {
TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_.cbHeader,
sendBuffer_.data()),
TLSBuffer(SECBUFFER_DATA, writeBuffered_,
buf.get() + streamSizes_->cbHeader),
TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_->cbTrailer,
buf.get() + streamSizes_->cbHeader + writeBuffered_),
sendBuffer_.data() + streamSizes_.cbHeader),
TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_.cbTrailer,
sendBuffer_.data() + streamSizes_.cbHeader + writeBuffered_),
TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
};
TLSBufferDesc desc(buffers, 4);
memcpy(buffers[1].pvBuffer, bytes, writeBuffered_);
TLSBufferDesc desc(sendRecordBuffers_.data(), sendRecordBuffers_.size());
std::copy_n(bytes, writeBuffered_,
static_cast<char*>(sendRecordBuffers_[1].pvBuffer));
status_ = ::EncryptMessage(&handle_, 0, &desc, 0);
if (status_ != SEC_E_OK) {
A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s",
@ -367,61 +390,32 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
return TLS_ERR_ERROR;
}
// EncryptMessage may have truncated the buffers.
// Should rarely happen, if ever, except for the trailer.
dl = buffers[0].cbBuffer;
if (dl < streamSizes_->cbHeader) {
// Move message.
memmove(buf.get() + dl, buffers[1].pvBuffer, buffers[1].cbBuffer);
}
dl += buffers[1].cbBuffer;
if (dl < streamSizes_->cbHeader + writeBuffered_) {
// Move trailer.
memmove(buf.get() + dl, buffers[2].pvBuffer, buffers[2].cbBuffer);
}
dl += buffers[2].cbBuffer;
A2_LOG_DEBUG(fmt("WinTLS: Write TLS record header: %" PRIu64
" body: %" PRIu64 " trailer: %" PRIu64,
static_cast<uint64_t>(sendRecordBuffers_[0].cbBuffer),
static_cast<uint64_t>(sendRecordBuffers_[1].cbBuffer),
static_cast<uint64_t>(sendRecordBuffers_[2].cbBuffer)));
// Write (or buffer) the message.
char* p = buf.get();
while (dl) {
auto written = ::send(sockfd_, p, dl, 0);
errno = ::WSAGetLastError();
if (written < 0 && errno == WSAEINTR) {
continue;
auto rv = sendTLSRecord();
if (rv == TLS_ERR_WOULDBLOCK) {
if (len == 0) {
return TLS_ERR_WOULDBLOCK;
}
if (written < 0 && errno == WSAEWOULDBLOCK) {
// Buffer the rest of the message...
writeBuf_.write(p, dl);
// and return...
return len;
}
if (written == 0) {
A2_LOG_ERROR("WinTLS: Connection closed while writing");
status_ = SEC_E_INCOMPLETE_MESSAGE;
state_ = st_error;
return TLS_ERR_ERROR;
}
if (written < 0) {
A2_LOG_ERROR("WinTLS: Connection error while writing");
status_ = SEC_E_INCOMPLETE_MESSAGE;
state_ = st_error;
return TLS_ERR_ERROR;
}
dl -= written;
p += written;
return len;
}
if (rv != 0) {
return rv;
}
len += writeBuffered_;
bytes += writeBuffered_;
process -= writeBuffered_;
left -= writeBuffered_;
writeBuffered_ = 0;
}
A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64 " buffered: %" PRIu64,
(uint64_t)len, (uint64_t)writeBuf_.size()));
if (!len) {
return TLS_ERR_WOULDBLOCK;
}
A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64, (uint64_t)len));
return len;
}
@ -777,6 +771,11 @@ restart:
// Fall through
case st_handshake_done:
if (obtainTLSRecordSizes() != 0) {
return TLS_ERR_ERROR;
}
ensureSendBuffer();
// All ready now :D
state_ = st_connected;
A2_LOG_INFO(
@ -833,4 +832,26 @@ std::string WinTLSSession::getLastErrorString()
size_t WinTLSSession::getRecvBufferedLength() { return decBuf_.size(); }
int WinTLSSession::obtainTLSRecordSizes()
{
status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
&streamSizes_);
if (status_ != SEC_E_OK || !streamSizes_.cbMaximumMessage) {
A2_LOG_ERROR("WinTLS: Unable to obtain stream sizes");
state_ = st_error;
return -1;
}
return 0;
}
void WinTLSSession::ensureSendBuffer()
{
auto sum = streamSizes_.cbHeader + streamSizes_.cbMaximumMessage +
streamSizes_.cbTrailer;
if (sendBuffer_.size() < sum) {
sendBuffer_.resize(sum);
}
}
} // namespace aria2

View File

@ -100,6 +100,18 @@ public:
};
} // namespace wintls
class TLSBuffer : public ::SecBuffer {
public:
TLSBuffer() : ::SecBuffer{}{}
explicit TLSBuffer(ULONG type, ULONG size, void* data)
{
cbBuffer = size;
BufferType = type;
pvBuffer = data;
}
};
class WinTLSSession : public TLSSession {
enum state_t {
st_constructed,
@ -172,16 +184,31 @@ public:
virtual size_t getRecvBufferedLength() CXX11_OVERRIDE;
private:
// Obtains TLS record size limits. This function returns 0 if it
// succeeds, or -1. status_ and state_ are updated according to the
// result.
int obtainTLSRecordSizes();
// Ensures the buffer size so that maximum TLS record can be sent.
void ensureSendBuffer();
// Sends TLS record specified in sendRecordBuffers_. It uses
// recordBytesSent_ to track down how many bytes have been sent.
// This function returns 0 if it succeeds, or negative error codes.
int sendTLSRecord();
// Returns the number of bytes in the remaining TLS record size.
size_t getLeftTLSRecordSize() const;
std::string hostname_;
sock_t sockfd_;
TLSSessionSide side_;
CredHandle* cred_;
CtxtHandle handle_;
// Buffer for already encrypted writes
// Buffer for already encrypted writes. This is only used in
// handshake.
wintls::Buffer writeBuf_;
// While the writeBuf_ holds encrypted messages, writeBuffered_ has the
// corresponding size of unencrypted data used to produce the messages.
// While the sendRecordBuffers_ holds encrypted messages,
// writeBuffered_ has the corresponding size of unencrypted data
// used to produce the messages.
size_t writeBuffered_;
// Buffer for still encrypted reads
wintls::Buffer readBuf_;
@ -191,7 +218,16 @@ private:
state_t state_;
SECURITY_STATUS status_;
std::unique_ptr<SecPkgContext_StreamSizes> streamSizes_;
// The number of maximum size for TLS record header, body, and
// trailer.
SecPkgContext_StreamSizes streamSizes_;
// Underlying buffer for outgoing TLS record.
std::vector<unsigned char> sendBuffer_;
// How many bytes has been sent for current TLS record held in
// sendRecordBuffers_.
size_t recordBytesSent_;
// This holds current outgoing TLS record.
std::array<TLSBuffer, 4> sendRecordBuffers_;
};
} // namespace aria2