Eliminated SocketCore::peekData from MSEHandshake.

pull/1/head
Tatsuhiro Tsujikawa 2011-01-08 17:32:16 +09:00
parent c48db2cdf3
commit ce2d401dce
9 changed files with 523 additions and 438 deletions

View File

@ -56,6 +56,7 @@
#include "bittorrent_helper.h"
#include "util.h"
#include "fmt.h"
#include "array_fun.h"
namespace aria2 {
@ -89,82 +90,108 @@ InitiatorMSEHandshakeCommand::~InitiatorMSEHandshakeCommand()
}
bool InitiatorMSEHandshakeCommand::executeInternal() {
switch(sequence_) {
case INITIATOR_SEND_KEY: {
if(!getSocket()->isWritable(0)) {
if(mseHandshake_->getWantRead()) {
mseHandshake_->read();
}
bool done = false;
while(!done) {
switch(sequence_) {
case INITIATOR_SEND_KEY: {
if(!getSocket()->isWritable(0)) {
getDownloadEngine()->addCommand(this);
return false;
}
setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT));
mseHandshake_->initEncryptionFacility(true);
mseHandshake_->sendPublicKey();
sequence_ = INITIATOR_SEND_KEY_PENDING;
break;
}
disableWriteCheckSocket();
setReadCheckSocket(getSocket());
//socket->setBlockingMode();
setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT));
mseHandshake_->initEncryptionFacility(true);
if(mseHandshake_->sendPublicKey()) {
sequence_ = INITIATOR_WAIT_KEY;
} else {
setWriteCheckSocket(getSocket());
sequence_ = INITIATOR_SEND_KEY_PENDING;
case INITIATOR_SEND_KEY_PENDING:
if(mseHandshake_->send()) {
sequence_ = INITIATOR_WAIT_KEY;
} else {
done = true;
}
break;
case INITIATOR_WAIT_KEY: {
if(mseHandshake_->receivePublicKey()) {
mseHandshake_->initCipher
(bittorrent::getInfoHash(requestGroup_->getDownloadContext()));;
mseHandshake_->sendInitiatorStep2();
sequence_ = INITIATOR_SEND_STEP2_PENDING;
} else {
done = true;
}
break;
}
break;
}
case INITIATOR_SEND_KEY_PENDING:
if(mseHandshake_->sendPublicKey()) {
disableWriteCheckSocket();
sequence_ = INITIATOR_WAIT_KEY;
}
break;
case INITIATOR_WAIT_KEY: {
if(mseHandshake_->receivePublicKey()) {
mseHandshake_->initCipher
(bittorrent::getInfoHash(requestGroup_->getDownloadContext()));;
if(mseHandshake_->sendInitiatorStep2()) {
case INITIATOR_SEND_STEP2_PENDING:
if(mseHandshake_->send()) {
sequence_ = INITIATOR_FIND_VC_MARKER;
} else {
setWriteCheckSocket(getSocket());
sequence_ = INITIATOR_SEND_STEP2_PENDING;
done = true;
}
}
break;
}
case INITIATOR_SEND_STEP2_PENDING:
if(mseHandshake_->sendInitiatorStep2()) {
disableWriteCheckSocket();
sequence_ = INITIATOR_FIND_VC_MARKER;
}
break;
case INITIATOR_FIND_VC_MARKER: {
if(mseHandshake_->findInitiatorVCMarker()) {
sequence_ = INITIATOR_RECEIVE_PAD_D_LENGTH;
}
break;
}
case INITIATOR_RECEIVE_PAD_D_LENGTH: {
if(mseHandshake_->receiveInitiatorCryptoSelectAndPadDLength()) {
sequence_ = INITIATOR_RECEIVE_PAD_D;
}
break;
}
case INITIATOR_RECEIVE_PAD_D: {
if(mseHandshake_->receivePad()) {
SharedHandle<PeerConnection> peerConnection
(new PeerConnection(getCuid(), getPeer(), getSocket()));
if(mseHandshake_->getNegotiatedCryptoType() == MSEHandshake::CRYPTO_ARC4){
peerConnection->enableEncryption(mseHandshake_->getEncryptor(),
mseHandshake_->getDecryptor());
break;
case INITIATOR_FIND_VC_MARKER: {
if(mseHandshake_->findInitiatorVCMarker()) {
sequence_ = INITIATOR_RECEIVE_PAD_D_LENGTH;
} else {
done = true;
}
PeerInteractionCommand* c =
new PeerInteractionCommand
(getCuid(), requestGroup_, getPeer(), getDownloadEngine(), btRuntime_,
pieceStorage_,
peerStorage_,
getSocket(),
PeerInteractionCommand::INITIATOR_SEND_HANDSHAKE,
peerConnection);
getDownloadEngine()->addCommand(c);
return true;
break;
}
case INITIATOR_RECEIVE_PAD_D_LENGTH: {
if(mseHandshake_->receiveInitiatorCryptoSelectAndPadDLength()) {
sequence_ = INITIATOR_RECEIVE_PAD_D;
} else {
done = true;
}
break;
}
case INITIATOR_RECEIVE_PAD_D: {
if(mseHandshake_->receivePad()) {
SharedHandle<PeerConnection> peerConnection
(new PeerConnection(getCuid(), getPeer(), getSocket()));
if(mseHandshake_->getNegotiatedCryptoType() ==
MSEHandshake::CRYPTO_ARC4){
peerConnection->enableEncryption(mseHandshake_->getEncryptor(),
mseHandshake_->getDecryptor());
size_t buflen = mseHandshake_->getBufferLength();
array_ptr<unsigned char> buffer(new unsigned char[buflen]);
mseHandshake_->getDecryptor()->decrypt(buffer, buflen,
mseHandshake_->getBuffer(),
buflen);
peerConnection->presetBuffer(buffer, buflen);
} else {
peerConnection->presetBuffer(mseHandshake_->getBuffer(),
mseHandshake_->getBufferLength());
}
PeerInteractionCommand* c =
new PeerInteractionCommand
(getCuid(), requestGroup_, getPeer(), getDownloadEngine(), btRuntime_,
pieceStorage_,
peerStorage_,
getSocket(),
PeerInteractionCommand::INITIATOR_SEND_HANDSHAKE,
peerConnection);
getDownloadEngine()->addCommand(c);
return true;
} else {
done = true;
}
break;
}
}
break;
}
if(mseHandshake_->getWantRead()) {
setReadCheckSocket(getSocket());
} else {
disableReadCheckSocket();
}
if(mseHandshake_->getWantWrite()) {
setWriteCheckSocket(getSocket());
} else {
disableWriteCheckSocket();
}
getDownloadEngine()->addCommand(this);
return false;

View File

@ -71,6 +71,7 @@ MSEHandshake::MSEHandshake
const Option* op)
: cuid_(cuid),
socket_(socket),
wantRead_(false),
option_(op),
rbufLength_(0),
socketBuffer_(socket),
@ -92,16 +93,8 @@ MSEHandshake::~MSEHandshake()
MSEHandshake::HANDSHAKE_TYPE MSEHandshake::identifyHandshakeType()
{
if(!socket_->isReadable(0)) {
return HANDSHAKE_NOT_YET;
}
size_t r = 20-rbufLength_;
socket_->readData(rbuf_+rbufLength_, r);
if(r == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
throw DL_ABORT_EX(EX_EOF_FROM_PEER);
}
rbufLength_ += r;
if(rbufLength_ < 20) {
wantRead_ = true;
return HANDSHAKE_NOT_YET;
}
if(rbuf_[0] == BtHandshakeMessage::PSTR_LENGTH &&
@ -126,35 +119,59 @@ void MSEHandshake::initEncryptionFacility(bool initiator)
initiator_ = initiator;
}
bool MSEHandshake::sendPublicKey()
void MSEHandshake::sendPublicKey()
{
if(socketBuffer_.sendBufferIsEmpty()) {
A2_LOG_DEBUG(fmt("CUID#%lld - Sending public key.",
cuid_));
unsigned char buffer[KEY_LENGTH+MAX_PAD_LENGTH];
dh_->getPublicKey(buffer, KEY_LENGTH);
A2_LOG_DEBUG(fmt("CUID#%lld - Sending public key.",
cuid_));
unsigned char buffer[KEY_LENGTH+MAX_PAD_LENGTH];
dh_->getPublicKey(buffer, KEY_LENGTH);
size_t padLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
dh_->generateNonce(buffer+KEY_LENGTH, padLength);
socketBuffer_.pushStr(std::string(&buffer[0],
&buffer[KEY_LENGTH+padLength]));
size_t padLength =
SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
dh_->generateNonce(buffer+KEY_LENGTH, padLength);
socketBuffer_.pushStr(std::string(&buffer[0],
&buffer[KEY_LENGTH+padLength]));
}
void MSEHandshake::read()
{
if(rbufLength_ >= MAX_BUFFER_LENGTH) {
assert(!wantRead_);
return;
}
size_t len = MAX_BUFFER_LENGTH-rbufLength_;
socket_->readData(rbuf_+rbufLength_, len);
if(len == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
// TODO Should we set graceful in peer?
throw DL_ABORT_EX(EX_EOF_FROM_PEER);
}
rbufLength_ += len;
wantRead_ = false;
}
bool MSEHandshake::send()
{
socketBuffer_.send();
return socketBuffer_.sendBufferIsEmpty();
}
void MSEHandshake::shiftBuffer(size_t offset)
{
memmove(rbuf_, rbuf_+offset, rbufLength_-offset);
rbufLength_ -= offset;
}
bool MSEHandshake::receivePublicKey()
{
size_t r = KEY_LENGTH-rbufLength_;
if(r > receiveNBytes(r)) {
if(rbufLength_ < KEY_LENGTH) {
wantRead_ = true;
return false;
}
A2_LOG_DEBUG(fmt("CUID#%lld - public key received.",
cuid_));
A2_LOG_DEBUG(fmt("CUID#%lld - public key received.", cuid_));
// TODO handle exception. in catch, resbufLength = 0;
dh_->computeSecret(secret_, sizeof(secret_), rbuf_, rbufLength_);
// reset rbufLength_
rbufLength_ = 0;
dh_->computeSecret(secret_, sizeof(secret_), rbuf_, KEY_LENGTH);
// shift buffer
shiftBuffer(KEY_LENGTH);
return true;
}
@ -251,109 +268,83 @@ uint16_t MSEHandshake::decodeLength16(const unsigned char* buffer)
return ntohs(be);
}
bool MSEHandshake::sendInitiatorStep2()
void MSEHandshake::sendInitiatorStep2()
{
if(socketBuffer_.sendBufferIsEmpty()) {
A2_LOG_DEBUG(fmt("CUID#%lld - Sending negotiation step2.",
cuid_));
unsigned char md[20];
createReq1Hash(md);
socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)]));
A2_LOG_DEBUG(fmt("CUID#%lld - Sending negotiation step2.", cuid_));
unsigned char md[20];
createReq1Hash(md);
socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)]));
createReq23Hash(md, infoHash_);
socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)]));
createReq23Hash(md, infoHash_);
socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)]));
// buffer is filled in this order:
// VC(VC_LENGTH bytes),
// crypto_provide(CRYPTO_BITFIELD_LENGTH bytes),
// len(padC)(2bytes),
// padC(len(padC)bytes <= MAX_PAD_LENGTH),
// len(IA)(2bytes)
unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH+2];
{
// buffer is filled in this order:
// VC(VC_LENGTH bytes),
// crypto_provide(CRYPTO_BITFIELD_LENGTH bytes),
// len(padC)(2bytes),
// padC(len(padC)bytes <= MAX_PAD_LENGTH),
// len(IA)(2bytes)
unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH+2];
// VC
memcpy(buffer, VC, sizeof(VC));
// crypto_provide
unsigned char cryptoProvide[CRYPTO_BITFIELD_LENGTH];
memset(cryptoProvide, 0, sizeof(cryptoProvide));
if(option_->get(PREF_BT_MIN_CRYPTO_LEVEL) == V_PLAIN) {
cryptoProvide[3] = CRYPTO_PLAIN_TEXT;
}
cryptoProvide[3] |= CRYPTO_ARC4;
memcpy(buffer+VC_LENGTH, cryptoProvide, sizeof(cryptoProvide));
// len(padC)
uint16_t padCLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
{
uint16_t padCLengthBE = htons(padCLength);
memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padCLengthBE,
sizeof(padCLengthBE));
}
// padC
memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padCLength);
// len(IA)
// currently, IA is zero-length.
uint16_t iaLength = 0;
{
uint16_t iaLengthBE = htons(iaLength);
memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength,
&iaLengthBE,sizeof(iaLengthBE));
}
encryptAndSendData(buffer,
VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength+2);
}
// VC
memcpy(buffer, VC, sizeof(VC));
// crypto_provide
unsigned char cryptoProvide[CRYPTO_BITFIELD_LENGTH];
memset(cryptoProvide, 0, sizeof(cryptoProvide));
if(option_->get(PREF_BT_MIN_CRYPTO_LEVEL) == V_PLAIN) {
cryptoProvide[3] = CRYPTO_PLAIN_TEXT;
}
socketBuffer_.send();
return socketBuffer_.sendBufferIsEmpty();
cryptoProvide[3] |= CRYPTO_ARC4;
memcpy(buffer+VC_LENGTH, cryptoProvide, sizeof(cryptoProvide));
// len(padC)
uint16_t padCLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
{
uint16_t padCLengthBE = htons(padCLength);
memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padCLengthBE,
sizeof(padCLengthBE));
}
// padC
memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padCLength);
// len(IA)
// currently, IA is zero-length.
uint16_t iaLength = 0;
{
uint16_t iaLengthBE = htons(iaLength);
memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength,
&iaLengthBE,sizeof(iaLengthBE));
}
encryptAndSendData(buffer,
VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength+2);
}
// This function reads exactly until the end of VC marker is reached.
bool MSEHandshake::findInitiatorVCMarker()
{
// 616 is synchronization point of initiator
size_t r = 616-KEY_LENGTH-rbufLength_;
if(!socket_->isReadable(0)) {
return false;
}
socket_->peekData(rbuf_+rbufLength_, r);
if(r == 0) {
if(socket_->wantRead() || socket_->wantWrite()) {
// find vc
std::string buf(&rbuf_[0], &rbuf_[rbufLength_]);
std::string vc(&initiatorVCMarker_[0], &initiatorVCMarker_[VC_LENGTH]);
if((markerIndex_ = buf.find(vc)) == std::string::npos) {
if(616-KEY_LENGTH <= rbufLength_) {
throw DL_ABORT_EX("Failed to find VC marker.");
} else {
wantRead_ = true;
return false;
}
throw DL_ABORT_EX(EX_EOF_FROM_PEER);
}
// find vc
{
std::string buf(&rbuf_[0], &rbuf_[rbufLength_+r]);
std::string vc(&initiatorVCMarker_[0], &initiatorVCMarker_[VC_LENGTH]);
if((markerIndex_ = buf.find(vc)) == std::string::npos) {
if(616-KEY_LENGTH <= rbufLength_+r) {
throw DL_ABORT_EX("Failed to find VC marker.");
} else {
socket_->readData(rbuf_+rbufLength_, r);
rbufLength_ += r;
return false;
}
}
}
assert(markerIndex_+VC_LENGTH-rbufLength_ <= r);
size_t toRead = markerIndex_+VC_LENGTH-rbufLength_;
socket_->readData(rbuf_+rbufLength_, toRead);
rbufLength_ += toRead;
A2_LOG_DEBUG(fmt("CUID#%lld - VC marker found at %lu",
cuid_,
static_cast<unsigned long>(markerIndex_)));
verifyVC(rbuf_+markerIndex_);
// reset rbufLength_
rbufLength_ = 0;
// shift rbuf
shiftBuffer(markerIndex_+VC_LENGTH);
return true;
}
bool MSEHandshake::receiveInitiatorCryptoSelectAndPadDLength()
{
size_t r = CRYPTO_BITFIELD_LENGTH+2/* PadD length*/-rbufLength_;
if(r > receiveNBytes(r)) {
if(CRYPTO_BITFIELD_LENGTH+2/* PadD length*/ > rbufLength_) {
wantRead_ = true;
return false;
}
//verifyCryptoSelect
@ -382,75 +373,57 @@ bool MSEHandshake::receiveInitiatorCryptoSelectAndPadDLength()
// padD length
rbufptr += CRYPTO_BITFIELD_LENGTH;
padLength_ = verifyPadLength(rbufptr, "PadD");
// reset rbufLength_
rbufLength_ = 0;
// shift rbuf
shiftBuffer(CRYPTO_BITFIELD_LENGTH+2/* PadD length*/);
return true;
}
bool MSEHandshake::receivePad()
{
if(padLength_ > rbufLength_) {
wantRead_ = true;
return false;
}
if(padLength_ == 0) {
return true;
}
size_t r = padLength_-rbufLength_;
if(r > receiveNBytes(r)) {
return false;
}
unsigned char temp[MAX_PAD_LENGTH];
decryptor_->decrypt(temp, padLength_, rbuf_, padLength_);
// reset rbufLength_
rbufLength_ = 0;
// shift rbuf_
shiftBuffer(padLength_);
return true;
}
bool MSEHandshake::findReceiverHashMarker()
{
// 628 is synchronization limit of receiver.
size_t r = 628-KEY_LENGTH-rbufLength_;
if(!socket_->isReadable(0)) {
return false;
}
socket_->peekData(rbuf_+rbufLength_, r);
if(r == 0) {
if(socket_->wantRead() || socket_->wantWrite()) {
// find hash('req1', S), S is secret_.
std::string buf(&rbuf_[0], &rbuf_[rbufLength_]);
unsigned char md[20];
createReq1Hash(md);
std::string req1(&md[0], &md[sizeof(md)]);
if((markerIndex_ = buf.find(req1)) == std::string::npos) {
if(628-KEY_LENGTH <= rbufLength_) {
throw DL_ABORT_EX("Failed to find hash marker.");
} else {
wantRead_ = true;
return false;
}
throw DL_ABORT_EX(EX_EOF_FROM_PEER);
}
// find hash('req1', S), S is secret_.
{
std::string buf(&rbuf_[0], &rbuf_[rbufLength_+r]);
unsigned char md[20];
createReq1Hash(md);
std::string req1(&md[0], &md[sizeof(md)]);
if((markerIndex_ = buf.find(req1)) == std::string::npos) {
if(628-KEY_LENGTH <= rbufLength_+r) {
throw DL_ABORT_EX("Failed to find hash marker.");
} else {
socket_->readData(rbuf_+rbufLength_, r);
rbufLength_ += r;
return false;
}
}
}
assert(markerIndex_+20-rbufLength_ <= r);
size_t toRead = markerIndex_+20-rbufLength_;
socket_->readData(rbuf_+rbufLength_, toRead);
rbufLength_ += toRead;
A2_LOG_DEBUG(fmt("CUID#%lld - Hash marker found at %lu.",
cuid_,
static_cast<unsigned long>(markerIndex_)));
verifyReq1Hash(rbuf_+markerIndex_);
// reset rbufLength_
rbufLength_ = 0;
// shift rbuf_
shiftBuffer(markerIndex_+20);
return true;
}
bool MSEHandshake::receiveReceiverHashAndPadCLength
(const std::vector<SharedHandle<DownloadContext> >& downloadContexts)
{
size_t r = 20+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2/*PadC length*/-rbufLength_;
if(r > receiveNBytes(r)) {
if(20+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2/*PadC length*/ > rbufLength_) {
wantRead_ = true;
return false;
}
// resolve info hash
@ -505,23 +478,22 @@ bool MSEHandshake::receiveReceiverHashAndPadCLength
// decrypt PadC length
rbufptr += CRYPTO_BITFIELD_LENGTH;
padLength_ = verifyPadLength(rbufptr, "PadC");
// reset rbufLength_
rbufLength_ = 0;
// shift rbuf_
shiftBuffer(20+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2/*PadC length*/);
return true;
}
bool MSEHandshake::receiveReceiverIALength()
{
size_t r = 2-rbufLength_;
assert(r > 0);
if(r > receiveNBytes(r)) {
if(2 > rbufLength_) {
wantRead_ = true;
return false;
}
iaLength_ = decodeLength16(rbuf_);
A2_LOG_DEBUG(fmt("CUID#%lld - len(IA)=%u.",
cuid_, iaLength_));
// reset rbufLength_
rbufLength_ = 0;
// TODO limit iaLength \19...+handshake
A2_LOG_DEBUG(fmt("CUID#%lld - len(IA)=%u.", cuid_, iaLength_));
// shift rbuf_
shiftBuffer(2);
return true;
}
@ -530,48 +502,44 @@ bool MSEHandshake::receiveReceiverIA()
if(iaLength_ == 0) {
return true;
}
size_t r = iaLength_-rbufLength_;
if(r > receiveNBytes(r)) {
if(iaLength_ > rbufLength_) {
wantRead_ = true;
return false;
}
delete [] ia_;
ia_ = new unsigned char[iaLength_];
decryptor_->decrypt(ia_, iaLength_, rbuf_, iaLength_);
A2_LOG_DEBUG(fmt("CUID#%lld - IA received.", cuid_));
// reset rbufLength_
rbufLength_ = 0;
// shift rbuf_
shiftBuffer(iaLength_);
return true;
}
bool MSEHandshake::sendReceiverStep2()
void MSEHandshake::sendReceiverStep2()
{
if(socketBuffer_.sendBufferIsEmpty()) {
// buffer is filled in this order:
// VC(VC_LENGTH bytes),
// cryptoSelect(CRYPTO_BITFIELD_LENGTH bytes),
// len(padD)(2bytes),
// padD(len(padD)bytes <= MAX_PAD_LENGTH)
unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH];
// VC
memcpy(buffer, VC, sizeof(VC));
// crypto_select
unsigned char cryptoSelect[CRYPTO_BITFIELD_LENGTH];
memset(cryptoSelect, 0, sizeof(cryptoSelect));
cryptoSelect[3] = negotiatedCryptoType_;
memcpy(buffer+VC_LENGTH, cryptoSelect, sizeof(cryptoSelect));
// len(padD)
uint16_t padDLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
{
uint16_t padDLengthBE = htons(padDLength);
memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padDLengthBE,
sizeof(padDLengthBE));
}
// padD, all zeroed
memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padDLength);
encryptAndSendData(buffer, VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padDLength);
// buffer is filled in this order:
// VC(VC_LENGTH bytes),
// cryptoSelect(CRYPTO_BITFIELD_LENGTH bytes),
// len(padD)(2bytes),
// padD(len(padD)bytes <= MAX_PAD_LENGTH)
unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH];
// VC
memcpy(buffer, VC, sizeof(VC));
// crypto_select
unsigned char cryptoSelect[CRYPTO_BITFIELD_LENGTH];
memset(cryptoSelect, 0, sizeof(cryptoSelect));
cryptoSelect[3] = negotiatedCryptoType_;
memcpy(buffer+VC_LENGTH, cryptoSelect, sizeof(cryptoSelect));
// len(padD)
uint16_t padDLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
{
uint16_t padDLengthBE = htons(padDLength);
memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padDLengthBE,
sizeof(padDLengthBE));
}
socketBuffer_.send();
return socketBuffer_.sendBufferIsEmpty();
// padD, all zeroed
memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padDLength);
encryptAndSendData(buffer, VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padDLength);
}
uint16_t MSEHandshake::verifyPadLength(const unsigned char* padlenbuf, const char* padName)
@ -609,20 +577,9 @@ void MSEHandshake::verifyReq1Hash(const unsigned char* req1buf)
}
}
size_t MSEHandshake::receiveNBytes(size_t bytes)
bool MSEHandshake::getWantWrite() const
{
size_t r = bytes;
if(r > 0) {
if(!socket_->isReadable(0)) {
return 0;
}
socket_->readData(rbuf_+rbufLength_, r);
if(r == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
throw DL_ABORT_EX(EX_EOF_FROM_PEER);
}
rbufLength_ += r;
}
return r;
return !socketBuffer_.sendBufferIsEmpty();
}
} // namespace aria2

