2008-09-27 Tatsuhiro Tsujikawa <tujikawa at rednoah dot com>

Fixed the bug that HTTPS download fails.
	* src/AbstractCommand.cc
	* src/AbstractCommand.h
	* src/DownloadCommand.cc
	* src/FtpConnection.cc
	* src/HttpConnection.cc
	* src/HttpRequestCommand.cc
	* src/HttpResponseCommand.cc
	* src/HttpSkipResponseCommand.cc
	* src/MSEHandshake.cc
	* src/PeerConnection.cc
	* src/SocketCore.cc
	* src/SocketCore.h

	Fixed the bug that aria2 doesn't download whole content body and 
cannot
	reuse connection if chunked transfer encoding and gzip content 
encoding
	are set.
	* src/DownloadCommand.cc
	* src/HttpSkipResponseCommand.cc
pull/1/head
Tatsuhiro Tsujikawa 2008-09-27 16:06:34 +00:00
parent 080fcd5fb8
commit e9e215dc1f
13 changed files with 355 additions and 94 deletions

View File

@ -1,3 +1,25 @@
2008-09-27 Tatsuhiro Tsujikawa <tujikawa at rednoah dot com>
Fixed the bug that HTTPS download fails.
* src/AbstractCommand.cc
* src/AbstractCommand.h
* src/DownloadCommand.cc
* src/FtpConnection.cc
* src/HttpConnection.cc
* src/HttpRequestCommand.cc
* src/HttpResponseCommand.cc
* src/HttpSkipResponseCommand.cc
* src/MSEHandshake.cc
* src/PeerConnection.cc
* src/SocketCore.cc
* src/SocketCore.h
Fixed the bug that aria2 doesn't download whole content body and cannot
reuse connection if chunked transfer encoding and gzip content encoding
are set.
* src/DownloadCommand.cc
* src/HttpSkipResponseCommand.cc
2008-09-27 Tatsuhiro Tsujikawa <tujikawa at rednoah dot com>
Updated man page.

View File

