Eliminated SocketCore::peekData from MSEHandshake.

pull/1/head
Tatsuhiro Tsujikawa 2011-01-08 17:32:16 +09:00
parent c48db2cdf3
commit ce2d401dce
9 changed files with 523 additions and 438 deletions

View File

@ -56,6 +56,7 @@
#include "bittorrent_helper.h" #include "bittorrent_helper.h"
#include "util.h" #include "util.h"
#include "fmt.h" #include "fmt.h"
#include "array_fun.h"
namespace aria2 { namespace aria2 {
@ -89,82 +90,108 @@ InitiatorMSEHandshakeCommand::~InitiatorMSEHandshakeCommand()
} }
bool InitiatorMSEHandshakeCommand::executeInternal() { bool InitiatorMSEHandshakeCommand::executeInternal() {
switch(sequence_) { if(mseHandshake_->getWantRead()) {
case INITIATOR_SEND_KEY: { mseHandshake_->read();
if(!getSocket()->isWritable(0)) { }
bool done = false;
while(!done) {
switch(sequence_) {
case INITIATOR_SEND_KEY: {
if(!getSocket()->isWritable(0)) {
getDownloadEngine()->addCommand(this);
return false;
}
setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT));
mseHandshake_->initEncryptionFacility(true);
mseHandshake_->sendPublicKey();
sequence_ = INITIATOR_SEND_KEY_PENDING;
break; break;
} }
disableWriteCheckSocket(); case INITIATOR_SEND_KEY_PENDING:
setReadCheckSocket(getSocket()); if(mseHandshake_->send()) {
//socket->setBlockingMode(); sequence_ = INITIATOR_WAIT_KEY;
setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT)); } else {
mseHandshake_->initEncryptionFacility(true); done = true;
if(mseHandshake_->sendPublicKey()) { }
sequence_ = INITIATOR_WAIT_KEY; break;
} else { case INITIATOR_WAIT_KEY: {
setWriteCheckSocket(getSocket()); if(mseHandshake_->receivePublicKey()) {
sequence_ = INITIATOR_SEND_KEY_PENDING; mseHandshake_->initCipher
(bittorrent::getInfoHash(requestGroup_->getDownloadContext()));;
mseHandshake_->sendInitiatorStep2();
sequence_ = INITIATOR_SEND_STEP2_PENDING;
} else {
done = true;
}
break;
} }
break; case INITIATOR_SEND_STEP2_PENDING:
} if(mseHandshake_->send()) {
case INITIATOR_SEND_KEY_PENDING:
if(mseHandshake_->sendPublicKey()) {
disableWriteCheckSocket();
sequence_ = INITIATOR_WAIT_KEY;
}
break;
case INITIATOR_WAIT_KEY: {
if(mseHandshake_->receivePublicKey()) {
mseHandshake_->initCipher
(bittorrent::getInfoHash(requestGroup_->getDownloadContext()));;
if(mseHandshake_->sendInitiatorStep2()) {
sequence_ = INITIATOR_FIND_VC_MARKER; sequence_ = INITIATOR_FIND_VC_MARKER;
} else { } else {
setWriteCheckSocket(getSocket()); done = true;
sequence_ = INITIATOR_SEND_STEP2_PENDING;
} }
} break;
break; case INITIATOR_FIND_VC_MARKER: {
} if(mseHandshake_->findInitiatorVCMarker()) {
case INITIATOR_SEND_STEP2_PENDING: sequence_ = INITIATOR_RECEIVE_PAD_D_LENGTH;
if(mseHandshake_->sendInitiatorStep2()) { } else {
disableWriteCheckSocket(); done = true;
sequence_ = INITIATOR_FIND_VC_MARKER;
}
break;
case INITIATOR_FIND_VC_MARKER: {
if(mseHandshake_->findInitiatorVCMarker()) {
sequence_ = INITIATOR_RECEIVE_PAD_D_LENGTH;
}
break;
}
case INITIATOR_RECEIVE_PAD_D_LENGTH: {
if(mseHandshake_->receiveInitiatorCryptoSelectAndPadDLength()) {
sequence_ = INITIATOR_RECEIVE_PAD_D;
}
break;
}
case INITIATOR_RECEIVE_PAD_D: {
if(mseHandshake_->receivePad()) {
SharedHandle<PeerConnection> peerConnection
(new PeerConnection(getCuid(), getPeer(), getSocket()));
if(mseHandshake_->getNegotiatedCryptoType() == MSEHandshake::CRYPTO_ARC4){
peerConnection->enableEncryption(mseHandshake_->getEncryptor(),
mseHandshake_->getDecryptor());
} }
PeerInteractionCommand* c = break;
new PeerInteractionCommand }
(getCuid(), requestGroup_, getPeer(), getDownloadEngine(), btRuntime_, case INITIATOR_RECEIVE_PAD_D_LENGTH: {
pieceStorage_, if(mseHandshake_->receiveInitiatorCryptoSelectAndPadDLength()) {
peerStorage_, sequence_ = INITIATOR_RECEIVE_PAD_D;
getSocket(), } else {
PeerInteractionCommand::INITIATOR_SEND_HANDSHAKE, done = true;
peerConnection); }
getDownloadEngine()->addCommand(c); break;
return true; }
case INITIATOR_RECEIVE_PAD_D: {
if(mseHandshake_->receivePad()) {
SharedHandle<PeerConnection> peerConnection
(new PeerConnection(getCuid(), getPeer(), getSocket()));
if(mseHandshake_->getNegotiatedCryptoType() ==
MSEHandshake::CRYPTO_ARC4){
peerConnection->enableEncryption(mseHandshake_->getEncryptor(),
mseHandshake_->getDecryptor());
size_t buflen = mseHandshake_->getBufferLength();
array_ptr<unsigned char> buffer(new unsigned char[buflen]);
mseHandshake_->getDecryptor()->decrypt(buffer, buflen,
mseHandshake_->getBuffer(),
buflen);
peerConnection->presetBuffer(buffer, buflen);
} else {
peerConnection->presetBuffer(mseHandshake_->getBuffer(),
mseHandshake_->getBufferLength());
}
PeerInteractionCommand* c =
new PeerInteractionCommand
(getCuid(), requestGroup_, getPeer(), getDownloadEngine(), btRuntime_,
pieceStorage_,
peerStorage_,
getSocket(),
PeerInteractionCommand::INITIATOR_SEND_HANDSHAKE,
peerConnection);
getDownloadEngine()->addCommand(c);
return true;
} else {
done = true;
}
break;
}
} }
break;
} }
if(mseHandshake_->getWantRead()) {
setReadCheckSocket(getSocket());
} else {
disableReadCheckSocket();
}
if(mseHandshake_->getWantWrite()) {
setWriteCheckSocket(getSocket());
} else {
disableWriteCheckSocket();
} }
getDownloadEngine()->addCommand(this); getDownloadEngine()->addCommand(this);
return false; return false;