View File

@ -83,6 +83,7 @@ private:
cuid_t cuid_;
SharedHandle<SocketCore> socket_;
bool wantRead_;
const Option* option_;
unsigned char rbuf_[MAX_BUFFER_LENGTH];
@ -130,8 +131,7 @@ private:
void verifyReq1Hash(const unsigned char* req1buf);
size_t receiveNBytes(size_t bytes);
void shiftBuffer(size_t offset);
public:
MSEHandshake(cuid_t cuid, const SharedHandle<SocketCore>& socket,
const Option* op);
@ -142,13 +142,33 @@ public:
void initEncryptionFacility(bool initiator);
bool sendPublicKey();
// Reads data from Socket. If EOF is reached, throws
// RecoverableException.
void read();
// Sends pending data in the send buffer. Returns true if all data
// is sent. Otherwise returns false.
bool send();
bool getWantRead() const
{
return wantRead_;
}
void setWantRead(bool wantRead)
{
wantRead_ = wantRead;
}
bool getWantWrite() const;
void sendPublicKey();
bool receivePublicKey();
void initCipher(const unsigned char* infoHash);
bool sendInitiatorStep2();
void sendInitiatorStep2();
bool findInitiatorVCMarker();
@ -165,7 +185,7 @@ public:
bool receiveReceiverIA();
bool sendReceiverStep2();
void sendReceiverStep2();
// returns plain text IA
const unsigned char* getIA() const
@ -207,7 +227,6 @@ public:
{
return rbufLength_;
}
};
} // namespace aria2

