/* */
#include "PeerConnection.h"

#include <cstring>
#include <cassert>
#include <algorithm>
#include <string>

#include "message.h"
#include "DlAbortEx.h"
#include "LogFactory.h"
#include "Logger.h"
#include "BtHandshakeMessage.h"
#include "Socket.h"
#include "a2netcompat.h"
#include "ARC4Encryptor.h"
#include "ARC4Decryptor.h"
#include "StringFormat.h"
#include "util.h"

namespace aria2 {

// Creates a connection bound to `socket`. A receive buffer of
// MAX_PAYLOAD_LEN bytes is allocated up front; encryption starts disabled
// until enableEncryption() is called.
PeerConnection::PeerConnection(cuid_t cuid, const SocketHandle& socket)
  :cuid_(cuid),
   socket_(socket),
   logger_(LogFactory::getInstance()),
   resbuf_(new unsigned char[MAX_PAYLOAD_LEN]),
   resbufLength_(0),
   currentPayloadLength_(0),
   lenbufLength_(0),
   socketBuffer_(socket),
   encryptionEnabled_(false),
   prevPeek_(false)
{}

// Releases the receive buffer allocated in the constructor.
PeerConnection::~PeerConnection()
{
  delete [] resbuf_;
}

// Queues `data` for sending. When encryption is enabled the bytes are
// encrypted into a freshly allocated buffer which is handed to
// socketBuffer_ (it is not freed here, so ownership presumably passes to
// socketBuffer_ -- confirm against SocketBuffer::pushBytes).
void PeerConnection::pushStr(const std::string& data)
{
  if(encryptionEnabled_) {
    const size_t len = data.size();
    unsigned char* chunk = new unsigned char[len];
    try {
      encryptor_->encrypt
        (chunk, len, reinterpret_cast<const unsigned char*>(data.data()), len);
    } catch(...) {
      // Free the scratch buffer on any exception, not only
      // RecoverableException, before propagating it.
      delete [] chunk;
      throw;
    }
    socketBuffer_.pushBytes(chunk, len);
  } else {
    socketBuffer_.pushStr(data);
  }
}

// Queues `len` bytes at `data` for sending. This function takes ownership
// of `data`: on the encryption path it is deleted here (also when an
// exception escapes); otherwise it is passed straight to socketBuffer_.
void PeerConnection::pushBytes(unsigned char* data, size_t len)
{
  if(encryptionEnabled_) {
    unsigned char* chunk = 0;
    try {
      // Allocate inside the try block so that `data`, which we own, is
      // still released if this allocation throws.
      chunk = new unsigned char[len];
      encryptor_->encrypt(chunk, len, data, len);
    } catch(...) {
      delete [] data;
      delete [] chunk; // no-op when the allocation itself failed
      throw;
    }
    delete [] data;
    socketBuffer_.pushBytes(chunk, len);
  } else {
    socketBuffer_.pushBytes(data, len);
  }
}

// Non-blocking read of one length-prefixed message: a 4-byte length in
// network byte order followed by that many payload bytes. Partial progress
// is kept in lenbuf_/resbuf_ across calls.
//
// Returns true once the whole payload is available. The payload is then
// copied to `data` (skipped when `data` is null), and `dataLength` is set
// to the payload length. `data` must be able to hold up to MAX_PAYLOAD_LEN
// bytes. Returns false if more bytes are needed.
// Throws DL_ABORT_EX on EOF or when the announced length exceeds
// MAX_PAYLOAD_LEN.
bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength)
{
  if(resbufLength_ == 0 && 4 > lenbufLength_) {
    if(!socket_->isReadable(0)) {
      return false;
    }
    // read payload size, 32bit unsigned integer
    size_t remaining = 4-lenbufLength_;
    size_t temp = remaining;
    readData(lenbuf_+lenbufLength_, remaining, encryptionEnabled_);
    if(remaining == 0) {
      if(socket_->wantRead() || socket_->wantWrite()) {
        return false;
      }
      // we got EOF
      if(logger_->debug()) {
        logger_->debug("CUID#%s - In PeerConnection::receiveMessage(),"
                       " remain=%lu",
                       util::itos(cuid_).c_str(),
                       static_cast<unsigned long>(temp));
      }
      throw DL_ABORT_EX(EX_EOF_FROM_PEER);
    }
    lenbufLength_ += remaining;
    if(4 > lenbufLength_) {
      // still 4-lenbufLength_ bytes to go
      return false;
    }
    uint32_t payloadLength;
    memcpy(&payloadLength, lenbuf_, sizeof(payloadLength));
    payloadLength = ntohl(payloadLength);
    if(payloadLength > MAX_PAYLOAD_LEN) {
      throw DL_ABORT_EX(StringFormat(EX_TOO_LONG_PAYLOAD, payloadLength).str());
    }
    currentPayloadLength_ = payloadLength;
  }
  // we have currentPayloadLen-resbufLen bytes to read
  size_t remaining = currentPayloadLength_-resbufLength_;
  size_t temp = remaining;
  if(remaining > 0) {
    // Only poll the socket when payload bytes are still missing. A
    // zero-length message is already complete and must not wait for
    // unrelated data to arrive.
    if(!socket_->isReadable(0)) {
      return false;
    }
    readData(resbuf_+resbufLength_, remaining, encryptionEnabled_);
    if(remaining == 0) {
      if(socket_->wantRead() || socket_->wantWrite()) {
        return false;
      }
      // we got EOF
      if(logger_->debug()) {
        logger_->debug("CUID#%s - In PeerConnection::receiveMessage(),"
                       " payloadlen=%lu, remaining=%lu",
                       util::itos(cuid_).c_str(),
                       static_cast<unsigned long>(currentPayloadLength_),
                       static_cast<unsigned long>(temp));
      }
      throw DL_ABORT_EX(EX_EOF_FROM_PEER);
    }
    resbufLength_ += remaining;
    if(currentPayloadLength_ > resbufLength_) {
      return false;
    }
  }
  // we got whole payload.
  resbufLength_ = 0;
  lenbufLength_ = 0;
  if(data) {
    memcpy(data, resbuf_, currentPayloadLength_);
  }
  dataLength = currentPayloadLength_;
  return true;
}