View File

@ -71,6 +71,7 @@ MSEHandshake::MSEHandshake
const Option* op) const Option* op)
: cuid_(cuid), : cuid_(cuid),
socket_(socket), socket_(socket),
wantRead_(false),
option_(op), option_(op),
rbufLength_(0), rbufLength_(0),
socketBuffer_(socket), socketBuffer_(socket),
@ -92,16 +93,8 @@ MSEHandshake::~MSEHandshake()
MSEHandshake::HANDSHAKE_TYPE MSEHandshake::identifyHandshakeType() MSEHandshake::HANDSHAKE_TYPE MSEHandshake::identifyHandshakeType()
{ {
if(!socket_->isReadable(0)) {
return HANDSHAKE_NOT_YET;
}
size_t r = 20-rbufLength_;
socket_->readData(rbuf_+rbufLength_, r);
if(r == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
throw DL_ABORT_EX(EX_EOF_FROM_PEER);
}
rbufLength_ += r;
if(rbufLength_ < 20) { if(rbufLength_ < 20) {
wantRead_ = true;
return HANDSHAKE_NOT_YET; return HANDSHAKE_NOT_YET;
} }
if(rbuf_[0] == BtHandshakeMessage::PSTR_LENGTH && if(rbuf_[0] == BtHandshakeMessage::PSTR_LENGTH &&
@ -126,35 +119,59 @@ void MSEHandshake::initEncryptionFacility(bool initiator)
initiator_ = initiator; initiator_ = initiator;
} }
bool MSEHandshake::sendPublicKey() void MSEHandshake::sendPublicKey()
{ {
if(socketBuffer_.sendBufferIsEmpty()) { A2_LOG_DEBUG(fmt("CUID#%lld - Sending public key.",
A2_LOG_DEBUG(fmt("CUID#%lld - Sending public key.", cuid_));
cuid_)); unsigned char buffer[KEY_LENGTH+MAX_PAD_LENGTH];
unsigned char buffer[KEY_LENGTH+MAX_PAD_LENGTH]; dh_->getPublicKey(buffer, KEY_LENGTH);
dh_->getPublicKey(buffer, KEY_LENGTH);
size_t padLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1); size_t padLength =
dh_->generateNonce(buffer+KEY_LENGTH, padLength); SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
socketBuffer_.pushStr(std::string(&buffer[0], dh_->generateNonce(buffer+KEY_LENGTH, padLength);
&buffer[KEY_LENGTH+padLength])); socketBuffer_.pushStr(std::string(&buffer[0],
&buffer[KEY_LENGTH+padLength]));
}
void MSEHandshake::read()
{
if(rbufLength_ >= MAX_BUFFER_LENGTH) {
assert(!wantRead_);
return;
} }
size_t len = MAX_BUFFER_LENGTH-rbufLength_;
socket_->readData(rbuf_+rbufLength_, len);
if(len == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
// TODO Should we set graceful in peer?
throw DL_ABORT_EX(EX_EOF_FROM_PEER);
}
rbufLength_ += len;
wantRead_ = false;
}
bool MSEHandshake::send()
{
socketBuffer_.send(); socketBuffer_.send();
return socketBuffer_.sendBufferIsEmpty(); return socketBuffer_.sendBufferIsEmpty();
} }
void MSEHandshake::shiftBuffer(size_t offset)
{
memmove(rbuf_, rbuf_+offset, rbufLength_-offset);
rbufLength_ -= offset;
}
bool MSEHandshake::receivePublicKey() bool MSEHandshake::receivePublicKey()
{ {
size_t r = KEY_LENGTH-rbufLength_; if(rbufLength_ < KEY_LENGTH) {
if(r > receiveNBytes(r)) { wantRead_ = true;
return false; return false;
} }
A2_LOG_DEBUG(fmt("CUID#%lld - public key received.", A2_LOG_DEBUG(fmt("CUID#%lld - public key received.", cuid_));
cuid_));
// TODO handle exception. in catch, resbufLength = 0; // TODO handle exception. in catch, resbufLength = 0;
dh_->computeSecret(secret_, sizeof(secret_), rbuf_, rbufLength_); dh_->computeSecret(secret_, sizeof(secret_), rbuf_, KEY_LENGTH);
// reset rbufLength_ // shift buffer
rbufLength_ = 0; shiftBuffer(KEY_LENGTH);
return true; return true;
} }
@ -251,109 +268,83 @@ uint16_t MSEHandshake::decodeLength16(const unsigned char* buffer)
return ntohs(be); return ntohs(be);
} }
bool MSEHandshake::sendInitiatorStep2() void MSEHandshake::sendInitiatorStep2()
{ {
if(socketBuffer_.sendBufferIsEmpty()) { A2_LOG_DEBUG(fmt("CUID#%lld - Sending negotiation step2.", cuid_));
A2_LOG_DEBUG(fmt("CUID#%lld - Sending negotiation step2.", unsigned char md[20];
cuid_)); createReq1Hash(md);
unsigned char md[20]; socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)]));
createReq1Hash(md); createReq23Hash(md, infoHash_);
socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)])); socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)]));
createReq23Hash(md, infoHash_); // buffer is filled in this order:
socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)])); // VC(VC_LENGTH bytes),
// crypto_provide(CRYPTO_BITFIELD_LENGTH bytes),
// len(padC)(2bytes),
// padC(len(padC)bytes <= MAX_PAD_LENGTH),
// len(IA)(2bytes)
unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH+2];
{ // VC
// buffer is filled in this order: memcpy(buffer, VC, sizeof(VC));
// VC(VC_LENGTH bytes), // crypto_provide
// crypto_provide(CRYPTO_BITFIELD_LENGTH bytes), unsigned char cryptoProvide[CRYPTO_BITFIELD_LENGTH];
// len(padC)(2bytes), memset(cryptoProvide, 0, sizeof(cryptoProvide));
// padC(len(padC)bytes <= MAX_PAD_LENGTH), if(option_->get(PREF_BT_MIN_CRYPTO_LEVEL) == V_PLAIN) {
// len(IA)(2bytes) cryptoProvide[3] = CRYPTO_PLAIN_TEXT;
unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH+2];
// VC
memcpy(buffer, VC, sizeof(VC));
// crypto_provide
unsigned char cryptoProvide[CRYPTO_BITFIELD_LENGTH];
memset(cryptoProvide, 0, sizeof(cryptoProvide));
if(option_->get(PREF_BT_MIN_CRYPTO_LEVEL) == V_PLAIN) {
cryptoProvide[3] = CRYPTO_PLAIN_TEXT;
}
cryptoProvide[3] |= CRYPTO_ARC4;
memcpy(buffer+VC_LENGTH, cryptoProvide, sizeof(cryptoProvide));
// len(padC)
uint16_t padCLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
{
uint16_t padCLengthBE = htons(padCLength);
memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padCLengthBE,
sizeof(padCLengthBE));
}
// padC
memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padCLength);
// len(IA)
// currently, IA is zero-length.
uint16_t iaLength = 0;
{
uint16_t iaLengthBE = htons(iaLength);
memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength,
&iaLengthBE,sizeof(iaLengthBE));
}
encryptAndSendData(buffer,
VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength+2);
}
} }
socketBuffer_.send(); cryptoProvide[3] |= CRYPTO_ARC4;
return socketBuffer_.sendBufferIsEmpty(); memcpy(buffer+VC_LENGTH, cryptoProvide, sizeof(cryptoProvide));
// len(padC)
uint16_t padCLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
{
uint16_t padCLengthBE = htons(padCLength);
memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padCLengthBE,
sizeof(padCLengthBE));
}
// padC
memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padCLength);
// len(IA)
// currently, IA is zero-length.
uint16_t iaLength = 0;
{
uint16_t iaLengthBE = htons(iaLength);
memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength,
&iaLengthBE,sizeof(iaLengthBE));
}
encryptAndSendData(buffer,
VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength+2);
} }
// This function reads exactly until the end of VC marker is reached. // This function reads exactly until the end of VC marker is reached.
bool MSEHandshake::findInitiatorVCMarker() bool MSEHandshake::findInitiatorVCMarker()
{ {
// 616 is synchronization point of initiator // 616 is synchronization point of initiator
size_t r = 616-KEY_LENGTH-rbufLength_; // find vc
if(!socket_->isReadable(0)) { std::string buf(&rbuf_[0], &rbuf_[rbufLength_]);
return false; std::string vc(&initiatorVCMarker_[0], &initiatorVCMarker_[VC_LENGTH]);
} if((markerIndex_ = buf.find(vc)) == std::string::npos) {
socket_->peekData(rbuf_+rbufLength_, r); if(616-KEY_LENGTH <= rbufLength_) {
if(r == 0) { throw DL_ABORT_EX("Failed to find VC marker.");
if(socket_->wantRead() || socket_->wantWrite()) { } else {
wantRead_ = true;
return false; return false;
} }
throw DL_ABORT_EX(EX_EOF_FROM_PEER);
} }
// find vc
{
std::string buf(&rbuf_[0], &rbuf_[rbufLength_+r]);
std::string vc(&initiatorVCMarker_[0], &initiatorVCMarker_[VC_LENGTH]);
if((markerIndex_ = buf.find(vc)) == std::string::npos) {
if(616-KEY_LENGTH <= rbufLength_+r) {
throw DL_ABORT_EX("Failed to find VC marker.");
} else {
socket_->readData(rbuf_+rbufLength_, r);
rbufLength_ += r;
return false;
}
}
}
assert(markerIndex_+VC_LENGTH-rbufLength_ <= r);
size_t toRead = markerIndex_+VC_LENGTH-rbufLength_;
socket_->readData(rbuf_+rbufLength_, toRead);
rbufLength_ += toRead;
A2_LOG_DEBUG(fmt("CUID#%lld - VC marker found at %lu", A2_LOG_DEBUG(fmt("CUID#%lld - VC marker found at %lu",
cuid_, cuid_,
static_cast<unsigned long>(markerIndex_))); static_cast<unsigned long>(markerIndex_)));
verifyVC(rbuf_+markerIndex_); verifyVC(rbuf_+markerIndex_);
// reset rbufLength_ // shift rbuf
rbufLength_ = 0; shiftBuffer(markerIndex_+VC_LENGTH);
return true; return true;
} }
bool MSEHandshake::receiveInitiatorCryptoSelectAndPadDLength() bool MSEHandshake::receiveInitiatorCryptoSelectAndPadDLength()
{ {
size_t r = CRYPTO_BITFIELD_LENGTH+2/* PadD length*/-rbufLength_; if(CRYPTO_BITFIELD_LENGTH+2/* PadD length*/ > rbufLength_) {
if(r > receiveNBytes(r)) { wantRead_ = true;
return false; return false;
} }
//verifyCryptoSelect //verifyCryptoSelect
@ -382,75 +373,57 @@ bool MSEHandshake::receiveInitiatorCryptoSelectAndPadDLength()
// padD length // padD length
rbufptr += CRYPTO_BITFIELD_LENGTH; rbufptr += CRYPTO_BITFIELD_LENGTH;
padLength_ = verifyPadLength(rbufptr, "PadD"); padLength_ = verifyPadLength(rbufptr, "PadD");
// reset rbufLength_ // shift rbuf
rbufLength_ = 0; shiftBuffer(CRYPTO_BITFIELD_LENGTH+2/* PadD length*/);
return true; return true;
} }
bool MSEHandshake::receivePad() bool MSEHandshake::receivePad()
{ {
if(padLength_ > rbufLength_) {
wantRead_ = true;
return false;
}
if(padLength_ == 0) { if(padLength_ == 0) {
return true; return true;
} }
size_t r = padLength_-rbufLength_;
if(r > receiveNBytes(r)) {
return false;
}
unsigned char temp[MAX_PAD_LENGTH]; unsigned char temp[MAX_PAD_LENGTH];
decryptor_->decrypt(temp, padLength_, rbuf_, padLength_); decryptor_->decrypt(temp, padLength_, rbuf_, padLength_);
// reset rbufLength_ // shift rbuf_
rbufLength_ = 0; shiftBuffer(padLength_);
return true; return true;
} }
bool MSEHandshake::findReceiverHashMarker() bool MSEHandshake::findReceiverHashMarker()
{ {
// 628 is synchronization limit of receiver. // 628 is synchronization limit of receiver.
size_t r = 628-KEY_LENGTH-rbufLength_; // find hash('req1', S), S is secret_.
if(!socket_->isReadable(0)) { std::string buf(&rbuf_[0], &rbuf_[rbufLength_]);
return false; unsigned char md[20];
} createReq1Hash(md);
socket_->peekData(rbuf_+rbufLength_, r); std::string req1(&md[0], &md[sizeof(md)]);
if(r == 0) { if((markerIndex_ = buf.find(req1)) == std::string::npos) {
if(socket_->wantRead() || socket_->wantWrite()) { if(628-KEY_LENGTH <= rbufLength_) {
throw DL_ABORT_EX("Failed to find hash marker.");
} else {
wantRead_ = true;
return false; return false;
} }
throw DL_ABORT_EX(EX_EOF_FROM_PEER);
} }
// find hash('req1', S), S is secret_.
{
std::string buf(&rbuf_[0], &rbuf_[rbufLength_+r]);
unsigned char md[20];
createReq1Hash(md);
std::string req1(&md[0], &md[sizeof(md)]);
if((markerIndex_ = buf.find(req1)) == std::string::npos) {
if(628-KEY_LENGTH <= rbufLength_+r) {
throw DL_ABORT_EX("Failed to find hash marker.");
} else {
socket_->readData(rbuf_+rbufLength_, r);
rbufLength_ += r;
return false;
}
}
}
assert(markerIndex_+20-rbufLength_ <= r);
size_t toRead = markerIndex_+20-rbufLength_;
socket_->readData(rbuf_+rbufLength_, toRead);
rbufLength_ += toRead;
A2_LOG_DEBUG(fmt("CUID#%lld - Hash marker found at %lu.", A2_LOG_DEBUG(fmt("CUID#%lld - Hash marker found at %lu.",
cuid_, cuid_,
static_cast<unsigned long>(markerIndex_))); static_cast<unsigned long>(markerIndex_)));
verifyReq1Hash(rbuf_+markerIndex_); verifyReq1Hash(rbuf_+markerIndex_);
// reset rbufLength_ // shift rbuf_
rbufLength_ = 0; shiftBuffer(markerIndex_+20);
return true; return true;
} }
bool MSEHandshake::receiveReceiverHashAndPadCLength bool MSEHandshake::receiveReceiverHashAndPadCLength
(const std::vector<SharedHandle<DownloadContext> >& downloadContexts) (const std::vector<SharedHandle<DownloadContext> >& downloadContexts)
{ {
size_t r = 20+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2/*PadC length*/-rbufLength_; if(20+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2/*PadC length*/ > rbufLength_) {
if(r > receiveNBytes(r)) { wantRead_ = true;
return false; return false;
} }
// resolve info hash // resolve info hash
@ -505,23 +478,22 @@ bool MSEHandshake::receiveReceiverHashAndPadCLength
// decrypt PadC length // decrypt PadC length
rbufptr += CRYPTO_BITFIELD_LENGTH; rbufptr += CRYPTO_BITFIELD_LENGTH;
padLength_ = verifyPadLength(rbufptr, "PadC"); padLength_ = verifyPadLength(rbufptr, "PadC");
// reset rbufLength_ // shift rbuf_
rbufLength_ = 0; shiftBuffer(20+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2/*PadC length*/);
return true; return true;
} }
bool MSEHandshake::receiveReceiverIALength() bool MSEHandshake::receiveReceiverIALength()
{ {
size_t r = 2-rbufLength_; if(2 > rbufLength_) {
assert(r > 0); wantRead_ = true;
if(r > receiveNBytes(r)) {
return false; return false;
} }
iaLength_ = decodeLength16(rbuf_); iaLength_ = decodeLength16(rbuf_);
A2_LOG_DEBUG(fmt("CUID#%lld - len(IA)=%u.", // TODO limit iaLength \19...+handshake
cuid_, iaLength_)); A2_LOG_DEBUG(fmt("CUID#%lld - len(IA)=%u.", cuid_, iaLength_));
// reset rbufLength_ // shift rbuf_
rbufLength_ = 0; shiftBuffer(2);
return true; return true;
} }
@ -530,48 +502,44 @@ bool MSEHandshake::receiveReceiverIA()
if(iaLength_ == 0) { if(iaLength_ == 0) {
return true; return true;
} }
size_t r = iaLength_-rbufLength_; if(iaLength_ > rbufLength_) {
if(r > receiveNBytes(r)) { wantRead_ = true;
return false; return false;
} }
delete [] ia_; delete [] ia_;
ia_ = new unsigned char[iaLength_]; ia_ = new unsigned char[iaLength_];
decryptor_->decrypt(ia_, iaLength_, rbuf_, iaLength_); decryptor_->decrypt(ia_, iaLength_, rbuf_, iaLength_);
A2_LOG_DEBUG(fmt("CUID#%lld - IA received.", cuid_)); A2_LOG_DEBUG(fmt("CUID#%lld - IA received.", cuid_));
// reset rbufLength_ // shift rbuf_
rbufLength_ = 0; shiftBuffer(iaLength_);
return true; return true;
} }
bool MSEHandshake::sendReceiverStep2() void MSEHandshake::sendReceiverStep2()
{ {
if(socketBuffer_.sendBufferIsEmpty()) { // buffer is filled in this order:
// buffer is filled in this order: // VC(VC_LENGTH bytes),
// VC(VC_LENGTH bytes), // cryptoSelect(CRYPTO_BITFIELD_LENGTH bytes),
// cryptoSelect(CRYPTO_BITFIELD_LENGTH bytes), // len(padD)(2bytes),
// len(padD)(2bytes), // padD(len(padD)bytes <= MAX_PAD_LENGTH)
// padD(len(padD)bytes <= MAX_PAD_LENGTH) unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH];
unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH]; // VC
// VC memcpy(buffer, VC, sizeof(VC));
memcpy(buffer, VC, sizeof(VC)); // crypto_select
// crypto_select unsigned char cryptoSelect[CRYPTO_BITFIELD_LENGTH];
unsigned char cryptoSelect[CRYPTO_BITFIELD_LENGTH]; memset(cryptoSelect, 0, sizeof(cryptoSelect));
memset(cryptoSelect, 0, sizeof(cryptoSelect)); cryptoSelect[3] = negotiatedCryptoType_;
cryptoSelect[3] = negotiatedCryptoType_; memcpy(buffer+VC_LENGTH, cryptoSelect, sizeof(cryptoSelect));
memcpy(buffer+VC_LENGTH, cryptoSelect, sizeof(cryptoSelect)); // len(padD)
// len(padD) uint16_t padDLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1);
uint16_t padDLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1); {
{ uint16_t padDLengthBE = htons(padDLength);
uint16_t padDLengthBE = htons(padDLength); memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padDLengthBE,
memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padDLengthBE, sizeof(padDLengthBE));
sizeof(padDLengthBE));
}
// padD, all zeroed
memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padDLength);
encryptAndSendData(buffer, VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padDLength);
} }
socketBuffer_.send(); // padD, all zeroed
return socketBuffer_.sendBufferIsEmpty(); memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padDLength);
encryptAndSendData(buffer, VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padDLength);
} }
uint16_t MSEHandshake::verifyPadLength(const unsigned char* padlenbuf, const char* padName) uint16_t MSEHandshake::verifyPadLength(const unsigned char* padlenbuf, const char* padName)
@ -609,20 +577,9 @@ void MSEHandshake::verifyReq1Hash(const unsigned char* req1buf)
} }
} }
size_t MSEHandshake::receiveNBytes(size_t bytes) bool MSEHandshake::getWantWrite() const
{ {
size_t r = bytes; return !socketBuffer_.sendBufferIsEmpty();
if(r > 0) {
if(!socket_->isReadable(0)) {
return 0;
}
socket_->readData(rbuf_+rbufLength_, r);
if(r == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
throw DL_ABORT_EX(EX_EOF_FROM_PEER);
}
rbufLength_ += r;
}
return r;
} }
} // namespace aria2 } // namespace aria2