View File

@ -110,9 +110,6 @@ void PeerConnection::pushBytes(unsigned char* data, size_t len)
bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) {
if(resbufLength_ == 0 && 4 > lenbufLength_) {
if(!socket_->isReadable(0)) {
return false;
}
// read payload size, 32bit unsigned integer
size_t remaining = 4-lenbufLength_;
size_t temp = remaining;
@ -182,7 +179,7 @@ bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
bool peek) {
assert(BtHandshakeMessage::MESSAGE_LENGTH >= resbufLength_);
bool retval = true;
if(prevPeek_ && !peek && resbufLength_) {
if(prevPeek_ && resbufLength_) {
// We have data in previous peek.
// There is a chance that socket is readable because of EOF, for example,
// official bttrack shutdowns socket after sending first 48 bytes of
@ -194,17 +191,10 @@ bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
} else {
prevPeek_ = peek;
size_t remaining = BtHandshakeMessage::MESSAGE_LENGTH-resbufLength_;
if(remaining > 0 && !socket_->isReadable(0)) {
dataLength = 0;
return false;
}
if(remaining > 0) {
size_t temp = remaining;
readData(resbuf_+resbufLength_, remaining, encryptionEnabled_);
if(remaining == 0) {
if(socket_->wantRead() || socket_->wantWrite()) {
return false;
}
if(remaining == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
// we got EOF
A2_LOG_DEBUG
(fmt("CUID#%lld - In PeerConnection::receiveHandshake(), remain=%lu",
@ -256,6 +246,9 @@ void PeerConnection::presetBuffer(const unsigned char* data, size_t length)
size_t nwrite = std::min((size_t)MAX_PAYLOAD_LEN, length);
memcpy(resbuf_, data, nwrite);
resbufLength_ = length;
if(resbufLength_ > 0) {
prevPeek_ = true;
}
}
bool PeerConnection::sendBufferIsEmpty() const

View File

@ -117,6 +117,11 @@ public:
return resbuf_;
}
size_t getBufferLength() const
{
return resbufLength_;
}
unsigned char* detachBuffer();
};

View File

@ -163,6 +163,11 @@ PeerInteractionCommand::PeerInteractionCommand
peerConnection.reset(new PeerConnection(cuid, getPeer(), getSocket()));
} else {
peerConnection = passedPeerConnection;
if(sequence_ == RECEIVER_WAIT_HANDSHAKE &&
peerConnection->getBufferLength() > 0) {
setStatus(Command::STATUS_ONESHOT_REALTIME);
getDownloadEngine()->setNoWait(true);
}
}
SharedHandle<DefaultBtMessageDispatcher> dispatcher
@ -274,71 +279,80 @@ PeerInteractionCommand::~PeerInteractionCommand() {
bool PeerInteractionCommand::executeInternal() {
setNoCheck(false);
switch(sequence_) {
case INITIATOR_SEND_HANDSHAKE:
if(!getSocket()->isWritable(0)) {
break;
}
disableWriteCheckSocket();
setReadCheckSocket(getSocket());
//socket->setBlockingMode();
setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT));
btInteractive_->initiateHandshake();
sequence_ = INITIATOR_WAIT_HANDSHAKE;
break;
case INITIATOR_WAIT_HANDSHAKE: {
if(btInteractive_->countPendingMessage() > 0) {
btInteractive_->sendPendingMessage();
if(btInteractive_->countPendingMessage() > 0) {
bool done = false;
while(!done) {
switch(sequence_) {
case INITIATOR_SEND_HANDSHAKE:
if(!getSocket()->isWritable(0)) {
done = true;
break;
}
}
BtMessageHandle handshakeMessage = btInteractive_->receiveHandshake();
if(!handshakeMessage) {
disableWriteCheckSocket();
setReadCheckSocket(getSocket());
//socket->setBlockingMode();
setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT));
btInteractive_->initiateHandshake();
sequence_ = INITIATOR_WAIT_HANDSHAKE;
break;
}
btInteractive_->doPostHandshakeProcessing();
sequence_ = WIRED;
break;
}
case RECEIVER_WAIT_HANDSHAKE: {
BtMessageHandle handshakeMessage =btInteractive_->receiveAndSendHandshake();
if(!handshakeMessage) {
break;
}
btInteractive_->doPostHandshakeProcessing();
sequence_ = WIRED;
break;
}
case WIRED:
// See the comment for writable check below.
disableWriteCheckSocket();
btInteractive_->doInteractionProcessing();
if(btInteractive_->countReceivedMessageInIteration() > 0) {
updateKeepAlive();
}
if((getPeer()->amInterested() && !getPeer()->peerChoking()) ||
btInteractive_->countOutstandingRequest() ||
(getPeer()->peerInterested() && !getPeer()->amChoking())) {
// Writable check to avoid slow seeding
if(btInteractive_->isSendingMessageInProgress()) {
setWriteCheckSocket(getSocket());
case INITIATOR_WAIT_HANDSHAKE: {
if(btInteractive_->countPendingMessage() > 0) {
btInteractive_->sendPendingMessage();
if(btInteractive_->countPendingMessage() > 0) {
done = true;
break;
}
}
BtMessageHandle handshakeMessage = btInteractive_->receiveHandshake();
if(!handshakeMessage) {
done = true;
break;
}
btInteractive_->doPostHandshakeProcessing();
sequence_ = WIRED;
break;
}
case RECEIVER_WAIT_HANDSHAKE: {
BtMessageHandle handshakeMessage =
btInteractive_->receiveAndSendHandshake();
if(!handshakeMessage) {
done = true;
break;
}
btInteractive_->doPostHandshakeProcessing();
sequence_ = WIRED;
break;
}
case WIRED:
// See the comment for writable check below.
disableWriteCheckSocket();
if(getDownloadEngine()->getRequestGroupMan()->
doesOverallDownloadSpeedExceed() ||
requestGroup_->doesDownloadSpeedExceed()) {
disableReadCheckSocket();
setNoCheck(true);
btInteractive_->doInteractionProcessing();
if(btInteractive_->countReceivedMessageInIteration() > 0) {
updateKeepAlive();
}
if((getPeer()->amInterested() && !getPeer()->peerChoking()) ||
btInteractive_->countOutstandingRequest() ||
(getPeer()->peerInterested() && !getPeer()->amChoking())) {
// Writable check to avoid slow seeding
if(btInteractive_->isSendingMessageInProgress()) {
setWriteCheckSocket(getSocket());
}
if(getDownloadEngine()->getRequestGroupMan()->
doesOverallDownloadSpeedExceed() ||
requestGroup_->doesDownloadSpeedExceed()) {
disableReadCheckSocket();
setNoCheck(true);
} else {
setReadCheckSocket(getSocket());
}
} else {
setReadCheckSocket(getSocket());
disableReadCheckSocket();
}
} else {
disableReadCheckSocket();
done = true;
break;
}
break;
}
if(btInteractive_->countPendingMessage() > 0) {
setNoCheck(true);

View File

@ -67,7 +67,12 @@ PeerReceiveHandshakeCommand::PeerReceiveHandshakeCommand
: PeerAbstractCommand(cuid, peer, e, s),
peerConnection_(peerConnection)
{
if(!peerConnection_) {
if(peerConnection_) {
if(peerConnection_->getBufferLength() > 0) {
setStatus(Command::STATUS_ONESHOT_REALTIME);
getDownloadEngine()->setNoWait(true);
}
} else {
peerConnection_.reset(new PeerConnection(cuid, getPeer(), getSocket()));
}
}

View File

@ -50,6 +50,7 @@
#include "RequestGroupMan.h"
#include "BtRegistry.h"
#include "DownloadContext.h"
#include "array_fun.h"
namespace aria2 {
@ -64,6 +65,7 @@ ReceiverMSEHandshakeCommand::ReceiverMSEHandshakeCommand
mseHandshake_(new MSEHandshake(cuid, s, e->getOption()))
{
setTimeout(e->getOption()->getAsInt(PREF_PEER_CONNECTION_TIMEOUT));
mseHandshake_->setWantRead(true);
}
ReceiverMSEHandshakeCommand::~ReceiverMSEHandshakeCommand()
@ -79,102 +81,125 @@ bool ReceiverMSEHandshakeCommand::exitBeforeExecute()
bool ReceiverMSEHandshakeCommand::executeInternal()
{
switch(sequence_) {
case RECEIVER_IDENTIFY_HANDSHAKE: {
MSEHandshake::HANDSHAKE_TYPE type = mseHandshake_->identifyHandshakeType();
switch(type) {
case MSEHandshake::HANDSHAKE_NOT_YET:
break;
case MSEHandshake::HANDSHAKE_ENCRYPTED:
mseHandshake_->initEncryptionFacility(false);
sequence_ = RECEIVER_WAIT_KEY;
break;
case MSEHandshake::HANDSHAKE_LEGACY: {
if(getDownloadEngine()->getOption()->getAsBool(PREF_BT_REQUIRE_CRYPTO)) {
throw DL_ABORT_EX
("The legacy BitTorrent handshake is not acceptable by the"
" preference.");
}
SharedHandle<PeerConnection> peerConnection
(new PeerConnection(getCuid(), getPeer(), getSocket()));
peerConnection->presetBuffer(mseHandshake_->getBuffer(),
mseHandshake_->getBufferLength());
Command* c = new PeerReceiveHandshakeCommand(getCuid(),
getPeer(),
getDownloadEngine(),
getSocket(),
peerConnection);
getDownloadEngine()->addCommand(c);
return true;
}
default:
throw DL_ABORT_EX("Not supported handshake type.");
}
break;
if(mseHandshake_->getWantRead()) {
mseHandshake_->read();
}
case RECEIVER_WAIT_KEY: {
if(mseHandshake_->receivePublicKey()) {
if(mseHandshake_->sendPublicKey()) {
bool done = false;
while(!done) {
switch(sequence_) {
case RECEIVER_IDENTIFY_HANDSHAKE: {
MSEHandshake::HANDSHAKE_TYPE type =
mseHandshake_->identifyHandshakeType();
switch(type) {
case MSEHandshake::HANDSHAKE_NOT_YET:
done = true;
break;
case MSEHandshake::HANDSHAKE_ENCRYPTED:
mseHandshake_->initEncryptionFacility(false);
sequence_ = RECEIVER_WAIT_KEY;
break;
case MSEHandshake::HANDSHAKE_LEGACY: {
if(getDownloadEngine()->getOption()->getAsBool(PREF_BT_REQUIRE_CRYPTO)){
throw DL_ABORT_EX
("The legacy BitTorrent handshake is not acceptable by the"
" preference.");
}
SharedHandle<PeerConnection> peerConnection
(new PeerConnection(getCuid(), getPeer(), getSocket()));
peerConnection->presetBuffer(mseHandshake_->getBuffer(),
mseHandshake_->getBufferLength());
Command* c = new PeerReceiveHandshakeCommand(getCuid(),
getPeer(),
getDownloadEngine(),
getSocket(),
peerConnection);
getDownloadEngine()->addCommand(c);
return true;
}
default:
throw DL_ABORT_EX("Not supported handshake type.");
}
break;
}
case RECEIVER_WAIT_KEY: {
if(mseHandshake_->receivePublicKey()) {
mseHandshake_->sendPublicKey();
sequence_ = RECEIVER_SEND_KEY_PENDING;
} else {
done = true;
}
break;
}
case RECEIVER_SEND_KEY_PENDING:
if(mseHandshake_->send()) {
sequence_ = RECEIVER_FIND_HASH_MARKER;
} else {
setWriteCheckSocket(getSocket());
sequence_ = RECEIVER_SEND_KEY_PENDING;
done = true;
}
break;
case RECEIVER_FIND_HASH_MARKER: {
if(mseHandshake_->findReceiverHashMarker()) {
sequence_ = RECEIVER_RECEIVE_PAD_C_LENGTH;
} else {
done = true;
}
break;
}
break;
}
case RECEIVER_SEND_KEY_PENDING:
if(mseHandshake_->sendPublicKey()) {
disableWriteCheckSocket();
sequence_ = RECEIVER_FIND_HASH_MARKER;
case RECEIVER_RECEIVE_PAD_C_LENGTH: {
std::vector<SharedHandle<DownloadContext> > downloadContexts;
getDownloadEngine()->getBtRegistry()->getAllDownloadContext
(std::back_inserter(downloadContexts));
if(mseHandshake_->receiveReceiverHashAndPadCLength(downloadContexts)) {
sequence_ = RECEIVER_RECEIVE_PAD_C;
} else {
done = true;
}
break;
}
break;
case RECEIVER_FIND_HASH_MARKER: {
if(mseHandshake_->findReceiverHashMarker()) {
sequence_ = RECEIVER_RECEIVE_PAD_C_LENGTH;
case RECEIVER_RECEIVE_PAD_C: {
if(mseHandshake_->receivePad()) {
sequence_ = RECEIVER_RECEIVE_IA_LENGTH;
} else {
done = true;
}
break;
}
break;
}
case RECEIVER_RECEIVE_PAD_C_LENGTH: {
std::vector<SharedHandle<DownloadContext> > downloadContexts;
getDownloadEngine()->getBtRegistry()->getAllDownloadContext
(std::back_inserter(downloadContexts));
if(mseHandshake_->receiveReceiverHashAndPadCLength(downloadContexts)) {
sequence_ = RECEIVER_RECEIVE_PAD_C;
case RECEIVER_RECEIVE_IA_LENGTH: {
if(mseHandshake_->receiveReceiverIALength()) {
sequence_ = RECEIVER_RECEIVE_IA;
} else {
done = true;
}
break;
}
break;
}
case RECEIVER_RECEIVE_PAD_C: {
if(mseHandshake_->receivePad()) {
sequence_ = RECEIVER_RECEIVE_IA_LENGTH;
case RECEIVER_RECEIVE_IA: {
if(mseHandshake_->receiveReceiverIA()) {
mseHandshake_->sendReceiverStep2();
sequence_ = RECEIVER_SEND_STEP2_PENDING;
} else {
done = true;
}
break;
}
break;
}
case RECEIVER_RECEIVE_IA_LENGTH: {
if(mseHandshake_->receiveReceiverIALength()) {
sequence_ = RECEIVER_RECEIVE_IA;
}
break;
}
case RECEIVER_RECEIVE_IA: {
if(mseHandshake_->receiveReceiverIA()) {
if(mseHandshake_->sendReceiverStep2()) {
case RECEIVER_SEND_STEP2_PENDING:
if(mseHandshake_->send()) {
createCommand();
return true;
} else {
setWriteCheckSocket(getSocket());
sequence_ = RECEIVER_SEND_STEP2_PENDING;
done = true;
}
break;
}
break;
}
case RECEIVER_SEND_STEP2_PENDING:
if(mseHandshake_->sendReceiverStep2()) {
disableWriteCheckSocket();
createCommand();
return true;
}
break;
if(mseHandshake_->getWantRead()) {
setReadCheckSocket(getSocket());
} else {
disableReadCheckSocket();
}
if(mseHandshake_->getWantWrite()) {
setWriteCheckSocket(getSocket());
} else {
disableWriteCheckSocket();
}
getDownloadEngine()->addCommand(this);
return false;
@ -188,10 +213,19 @@ void ReceiverMSEHandshakeCommand::createCommand()
peerConnection->enableEncryption(mseHandshake_->getEncryptor(),
mseHandshake_->getDecryptor());
}
if(mseHandshake_->getIALength() > 0) {
peerConnection->presetBuffer(mseHandshake_->getIA(),
mseHandshake_->getIALength());
size_t buflen = mseHandshake_->getIALength()+mseHandshake_->getBufferLength();
array_ptr<unsigned char> buffer(new unsigned char[buflen]);
memcpy(buffer, mseHandshake_->getIA(), mseHandshake_->getIALength());
if(mseHandshake_->getNegotiatedCryptoType() == MSEHandshake::CRYPTO_ARC4) {
mseHandshake_->getDecryptor()->decrypt(buffer+mseHandshake_->getIALength(),
mseHandshake_->getBufferLength(),
mseHandshake_->getBuffer(),
mseHandshake_->getBufferLength());
} else {
memcpy(buffer+mseHandshake_->getIALength(),
mseHandshake_->getBuffer(), mseHandshake_->getBufferLength());
}
peerConnection->presetBuffer(buffer, buflen);
// TODO add mseHandshake_->getInfoHash() to PeerReceiveHandshakeCommand
// as a hint. If this info hash and one in BitTorrent Handshake does not
// match, then drop connection.

View File

@ -70,26 +70,57 @@ createSocketPair()
void MSEHandshakeTest::doHandshake(const SharedHandle<MSEHandshake>& initiator, const SharedHandle<MSEHandshake>& receiver)
{
initiator->sendPublicKey();
while(!receiver->receivePublicKey());
while(initiator->getWantWrite()) {
initiator->send();
}
while(!receiver->receivePublicKey()) {
receiver->read();
}
receiver->sendPublicKey();
while(receiver->getWantWrite()) {
receiver->send();
}
while(!initiator->receivePublicKey());
while(!initiator->receivePublicKey()) {
initiator->read();
}
initiator->initCipher(bittorrent::getInfoHash(dctx_));
initiator->sendInitiatorStep2();
while(initiator->getWantWrite()) {
initiator->send();
}
while(!receiver->findReceiverHashMarker());
while(!receiver->findReceiverHashMarker()) {
receiver->read();
}
std::vector<SharedHandle<DownloadContext> > contexts;
contexts.push_back(dctx_);
while(!receiver->receiveReceiverHashAndPadCLength(contexts));
while(!receiver->receivePad());
while(!receiver->receiveReceiverIALength());
while(!receiver->receiveReceiverIA());
while(!receiver->receiveReceiverHashAndPadCLength(contexts)) {
receiver->read();
}
while(!receiver->receivePad()) {
receiver->read();
}
while(!receiver->receiveReceiverIALength()) {
receiver->read();
}
while(!receiver->receiveReceiverIA()) {
receiver->read();
}
receiver->sendReceiverStep2();
while(receiver->getWantWrite()) {
receiver->send();
}
while(!initiator->findInitiatorVCMarker());
while(!initiator->receiveInitiatorCryptoSelectAndPadDLength());
while(!initiator->receivePad());
while(!initiator->findInitiatorVCMarker()) {
initiator->read();
}
while(!initiator->receiveInitiatorCryptoSelectAndPadDLength()) {
initiator->read();
}
while(!initiator->receivePad()) {
initiator->read();
}
}
namespace {