/* */
#include "PeerConnection.h"

#include <cstring>
#include <cassert>
#include <algorithm>

#include "message.h"
#include "DlAbortEx.h"
#include "LogFactory.h"
#include "Logger.h"
#include "BtHandshakeMessage.h"
#include "Socket.h"
#include "a2netcompat.h"
#include "ARC4Encryptor.h"
#include "ARC4Decryptor.h"
#include "StringFormat.h"

namespace aria2 {

// Creates a connection wrapper around `socket` for the command identified
// by `cuid`.  A receive buffer of MAX_PAYLOAD_LEN bytes is allocated here
// and released in the destructor.  Encryption starts out disabled.
PeerConnection::PeerConnection(int32_t cuid, const SocketHandle& socket)
  :cuid(cuid),
   socket(socket),
   logger(LogFactory::getInstance()),
   resbuf(new unsigned char[MAX_PAYLOAD_LEN]),
   resbufLength(0),
   currentPayloadLength(0),
   lenbufLength(0),
   _socketBuffer(socket),
   _encryptionEnabled(false),
   _prevPeek(false)
{}

// Releases the receive buffer allocated by the constructor.
PeerConnection::~PeerConnection()
{
  delete [] resbuf;
}

// Queues `data` for sending.  When encryption is enabled the bytes are
// first encrypted into a freshly allocated chunk which is then handed to
// the socket buffer; otherwise the string is queued unchanged.
void PeerConnection::pushStr(const std::string& data)
{
  if(_encryptionEnabled) {
    const size_t len = data.size();
    unsigned char* chunk = new unsigned char[len];
    try {
      _encryptor->encrypt
        (chunk, len,
         reinterpret_cast<const unsigned char*>(data.data()), len);
    } catch(RecoverableException&) {
      // The chunk has not been handed to the socket buffer yet, so it must
      // be released here before the exception propagates.
      delete [] chunk;
      throw;
    }
    // NOTE(review): presumably _socketBuffer.pushBytes() takes ownership of
    // chunk, since it is not deleted here -- confirm against SocketBuffer.
    _socketBuffer.pushBytes(chunk, len);
  } else {
    _socketBuffer.pushStr(data);
  }
}

// Queues `len` bytes at `data` for sending.  In the encrypted path `data`
// is deleted by this function after its contents were encrypted into a new
// chunk; in the plain path the pointer is passed straight to the socket
// buffer.
// NOTE(review): callers therefore appear to pass an array allocated with
// new[] and give up ownership in both paths -- confirm against SocketBuffer.
void PeerConnection::pushBytes(unsigned char* data, size_t len)
{
  if(_encryptionEnabled) {
    unsigned char* chunk = new unsigned char[len];
    try {
      _encryptor->encrypt(chunk, len, data, len);
    } catch(RecoverableException&) {
      delete [] data;
      delete [] chunk;
      throw;
    }
    delete [] data;
    _socketBuffer.pushBytes(chunk, len);
  } else {
    _socketBuffer.pushBytes(data, len);
  }
}

// Tries to receive one length-prefixed message without blocking.
//
// The 4 byte length prefix (network byte order) is accumulated in lenbuf
// and the payload in resbuf, so a message may be assembled over several
// calls.
//
// data       - if non-null, receives a copy of the complete payload.
// dataLength - set to the payload length when a whole message is available.
//
// Returns true when a complete message has been received, false when more
// data is needed.  Throws DL_ABORT_EX when the peer closed the connection
// or announced a payload longer than MAX_PAYLOAD_LEN.
bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength)
{
  if(resbufLength == 0 && 4 > lenbufLength) {
    if(!socket->isReadable(0)) {
      return false;
    }
    // read payload size, 32bit unsigned integer
    size_t remaining = 4-lenbufLength;
    size_t temp = remaining;
    // readData() overwrites `remaining` with the number of bytes read.
    readData(lenbuf+lenbufLength, remaining, _encryptionEnabled);
    if(remaining == 0) {
      if(socket->wantRead() || socket->wantWrite()) {
        return false;
      }
      // we got EOF
      if(logger->debug()) {
        logger->debug("CUID#%d - In PeerConnection::receiveMessage(),"
                      " remain=%lu",
                      cuid, static_cast<unsigned long>(temp));
      }
      throw DL_ABORT_EX(EX_EOF_FROM_PEER);
    }
    lenbufLength += remaining;
    if(4 > lenbufLength) {
      // still 4-lenbufLength bytes to go
      return false;
    }
    uint32_t payloadLength;
    memcpy(&payloadLength, lenbuf, sizeof(payloadLength));
    payloadLength = ntohl(payloadLength);
    // Reject anything that would not fit into resbuf.
    if(payloadLength > MAX_PAYLOAD_LEN) {
      throw DL_ABORT_EX(StringFormat(EX_TOO_LONG_PAYLOAD,
                                     payloadLength).str());
    }
    currentPayloadLength = payloadLength;
  }
  if(!socket->isReadable(0)) {
    return false;
  }
  // we have currentPayloadLen-resbufLen bytes to read
  size_t remaining = currentPayloadLength-resbufLength;
  size_t temp = remaining;
  if(remaining > 0) {
    readData(resbuf+resbufLength, remaining, _encryptionEnabled);
    if(remaining == 0) {
      if(socket->wantRead() || socket->wantWrite()) {
        return false;
      }
      // we got EOF
      if(logger->debug()) {
        logger->debug("CUID#%d - In PeerConnection::receiveMessage(),"
                      " payloadlen=%lu, remaining=%lu",
                      cuid,
                      static_cast<unsigned long>(currentPayloadLength),
                      static_cast<unsigned long>(temp));
      }
      throw DL_ABORT_EX(EX_EOF_FROM_PEER);
    }
    resbufLength += remaining;
    if(currentPayloadLength > resbufLength) {
      return false;
    }
  }
  // we got whole payload.
  // Reset the accumulators so that the next call starts a new message.
  resbufLength = 0;
  lenbufLength = 0;
  if(data) {
    memcpy(data, resbuf, currentPayloadLength);
  }
  dataLength = currentPayloadLength;
  return true;
}