View File

@ -83,6 +83,7 @@ private:
cuid_t cuid_; cuid_t cuid_;
SharedHandle<SocketCore> socket_; SharedHandle<SocketCore> socket_;
bool wantRead_;
const Option* option_; const Option* option_;
unsigned char rbuf_[MAX_BUFFER_LENGTH]; unsigned char rbuf_[MAX_BUFFER_LENGTH];
@ -130,8 +131,7 @@ private:
void verifyReq1Hash(const unsigned char* req1buf); void verifyReq1Hash(const unsigned char* req1buf);
size_t receiveNBytes(size_t bytes); void shiftBuffer(size_t offset);
public: public:
MSEHandshake(cuid_t cuid, const SharedHandle<SocketCore>& socket, MSEHandshake(cuid_t cuid, const SharedHandle<SocketCore>& socket,
const Option* op); const Option* op);
@ -142,13 +142,33 @@ public:
void initEncryptionFacility(bool initiator); void initEncryptionFacility(bool initiator);
bool sendPublicKey(); // Reads data from Socket. If EOF is reached, throws
// RecoverableException.
void read();
// Sends pending data in the send buffer. Returns true if all data
// is sent. Otherwise returns false.
bool send();
bool getWantRead() const
{
return wantRead_;
}
void setWantRead(bool wantRead)
{
wantRead_ = wantRead;
}
bool getWantWrite() const;
void sendPublicKey();
bool receivePublicKey(); bool receivePublicKey();
void initCipher(const unsigned char* infoHash); void initCipher(const unsigned char* infoHash);
bool sendInitiatorStep2(); void sendInitiatorStep2();
bool findInitiatorVCMarker(); bool findInitiatorVCMarker();
@ -165,7 +185,7 @@ public:
bool receiveReceiverIA(); bool receiveReceiverIA();
bool sendReceiverStep2(); void sendReceiverStep2();
// returns plain text IA // returns plain text IA
const unsigned char* getIA() const const unsigned char* getIA() const
@ -207,7 +227,6 @@ public:
{ {
return rbufLength_; return rbufLength_;
} }
}; };
} // namespace aria2 } // namespace aria2

