From ce2d401dce35b94c15120a4cac441a4614658e3d Mon Sep 17 00:00:00 2001 From: Tatsuhiro Tsujikawa Date: Sat, 8 Jan 2011 17:32:16 +0900 Subject: [PATCH] Eliminated SocketCore::peekData from MSEHandshake. --- src/InitiatorMSEHandshakeCommand.cc | 161 +++++++------ src/MSEHandshake.cc | 357 ++++++++++++---------------- src/MSEHandshake.h | 31 ++- src/PeerConnection.cc | 17 +- src/PeerConnection.h | 5 + src/PeerInteractionCommand.cc | 128 +++++----- src/PeerReceiveHandshakeCommand.cc | 7 +- src/ReceiverMSEHandshakeCommand.cc | 202 +++++++++------- test/MSEHandshakeTest.cc | 53 ++++- 9 files changed, 523 insertions(+), 438 deletions(-) diff --git a/src/InitiatorMSEHandshakeCommand.cc b/src/InitiatorMSEHandshakeCommand.cc index b5ec29a2..c8795685 100644 --- a/src/InitiatorMSEHandshakeCommand.cc +++ b/src/InitiatorMSEHandshakeCommand.cc @@ -56,6 +56,7 @@ #include "bittorrent_helper.h" #include "util.h" #include "fmt.h" +#include "array_fun.h" namespace aria2 { @@ -89,82 +90,108 @@ InitiatorMSEHandshakeCommand::~InitiatorMSEHandshakeCommand() } bool InitiatorMSEHandshakeCommand::executeInternal() { - switch(sequence_) { - case INITIATOR_SEND_KEY: { - if(!getSocket()->isWritable(0)) { + if(mseHandshake_->getWantRead()) { + mseHandshake_->read(); + } + bool done = false; + while(!done) { + switch(sequence_) { + case INITIATOR_SEND_KEY: { + if(!getSocket()->isWritable(0)) { + getDownloadEngine()->addCommand(this); + return false; + } + setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT)); + mseHandshake_->initEncryptionFacility(true); + mseHandshake_->sendPublicKey(); + sequence_ = INITIATOR_SEND_KEY_PENDING; break; } - disableWriteCheckSocket(); - setReadCheckSocket(getSocket()); - //socket->setBlockingMode(); - setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT)); - mseHandshake_->initEncryptionFacility(true); - if(mseHandshake_->sendPublicKey()) { - sequence_ = INITIATOR_WAIT_KEY; - } else { - setWriteCheckSocket(getSocket()); - sequence_ = INITIATOR_SEND_KEY_PENDING; + case INITIATOR_SEND_KEY_PENDING: + if(mseHandshake_->send()) { + sequence_ = INITIATOR_WAIT_KEY; + } else { + done = true; + } + break; + case INITIATOR_WAIT_KEY: { + if(mseHandshake_->receivePublicKey()) { + mseHandshake_->initCipher + (bittorrent::getInfoHash(requestGroup_->getDownloadContext()));; + mseHandshake_->sendInitiatorStep2(); + sequence_ = INITIATOR_SEND_STEP2_PENDING; + } else { + done = true; + } + break; } - break; - } - case INITIATOR_SEND_KEY_PENDING: - if(mseHandshake_->sendPublicKey()) { - disableWriteCheckSocket(); - sequence_ = INITIATOR_WAIT_KEY; - } - break; - case INITIATOR_WAIT_KEY: { - if(mseHandshake_->receivePublicKey()) { - mseHandshake_->initCipher - (bittorrent::getInfoHash(requestGroup_->getDownloadContext()));; - if(mseHandshake_->sendInitiatorStep2()) { + case INITIATOR_SEND_STEP2_PENDING: + if(mseHandshake_->send()) { sequence_ = INITIATOR_FIND_VC_MARKER; } else { - setWriteCheckSocket(getSocket()); - sequence_ = INITIATOR_SEND_STEP2_PENDING; + done = true; } - } - break; - } - case INITIATOR_SEND_STEP2_PENDING: - if(mseHandshake_->sendInitiatorStep2()) { - disableWriteCheckSocket(); - sequence_ = INITIATOR_FIND_VC_MARKER; - } - break; - case INITIATOR_FIND_VC_MARKER: { - if(mseHandshake_->findInitiatorVCMarker()) { - sequence_ = INITIATOR_RECEIVE_PAD_D_LENGTH; - } - break; - } - case INITIATOR_RECEIVE_PAD_D_LENGTH: { - if(mseHandshake_->receiveInitiatorCryptoSelectAndPadDLength()) { - sequence_ = INITIATOR_RECEIVE_PAD_D; - } - break; - } - case INITIATOR_RECEIVE_PAD_D: { - if(mseHandshake_->receivePad()) { - SharedHandle peerConnection - (new PeerConnection(getCuid(), getPeer(), getSocket())); - if(mseHandshake_->getNegotiatedCryptoType() == MSEHandshake::CRYPTO_ARC4){ - peerConnection->enableEncryption(mseHandshake_->getEncryptor(), - mseHandshake_->getDecryptor()); + break; + case INITIATOR_FIND_VC_MARKER: { + if(mseHandshake_->findInitiatorVCMarker()) { + sequence_ = INITIATOR_RECEIVE_PAD_D_LENGTH; + } else { + done = true; } - PeerInteractionCommand* c = - new PeerInteractionCommand - (getCuid(), requestGroup_, getPeer(), getDownloadEngine(), btRuntime_, - pieceStorage_, - peerStorage_, - getSocket(), - PeerInteractionCommand::INITIATOR_SEND_HANDSHAKE, - peerConnection); - getDownloadEngine()->addCommand(c); - return true; + break; + } + case INITIATOR_RECEIVE_PAD_D_LENGTH: { + if(mseHandshake_->receiveInitiatorCryptoSelectAndPadDLength()) { + sequence_ = INITIATOR_RECEIVE_PAD_D; + } else { + done = true; + } + break; + } + case INITIATOR_RECEIVE_PAD_D: { + if(mseHandshake_->receivePad()) { + SharedHandle peerConnection + (new PeerConnection(getCuid(), getPeer(), getSocket())); + if(mseHandshake_->getNegotiatedCryptoType() == + MSEHandshake::CRYPTO_ARC4){ + peerConnection->enableEncryption(mseHandshake_->getEncryptor(), + mseHandshake_->getDecryptor()); + size_t buflen = mseHandshake_->getBufferLength(); + array_ptr buffer(new unsigned char[buflen]); + mseHandshake_->getDecryptor()->decrypt(buffer, buflen, + mseHandshake_->getBuffer(), + buflen); + peerConnection->presetBuffer(buffer, buflen); + } else { + peerConnection->presetBuffer(mseHandshake_->getBuffer(), + mseHandshake_->getBufferLength()); + } + PeerInteractionCommand* c = + new PeerInteractionCommand + (getCuid(), requestGroup_, getPeer(), getDownloadEngine(), btRuntime_, + pieceStorage_, + peerStorage_, + getSocket(), + PeerInteractionCommand::INITIATOR_SEND_HANDSHAKE, + peerConnection); + getDownloadEngine()->addCommand(c); + return true; + } else { + done = true; + } + break; + } } - break; } + if(mseHandshake_->getWantRead()) { + setReadCheckSocket(getSocket()); + } else { + disableReadCheckSocket(); + } + if(mseHandshake_->getWantWrite()) { + setWriteCheckSocket(getSocket()); + } else { + disableWriteCheckSocket(); } getDownloadEngine()->addCommand(this); return false; diff --git a/src/MSEHandshake.cc b/src/MSEHandshake.cc index b82a9a10..d073381a 100644 --- a/src/MSEHandshake.cc +++ b/src/MSEHandshake.cc @@ -71,6 +71,7 @@ MSEHandshake::MSEHandshake const Option* op) : cuid_(cuid), socket_(socket), + wantRead_(false), option_(op), rbufLength_(0), socketBuffer_(socket), @@ -92,16 +93,8 @@ MSEHandshake::~MSEHandshake() MSEHandshake::HANDSHAKE_TYPE MSEHandshake::identifyHandshakeType() { - if(!socket_->isReadable(0)) { - return HANDSHAKE_NOT_YET; - } - size_t r = 20-rbufLength_; - socket_->readData(rbuf_+rbufLength_, r); - if(r == 0 && !socket_->wantRead() && !socket_->wantWrite()) { - throw DL_ABORT_EX(EX_EOF_FROM_PEER); - } - rbufLength_ += r; if(rbufLength_ < 20) { + wantRead_ = true; return HANDSHAKE_NOT_YET; } if(rbuf_[0] == BtHandshakeMessage::PSTR_LENGTH && @@ -126,35 +119,59 @@ void MSEHandshake::initEncryptionFacility(bool initiator) initiator_ = initiator; } -bool MSEHandshake::sendPublicKey() +void MSEHandshake::sendPublicKey() { - if(socketBuffer_.sendBufferIsEmpty()) { - A2_LOG_DEBUG(fmt("CUID#%lld - Sending public key.", - cuid_)); - unsigned char buffer[KEY_LENGTH+MAX_PAD_LENGTH]; - dh_->getPublicKey(buffer, KEY_LENGTH); + A2_LOG_DEBUG(fmt("CUID#%lld - Sending public key.", + cuid_)); + unsigned char buffer[KEY_LENGTH+MAX_PAD_LENGTH]; + dh_->getPublicKey(buffer, KEY_LENGTH); - size_t padLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1); - dh_->generateNonce(buffer+KEY_LENGTH, padLength); - socketBuffer_.pushStr(std::string(&buffer[0], - &buffer[KEY_LENGTH+padLength])); + size_t padLength = + SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1); + dh_->generateNonce(buffer+KEY_LENGTH, padLength); + socketBuffer_.pushStr(std::string(&buffer[0], + &buffer[KEY_LENGTH+padLength])); +} + +void MSEHandshake::read() +{ + if(rbufLength_ >= MAX_BUFFER_LENGTH) { + assert(!wantRead_); + return; } + size_t len = MAX_BUFFER_LENGTH-rbufLength_; + socket_->readData(rbuf_+rbufLength_, len); + if(len == 0 && !socket_->wantRead() && !socket_->wantWrite()) { + // TODO Should we set graceful in peer? + throw DL_ABORT_EX(EX_EOF_FROM_PEER); + } + rbufLength_ += len; + wantRead_ = false; +} + +bool MSEHandshake::send() +{ socketBuffer_.send(); return socketBuffer_.sendBufferIsEmpty(); } +void MSEHandshake::shiftBuffer(size_t offset) +{ + memmove(rbuf_, rbuf_+offset, rbufLength_-offset); + rbufLength_ -= offset; +} + bool MSEHandshake::receivePublicKey() { - size_t r = KEY_LENGTH-rbufLength_; - if(r > receiveNBytes(r)) { + if(rbufLength_ < KEY_LENGTH) { + wantRead_ = true; return false; } - A2_LOG_DEBUG(fmt("CUID#%lld - public key received.", - cuid_)); + A2_LOG_DEBUG(fmt("CUID#%lld - public key received.", cuid_)); // TODO handle exception. in catch, resbufLength = 0; - dh_->computeSecret(secret_, sizeof(secret_), rbuf_, rbufLength_); - // reset rbufLength_ - rbufLength_ = 0; + dh_->computeSecret(secret_, sizeof(secret_), rbuf_, KEY_LENGTH); + // shift buffer + shiftBuffer(KEY_LENGTH); return true; } @@ -251,109 +268,83 @@ uint16_t MSEHandshake::decodeLength16(const unsigned char* buffer) return ntohs(be); } -bool MSEHandshake::sendInitiatorStep2() +void MSEHandshake::sendInitiatorStep2() { - if(socketBuffer_.sendBufferIsEmpty()) { - A2_LOG_DEBUG(fmt("CUID#%lld - Sending negotiation step2.", - cuid_)); - unsigned char md[20]; - createReq1Hash(md); - socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)])); + A2_LOG_DEBUG(fmt("CUID#%lld - Sending negotiation step2.", cuid_)); + unsigned char md[20]; + createReq1Hash(md); + socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)])); + createReq23Hash(md, infoHash_); + socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)])); - createReq23Hash(md, infoHash_); - socketBuffer_.pushStr(std::string(&md[0], &md[sizeof(md)])); + // buffer is filled in this order: + // VC(VC_LENGTH bytes), + // crypto_provide(CRYPTO_BITFIELD_LENGTH bytes), + // len(padC)(2bytes), + // padC(len(padC)bytes <= MAX_PAD_LENGTH), + // len(IA)(2bytes) + unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH+2]; - { - // buffer is filled in this order: - // VC(VC_LENGTH bytes), - // crypto_provide(CRYPTO_BITFIELD_LENGTH bytes), - // len(padC)(2bytes), - // padC(len(padC)bytes <= MAX_PAD_LENGTH), - // len(IA)(2bytes) - unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH+2]; - - // VC - memcpy(buffer, VC, sizeof(VC)); - // crypto_provide - unsigned char cryptoProvide[CRYPTO_BITFIELD_LENGTH]; - memset(cryptoProvide, 0, sizeof(cryptoProvide)); - if(option_->get(PREF_BT_MIN_CRYPTO_LEVEL) == V_PLAIN) { - cryptoProvide[3] = CRYPTO_PLAIN_TEXT; - } - cryptoProvide[3] |= CRYPTO_ARC4; - memcpy(buffer+VC_LENGTH, cryptoProvide, sizeof(cryptoProvide)); - - // len(padC) - uint16_t padCLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1); - { - uint16_t padCLengthBE = htons(padCLength); - memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padCLengthBE, - sizeof(padCLengthBE)); - } - // padC - memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padCLength); - // len(IA) - // currently, IA is zero-length. - uint16_t iaLength = 0; - { - uint16_t iaLengthBE = htons(iaLength); - memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength, - &iaLengthBE,sizeof(iaLengthBE)); - } - encryptAndSendData(buffer, - VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength+2); - } + // VC + memcpy(buffer, VC, sizeof(VC)); + // crypto_provide + unsigned char cryptoProvide[CRYPTO_BITFIELD_LENGTH]; + memset(cryptoProvide, 0, sizeof(cryptoProvide)); + if(option_->get(PREF_BT_MIN_CRYPTO_LEVEL) == V_PLAIN) { + cryptoProvide[3] = CRYPTO_PLAIN_TEXT; } - socketBuffer_.send(); - return socketBuffer_.sendBufferIsEmpty(); + cryptoProvide[3] |= CRYPTO_ARC4; + memcpy(buffer+VC_LENGTH, cryptoProvide, sizeof(cryptoProvide)); + + // len(padC) + uint16_t padCLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1); + { + uint16_t padCLengthBE = htons(padCLength); + memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padCLengthBE, + sizeof(padCLengthBE)); + } + // padC + memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padCLength); + // len(IA) + // currently, IA is zero-length. + uint16_t iaLength = 0; + { + uint16_t iaLengthBE = htons(iaLength); + memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength, + &iaLengthBE,sizeof(iaLengthBE)); + } + encryptAndSendData(buffer, + VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padCLength+2); } // This function reads exactly until the end of VC marker is reached. bool MSEHandshake::findInitiatorVCMarker() { // 616 is synchronization point of initiator - size_t r = 616-KEY_LENGTH-rbufLength_; - if(!socket_->isReadable(0)) { - return false; - } - socket_->peekData(rbuf_+rbufLength_, r); - if(r == 0) { - if(socket_->wantRead() || socket_->wantWrite()) { + // find vc + std::string buf(&rbuf_[0], &rbuf_[rbufLength_]); + std::string vc(&initiatorVCMarker_[0], &initiatorVCMarker_[VC_LENGTH]); + if((markerIndex_ = buf.find(vc)) == std::string::npos) { + if(616-KEY_LENGTH <= rbufLength_) { + throw DL_ABORT_EX("Failed to find VC marker."); + } else { + wantRead_ = true; return false; } - throw DL_ABORT_EX(EX_EOF_FROM_PEER); } - // find vc - { - std::string buf(&rbuf_[0], &rbuf_[rbufLength_+r]); - std::string vc(&initiatorVCMarker_[0], &initiatorVCMarker_[VC_LENGTH]); - if((markerIndex_ = buf.find(vc)) == std::string::npos) { - if(616-KEY_LENGTH <= rbufLength_+r) { - throw DL_ABORT_EX("Failed to find VC marker."); - } else { - socket_->readData(rbuf_+rbufLength_, r); - rbufLength_ += r; - return false; - } - } - } - assert(markerIndex_+VC_LENGTH-rbufLength_ <= r); - size_t toRead = markerIndex_+VC_LENGTH-rbufLength_; - socket_->readData(rbuf_+rbufLength_, toRead); - rbufLength_ += toRead; A2_LOG_DEBUG(fmt("CUID#%lld - VC marker found at %lu", cuid_, static_cast(markerIndex_))); verifyVC(rbuf_+markerIndex_); - // reset rbufLength_ - rbufLength_ = 0; + // shift rbuf + shiftBuffer(markerIndex_+VC_LENGTH); return true; } bool MSEHandshake::receiveInitiatorCryptoSelectAndPadDLength() { - size_t r = CRYPTO_BITFIELD_LENGTH+2/* PadD length*/-rbufLength_; - if(r > receiveNBytes(r)) { + if(CRYPTO_BITFIELD_LENGTH+2/* PadD length*/ > rbufLength_) { + wantRead_ = true; return false; } //verifyCryptoSelect @@ -382,75 +373,57 @@ bool MSEHandshake::receiveInitiatorCryptoSelectAndPadDLength() // padD length rbufptr += CRYPTO_BITFIELD_LENGTH; padLength_ = verifyPadLength(rbufptr, "PadD"); - // reset rbufLength_ - rbufLength_ = 0; + // shift rbuf + shiftBuffer(CRYPTO_BITFIELD_LENGTH+2/* PadD length*/); return true; } bool MSEHandshake::receivePad() { + if(padLength_ > rbufLength_) { + wantRead_ = true; + return false; + } if(padLength_ == 0) { return true; } - size_t r = padLength_-rbufLength_; - if(r > receiveNBytes(r)) { - return false; - } unsigned char temp[MAX_PAD_LENGTH]; decryptor_->decrypt(temp, padLength_, rbuf_, padLength_); - // reset rbufLength_ - rbufLength_ = 0; + // shift rbuf_ + shiftBuffer(padLength_); return true; } bool MSEHandshake::findReceiverHashMarker() { // 628 is synchronization limit of receiver. - size_t r = 628-KEY_LENGTH-rbufLength_; - if(!socket_->isReadable(0)) { - return false; - } - socket_->peekData(rbuf_+rbufLength_, r); - if(r == 0) { - if(socket_->wantRead() || socket_->wantWrite()) { + // find hash('req1', S), S is secret_. + std::string buf(&rbuf_[0], &rbuf_[rbufLength_]); + unsigned char md[20]; + createReq1Hash(md); + std::string req1(&md[0], &md[sizeof(md)]); + if((markerIndex_ = buf.find(req1)) == std::string::npos) { + if(628-KEY_LENGTH <= rbufLength_) { + throw DL_ABORT_EX("Failed to find hash marker."); + } else { + wantRead_ = true; return false; } - throw DL_ABORT_EX(EX_EOF_FROM_PEER); } - // find hash('req1', S), S is secret_. - { - std::string buf(&rbuf_[0], &rbuf_[rbufLength_+r]); - unsigned char md[20]; - createReq1Hash(md); - std::string req1(&md[0], &md[sizeof(md)]); - if((markerIndex_ = buf.find(req1)) == std::string::npos) { - if(628-KEY_LENGTH <= rbufLength_+r) { - throw DL_ABORT_EX("Failed to find hash marker."); - } else { - socket_->readData(rbuf_+rbufLength_, r); - rbufLength_ += r; - return false; - } - } - } - assert(markerIndex_+20-rbufLength_ <= r); - size_t toRead = markerIndex_+20-rbufLength_; - socket_->readData(rbuf_+rbufLength_, toRead); - rbufLength_ += toRead; A2_LOG_DEBUG(fmt("CUID#%lld - Hash marker found at %lu.", cuid_, static_cast(markerIndex_))); verifyReq1Hash(rbuf_+markerIndex_); - // reset rbufLength_ - rbufLength_ = 0; + // shift rbuf_ + shiftBuffer(markerIndex_+20); return true; } bool MSEHandshake::receiveReceiverHashAndPadCLength (const std::vector >& downloadContexts) { - size_t r = 20+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2/*PadC length*/-rbufLength_; - if(r > receiveNBytes(r)) { + if(20+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2/*PadC length*/ > rbufLength_) { + wantRead_ = true; return false; } // resolve info hash @@ -505,23 +478,22 @@ bool MSEHandshake::receiveReceiverHashAndPadCLength // decrypt PadC length rbufptr += CRYPTO_BITFIELD_LENGTH; padLength_ = verifyPadLength(rbufptr, "PadC"); - // reset rbufLength_ - rbufLength_ = 0; + // shift rbuf_ + shiftBuffer(20+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2/*PadC length*/); return true; } bool MSEHandshake::receiveReceiverIALength() { - size_t r = 2-rbufLength_; - assert(r > 0); - if(r > receiveNBytes(r)) { + if(2 > rbufLength_) { + wantRead_ = true; return false; } iaLength_ = decodeLength16(rbuf_); - A2_LOG_DEBUG(fmt("CUID#%lld - len(IA)=%u.", - cuid_, iaLength_)); - // reset rbufLength_ - rbufLength_ = 0; + // TODO limit iaLength \19...+handshake + A2_LOG_DEBUG(fmt("CUID#%lld - len(IA)=%u.", cuid_, iaLength_)); + // shift rbuf_ + shiftBuffer(2); return true; } @@ -530,48 +502,44 @@ bool MSEHandshake::receiveReceiverIA() if(iaLength_ == 0) { return true; } - size_t r = iaLength_-rbufLength_; - if(r > receiveNBytes(r)) { + if(iaLength_ > rbufLength_) { + wantRead_ = true; return false; } delete [] ia_; ia_ = new unsigned char[iaLength_]; decryptor_->decrypt(ia_, iaLength_, rbuf_, iaLength_); A2_LOG_DEBUG(fmt("CUID#%lld - IA received.", cuid_)); - // reset rbufLength_ - rbufLength_ = 0; + // shift rbuf_ + shiftBuffer(iaLength_); return true; } -bool MSEHandshake::sendReceiverStep2() +void MSEHandshake::sendReceiverStep2() { - if(socketBuffer_.sendBufferIsEmpty()) { - // buffer is filled in this order: - // VC(VC_LENGTH bytes), - // cryptoSelect(CRYPTO_BITFIELD_LENGTH bytes), - // len(padD)(2bytes), - // padD(len(padD)bytes <= MAX_PAD_LENGTH) - unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH]; - // VC - memcpy(buffer, VC, sizeof(VC)); - // crypto_select - unsigned char cryptoSelect[CRYPTO_BITFIELD_LENGTH]; - memset(cryptoSelect, 0, sizeof(cryptoSelect)); - cryptoSelect[3] = negotiatedCryptoType_; - memcpy(buffer+VC_LENGTH, cryptoSelect, sizeof(cryptoSelect)); - // len(padD) - uint16_t padDLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1); - { - uint16_t padDLengthBE = htons(padDLength); - memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padDLengthBE, - sizeof(padDLengthBE)); - } - // padD, all zeroed - memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padDLength); - encryptAndSendData(buffer, VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padDLength); + // buffer is filled in this order: + // VC(VC_LENGTH bytes), + // cryptoSelect(CRYPTO_BITFIELD_LENGTH bytes), + // len(padD)(2bytes), + // padD(len(padD)bytes <= MAX_PAD_LENGTH) + unsigned char buffer[VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+MAX_PAD_LENGTH]; + // VC + memcpy(buffer, VC, sizeof(VC)); + // crypto_select + unsigned char cryptoSelect[CRYPTO_BITFIELD_LENGTH]; + memset(cryptoSelect, 0, sizeof(cryptoSelect)); + cryptoSelect[3] = negotiatedCryptoType_; + memcpy(buffer+VC_LENGTH, cryptoSelect, sizeof(cryptoSelect)); + // len(padD) + uint16_t padDLength = SimpleRandomizer::getInstance()->getRandomNumber(MAX_PAD_LENGTH+1); + { + uint16_t padDLengthBE = htons(padDLength); + memcpy(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH, &padDLengthBE, + sizeof(padDLengthBE)); } - socketBuffer_.send(); - return socketBuffer_.sendBufferIsEmpty(); + // padD, all zeroed + memset(buffer+VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2, 0, padDLength); + encryptAndSendData(buffer, VC_LENGTH+CRYPTO_BITFIELD_LENGTH+2+padDLength); } uint16_t MSEHandshake::verifyPadLength(const unsigned char* padlenbuf, const char* padName) @@ -609,20 +577,9 @@ void MSEHandshake::verifyReq1Hash(const unsigned char* req1buf) } } -size_t MSEHandshake::receiveNBytes(size_t bytes) +bool MSEHandshake::getWantWrite() const { - size_t r = bytes; - if(r > 0) { - if(!socket_->isReadable(0)) { - return 0; - } - socket_->readData(rbuf_+rbufLength_, r); - if(r == 0 && !socket_->wantRead() && !socket_->wantWrite()) { - throw DL_ABORT_EX(EX_EOF_FROM_PEER); - } - rbufLength_ += r; - } - return r; + return !socketBuffer_.sendBufferIsEmpty(); } } // namespace aria2 diff --git a/src/MSEHandshake.h b/src/MSEHandshake.h index 4e22d407..9c8cf12f 100644 --- a/src/MSEHandshake.h +++ b/src/MSEHandshake.h @@ -83,6 +83,7 @@ private: cuid_t cuid_; SharedHandle socket_; + bool wantRead_; const Option* option_; unsigned char rbuf_[MAX_BUFFER_LENGTH]; @@ -130,8 +131,7 @@ private: void verifyReq1Hash(const unsigned char* req1buf); - size_t receiveNBytes(size_t bytes); - + void shiftBuffer(size_t offset); public: MSEHandshake(cuid_t cuid, const SharedHandle& socket, const Option* op); @@ -142,13 +142,33 @@ public: void initEncryptionFacility(bool initiator); - bool sendPublicKey(); + // Reads data from Socket. If EOF is reached, throws + // RecoverableException. + void read(); + + // Sends pending data in the send buffer. Returns true if all data + // is sent. Otherwise returns false. + bool send(); + + bool getWantRead() const + { + return wantRead_; + } + + void setWantRead(bool wantRead) + { + wantRead_ = wantRead; + } + + bool getWantWrite() const; + + void sendPublicKey(); bool receivePublicKey(); void initCipher(const unsigned char* infoHash); - bool sendInitiatorStep2(); + void sendInitiatorStep2(); bool findInitiatorVCMarker(); @@ -165,7 +185,7 @@ public: bool receiveReceiverIA(); - bool sendReceiverStep2(); + void sendReceiverStep2(); // returns plain text IA const unsigned char* getIA() const @@ -207,7 +227,6 @@ public: { return rbufLength_; } - }; } // namespace aria2 diff --git a/src/PeerConnection.cc b/src/PeerConnection.cc index 8da69e17..c6b78c90 100644 --- a/src/PeerConnection.cc +++ b/src/PeerConnection.cc @@ -110,9 +110,6 @@ void PeerConnection::pushBytes(unsigned char* data, size_t len) bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) { if(resbufLength_ == 0 && 4 > lenbufLength_) { - if(!socket_->isReadable(0)) { - return false; - } // read payload size, 32bit unsigned integer size_t remaining = 4-lenbufLength_; size_t temp = remaining; @@ -182,7 +179,7 @@ bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength, bool peek) { assert(BtHandshakeMessage::MESSAGE_LENGTH >= resbufLength_); bool retval = true; - if(prevPeek_ && !peek && resbufLength_) { + if(prevPeek_ && resbufLength_) { // We have data in previous peek. // There is a chance that socket is readable because of EOF, for example, // official bttrack shutdowns socket after sending first 48 bytes of @@ -194,17 +191,10 @@ bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength, } else { prevPeek_ = peek; size_t remaining = BtHandshakeMessage::MESSAGE_LENGTH-resbufLength_; - if(remaining > 0 && !socket_->isReadable(0)) { - dataLength = 0; - return false; - } if(remaining > 0) { size_t temp = remaining; readData(resbuf_+resbufLength_, remaining, encryptionEnabled_); - if(remaining == 0) { - if(socket_->wantRead() || socket_->wantWrite()) { - return false; - } + if(remaining == 0 && !socket_->wantRead() && !socket_->wantWrite()) { // we got EOF A2_LOG_DEBUG (fmt("CUID#%lld - In PeerConnection::receiveHandshake(), remain=%lu", @@ -256,6 +246,9 @@ void PeerConnection::presetBuffer(const unsigned char* data, size_t length) size_t nwrite = std::min((size_t)MAX_PAYLOAD_LEN, length); memcpy(resbuf_, data, nwrite); resbufLength_ = length; + if(resbufLength_ > 0) { + prevPeek_ = true; + } } bool PeerConnection::sendBufferIsEmpty() const diff --git a/src/PeerConnection.h b/src/PeerConnection.h index 0557204d..15676381 100644 --- a/src/PeerConnection.h +++ b/src/PeerConnection.h @@ -117,6 +117,11 @@ public: return resbuf_; } + size_t getBufferLength() const + { + return resbufLength_; + } + unsigned char* detachBuffer(); }; diff --git a/src/PeerInteractionCommand.cc b/src/PeerInteractionCommand.cc index dece90a8..bb8ea4bb 100644 --- a/src/PeerInteractionCommand.cc +++ b/src/PeerInteractionCommand.cc @@ -163,6 +163,11 @@ PeerInteractionCommand::PeerInteractionCommand peerConnection.reset(new PeerConnection(cuid, getPeer(), getSocket())); } else { peerConnection = passedPeerConnection; + if(sequence_ == RECEIVER_WAIT_HANDSHAKE && + peerConnection->getBufferLength() > 0) { + setStatus(Command::STATUS_ONESHOT_REALTIME); + getDownloadEngine()->setNoWait(true); + } } SharedHandle dispatcher @@ -274,71 +279,80 @@ PeerInteractionCommand::~PeerInteractionCommand() { bool PeerInteractionCommand::executeInternal() { setNoCheck(false); - switch(sequence_) { - case INITIATOR_SEND_HANDSHAKE: - if(!getSocket()->isWritable(0)) { - break; - } - disableWriteCheckSocket(); - setReadCheckSocket(getSocket()); - //socket->setBlockingMode(); - setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT)); - btInteractive_->initiateHandshake(); - sequence_ = INITIATOR_WAIT_HANDSHAKE; - break; - case INITIATOR_WAIT_HANDSHAKE: { - if(btInteractive_->countPendingMessage() > 0) { - btInteractive_->sendPendingMessage(); - if(btInteractive_->countPendingMessage() > 0) { + bool done = false; + while(!done) { + switch(sequence_) { + case INITIATOR_SEND_HANDSHAKE: + if(!getSocket()->isWritable(0)) { + done = true; break; } - } - BtMessageHandle handshakeMessage = btInteractive_->receiveHandshake(); - if(!handshakeMessage) { + disableWriteCheckSocket(); + setReadCheckSocket(getSocket()); + //socket->setBlockingMode(); + setTimeout(getOption()->getAsInt(PREF_BT_TIMEOUT)); + btInteractive_->initiateHandshake(); + sequence_ = INITIATOR_WAIT_HANDSHAKE; break; - } - btInteractive_->doPostHandshakeProcessing(); - sequence_ = WIRED; - break; - } - case RECEIVER_WAIT_HANDSHAKE: { - BtMessageHandle handshakeMessage =btInteractive_->receiveAndSendHandshake(); - if(!handshakeMessage) { - break; - } - btInteractive_->doPostHandshakeProcessing(); - sequence_ = WIRED; - break; - } - case WIRED: - // See the comment for writable check below. - disableWriteCheckSocket(); - - btInteractive_->doInteractionProcessing(); - if(btInteractive_->countReceivedMessageInIteration() > 0) { - updateKeepAlive(); - } - if((getPeer()->amInterested() && !getPeer()->peerChoking()) || - btInteractive_->countOutstandingRequest() || - (getPeer()->peerInterested() && !getPeer()->amChoking())) { - - // Writable check to avoid slow seeding - if(btInteractive_->isSendingMessageInProgress()) { - setWriteCheckSocket(getSocket()); + case INITIATOR_WAIT_HANDSHAKE: { + if(btInteractive_->countPendingMessage() > 0) { + btInteractive_->sendPendingMessage(); + if(btInteractive_->countPendingMessage() > 0) { + done = true; + break; + } } + BtMessageHandle handshakeMessage = btInteractive_->receiveHandshake(); + if(!handshakeMessage) { + done = true; + break; + } + btInteractive_->doPostHandshakeProcessing(); + sequence_ = WIRED; + break; + } + case RECEIVER_WAIT_HANDSHAKE: { + BtMessageHandle handshakeMessage = + btInteractive_->receiveAndSendHandshake(); + if(!handshakeMessage) { + done = true; + break; + } + btInteractive_->doPostHandshakeProcessing(); + sequence_ = WIRED; + break; + } + case WIRED: + // See the comment for writable check below. + disableWriteCheckSocket(); - if(getDownloadEngine()->getRequestGroupMan()-> - doesOverallDownloadSpeedExceed() || - requestGroup_->doesDownloadSpeedExceed()) { - disableReadCheckSocket(); - setNoCheck(true); + btInteractive_->doInteractionProcessing(); + if(btInteractive_->countReceivedMessageInIteration() > 0) { + updateKeepAlive(); + } + if((getPeer()->amInterested() && !getPeer()->peerChoking()) || + btInteractive_->countOutstandingRequest() || + (getPeer()->peerInterested() && !getPeer()->amChoking())) { + + // Writable check to avoid slow seeding + if(btInteractive_->isSendingMessageInProgress()) { + setWriteCheckSocket(getSocket()); + } + + if(getDownloadEngine()->getRequestGroupMan()-> + doesOverallDownloadSpeedExceed() || + requestGroup_->doesDownloadSpeedExceed()) { + disableReadCheckSocket(); + setNoCheck(true); + } else { + setReadCheckSocket(getSocket()); + } } else { - setReadCheckSocket(getSocket()); + disableReadCheckSocket(); } - } else { - disableReadCheckSocket(); + done = true; + break; } - break; } if(btInteractive_->countPendingMessage() > 0) { setNoCheck(true); diff --git a/src/PeerReceiveHandshakeCommand.cc b/src/PeerReceiveHandshakeCommand.cc index bddf352f..5af632a5 100644 --- a/src/PeerReceiveHandshakeCommand.cc +++ b/src/PeerReceiveHandshakeCommand.cc @@ -67,7 +67,12 @@ PeerReceiveHandshakeCommand::PeerReceiveHandshakeCommand : PeerAbstractCommand(cuid, peer, e, s), peerConnection_(peerConnection) { - if(!peerConnection_) { + if(peerConnection_) { + if(peerConnection_->getBufferLength() > 0) { + setStatus(Command::STATUS_ONESHOT_REALTIME); + getDownloadEngine()->setNoWait(true); + } + } else { peerConnection_.reset(new PeerConnection(cuid, getPeer(), getSocket())); } } diff --git a/src/ReceiverMSEHandshakeCommand.cc b/src/ReceiverMSEHandshakeCommand.cc index 14edea78..17748e8c 100644 --- a/src/ReceiverMSEHandshakeCommand.cc +++ b/src/ReceiverMSEHandshakeCommand.cc @@ -50,6 +50,7 @@ #include "RequestGroupMan.h" #include "BtRegistry.h" #include "DownloadContext.h" +#include "array_fun.h" namespace aria2 { @@ -64,6 +65,7 @@ ReceiverMSEHandshakeCommand::ReceiverMSEHandshakeCommand mseHandshake_(new MSEHandshake(cuid, s, e->getOption())) { setTimeout(e->getOption()->getAsInt(PREF_PEER_CONNECTION_TIMEOUT)); + mseHandshake_->setWantRead(true); } ReceiverMSEHandshakeCommand::~ReceiverMSEHandshakeCommand() @@ -79,102 +81,125 @@ bool ReceiverMSEHandshakeCommand::exitBeforeExecute() bool ReceiverMSEHandshakeCommand::executeInternal() { - switch(sequence_) { - case RECEIVER_IDENTIFY_HANDSHAKE: { - MSEHandshake::HANDSHAKE_TYPE type = mseHandshake_->identifyHandshakeType(); - switch(type) { - case MSEHandshake::HANDSHAKE_NOT_YET: - break; - case MSEHandshake::HANDSHAKE_ENCRYPTED: - mseHandshake_->initEncryptionFacility(false); - sequence_ = RECEIVER_WAIT_KEY; - break; - case MSEHandshake::HANDSHAKE_LEGACY: { - if(getDownloadEngine()->getOption()->getAsBool(PREF_BT_REQUIRE_CRYPTO)) { - throw DL_ABORT_EX - ("The legacy BitTorrent handshake is not acceptable by the" - " preference."); - } - SharedHandle peerConnection - (new PeerConnection(getCuid(), getPeer(), getSocket())); - peerConnection->presetBuffer(mseHandshake_->getBuffer(), - mseHandshake_->getBufferLength()); - Command* c = new PeerReceiveHandshakeCommand(getCuid(), - getPeer(), - getDownloadEngine(), - getSocket(), - peerConnection); - getDownloadEngine()->addCommand(c); - return true; - } - default: - throw DL_ABORT_EX("Not supported handshake type."); - } - break; + if(mseHandshake_->getWantRead()) { + mseHandshake_->read(); } - case RECEIVER_WAIT_KEY: { - if(mseHandshake_->receivePublicKey()) { - if(mseHandshake_->sendPublicKey()) { + bool done = false; + while(!done) { + switch(sequence_) { + case RECEIVER_IDENTIFY_HANDSHAKE: { + MSEHandshake::HANDSHAKE_TYPE type = + mseHandshake_->identifyHandshakeType(); + switch(type) { + case MSEHandshake::HANDSHAKE_NOT_YET: + done = true; + break; + case MSEHandshake::HANDSHAKE_ENCRYPTED: + mseHandshake_->initEncryptionFacility(false); + sequence_ = RECEIVER_WAIT_KEY; + break; + case MSEHandshake::HANDSHAKE_LEGACY: { + if(getDownloadEngine()->getOption()->getAsBool(PREF_BT_REQUIRE_CRYPTO)){ + throw DL_ABORT_EX + ("The legacy BitTorrent handshake is not acceptable by the" + " preference."); + } + SharedHandle peerConnection + (new PeerConnection(getCuid(), getPeer(), getSocket())); + peerConnection->presetBuffer(mseHandshake_->getBuffer(), + mseHandshake_->getBufferLength()); + Command* c = new PeerReceiveHandshakeCommand(getCuid(), + getPeer(), + getDownloadEngine(), + getSocket(), + peerConnection); + getDownloadEngine()->addCommand(c); + return true; + } + default: + throw DL_ABORT_EX("Not supported handshake type."); + } + break; + } + case RECEIVER_WAIT_KEY: { + if(mseHandshake_->receivePublicKey()) { + mseHandshake_->sendPublicKey(); + sequence_ = RECEIVER_SEND_KEY_PENDING; + } else { + done = true; + } + break; + } + case RECEIVER_SEND_KEY_PENDING: + if(mseHandshake_->send()) { sequence_ = RECEIVER_FIND_HASH_MARKER; } else { - setWriteCheckSocket(getSocket()); - sequence_ = RECEIVER_SEND_KEY_PENDING; + done = true; } + break; + case RECEIVER_FIND_HASH_MARKER: { + if(mseHandshake_->findReceiverHashMarker()) { + sequence_ = RECEIVER_RECEIVE_PAD_C_LENGTH; + } else { + done = true; + } + break; } - break; - } - case RECEIVER_SEND_KEY_PENDING: - if(mseHandshake_->sendPublicKey()) { - disableWriteCheckSocket(); - sequence_ = RECEIVER_FIND_HASH_MARKER; + case RECEIVER_RECEIVE_PAD_C_LENGTH: { + std::vector > downloadContexts; + getDownloadEngine()->getBtRegistry()->getAllDownloadContext + (std::back_inserter(downloadContexts)); + if(mseHandshake_->receiveReceiverHashAndPadCLength(downloadContexts)) { + sequence_ = RECEIVER_RECEIVE_PAD_C; + } else { + done = true; + } + break; } - break; - case RECEIVER_FIND_HASH_MARKER: { - if(mseHandshake_->findReceiverHashMarker()) { - sequence_ = RECEIVER_RECEIVE_PAD_C_LENGTH; + case RECEIVER_RECEIVE_PAD_C: { + if(mseHandshake_->receivePad()) { + sequence_ = RECEIVER_RECEIVE_IA_LENGTH; + } else { + done = true; + } + break; } - break; - } - case RECEIVER_RECEIVE_PAD_C_LENGTH: { - std::vector > downloadContexts; - getDownloadEngine()->getBtRegistry()->getAllDownloadContext - (std::back_inserter(downloadContexts)); - if(mseHandshake_->receiveReceiverHashAndPadCLength(downloadContexts)) { - sequence_ = RECEIVER_RECEIVE_PAD_C; + case RECEIVER_RECEIVE_IA_LENGTH: { + if(mseHandshake_->receiveReceiverIALength()) { + sequence_ = RECEIVER_RECEIVE_IA; + } else { + done = true; + } + break; } - break; - } - case RECEIVER_RECEIVE_PAD_C: { - if(mseHandshake_->receivePad()) { - sequence_ = RECEIVER_RECEIVE_IA_LENGTH; + case RECEIVER_RECEIVE_IA: { + if(mseHandshake_->receiveReceiverIA()) { + mseHandshake_->sendReceiverStep2(); + sequence_ = RECEIVER_SEND_STEP2_PENDING; + } else { + done = true; + } + break; } - break; - } - case RECEIVER_RECEIVE_IA_LENGTH: { - if(mseHandshake_->receiveReceiverIALength()) { - sequence_ = RECEIVER_RECEIVE_IA; - } - break; - } - case RECEIVER_RECEIVE_IA: { - if(mseHandshake_->receiveReceiverIA()) { - if(mseHandshake_->sendReceiverStep2()) { + case RECEIVER_SEND_STEP2_PENDING: + if(mseHandshake_->send()) { createCommand(); return true; } else { - setWriteCheckSocket(getSocket()); - sequence_ = RECEIVER_SEND_STEP2_PENDING; + done = true; } + break; } - break; } - case RECEIVER_SEND_STEP2_PENDING: - if(mseHandshake_->sendReceiverStep2()) { - disableWriteCheckSocket(); - createCommand(); - return true; - } - break; + if(mseHandshake_->getWantRead()) { + setReadCheckSocket(getSocket()); + } else { + disableReadCheckSocket(); + } + if(mseHandshake_->getWantWrite()) { + setWriteCheckSocket(getSocket()); + } else { + disableWriteCheckSocket(); } getDownloadEngine()->addCommand(this); return false; @@ -188,10 +213,19 @@ void ReceiverMSEHandshakeCommand::createCommand() peerConnection->enableEncryption(mseHandshake_->getEncryptor(), mseHandshake_->getDecryptor()); } - if(mseHandshake_->getIALength() > 0) { - peerConnection->presetBuffer(mseHandshake_->getIA(), - mseHandshake_->getIALength()); + size_t buflen = mseHandshake_->getIALength()+mseHandshake_->getBufferLength(); + array_ptr buffer(new unsigned char[buflen]); + memcpy(buffer, mseHandshake_->getIA(), mseHandshake_->getIALength()); + if(mseHandshake_->getNegotiatedCryptoType() == MSEHandshake::CRYPTO_ARC4) { + mseHandshake_->getDecryptor()->decrypt(buffer+mseHandshake_->getIALength(), + mseHandshake_->getBufferLength(), + mseHandshake_->getBuffer(), + mseHandshake_->getBufferLength()); + } else { + memcpy(buffer+mseHandshake_->getIALength(), + mseHandshake_->getBuffer(), mseHandshake_->getBufferLength()); } + peerConnection->presetBuffer(buffer, buflen); // TODO add mseHandshake_->getInfoHash() to PeerReceiveHandshakeCommand // as a hint. If this info hash and one in BitTorrent Handshake does not // match, then drop connection. diff --git a/test/MSEHandshakeTest.cc b/test/MSEHandshakeTest.cc index b523b221..2a29a7c3 100644 --- a/test/MSEHandshakeTest.cc +++ b/test/MSEHandshakeTest.cc @@ -70,26 +70,57 @@ createSocketPair() void MSEHandshakeTest::doHandshake(const SharedHandle& initiator, const SharedHandle& receiver) { initiator->sendPublicKey(); - - while(!receiver->receivePublicKey()); + while(initiator->getWantWrite()) { + initiator->send(); + } + while(!receiver->receivePublicKey()) { + receiver->read(); + } receiver->sendPublicKey(); + while(receiver->getWantWrite()) { + receiver->send(); + } - while(!initiator->receivePublicKey()); + while(!initiator->receivePublicKey()) { + initiator->read(); + } initiator->initCipher(bittorrent::getInfoHash(dctx_)); initiator->sendInitiatorStep2(); + while(initiator->getWantWrite()) { + initiator->send(); + } - while(!receiver->findReceiverHashMarker()); + while(!receiver->findReceiverHashMarker()) { + receiver->read(); + } std::vector > contexts; contexts.push_back(dctx_); - while(!receiver->receiveReceiverHashAndPadCLength(contexts)); - while(!receiver->receivePad()); - while(!receiver->receiveReceiverIALength()); - while(!receiver->receiveReceiverIA()); + while(!receiver->receiveReceiverHashAndPadCLength(contexts)) { + receiver->read(); + } + while(!receiver->receivePad()) { + receiver->read(); + } + while(!receiver->receiveReceiverIALength()) { + receiver->read(); + } + while(!receiver->receiveReceiverIA()) { + receiver->read(); + } receiver->sendReceiverStep2(); + while(receiver->getWantWrite()) { + receiver->send(); + } - while(!initiator->findInitiatorVCMarker()); - while(!initiator->receiveInitiatorCryptoSelectAndPadDLength()); - while(!initiator->receivePad()); + while(!initiator->findInitiatorVCMarker()) { + initiator->read(); + } + while(!initiator->receiveInitiatorCryptoSelectAndPadDLength()) { + initiator->read(); + } + while(!initiator->receivePad()) { + initiator->read(); + } } namespace {