// Non-blocking read of a handshake of BtHandshakeMessage::MESSAGE_LENGTH
// bytes into resbuf_.
//
// On entry `dataLength` is the capacity of `data`. On exit it is the
// number of buffered bytes copied to `data`, or 0 when nothing could be
// read in this call. Returns true when the full handshake is buffered.
// With `peek` set, buffered bytes are kept so a later call can return
// them again. Throws DL_ABORT_EX on EOF.
bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
                                      bool peek)
{
  assert(BtHandshakeMessage::MESSAGE_LENGTH >= resbufLength_);
  bool retval = true;
  if(prevPeek_ && !peek && resbufLength_) {
    // We have data in previous peek.
    // There is a chance that socket is readable because of EOF, for example,
    // official bttrack shutdowns socket after sending first 48 bytes of
    // handshake in its NAT checking.
    // So if there are data in resbuf_, return it without checking socket
    // status.
    prevPeek_ = false;
    retval = BtHandshakeMessage::MESSAGE_LENGTH <= resbufLength_;
  } else {
    prevPeek_ = peek;
    size_t remaining = BtHandshakeMessage::MESSAGE_LENGTH-resbufLength_;
    if(remaining > 0 && !socket_->isReadable(0)) {
      dataLength = 0;
      return false;
    }
    if(remaining > 0) {
      size_t temp = remaining;
      readData(resbuf_+resbufLength_, remaining, encryptionEnabled_);
      if(remaining == 0) {
        if(socket_->wantRead() || socket_->wantWrite()) {
          // Nothing was copied into `data`. Report 0 bytes, as in the
          // not-readable case above, so that `dataLength` does not still
          // hold the caller's buffer capacity.
          dataLength = 0;
          return false;
        }
        // we got EOF
        if(logger_->debug()) {
          logger_->debug
            ("CUID#%s - In PeerConnection::receiveHandshake(), remain=%lu",
             util::itos(cuid_).c_str(), static_cast<unsigned long>(temp));
        }
        throw DL_ABORT_EX(EX_EOF_FROM_PEER);
      }
      resbufLength_ += remaining;
      if(BtHandshakeMessage::MESSAGE_LENGTH > resbufLength_) {
        retval = false;
      }
    }
  }
  size_t writeLength = std::min(resbufLength_, dataLength);
  memcpy(data, resbuf_, writeLength);
  dataLength = writeLength;
  if(retval && !peek) {
    resbufLength_ = 0;
  }
  return retval;
}

// Reads up to `length` bytes from the socket into `data`. `length` is
// updated to the number of bytes actually read by socket_->readData.
// When `encryption` is set, the bytes are read into a stack buffer and
// decrypted into `data`. `length` must not exceed MAX_PAYLOAD_LEN.
void PeerConnection::readData
(unsigned char* data, size_t& length, bool encryption)
{
  if(encryption) {
    unsigned char temp[MAX_PAYLOAD_LEN];
    assert(MAX_PAYLOAD_LEN >= length);
    socket_->readData(temp, length);
    decryptor_->decrypt(data, length, temp, length);
  } else {
    socket_->readData(data, length);
  }
}

// Installs the cipher pair and switches all subsequent sends and reads to
// the encrypted path.
void PeerConnection::enableEncryption
(const SharedHandle<ARC4Encryptor>& encryptor,
 const SharedHandle<ARC4Decryptor>& decryptor)
{
  encryptor_ = encryptor;
  decryptor_ = decryptor;

  encryptionEnabled_ = true;
}

// Seeds the receive buffer with already-received bytes. At most
// MAX_PAYLOAD_LEN bytes are kept.
void PeerConnection::presetBuffer(const unsigned char* data, size_t length)
{
  size_t nwrite = std::min((size_t)MAX_PAYLOAD_LEN, length);
  memcpy(resbuf_, data, nwrite);
  // Record only what was actually copied. Recording `length` would make
  // resbufLength_ exceed both the buffer size and the valid data.
  resbufLength_ = nwrite;
}

// True when socketBuffer_ has nothing left to send.
bool PeerConnection::sendBufferIsEmpty() const
{
  return socketBuffer_.sendBufferIsEmpty();
}

// Flushes queued data via socketBuffer_.send() and returns its result.
ssize_t PeerConnection::sendPendingData()
{
  ssize_t writtenLength = socketBuffer_.send();
  if(logger_->debug()) {
    // ssize_t is not int on LP64 targets; cast explicitly to match %ld.
    logger_->debug("sent %ld byte(s).", static_cast<long>(writtenLength));
  }
  return writtenLength;
}

} // namespace aria2