// Tries to receive a BitTorrent handshake without blocking.
//
// data       - receives up to `dataLength` bytes of what has been buffered.
// dataLength - in: capacity of `data`; out: number of bytes copied to `data`
//              (0 when nothing could be read in this call).
// peek       - when true the buffered bytes are kept even if the handshake
//              is complete, so that a later call can return them again.
//
// Returns true when BtHandshakeMessage::MESSAGE_LENGTH bytes are buffered,
// false otherwise.  Throws DL_ABORT_EX when the peer closed the connection.
bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
                                      bool peek)
{
  assert(BtHandshakeMessage::MESSAGE_LENGTH >= resbufLength);
  bool retval = true;
  if(_prevPeek && !peek && resbufLength) {
    // We have data in previous peek.
    // There is a chance that socket is readable because of EOF, for example,
    // official bttrack shutdowns socket after sending first 48 bytes of
    // handshake in its NAT checking.
    // So if there are data in resbuf, return it without checking socket
    // status.
    _prevPeek = false;
    retval = BtHandshakeMessage::MESSAGE_LENGTH <= resbufLength;
  } else {
    _prevPeek = peek;
    size_t remaining = BtHandshakeMessage::MESSAGE_LENGTH-resbufLength;
    if(remaining > 0 && !socket->isReadable(0)) {
      dataLength = 0;
      return false;
    }
    if(remaining > 0) {
      size_t temp = remaining;
      // readData() overwrites `remaining` with the number of bytes read.
      readData(resbuf+resbufLength, remaining, _encryptionEnabled);
      if(remaining == 0) {
        if(socket->wantRead() || socket->wantWrite()) {
          // Nothing was copied into `data`; report that explicitly, like
          // the "socket not readable" early return above does.  Otherwise
          // `dataLength` would still hold the caller's buffer capacity.
          dataLength = 0;
          return false;
        }
        // we got EOF
        if(logger->debug()) {
          logger->debug
            ("CUID#%d - In PeerConnection::receiveHandshake(), remain=%lu",
             cuid, static_cast<unsigned long>(temp));
        }
        throw DL_ABORT_EX(EX_EOF_FROM_PEER);
      }
      resbufLength += remaining;
      if(BtHandshakeMessage::MESSAGE_LENGTH > resbufLength) {
        retval = false;
      }
    }
  }
  // Never copy more than the caller's buffer can hold.
  size_t writeLength = std::min(resbufLength, dataLength);
  memcpy(data, resbuf, writeLength);
  dataLength = writeLength;
  if(retval && !peek) {
    // The complete handshake has been consumed.
    resbufLength = 0;
  }
  return retval;
}

// Reads at most `length` bytes from the socket into `data`; `length` is
// updated by socket->readData().  When `encryption` is true the bytes are
// read into a temporary buffer and decrypted into `data`.
void PeerConnection::readData(unsigned char* data, size_t& length,
                              bool encryption)
{
  if(encryption) {
    unsigned char temp[MAX_PAYLOAD_LEN];
    // The temporary buffer holds MAX_PAYLOAD_LEN bytes; callers must not
    // ask for more.
    assert(MAX_PAYLOAD_LEN >= length);
    socket->readData(temp, length);
    _decryptor->decrypt(data, length, temp, length);
  } else {
    socket->readData(data, length);
  }
}

// Installs the cipher objects and turns encryption on for all subsequent
// pushStr()/pushBytes() calls and for reads done by receiveMessage() and
// receiveHandshake().
void PeerConnection::enableEncryption
(const SharedHandle<ARC4Encryptor>& encryptor,
 const SharedHandle<ARC4Decryptor>& decryptor)
{
  _encryptor = encryptor;
  _decryptor = decryptor;
  _encryptionEnabled = true;
}

// Pre-fills the receive buffer with `length` bytes from `data`.  At most
// MAX_PAYLOAD_LEN bytes are copied, because that is the size of resbuf.
void PeerConnection::presetBuffer(const unsigned char* data, size_t length)
{
  size_t nwrite = std::min(static_cast<size_t>(MAX_PAYLOAD_LEN), length);
  memcpy(resbuf, data, nwrite);
  // Record the number of bytes actually stored.  Recording the unclamped
  // `length` would let later reads and copies run past the end of resbuf.
  resbufLength = nwrite;
}

// Returns whatever the socket buffer reports from sendBufferIsEmpty().
bool PeerConnection::sendBufferIsEmpty() const
{
  return _socketBuffer.sendBufferIsEmpty();
}

// Calls send() on the socket buffer and returns its result, which is
// logged at debug level as the number of bytes sent.
ssize_t PeerConnection::sendPendingData()
{
  ssize_t writtenLength = _socketBuffer.send();
  if(logger->debug()) {
    // ssize_t may be wider than int, so convert explicitly to match the
    // conversion specifier.
    logger->debug("sent %ld byte(s).", static_cast<long>(writtenLength));
  }
  return writtenLength;
}

} // namespace aria2