View File

@ -110,9 +110,6 @@ void PeerConnection::pushBytes(unsigned char* data, size_t len)
bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) { bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) {
if(resbufLength_ == 0 && 4 > lenbufLength_) { if(resbufLength_ == 0 && 4 > lenbufLength_) {
if(!socket_->isReadable(0)) {
return false;
}
// read payload size, 32bit unsigned integer // read payload size, 32bit unsigned integer
size_t remaining = 4-lenbufLength_; size_t remaining = 4-lenbufLength_;
size_t temp = remaining; size_t temp = remaining;
@ -182,7 +179,7 @@ bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
bool peek) { bool peek) {
assert(BtHandshakeMessage::MESSAGE_LENGTH >= resbufLength_); assert(BtHandshakeMessage::MESSAGE_LENGTH >= resbufLength_);
bool retval = true; bool retval = true;
if(prevPeek_ && !peek && resbufLength_) { if(prevPeek_ && resbufLength_) {
// We have data in previous peek. // We have data in previous peek.
// There is a chance that socket is readable because of EOF, for example, // There is a chance that socket is readable because of EOF, for example,
// official bttrack shutdowns socket after sending first 48 bytes of // official bttrack shutdowns socket after sending first 48 bytes of
@ -194,17 +191,10 @@ bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
} else { } else {
prevPeek_ = peek; prevPeek_ = peek;
size_t remaining = BtHandshakeMessage::MESSAGE_LENGTH-resbufLength_; size_t remaining = BtHandshakeMessage::MESSAGE_LENGTH-resbufLength_;
if(remaining > 0 && !socket_->isReadable(0)) {
dataLength = 0;
return false;
}
if(remaining > 0) { if(remaining > 0) {
size_t temp = remaining; size_t temp = remaining;
readData(resbuf_+resbufLength_, remaining, encryptionEnabled_); readData(resbuf_+resbufLength_, remaining, encryptionEnabled_);
if(remaining == 0) { if(remaining == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
if(socket_->wantRead() || socket_->wantWrite()) {
return false;
}
// we got EOF // we got EOF
A2_LOG_DEBUG A2_LOG_DEBUG
(fmt("CUID#%lld - In PeerConnection::receiveHandshake(), remain=%lu", (fmt("CUID#%lld - In PeerConnection::receiveHandshake(), remain=%lu",
@ -256,6 +246,9 @@ void PeerConnection::presetBuffer(const unsigned char* data, size_t length)
size_t nwrite = std::min((size_t)MAX_PAYLOAD_LEN, length); size_t nwrite = std::min((size_t)MAX_PAYLOAD_LEN, length);
memcpy(resbuf_, data, nwrite); memcpy(resbuf_, data, nwrite);
resbufLength_ = length; resbufLength_ = length;
if(resbufLength_ > 0) {
prevPeek_ = true;
}
} }
bool PeerConnection::sendBufferIsEmpty() const bool PeerConnection::sendBufferIsEmpty() const

View File

@ -117,6 +117,11 @@ public:
return resbuf_; return resbuf_;
} }
size_t getBufferLength() const
{
return resbufLength_;
}
unsigned char* detachBuffer(); unsigned char* detachBuffer();
}; };

View File

@ -163,6 +163,11 @@ PeerInteractionCommand::PeerInteractionCommand
peerConnection.reset(new PeerConnection(cuid, getPeer(), getSocket())); peerConnection.reset(new PeerConnection(cuid, getPeer(), getSocket()));
} else { } else {
peerConnection = passedPeerConnection; peerConnection = passedPeerConnection;
if(sequence_ == RECEIVER_WAIT_HANDSHAKE &&
peerConnection->getBufferLength() > 0) {
setStatus(Command::STATUS_ONESHOT_REALTIME);
getDownloadEngine()->setNoWait(true);
}
} }
SharedHandle<DefaultBtMessageDispatcher> dispatcher SharedHandle<DefaultBtMessageDispatcher> dispatcher
@ -274,71 +279,80 @@ PeerInteractionCommand::~PeerInteractionCommand() {
bool PeerInteractionCommand::executeInternal() { bool PeerInteractionCommand::executeInternal() {
setNoCheck(false); setNoCheck(false);
switch(sequence_) { bool done = false;
case INITIATOR_SEND_HANDSHAKE: while(!done) {
if(!getSocket()->isWritable(0)) { switch(sequence_) {
break; case INITIATOR_SEND_HANDSHAKE:
} if(!getSocket()->isWritable(0)) {
disableWriteCheckSocket(); done = true;
setReadCheckSocket(getSocket());
//socket->setBlockingMode();
setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT));
btInteractive_->initiateHandshake();
sequence_ = INITIATOR_WAIT_HANDSHAKE;
break;
case INITIATOR_WAIT_HANDSHAKE: {
if(btInteractive_->countPendingMessage() > 0) {
btInteractive_->sendPendingMessage();
if(btInteractive_->countPendingMessage() > 0) {
break; break;
} }
} disableWriteCheckSocket();
BtMessageHandle handshakeMessage = btInteractive_->receiveHandshake(); setReadCheckSocket(getSocket());
if(!handshakeMessage) { //socket->setBlockingMode();
setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT));
btInteractive_->initiateHandshake();
sequence_ = INITIATOR_WAIT_HANDSHAKE;
break; break;
} case INITIATOR_WAIT_HANDSHAKE: {
btInteractive_->doPostHandshakeProcessing(); if(btInteractive_->countPendingMessage() > 0) {
sequence_ = WIRED; btInteractive_->sendPendingMessage();
break; if(btInteractive_->countPendingMessage() > 0) {
} done = true;
case RECEIVER_WAIT_HANDSHAKE: { break;
BtMessageHandle handshakeMessage =btInteractive_->receiveAndSendHandshake(); }
if(!handshakeMessage) {
break;
}
btInteractive_->doPostHandshakeProcessing();
sequence_ = WIRED;
break;
}
case WIRED:
// See the comment for writable check below.
disableWriteCheckSocket();
btInteractive_->doInteractionProcessing();
if(btInteractive_->countReceivedMessageInIteration() > 0) {
updateKeepAlive();
}
if((getPeer()->amInterested() && !getPeer()->peerChoking()) ||
btInteractive_->countOutstandingRequest() ||
(getPeer()->peerInterested() && !getPeer()->amChoking())) {
// Writable check to avoid slow seeding
if(btInteractive_->isSendingMessageInProgress()) {
setWriteCheckSocket(getSocket());
} }
BtMessageHandle handshakeMessage = btInteractive_->receiveHandshake();
if(!handshakeMessage) {
done = true;
break;
}
btInteractive_->doPostHandshakeProcessing();
sequence_ = WIRED;
break;
}
case RECEIVER_WAIT_HANDSHAKE: {
BtMessageHandle handshakeMessage =
btInteractive_->receiveAndSendHandshake();
if(!handshakeMessage) {
done = true;
break;
}
btInteractive_->doPostHandshakeProcessing();
sequence_ = WIRED;
break;
}
case WIRED:
// See the comment for writable check below.
disableWriteCheckSocket();
if(getDownloadEngine()->getRequestGroupMan()-> btInteractive_->doInteractionProcessing();
doesOverallDownloadSpeedExceed() || if(btInteractive_->countReceivedMessageInIteration() > 0) {
requestGroup_->doesDownloadSpeedExceed()) { updateKeepAlive();
disableReadCheckSocket(); }
setNoCheck(true); if((getPeer()->amInterested() && !getPeer()->peerChoking()) ||
btInteractive_->countOutstandingRequest() ||
(getPeer()->peerInterested() && !getPeer()->amChoking())) {
// Writable check to avoid slow seeding
if(btInteractive_->isSendingMessageInProgress()) {
setWriteCheckSocket(getSocket());
}
if(getDownloadEngine()->getRequestGroupMan()->
doesOverallDownloadSpeedExceed() ||
requestGroup_->doesDownloadSpeedExceed()) {
disableReadCheckSocket();
setNoCheck(true);
} else {
setReadCheckSocket(getSocket());
}
} else { } else {
setReadCheckSocket(getSocket()); disableReadCheckSocket();
} }
} else { done = true;
disableReadCheckSocket(); break;
} }
break;
} }
if(btInteractive_->countPendingMessage() > 0) { if(btInteractive_->countPendingMessage() > 0) {
setNoCheck(true); setNoCheck(true);

View File

@ -67,7 +67,12 @@ PeerReceiveHandshakeCommand::PeerReceiveHandshakeCommand
: PeerAbstractCommand(cuid, peer, e, s), : PeerAbstractCommand(cuid, peer, e, s),
peerConnection_(peerConnection) peerConnection_(peerConnection)
{ {
if(!peerConnection_) { if(peerConnection_) {
if(peerConnection_->getBufferLength() > 0) {
setStatus(Command::STATUS_ONESHOT_REALTIME);
getDownloadEngine()->setNoWait(true);
}
} else {
peerConnection_.reset(new PeerConnection(cuid, getPeer(), getSocket())); peerConnection_.reset(new PeerConnection(cuid, getPeer(), getSocket()));
} }
} }

View File

@ -50,6 +50,7 @@
#include "RequestGroupMan.h" #include "RequestGroupMan.h"
#include "BtRegistry.h" #include "BtRegistry.h"
#include "DownloadContext.h" #include "DownloadContext.h"
#include "array_fun.h"
namespace aria2 { namespace aria2 {
@ -64,6 +65,7 @@ ReceiverMSEHandshakeCommand::ReceiverMSEHandshakeCommand
mseHandshake_(new MSEHandshake(cuid, s, e->getOption())) mseHandshake_(new MSEHandshake(cuid, s, e->getOption()))
{ {
setTimeout(e->getOption()->getAsInt(PREF_PEER_CONNECTION_TIMEOUT)); setTimeout(e->getOption()->getAsInt(PREF_PEER_CONNECTION_TIMEOUT));
mseHandshake_->setWantRead(true);
} }
ReceiverMSEHandshakeCommand::~ReceiverMSEHandshakeCommand() ReceiverMSEHandshakeCommand::~ReceiverMSEHandshakeCommand()
@ -79,102 +81,125 @@ bool ReceiverMSEHandshakeCommand::exitBeforeExecute()
bool ReceiverMSEHandshakeCommand::executeInternal() bool ReceiverMSEHandshakeCommand::executeInternal()
{ {
switch(sequence_) { if(mseHandshake_->getWantRead()) {
case RECEIVER_IDENTIFY_HANDSHAKE: { mseHandshake_->read();
MSEHandshake::HANDSHAKE_TYPE type = mseHandshake_->identifyHandshakeType();
switch(type) {
case MSEHandshake::HANDSHAKE_NOT_YET:
break;
case MSEHandshake::HANDSHAKE_ENCRYPTED:
mseHandshake_->initEncryptionFacility(false);
sequence_ = RECEIVER_WAIT_KEY;
break;
case MSEHandshake::HANDSHAKE_LEGACY: {
if(getDownloadEngine()->getOption()->getAsBool(PREF_BT_REQUIRE_CRYPTO)) {
throw DL_ABORT_EX
("The legacy BitTorrent handshake is not acceptable by the"
" preference.");
}
SharedHandle<PeerConnection> peerConnection
(new PeerConnection(getCuid(), getPeer(), getSocket()));
peerConnection->presetBuffer(mseHandshake_->getBuffer(),
mseHandshake_->getBufferLength());
Command* c = new PeerReceiveHandshakeCommand(getCuid(),
getPeer(),
getDownloadEngine(),
getSocket(),
peerConnection);
getDownloadEngine()->addCommand(c);
return true;
}
default:
throw DL_ABORT_EX("Not supported handshake type.");
}
break;
} }
case RECEIVER_WAIT_KEY: { bool done = false;
if(mseHandshake_->receivePublicKey()) { while(!done) {
if(mseHandshake_->sendPublicKey()) { switch(sequence_) {
case RECEIVER_IDENTIFY_HANDSHAKE: {
MSEHandshake::HANDSHAKE_TYPE type =
mseHandshake_->identifyHandshakeType();
switch(type) {
case MSEHandshake::HANDSHAKE_NOT_YET:
done = true;
break;
case MSEHandshake::HANDSHAKE_ENCRYPTED:
mseHandshake_->initEncryptionFacility(false);
sequence_ = RECEIVER_WAIT_KEY;
break;
case MSEHandshake::HANDSHAKE_LEGACY: {
if(getDownloadEngine()->getOption()->getAsBool(PREF_BT_REQUIRE_CRYPTO)){
throw DL_ABORT_EX
("The legacy BitTorrent handshake is not acceptable by the"
" preference.");
}
SharedHandle<PeerConnection> peerConnection
(new PeerConnection(getCuid(), getPeer(), getSocket()));
peerConnection->presetBuffer(mseHandshake_->getBuffer(),
mseHandshake_->getBufferLength());
Command* c = new PeerReceiveHandshakeCommand(getCuid(),
getPeer(),
getDownloadEngine(),
getSocket(),
peerConnection);
getDownloadEngine()->addCommand(c);
return true;
}
default:
throw DL_ABORT_EX("Not supported handshake type.");
}
break;
}
case RECEIVER_WAIT_KEY: {
if(mseHandshake_->receivePublicKey()) {
mseHandshake_->sendPublicKey();
sequence_ = RECEIVER_SEND_KEY_PENDING;
} else {
done = true;
}
break;
}
case RECEIVER_SEND_KEY_PENDING:
if(mseHandshake_->send()) {
sequence_ = RECEIVER_FIND_HASH_MARKER; sequence_ = RECEIVER_FIND_HASH_MARKER;
} else { } else {
setWriteCheckSocket(getSocket()); done = true;
sequence_ = RECEIVER_SEND_KEY_PENDING;
} }
break;
case RECEIVER_FIND_HASH_MARKER: {
if(mseHandshake_->findReceiverHashMarker()) {
sequence_ = RECEIVER_RECEIVE_PAD_C_LENGTH;
} else {
done = true;
}
break;
} }
break; case RECEIVER_RECEIVE_PAD_C_LENGTH: {
} std::vector<SharedHandle<DownloadContext> > downloadContexts;
case RECEIVER_SEND_KEY_PENDING: getDownloadEngine()->getBtRegistry()->getAllDownloadContext
if(mseHandshake_->sendPublicKey()) { (std::back_inserter(downloadContexts));
disableWriteCheckSocket(); if(mseHandshake_->receiveReceiverHashAndPadCLength(downloadContexts)) {
sequence_ = RECEIVER_FIND_HASH_MARKER; sequence_ = RECEIVER_RECEIVE_PAD_C;
} else {
done = true;
}
break;
} }
break; case RECEIVER_RECEIVE_PAD_C: {
case RECEIVER_FIND_HASH_MARKER: { if(mseHandshake_->receivePad()) {
if(mseHandshake_->findReceiverHashMarker()) { sequence_ = RECEIVER_RECEIVE_IA_LENGTH;
sequence_ = RECEIVER_RECEIVE_PAD_C_LENGTH; } else {
done = true;
}
break;
} }
break; case RECEIVER_RECEIVE_IA_LENGTH: {
} if(mseHandshake_->receiveReceiverIALength()) {
case RECEIVER_RECEIVE_PAD_C_LENGTH: { sequence_ = RECEIVER_RECEIVE_IA;
std::vector<SharedHandle<DownloadContext> > downloadContexts; } else {
getDownloadEngine()->getBtRegistry()->getAllDownloadContext done = true;
(std::back_inserter(downloadContexts)); }
if(mseHandshake_->receiveReceiverHashAndPadCLength(downloadContexts)) { break;
sequence_ = RECEIVER_RECEIVE_PAD_C;
} }
break; case RECEIVER_RECEIVE_IA: {
} if(mseHandshake_->receiveReceiverIA()) {
case RECEIVER_RECEIVE_PAD_C: { mseHandshake_->sendReceiverStep2();
if(mseHandshake_->receivePad()) { sequence_ = RECEIVER_SEND_STEP2_PENDING;
sequence_ = RECEIVER_RECEIVE_IA_LENGTH; } else {
done = true;
}
break;
} }
break; case RECEIVER_SEND_STEP2_PENDING:
} if(mseHandshake_->send()) {
case RECEIVER_RECEIVE_IA_LENGTH: {
if(mseHandshake_->receiveReceiverIALength()) {
sequence_ = RECEIVER_RECEIVE_IA;
}
break;
}
case RECEIVER_RECEIVE_IA: {
if(mseHandshake_->receiveReceiverIA()) {
if(mseHandshake_->sendReceiverStep2()) {
createCommand(); createCommand();
return true; return true;
} else { } else {
setWriteCheckSocket(getSocket()); done = true;
sequence_ = RECEIVER_SEND_STEP2_PENDING;
} }
break;
} }
break;
} }
case RECEIVER_SEND_STEP2_PENDING: if(mseHandshake_->getWantRead()) {
if(mseHandshake_->sendReceiverStep2()) { setReadCheckSocket(getSocket());
disableWriteCheckSocket(); } else {
createCommand(); disableReadCheckSocket();
return true; }
} if(mseHandshake_->getWantWrite()) {
break; setWriteCheckSocket(getSocket());
} else {
disableWriteCheckSocket();
} }
getDownloadEngine()->addCommand(this); getDownloadEngine()->addCommand(this);
return false; return false;
@ -188,10 +213,19 @@ void ReceiverMSEHandshakeCommand::createCommand()
peerConnection->enableEncryption(mseHandshake_->getEncryptor(), peerConnection->enableEncryption(mseHandshake_->getEncryptor(),
mseHandshake_->getDecryptor()); mseHandshake_->getDecryptor());
} }
if(mseHandshake_->getIALength() > 0) { size_t buflen = mseHandshake_->getIALength()+mseHandshake_->getBufferLength();
peerConnection->presetBuffer(mseHandshake_->getIA(), array_ptr<unsigned char> buffer(new unsigned char[buflen]);
mseHandshake_->getIALength()); memcpy(buffer, mseHandshake_->getIA(), mseHandshake_->getIALength());
if(mseHandshake_->getNegotiatedCryptoType() == MSEHandshake::CRYPTO_ARC4) {
mseHandshake_->getDecryptor()->decrypt(buffer+mseHandshake_->getIALength(),
mseHandshake_->getBufferLength(),
mseHandshake_->getBuffer(),
mseHandshake_->getBufferLength());
} else {
memcpy(buffer+mseHandshake_->getIALength(),
mseHandshake_->getBuffer(), mseHandshake_->getBufferLength());
} }
peerConnection->presetBuffer(buffer, buflen);
// TODO add mseHandshake_->getInfoHash() to PeerReceiveHandshakeCommand // TODO add mseHandshake_->getInfoHash() to PeerReceiveHandshakeCommand
// as a hint. If this info hash and one in BitTorrent Handshake does not // as a hint. If this info hash and one in BitTorrent Handshake does not
// match, then drop connection. // match, then drop connection.

View File

@ -70,26 +70,57 @@ createSocketPair()
void MSEHandshakeTest::doHandshake(const SharedHandle<MSEHandshake>& initiator, const SharedHandle<MSEHandshake>& receiver) void MSEHandshakeTest::doHandshake(const SharedHandle<MSEHandshake>& initiator, const SharedHandle<MSEHandshake>& receiver)
{ {
initiator->sendPublicKey(); initiator->sendPublicKey();
while(initiator->getWantWrite()) {
while(!receiver->receivePublicKey()); initiator->send();
}
while(!receiver->receivePublicKey()) {
receiver->read();
}
receiver->sendPublicKey(); receiver->sendPublicKey();
while(receiver->getWantWrite()) {
receiver->send();
}
while(!initiator->receivePublicKey()); while(!initiator->receivePublicKey()) {
initiator->read();
}
initiator->initCipher(bittorrent::getInfoHash(dctx_)); initiator->initCipher(bittorrent::getInfoHash(dctx_));
initiator->sendInitiatorStep2(); initiator->sendInitiatorStep2();
while(initiator->getWantWrite()) {
initiator->send();
}
while(!receiver->findReceiverHashMarker()); while(!receiver->findReceiverHashMarker()) {
receiver->read();
}
std::vector<SharedHandle<DownloadContext> > contexts; std::vector<SharedHandle<DownloadContext> > contexts;
contexts.push_back(dctx_); contexts.push_back(dctx_);
while(!receiver->receiveReceiverHashAndPadCLength(contexts)); while(!receiver->receiveReceiverHashAndPadCLength(contexts)) {
while(!receiver->receivePad()); receiver->read();
while(!receiver->receiveReceiverIALength()); }
while(!receiver->receiveReceiverIA()); while(!receiver->receivePad()) {
receiver->read();
}
while(!receiver->receiveReceiverIALength()) {
receiver->read();
}
while(!receiver->receiveReceiverIA()) {
receiver->read();
}
receiver->sendReceiverStep2(); receiver->sendReceiverStep2();
while(receiver->getWantWrite()) {
receiver->send();
}
while(!initiator->findInitiatorVCMarker()); while(!initiator->findInitiatorVCMarker()) {
while(!initiator->receiveInitiatorCryptoSelectAndPadDLength()); initiator->read();
while(!initiator->receivePad()); }
while(!initiator->receiveInitiatorCryptoSelectAndPadDLength()) {
initiator->read();
}
while(!initiator->receivePad()) {
initiator->read();
}
} }
namespace { namespace {