mirror of https://github.com/aria2/aria2
commit
9df50804d4
|
@ -86,16 +86,6 @@ static const ULONG kReqAFlags =
|
|||
ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY |
|
||||
ASC_REQ_EXTENDED_ERROR | ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_STREAM;
|
||||
|
||||
class TLSBuffer : public ::SecBuffer {
|
||||
public:
|
||||
explicit TLSBuffer(ULONG type, ULONG size, void* data)
|
||||
{
|
||||
cbBuffer = size;
|
||||
BufferType = type;
|
||||
pvBuffer = data;
|
||||
}
|
||||
};
|
||||
|
||||
class TLSBufferDesc : public ::SecBufferDesc {
|
||||
public:
|
||||
explicit TLSBufferDesc(SecBuffer* arr, ULONG buffers)
|
||||
|
@ -142,7 +132,8 @@ WinTLSSession::WinTLSSession(WinTLSContext* ctx)
|
|||
cred_(ctx->getCredHandle()),
|
||||
writeBuffered_(0),
|
||||
state_(st_constructed),
|
||||
status_(SEC_E_OK)
|
||||
status_(SEC_E_OK),
|
||||
recordBytesSent_(0)
|
||||
{
|
||||
memset(&handle_, 0, sizeof(handle_));
|
||||
}
|
||||
|
@ -213,7 +204,8 @@ int WinTLSSession::closeConnection()
|
|||
status_ = ::AcceptSecurityContext(cred_, &handle_, nullptr, kReqAFlags, 0,
|
||||
&handle_, &desc, &flags, nullptr);
|
||||
}
|
||||
if (status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) {
|
||||
if ((status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) &&
|
||||
getLeftTLSRecordSize() == 0) {
|
||||
size_t len = ctx.cbBuffer;
|
||||
ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer);
|
||||
::FreeContextBuffer(ctx.pvBuffer);
|
||||
|
@ -228,17 +220,6 @@ int WinTLSSession::closeConnection()
|
|||
}
|
||||
}
|
||||
|
||||
// Send remaining data.
|
||||
while (writeBuf_.size()) {
|
||||
int rv = writeData(nullptr, 0);
|
||||
if (rv == 0) {
|
||||
break;
|
||||
}
|
||||
if (rv < 0) {
|
||||
return rv;
|
||||
}
|
||||
}
|
||||
|
||||
A2_LOG_DEBUG("WinTLS: Closed Connection");
|
||||
state_ = st_closed;
|
||||
return TLS_ERR_OK;
|
||||
|
@ -255,12 +236,82 @@ int WinTLSSession::checkDirection()
|
|||
if (readBuf_.size() || decBuf_.size()) {
|
||||
return TLS_WANT_READ;
|
||||
}
|
||||
if (writeBuf_.size()) {
|
||||
if (getLeftTLSRecordSize() || writeBuf_.size()) {
|
||||
return TLS_WANT_WRITE;
|
||||
}
|
||||
return TLS_WANT_READ;
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Fills |iov| of length |len| to send remaining data in |buffers|.
|
||||
// We have already sent |offset| bytes. This function returns the
|
||||
// number of |iov| filled. It assumes the array |buffers| is at least
|
||||
// |len| elements.
|
||||
size_t fillSendIOV(a2iovec* iov, size_t len, TLSBuffer* buffers, size_t offset)
|
||||
{
|
||||
size_t iovcnt = 0;
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
if (offset < buffers[i].cbBuffer) {
|
||||
iov[iovcnt].A2IOVEC_BASE =
|
||||
static_cast<char*>(buffers[i].pvBuffer) + offset;
|
||||
iov[iovcnt].A2IOVEC_LEN = buffers[i].cbBuffer - offset;
|
||||
++iovcnt;
|
||||
offset = 0;
|
||||
}
|
||||
else {
|
||||
offset -= buffers[i].cbBuffer;
|
||||
}
|
||||
}
|
||||
return iovcnt;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
size_t WinTLSSession::getLeftTLSRecordSize() const
|
||||
{
|
||||
return sendRecordBuffers_[0].cbBuffer + sendRecordBuffers_[1].cbBuffer +
|
||||
sendRecordBuffers_[2].cbBuffer - recordBytesSent_;
|
||||
}
|
||||
|
||||
int WinTLSSession::sendTLSRecord()
|
||||
{
|
||||
A2_LOG_DEBUG(fmt("WinTLS: TLS record %" PRIu64 " bytes left",
|
||||
static_cast<uint64_t>(getLeftTLSRecordSize())));
|
||||
|
||||
while (getLeftTLSRecordSize()) {
|
||||
std::array<a2iovec, 3> iov;
|
||||
auto iovcnt = fillSendIOV(iov.data(), iov.size(), sendRecordBuffers_.data(),
|
||||
recordBytesSent_);
|
||||
|
||||
DWORD nwrite;
|
||||
auto rv =
|
||||
WSASend(sockfd_, iov.data(), iovcnt, &nwrite, 0, nullptr, nullptr);
|
||||
if (rv != 0) {
|
||||
auto errnum = ::WSAGetLastError();
|
||||
if (errnum == WSAEINTR) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (errnum == WSAEWOULDBLOCK) {
|
||||
return TLS_ERR_WOULDBLOCK;
|
||||
}
|
||||
|
||||
A2_LOG_ERROR("WinTLS: Connection error while writing");
|
||||
status_ = SEC_E_INCOMPLETE_MESSAGE;
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
|
||||
recordBytesSent_ += nwrite;
|
||||
}
|
||||
|
||||
recordBytesSent_ = 0;
|
||||
sendRecordBuffers_[0].cbBuffer = 0;
|
||||
sendRecordBuffers_[1].cbBuffer = 0;
|
||||
sendRecordBuffers_[2].cbBuffer = 0;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
ssize_t WinTLSSession::writeData(const void* data, size_t len)
|
||||
{
|
||||
if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
|
||||
|
@ -281,45 +332,15 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
|
|||
}
|
||||
|
||||
A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64,
|
||||
(uint64_t)len, (uint64_t)writeBuf_.size()));
|
||||
(uint64_t)len, (uint64_t)recordBytesSent_));
|
||||
|
||||
// Write remaining buffered data, if any.
|
||||
while (writeBuf_.size()) {
|
||||
auto written = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
|
||||
errno = ::WSAGetLastError();
|
||||
if (written < 0 && errno == WSAEINTR) {
|
||||
continue;
|
||||
}
|
||||
if (written < 0 && errno == WSAEWOULDBLOCK) {
|
||||
return TLS_ERR_WOULDBLOCK;
|
||||
}
|
||||
if (written == 0) {
|
||||
return written;
|
||||
}
|
||||
if (written < 0) {
|
||||
status_ = SEC_E_INVALID_HANDLE;
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
writeBuf_.eat(written);
|
||||
auto rv = sendTLSRecord();
|
||||
if (rv != 0) {
|
||||
return rv;
|
||||
}
|
||||
|
||||
if (len == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (!streamSizes_) {
|
||||
streamSizes_.reset(new SecPkgContext_StreamSizes());
|
||||
status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
|
||||
streamSizes_.get());
|
||||
if (status_ != SEC_E_OK || !streamSizes_->cbMaximumMessage) {
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
size_t process = len;
|
||||
auto bytes = reinterpret_cast<const char*>(data);
|
||||
auto left = len;
|
||||
auto bytes = static_cast<const char*>(data);
|
||||
if (writeBuffered_) {
|
||||
// There was buffered data, hence we need to "remove" that data from the
|
||||
// incoming buffer to avoid writing it again
|
||||
|
@ -331,10 +352,10 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
|
|||
}
|
||||
// just advance the buffer by writeBuffered_ bytes
|
||||
bytes += writeBuffered_;
|
||||
process -= writeBuffered_;
|
||||
left -= writeBuffered_;
|
||||
writeBuffered_ = 0;
|
||||
}
|
||||
if (!process) {
|
||||
if (!left) {
|
||||
// The buffer contained the full remainder. At this point, the buffer has
|
||||
// been written, so the request is done in its entirety;
|
||||
return len;
|
||||
|
@ -342,23 +363,25 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
|
|||
|
||||
// Buffered data was already written ;)
|
||||
// If there was no buffered data, this will be len - len = 0.
|
||||
len = len - process;
|
||||
while (process) {
|
||||
len -= left;
|
||||
while (left) {
|
||||
// Set up an outgoing message, according to streamSizes_
|
||||
writeBuffered_ = std::min(process, (size_t)streamSizes_->cbMaximumMessage);
|
||||
size_t dl =
|
||||
streamSizes_->cbHeader + writeBuffered_ + streamSizes_->cbTrailer;
|
||||
auto buf = make_unique<char[]>(dl);
|
||||
TLSBuffer buffers[] = {
|
||||
TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_->cbHeader, buf.get()),
|
||||
writeBuffered_ =
|
||||
std::min(left, static_cast<size_t>(streamSizes_.cbMaximumMessage));
|
||||
|
||||
sendRecordBuffers_ = {
|
||||
TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_.cbHeader,
|
||||
sendBuffer_.data()),
|
||||
TLSBuffer(SECBUFFER_DATA, writeBuffered_,
|
||||
buf.get() + streamSizes_->cbHeader),
|
||||
TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_->cbTrailer,
|
||||
buf.get() + streamSizes_->cbHeader + writeBuffered_),
|
||||
sendBuffer_.data() + streamSizes_.cbHeader),
|
||||
TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_.cbTrailer,
|
||||
sendBuffer_.data() + streamSizes_.cbHeader + writeBuffered_),
|
||||
TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
|
||||
};
|
||||
TLSBufferDesc desc(buffers, 4);
|
||||
memcpy(buffers[1].pvBuffer, bytes, writeBuffered_);
|
||||
|
||||
TLSBufferDesc desc(sendRecordBuffers_.data(), sendRecordBuffers_.size());
|
||||
std::copy_n(bytes, writeBuffered_,
|
||||
static_cast<char*>(sendRecordBuffers_[1].pvBuffer));
|
||||
status_ = ::EncryptMessage(&handle_, 0, &desc, 0);
|
||||
if (status_ != SEC_E_OK) {
|
||||
A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s",
|
||||
|
@ -367,61 +390,32 @@ ssize_t WinTLSSession::writeData(const void* data, size_t len)
|
|||
return TLS_ERR_ERROR;
|
||||
}
|
||||
|
||||
// EncryptMessage may have truncated the buffers.
|
||||
// Should rarely happen, if ever, except for the trailer.
|
||||
dl = buffers[0].cbBuffer;
|
||||
if (dl < streamSizes_->cbHeader) {
|
||||
// Move message.
|
||||
memmove(buf.get() + dl, buffers[1].pvBuffer, buffers[1].cbBuffer);
|
||||
}
|
||||
dl += buffers[1].cbBuffer;
|
||||
if (dl < streamSizes_->cbHeader + writeBuffered_) {
|
||||
// Move trailer.
|
||||
memmove(buf.get() + dl, buffers[2].pvBuffer, buffers[2].cbBuffer);
|
||||
}
|
||||
dl += buffers[2].cbBuffer;
|
||||
A2_LOG_DEBUG(fmt("WinTLS: Write TLS record header: %" PRIu64
|
||||
" body: %" PRIu64 " trailer: %" PRIu64,
|
||||
static_cast<uint64_t>(sendRecordBuffers_[0].cbBuffer),
|
||||
static_cast<uint64_t>(sendRecordBuffers_[1].cbBuffer),
|
||||
static_cast<uint64_t>(sendRecordBuffers_[2].cbBuffer)));
|
||||
|
||||
// Write (or buffer) the message.
|
||||
char* p = buf.get();
|
||||
while (dl) {
|
||||
auto written = ::send(sockfd_, p, dl, 0);
|
||||
errno = ::WSAGetLastError();
|
||||
if (written < 0 && errno == WSAEINTR) {
|
||||
continue;
|
||||
auto rv = sendTLSRecord();
|
||||
if (rv == TLS_ERR_WOULDBLOCK) {
|
||||
if (len == 0) {
|
||||
return TLS_ERR_WOULDBLOCK;
|
||||
}
|
||||
if (written < 0 && errno == WSAEWOULDBLOCK) {
|
||||
// Buffer the rest of the message...
|
||||
writeBuf_.write(p, dl);
|
||||
// and return...
|
||||
return len;
|
||||
}
|
||||
if (written == 0) {
|
||||
A2_LOG_ERROR("WinTLS: Connection closed while writing");
|
||||
status_ = SEC_E_INCOMPLETE_MESSAGE;
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
if (written < 0) {
|
||||
A2_LOG_ERROR("WinTLS: Connection error while writing");
|
||||
status_ = SEC_E_INCOMPLETE_MESSAGE;
|
||||
state_ = st_error;
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
dl -= written;
|
||||
p += written;
|
||||
return len;
|
||||
}
|
||||
|
||||
if (rv != 0) {
|
||||
return rv;
|
||||
}
|
||||
|
||||
len += writeBuffered_;
|
||||
bytes += writeBuffered_;
|
||||
process -= writeBuffered_;
|
||||
left -= writeBuffered_;
|
||||
writeBuffered_ = 0;
|
||||
}
|
||||
|
||||
A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64 " buffered: %" PRIu64,
|
||||
(uint64_t)len, (uint64_t)writeBuf_.size()));
|
||||
if (!len) {
|
||||
return TLS_ERR_WOULDBLOCK;
|
||||
}
|
||||
A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64, (uint64_t)len));
|
||||
|
||||
return len;
|
||||
}
|
||||
|
||||
|
@ -777,6 +771,11 @@ restart:
|
|||
// Fall through
|
||||
|
||||
case st_handshake_done:
|
||||
if (obtainTLSRecordSizes() != 0) {
|
||||
return TLS_ERR_ERROR;
|
||||
}
|
||||
ensureSendBuffer();
|
||||
|
||||
// All ready now :D
|
||||
state_ = st_connected;
|
||||
A2_LOG_INFO(
|
||||
|
@ -833,4 +832,26 @@ std::string WinTLSSession::getLastErrorString()
|
|||
|
||||
size_t WinTLSSession::getRecvBufferedLength() { return decBuf_.size(); }
|
||||
|
||||
int WinTLSSession::obtainTLSRecordSizes()
|
||||
{
|
||||
status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
|
||||
&streamSizes_);
|
||||
if (status_ != SEC_E_OK || !streamSizes_.cbMaximumMessage) {
|
||||
A2_LOG_ERROR("WinTLS: Unable to obtain stream sizes");
|
||||
state_ = st_error;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void WinTLSSession::ensureSendBuffer()
|
||||
{
|
||||
auto sum = streamSizes_.cbHeader + streamSizes_.cbMaximumMessage +
|
||||
streamSizes_.cbTrailer;
|
||||
if (sendBuffer_.size() < sum) {
|
||||
sendBuffer_.resize(sum);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace aria2
|
||||
|
|
|
@ -100,6 +100,18 @@ public:
|
|||
};
|
||||
} // namespace wintls
|
||||
|
||||
class TLSBuffer : public ::SecBuffer {
|
||||
public:
|
||||
TLSBuffer() : ::SecBuffer{}{}
|
||||
|
||||
explicit TLSBuffer(ULONG type, ULONG size, void* data)
|
||||
{
|
||||
cbBuffer = size;
|
||||
BufferType = type;
|
||||
pvBuffer = data;
|
||||
}
|
||||
};
|
||||
|
||||
class WinTLSSession : public TLSSession {
|
||||
enum state_t {
|
||||
st_constructed,
|
||||
|
@ -172,16 +184,31 @@ public:
|
|||
virtual size_t getRecvBufferedLength() CXX11_OVERRIDE;
|
||||
|
||||
private:
|
||||
// Obtains TLS record size limits. This function returns 0 if it
|
||||
// succeeds, or -1. status_ and state_ are updated according to the
|
||||
// result.
|
||||
int obtainTLSRecordSizes();
|
||||
// Ensures the buffer size so that maximum TLS record can be sent.
|
||||
void ensureSendBuffer();
|
||||
// Sends TLS record specified in sendRecordBuffers_. It uses
|
||||
// recordBytesSent_ to track down how many bytes have been sent.
|
||||
// This function returns 0 if it succeeds, or negative error codes.
|
||||
int sendTLSRecord();
|
||||
// Returns the number of bytes in the remaining TLS record size.
|
||||
size_t getLeftTLSRecordSize() const;
|
||||
|
||||
std::string hostname_;
|
||||
sock_t sockfd_;
|
||||
TLSSessionSide side_;
|
||||
CredHandle* cred_;
|
||||
CtxtHandle handle_;
|
||||
|
||||
// Buffer for already encrypted writes
|
||||
// Buffer for already encrypted writes. This is only used in
|
||||
// handshake.
|
||||
wintls::Buffer writeBuf_;
|
||||
// While the writeBuf_ holds encrypted messages, writeBuffered_ has the
|
||||
// corresponding size of unencrypted data used to produce the messages.
|
||||
// While the sendRecordBuffers_ holds encrypted messages,
|
||||
// writeBuffered_ has the corresponding size of unencrypted data
|
||||
// used to produce the messages.
|
||||
size_t writeBuffered_;
|
||||
// Buffer for still encrypted reads
|
||||
wintls::Buffer readBuf_;
|
||||
|
@ -191,7 +218,16 @@ private:
|
|||
state_t state_;
|
||||
|
||||
SECURITY_STATUS status_;
|
||||
std::unique_ptr<SecPkgContext_StreamSizes> streamSizes_;
|
||||
// The number of maximum size for TLS record header, body, and
|
||||
// trailer.
|
||||
SecPkgContext_StreamSizes streamSizes_;
|
||||
// Underlying buffer for outgoing TLS record.
|
||||
std::vector<unsigned char> sendBuffer_;
|
||||
// How many bytes has been sent for current TLS record held in
|
||||
// sendRecordBuffers_.
|
||||
size_t recordBytesSent_;
|
||||
// This holds current outgoing TLS record.
|
||||
std::array<TLSBuffer, 4> sendRecordBuffers_;
|
||||
};
|
||||
|
||||
} // namespace aria2
|
||||
|
|
Loading…
Reference in New Issue