@ -245,6 +245,16 @@ void AbstractCommand::setReadCheckSocket(const SocketHandle& socket) {
}
}
void AbstractCommand::setReadCheckSocketIf
(const SharedHandle<SocketCore>& socket, bool pred)
{
if(pred) {
setReadCheckSocket(socket);
} else {
disableReadCheckSocket();
}
}
void AbstractCommand::disableWriteCheckSocket() {
if(checkSocketIsWritable) {
e->deleteSocketForWriteCheck(writeCheckTarget, this);
@ -271,6 +281,16 @@ void AbstractCommand::setWriteCheckSocket(const SocketHandle& socket) {
}
}
void AbstractCommand::setWriteCheckSocketIf
(const SharedHandle<SocketCore>& socket, bool pred)
{
if(pred) {
setWriteCheckSocket(socket);
} else {
disableWriteCheckSocket();
}
}
static bool isProxyGETRequest(const std::string& protocol, const Option* option)
{
return

View File

@ -84,6 +84,17 @@ protected:
void disableReadCheckSocket();
void disableWriteCheckSocket();
/**
* If pred == true, calls setReadCheckSocket(socket). Otherwise, calls
* disableReadCheckSocket().
*/
void setReadCheckSocketIf(const SharedHandle<SocketCore>& socket, bool pred);
/**
* If pred == true, calls setWriteCheckSocket(socket). Otherwise, calls
* disableWriteCheckSocket().
*/
void setWriteCheckSocketIf(const SharedHandle<SocketCore>& socket, bool pred);
void setTimeout(time_t timeout) { this->timeout = timeout; }
void prepareForNextAction(Command* nextCommand = 0);

View File

@ -159,24 +159,38 @@ bool DownloadCommand::executeInternal() {
peerStat->updateDownloadLength(bufSize);
if(_requestGroup->getTotalLength() != 0 && bufSize == 0) {
if(_requestGroup->getTotalLength() != 0 && bufSize == 0 &&
!socket->wantRead() && !socket->wantWrite()) {
throw DlRetryEx(EX_GOT_EOF);
}
if((!_transferEncodingDecoder.isNull() &&
_transferEncodingDecoder->finished())
|| (_transferEncodingDecoder.isNull() && segment->complete())
|| (!_contentEncodingDecoder.isNull() &&
_contentEncodingDecoder->finished())
|| bufSize == 0) {
logger->info(MSG_SEGMENT_DOWNLOAD_COMPLETED, cuid);
if(!_contentEncodingDecoder.isNull() &&
!_contentEncodingDecoder->finished()) {
logger->warn("CUID#%d - Transfer was completed, but inflate operation"
" have not finished. Maybe the file is broken in the server"
" side.", cuid);
bool segmentComplete = false;
if(_transferEncodingDecoder.isNull() && _contentEncodingDecoder.isNull()) {
if(segment->complete()) {
segmentComplete = true;
} else if(segment->getLength() == 0 && bufSize == 0 &&
!socket->wantRead() && !socket->wantWrite()) {
segmentComplete = true;
}
} else if(!_transferEncodingDecoder.isNull() &&
!_contentEncodingDecoder.isNull()) {
if(_transferEncodingDecoder->finished() &&
_contentEncodingDecoder->finished()) {
segmentComplete = true;
}
} else if(!_transferEncodingDecoder.isNull() &&
_contentEncodingDecoder.isNull()) {
if(_transferEncodingDecoder->finished()) {
segmentComplete = true;
}
} else if(_transferEncodingDecoder.isNull() &&
!_contentEncodingDecoder.isNull()) {
if(_contentEncodingDecoder->finished()) {
segmentComplete = true;
}
}
if(segmentComplete) {
logger->info(MSG_SEGMENT_DOWNLOAD_COMPLETED, cuid);
#ifdef ENABLE_MESSAGE_DIGEST
{
@ -211,6 +225,7 @@ bool DownloadCommand::executeInternal() {
return prepareForNextSegment();
} else {
checkLowestDownloadSpeed();
setWriteCheckSocketIf(socket, socket->wantWrite());
e->commands.push_back(this);
return false;
}

View File

@ -285,6 +285,9 @@ bool FtpConnection::bulkReceiveResponse(std::pair<unsigned int, std::string>& re
size_t size = sizeof(buf);
socket->readData(buf, size);
if(size == 0) {
if(socket->wantRead() || socket->wantWrite()) {
return false;
}
throw DlRetryEx(EX_GOT_EOF);
}
if(strbuf.size()+size > MAX_RECV_BUFFER) {

View File

@ -126,7 +126,11 @@ HttpResponseHandle HttpConnection::receiveResponse()
size_t size = sizeof(buf);
socket->peekData(buf, size);
if(size == 0) {
throw DlRetryEx(EX_INVALID_RESPONSE);
if(socket->wantRead() || socket->wantWrite()) {
return SharedHandle<HttpResponse>();
} else {
throw DlRetryEx(EX_INVALID_RESPONSE);
}
}
proc->update(buf, size);
if(!proc->eoh()) {

View File

@ -99,11 +99,17 @@ createHttpRequest(const SharedHandle<Request>& req,
bool HttpRequestCommand::executeInternal() {
//socket->setBlockingMode();
if(req->getProtocol() == Request::PROTO_HTTPS) {
socket->prepareSecureConnection();
if(!socket->initiateSecureConnection()) {
setReadCheckSocketIf(socket, socket->wantRead());
setWriteCheckSocketIf(socket, socket->wantWrite());
e->commands.push_back(this);
return false;
}
}
if(_httpConnection->sendBufferIsEmpty()) {
checkIfConnectionEstablished(socket);
if(req->getProtocol() == Request::PROTO_HTTPS) {
socket->initiateSecureConnection();
}
if(_segments.empty()) {
HttpRequestHandle httpRequest
@ -134,7 +140,8 @@ bool HttpRequestCommand::executeInternal() {
e->commands.push_back(command);
return true;
} else {
setWriteCheckSocket(socket);
setReadCheckSocketIf(socket, socket->wantRead());
setWriteCheckSocketIf(socket, socket->wantWrite());
e->commands.push_back(this);
return false;
}

View File

@ -83,6 +83,9 @@ bool HttpResponseCommand::executeInternal()
HttpResponseHandle httpResponse = httpConnection->receiveResponse();
if(httpResponse.isNull()) {
// The server has not responded to our request yet.
// For socket->wantRead() == true, setReadCheckSocket(socket) is already
// done in the constructor.
setWriteCheckSocketIf(socket, socket->wantWrite());
e->commands.push_back(this);
return false;
}

View File

@ -96,7 +96,8 @@ bool HttpSkipResponseCommand::executeInternal()
// The return value is safely ignored here.
_transferEncodingDecoder->decode(buf, bufSize);
}
if(_totalLength != 0 && bufSize == 0) {
if(_totalLength != 0 && bufSize == 0 &&
!socket->wantRead() && !socket->wantWrite()) {
throw DlRetryEx(EX_GOT_EOF);
}
} catch(RecoverableException& e) {
@ -104,15 +105,19 @@ bool HttpSkipResponseCommand::executeInternal()
return processResponse();
}
if(bufSize == 0) {
// Since this method is called by DownloadEngine only when the socket is
// readable, bufSize == 0 means server shutdown the connection.
// So socket cannot be reused in this case.
return prepareForRetry(0);
} else if((!_transferEncodingDecoder.isNull() &&
_transferEncodingDecoder->finished())
|| (_transferEncodingDecoder.isNull() &&
_totalLength == _receivedBytes)) {
bool finished = false;
if(_transferEncodingDecoder.isNull()) {
if(bufSize == 0) {
if(!socket->wantRead() && !socket->wantWrite()) {
return processResponse();
}
} else {
finished = (_totalLength == _receivedBytes);
}
} else {
finished = _transferEncodingDecoder->finished();
}
if(finished) {
if(!e->option->getAsBool(PREF_HTTP_PROXY_ENABLED) &&
req->supportsPersistentConnection()) {
std::pair<std::string, uint16_t> peerInfo;
@ -121,6 +126,7 @@ bool HttpSkipResponseCommand::executeInternal()
}
return processResponse();
} else {
setWriteCheckSocketIf(socket, socket->wantWrite());
e->commands.push_back(this);
return false;
}

View File

@ -93,7 +93,7 @@ MSEHandshake::HANDSHAKE_TYPE MSEHandshake::identifyHandshakeType()
}
size_t r = 20-_rbufLength;
_socket->readData(_rbuf+_rbufLength, r);
if(r == 0) {
if(r == 0 && !_socket->wantRead() && !_socket->wantWrite()) {
throw DlAbortEx(EX_EOF_FROM_PEER);
}
_rbufLength += r;
@ -301,6 +301,9 @@ bool MSEHandshake::findInitiatorVCMarker()
}
_socket->peekData(_rbuf+_rbufLength, r);
if(r == 0) {
if(_socket->wantRead() || _socket->wantWrite()) {
return false;
}
throw DlAbortEx(EX_EOF_FROM_PEER);
}
// find vc
@ -388,6 +391,9 @@ bool MSEHandshake::findReceiverHashMarker()
}
_socket->peekData(_rbuf+_rbufLength, r);
if(r == 0) {
if(_socket->wantRead() || _socket->wantWrite()) {
return false;
}
throw DlAbortEx(EX_EOF_FROM_PEER);
}
// find hash('req1', S), S is _secret.
@ -575,7 +581,7 @@ size_t MSEHandshake::receiveNBytes(size_t bytes)
return 0;
}
_socket->readData(_rbuf+_rbufLength, r);
if(r == 0) {
if(r == 0 && !_socket->wantRead() && !_socket->wantWrite()) {
throw DlAbortEx(EX_EOF_FROM_PEER);
}
_rbufLength += r;

View File

@ -86,6 +86,9 @@ bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) {
size_t temp = remaining;
readData(lenbuf+lenbufLength, remaining, _encryptionEnabled);
if(remaining == 0) {
if(socket->wantRead() || socket->wantWrite()) {
return false;
}
// we got EOF
logger->debug("CUID#%d - In PeerConnection::receiveMessage(), remain=%zu",
cuid, temp);
@ -111,6 +114,9 @@ bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) {
if(remaining > 0) {
readData(resbuf+resbufLength, remaining, _encryptionEnabled);
if(remaining == 0) {
if(socket->wantRead() || socket->wantWrite()) {
return false;
}
// we got EOF
logger->debug("CUID#%d - In PeerConnection::receiveMessage(), payloadlen=%zu, remaining=%zu",
cuid, currentPayloadLength, temp);
@ -154,6 +160,9 @@ bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
size_t temp = remaining;
readData(resbuf+resbufLength, remaining, _encryptionEnabled);
if(remaining == 0) {
if(socket->wantRead() || socket->wantWrite()) {
return false;
}
// we got EOF
logger->debug("CUID#%d - In PeerConnection::receiveHandshake(), remain=%zu",
cuid, temp);

View File

@ -59,7 +59,7 @@
#else
# define CLOSE(X) while(close(X) == -1 && errno == EINTR)
#endif // __MINGW32__
#include "LogFactory.h"
namespace aria2 {
SocketCore::SocketCore(int sockType):_sockType(sockType), sockfd(-1) {
@ -80,7 +80,11 @@ void SocketCore::init()
#endif // HAVE_EPOLL
blocking = true;
secure = false;
secure = 0;
_wantRead = false;
_wantWrite = false;
#ifdef HAVE_LIBSSL
// for SSL
sslCtx = NULL;
@ -440,42 +444,74 @@ bool SocketCore::isReadable(time_t timeout)
}
#ifdef HAVE_LIBSSL
int SocketCore::sslHandleEAGAIN(int ret)
{
int error = SSL_get_error(ssl, ret);
if(error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) {
ret = 0;
if(error == SSL_ERROR_WANT_READ) {
_wantRead = true;
} else {
_wantWrite = true;
}
}
return ret;
}
#endif // HAVE_LIBSSL
#ifdef HAVE_LIBGNUTLS
void SocketCore::gnutlsRecordCheckDirection()
{
int direction = gnutls_record_get_direction(sslSession);
if(direction == 0) {
_wantRead = true;
} else { // if(direction == 1) {
_wantWrite = true;
}
}
#endif // HAVE_LIBGNUTLS
ssize_t SocketCore::writeData(const char* data, size_t len)
{
ssize_t ret = 0;
_wantRead = false;
_wantWrite = false;
if(!secure) {
while((ret = send(sockfd, data, len, 0)) == -1 && errno == EINTR);
if(ret == -1 && errno == EAGAIN) {
ret = 0;
}
if(ret == -1) {
throw DlRetryEx(StringFormat(EX_SOCKET_SEND, errorMsg()).str());
if(errno == EAGAIN) {
_wantWrite = true;
ret = 0;
} else {
throw DlRetryEx(StringFormat(EX_SOCKET_SEND, errorMsg()).str());
}
}
} else {
#ifdef HAVE_LIBSSL
// for SSL
// TODO handling len == 0 case required
ret = SSL_write(ssl, data, len);
if(ret < 0) {
switch(SSL_get_error(ssl, ret)) {
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
ret = 0;
}
if(ret == 0) {
throw DlRetryEx
(StringFormat
(EX_SOCKET_SEND, ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
}
if(ret <= 0) {
throw DlRetryEx(StringFormat(EX_SOCKET_SEND,
ERR_error_string(ERR_get_error(), 0)).str());
if(ret < 0) {
ret = sslHandleEAGAIN(ret);
}
if(ret < 0) {
throw DlRetryEx
(StringFormat
(EX_SOCKET_SEND, ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
}
#endif // HAVE_LIBSSL
#ifdef HAVE_LIBGNUTLS
while((ret = gnutls_record_send(sslSession, data, len)) ==
GNUTLS_E_INTERRUPTED);
if(ret == GNUTLS_E_AGAIN) {
gnutlsRecordCheckDirection();
ret = 0;
}
if(ret < 0) {
} else if(ret < 0) {
throw DlRetryEx(StringFormat(EX_SOCKET_SEND, gnutls_strerror(ret)).str());
}
#endif // HAVE_LIBGNUTLS
@ -487,24 +523,45 @@ ssize_t SocketCore::writeData(const char* data, size_t len)
void SocketCore::readData(char* data, size_t& len)
{
ssize_t ret = 0;
_wantRead = false;
_wantWrite = false;
if(!secure) {
while((ret = recv(sockfd, data, len, 0)) == -1 && errno == EINTR);
if(ret == -1) {
throw DlRetryEx(StringFormat(EX_SOCKET_RECV, errorMsg()).str());
if(errno == EAGAIN) {
_wantRead = true;
ret = 0;
} else {
throw DlRetryEx(StringFormat(EX_SOCKET_RECV, errorMsg()).str());
}
}
} else {
#ifdef HAVE_LIBSSL
// for SSL
// TODO handling len == 0 case required
if ((ret = SSL_read(ssl, data, len)) <= 0) {
ret = SSL_read(ssl, data, len);
if(ret == 0) {
throw DlRetryEx
(StringFormat(EX_SOCKET_RECV,
ERR_error_string(ERR_get_error(), 0)).str());
(StringFormat
(EX_SOCKET_RECV, ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
}
if(ret < 0) {
ret = sslHandleEAGAIN(ret);
}
if(ret < 0) {
throw DlRetryEx
(StringFormat
(EX_SOCKET_RECV, ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
}
#endif // HAVE_LIBSSL
#ifdef HAVE_LIBGNUTLS
if ((ret = gnutlsRecv(data, len)) < 0) {
ret = gnutlsRecv(data, len);
if(ret == GNUTLS_E_AGAIN) {
gnutlsRecordCheckDirection();
ret = 0;
} else if(ret < 0) {
throw DlRetryEx
(StringFormat(EX_SOCKET_RECV, gnutls_strerror(ret)).str());
}
@ -517,24 +574,45 @@ void SocketCore::readData(char* data, size_t& len)
void SocketCore::peekData(char* data, size_t& len)
{
ssize_t ret = 0;
_wantRead = false;
_wantWrite = false;
if(!secure) {
while((ret = recv(sockfd, data, len, MSG_PEEK)) == -1 && errno == EINTR);
if(ret == -1) {
throw DlRetryEx(StringFormat(EX_SOCKET_PEEK, errorMsg()).str());
if(errno == EAGAIN) {
_wantRead = true;
ret = 0;
} else {
throw DlRetryEx(StringFormat(EX_SOCKET_PEEK, errorMsg()).str());
}
}
} else {
#ifdef HAVE_LIBSSL
// for SSL
// TODO handling len == 0 case required
if ((ret = SSL_peek(ssl, data, len)) < 0) {
ret = SSL_peek(ssl, data, len);
LogFactory::getInstance()->debug("len = %d", ret);
if(ret == 0) {
throw DlRetryEx
(StringFormat(EX_SOCKET_PEEK,
ERR_error_string(ERR_get_error(), 0)).str());
ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
}
if(ret < 0) {
ret = sslHandleEAGAIN(ret);
}
if(ret < 0) {
throw DlRetryEx
(StringFormat(EX_SOCKET_PEEK,
ERR_error_string(SSL_get_error(ssl, ret), 0)).str());
}
#endif // HAVE_LIBSSL
#ifdef HAVE_LIBGNUTLS
if ((ret = gnutlsPeek(data, len)) < 0) {
ret = gnutlsPeek(data, len);
if(ret == GNUTLS_E_AGAIN) {
gnutlsRecordCheckDirection();
ret = 0;
} else if(ret < 0) {
throw DlRetryEx(StringFormat(EX_SOCKET_PEEK,
gnutls_strerror(ret)).str());
}
@ -577,13 +655,27 @@ void SocketCore::addPeekData(char* data, size_t len)
peekBufLength += len;
}
static ssize_t GNUTLS_RECORD_RECV_NO_INTERRUPT
(gnutls_session_t sslSession, char* data, size_t len)
{
int ret;
while((ret = gnutls_record_recv(sslSession, data, len)) ==
GNUTLS_E_INTERRUPTED);
if(ret < 0 && ret != GNUTLS_E_AGAIN) {
throw DlRetryEx
(StringFormat(EX_SOCKET_RECV, gnutls_strerror(ret)).str());
}
return ret;
}
ssize_t SocketCore::gnutlsRecv(char* data, size_t len)
{
size_t plen = shiftPeekData(data, len);
if(plen < len) {
ssize_t ret = gnutls_record_recv(sslSession, data+plen, len-plen);
if(ret < 0) {
throw DlRetryEx(StringFormat(EX_SOCKET_RECV, gnutls_strerror(ret)).str());
ssize_t ret = GNUTLS_RECORD_RECV_NO_INTERRUPT
(sslSession, data+plen, len-plen);
if(ret == GNUTLS_E_AGAIN) {
return GNUTLS_E_AGAIN;
}
return plen+ret;
} else {
@ -598,9 +690,10 @@ ssize_t SocketCore::gnutlsPeek(char* data, size_t len)
return len;
} else {
memcpy(data, peekBuf, peekBufLength);
ssize_t ret = gnutls_record_recv(sslSession, data+peekBufLength, len-peekBufLength);
if(ret < 0) {
throw DlRetryEx(StringFormat(EX_SOCKET_PEEK, gnutls_strerror(ret)).str());
ssize_t ret = GNUTLS_RECORD_RECV_NO_INTERRUPT
(sslSession, data+peekBufLength, len-peekBufLength);
if(ret == GNUTLS_E_AGAIN) {
return GNUTLS_E_AGAIN;
}
addPeekData(data+peekBufLength, ret);
return peekBufLength;
@ -608,11 +701,11 @@ ssize_t SocketCore::gnutlsPeek(char* data, size_t len)
}
#endif // HAVE_LIBGNUTLS
void SocketCore::initiateSecureConnection()
void SocketCore::prepareSecureConnection()
{
if(!secure) {
#ifdef HAVE_LIBSSL
// for SSL
if(!secure) {
sslCtx = SSL_CTX_new(SSLv23_client_method());
if(sslCtx == NULL) {
throw DlAbortEx
@ -631,7 +724,31 @@ void SocketCore::initiateSecureConnection()
(StringFormat(EX_SSL_INIT_FAILURE,
ERR_error_string(ERR_get_error(), 0)).str());
}
// TODO handling return value == 0 case required
#endif // HAVE_LIBSSL
#ifdef HAVE_LIBGNUTLS
const int cert_type_priority[3] = { GNUTLS_CRT_X509,
GNUTLS_CRT_OPENPGP, 0
};
// while we do not support X509 certificate, most web servers require
// X509 stuff.
gnutls_certificate_allocate_credentials (&sslXcred);
gnutls_init(&sslSession, GNUTLS_CLIENT);
gnutls_set_default_priority(sslSession);
gnutls_kx_set_priority(sslSession, cert_type_priority);
// put the x509 credentials to the current session
gnutls_credentials_set(sslSession, GNUTLS_CRD_CERTIFICATE, sslXcred);
gnutls_transport_set_ptr(sslSession, (gnutls_transport_ptr_t)sockfd);
#endif // HAVE_LIBGNUTLS
secure = 1;
}
}
bool SocketCore::initiateSecureConnection()
{
if(secure == 1) {
_wantRead = false;
_wantWrite = false;
#ifdef HAVE_LIBSSL
int e = SSL_connect(ssl);
if (e <= 0) {
@ -641,7 +758,11 @@ void SocketCore::initiateSecureConnection()
break;
case SSL_ERROR_WANT_READ:
_wantRead = true;
return false;
case SSL_ERROR_WANT_WRITE:
_wantWrite = true;
return false;
case SSL_ERROR_WANT_X509_LOOKUP:
case SSL_ERROR_ZERO_RETURN:
if (blocking) {
@ -661,32 +782,24 @@ void SocketCore::initiateSecureConnection()
(StringFormat(EX_SSL_UNKNOWN_ERROR, ssl_error).str());
}
}
}
#endif // HAVE_LIBSSL
#ifdef HAVE_LIBGNUTLS
if(!secure) {
const int cert_type_priority[3] = { GNUTLS_CRT_X509,
GNUTLS_CRT_OPENPGP, 0
};
// while we do not support X509 certificate, most web servers require
// X509 stuff.
gnutls_certificate_allocate_credentials (&sslXcred);
gnutls_init(&sslSession, GNUTLS_CLIENT);
gnutls_set_default_priority(sslSession);
gnutls_kx_set_priority(sslSession, cert_type_priority);
// put the x509 credentials to the current session
gnutls_credentials_set(sslSession, GNUTLS_CRD_CERTIFICATE, sslXcred);
gnutls_transport_set_ptr(sslSession, (gnutls_transport_ptr_t)sockfd);
int ret = gnutls_handshake(sslSession);
if(ret < 0) {
if(ret == GNUTLS_E_AGAIN) {
gnutlsRecordCheckDirection();
return false;
} else if(ret < 0) {
throw DlAbortEx
(StringFormat(EX_SSL_INIT_FAILURE, gnutls_strerror(ret)).str());
} else {
peekBuf = new char[peekBufMax];
}
peekBuf = new char[peekBufMax];
}
#endif // HAVE_LIBGNUTLS
secure = true;
secure = 2;
return true;
} else {
return true;
}
}
/* static */ int SocketCore::error()
@ -783,4 +896,14 @@ std::string SocketCore::getSocketError() const
}
}
bool SocketCore::wantRead() const
{
return _wantRead;
}
bool SocketCore::wantWrite() const
{
return _wantWrite;
}
} // namespace aria2

View File

@ -77,11 +77,17 @@ private:
#endif // HAVE_EPOLL
bool blocking;
bool secure;
int secure;
bool _wantRead;
bool _wantWrite;
#ifdef HAVE_LIBSSL
// for SSL
SSL_CTX* sslCtx;
SSL* ssl;
int sslHandleEAGAIN(int ret);
#endif // HAVE_LIBSSL
#ifdef HAVE_LIBGNUTLS
gnutls_session_t sslSession;
@ -94,6 +100,8 @@ private:
void addPeekData(char* data, size_t len);
ssize_t gnutlsRecv(char* data, size_t len);
ssize_t gnutlsPeek(char* data, size_t len);
void gnutlsRecordCheckDirection();
#endif // HAVE_LIBGNUTLS
void init();
@ -105,6 +113,7 @@ private:
#endif // HAVE_EPOLL
SocketCore(sock_t sockfd, int sockType);
static int error();
static const char *errorMsg();
static const char *errorMsg(const int err);
@ -189,10 +198,14 @@ public:
bool isReadable(time_t timeout);
/**
* Writes characters into this socket. data is a pointer pointing the first
* Writes data into this socket. data is a pointer pointing the first
* byte of the data and len is the length of data.
* This method internally calls isWritable(). The parmeter timeout is used
* for this method call.
* If the underlying socket is in blocking mode, this method may block until
* all data is sent.
* If the underlying socket is in non-blocking mode, this method may return
* even if all data is sent. The size of written data is returned. If
* underlying socket gets EAGAIN, _wantRead or _wantWrite is set accordingly.
* This method sets _wantRead and _wantWrite to false before do anything else.
* @param data data to write
* @param len length of data
*/
@ -220,8 +233,12 @@ public:
* byte of the data, which must be allocated before this method is called.
* len is the size of the allocated memory. When this method returns
* successfully, len is replaced by the size of the read data.
* This method internally calls isReadable(). The parameter timeout is used
* for this method call.
* If the underlying socket is in blocking mode, this method may block until
* at least 1byte is received.
* If the underlying socket is in non-blocking mode, this method may return
* even if no single byte is received. If the underlying socket gets EAGAIN,
* _wantRead or _wantWrite is set accordingly.
* This method sets _wantRead and _wantWrite to false before do anything else.
* @param data holder to store data.
* @param len the maximum size data can store. This method assigns
* the number of bytes read to len.
@ -265,7 +282,9 @@ public:
* If the system has not OpenSSL, then this method do nothing.
* connection must be established before calling this method.
*/
void initiateSecureConnection();
bool initiateSecureConnection();
void prepareSecureConnection();
bool operator==(const SocketCore& s) {
return sockfd == s.sockfd;
@ -280,6 +299,19 @@ public:
}
std::string getSocketError() const;
/**
* Returns true if the underlying socket gets EAGAIN in the previous
* readData() or writeData() and the socket needs more incoming data to
* continue the operation.
*/
bool wantRead() const;
/**
* Returns true if the underlying socket gets EAGAIN in the previous
* readData() or writeData() and the socket needs to write more data.
*/
bool wantWrite() const;
};
} // namespace aria2