/* */
#include "PeerConnection.h"
#include "message.h"
#include "DlAbortEx.h"
#include "LogFactory.h"
#include "Logger.h"
#include "BtHandshakeMessage.h"
#include "Socket.h"
#include "a2netcompat.h"
#include "ARC4Encryptor.h"
#include "ARC4Decryptor.h"
#include "StringFormat.h"
#include <cstring>
#include <cassert>
#include <algorithm>

namespace aria2 {

// Wraps a peer socket and incrementally (re)assembles length-prefixed
// messages and the fixed-size handshake, optionally through ARC4
// encryption/decryption.
PeerConnection::PeerConnection(int32_t cuid,
                               const SocketHandle& socket,
                               const Option* op)
  :cuid(cuid),
   socket(socket),
   option(op),
   logger(LogFactory::getInstance()),
   resbufLength(0),
   currentPayloadLength(0),
   lenbufLength(0),
   _encryptionEnabled(false)
{}

PeerConnection::~PeerConnection() {}

// Sends dataLength bytes from data, if the socket is writable right now
// (checked with isWritable(0)).
// Returns dataLength when the data was handed to sendData(), or 0 when the
// socket was not writable (nothing is sent in that case).
ssize_t PeerConnection::sendMessage(const unsigned char* data,
                                    size_t dataLength)
{
  if(socket->isWritable(0)) {
    // TODO fix this
    // NOTE(review): the whole length is reported as sent; whether
    // Socket::writeData can perform a partial write is not visible here.
    sendData(data, dataLength, _encryptionEnabled);
    return dataLength;
  } else {
    return 0;
  }
}

// Incrementally reads one message framed as a 4-byte big-endian payload
// length followed by that many payload bytes. State is kept across calls in
// lenbuf/lenbufLength (length prefix) and resbuf/resbufLength (payload).
//
// data       - destination for the payload; presumably must hold at least
//              MAX_PAYLOAD_LEN bytes (TODO confirm against callers).
// dataLength - set to the payload length when a whole message is returned.
// Returns true when a complete message was copied into data, false when
// more bytes are still needed.
// Throws DlAbortEx on EOF from the peer or when the announced payload
// length exceeds MAX_PAYLOAD_LEN.
bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength)
{
  if(resbufLength == 0 && 4 > lenbufLength) {
    if(!socket->isReadable(0)) {
      return false;
    }
    // read payload size, 32bit unsigned integer
    size_t remaining = 4-lenbufLength;
    size_t temp = remaining;
    readData(lenbuf+lenbufLength, remaining, _encryptionEnabled);
    if(remaining == 0) {
      // we got EOF
      logger->debug("CUID#%d - In PeerConnection::receiveMessage(), remain=%zu",
                    cuid, temp);
      throw DlAbortEx(EX_EOF_FROM_PEER);
    }
    lenbufLength += remaining;
    if(4 > lenbufLength) {
      // still 4-lenbufLength bytes to go
      return false;
    }
    // Copy the prefix into a real uint32_t instead of dereferencing a
    // reinterpret_cast'ed pointer: lenbuf has no alignment guarantee and
    // the cast would also break strict-aliasing rules.
    uint32_t netLength = 0;
    memcpy(&netLength, lenbuf, sizeof(netLength));
    uint32_t payloadLength = ntohl(netLength);
    if(payloadLength > MAX_PAYLOAD_LEN) {
      throw DlAbortEx(StringFormat(EX_TOO_LONG_PAYLOAD, payloadLength).str());
    }
    currentPayloadLength = payloadLength;
  }
  // we have currentPayloadLen-resbufLen bytes to read
  size_t remaining = currentPayloadLength-resbufLength;
  size_t temp = remaining;
  if(remaining > 0) {
    // Only wait for readability while payload bytes are missing; a
    // zero-length payload is already complete once its prefix is read and
    // must not be held back until unrelated data arrives.
    if(!socket->isReadable(0)) {
      return false;
    }
    readData(resbuf+resbufLength, remaining, _encryptionEnabled);
    if(remaining == 0) {
      // we got EOF
      logger->debug("CUID#%d - In PeerConnection::receiveMessage(), payloadlen=%zu, remaining=%zu",
                    cuid, currentPayloadLength, temp);
      throw DlAbortEx(EX_EOF_FROM_PEER);
    }
    resbufLength += remaining;
    if(currentPayloadLength > resbufLength) {
      return false;
    }
  }
  // we got whole payload.
  resbufLength = 0;
  lenbufLength = 0;

  memcpy(data, resbuf, currentPayloadLength);
  dataLength = currentPayloadLength;
  return true;
}

// Incrementally reads the fixed-size handshake
// (BtHandshakeMessage::MESSAGE_LENGTH bytes) into resbuf.
//
// data       - destination buffer.
// dataLength - in: capacity of data; out: number of bytes copied
//              (the buffered bytes so far, capped at the capacity).
// peek       - when true, buffered bytes are kept even after the whole
//              handshake has arrived, so a later call sees them again.
// Returns true once the full handshake is buffered, false otherwise.
// Throws DlAbortEx on EOF from the peer.
bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
                                      bool peek)
{
  // presetBuffer() may already have stored more than a handshake's worth of
  // bytes; guard the subtraction so size_t cannot wrap around.
  size_t remaining = 0;
  if(resbufLength < BtHandshakeMessage::MESSAGE_LENGTH) {
    remaining = BtHandshakeMessage::MESSAGE_LENGTH-resbufLength;
  }
  if(remaining > 0 && !socket->isReadable(0)) {
    dataLength = 0;
    return false;
  }
  bool retval = true;
  if(remaining > 0) {
    size_t temp = remaining;
    readData(resbuf+resbufLength, remaining, _encryptionEnabled);
    if(remaining == 0) {
      // we got EOF
      logger->debug("CUID#%d - In PeerConnection::receiveHandshake(), remain=%zu",
                    cuid, temp);
      throw DlAbortEx(EX_EOF_FROM_PEER);
    }
    resbufLength += remaining;
    if(BtHandshakeMessage::MESSAGE_LENGTH > resbufLength) {
      retval = false;
    }
  }
  size_t writeLength = std::min(resbufLength, dataLength);
  memcpy(data, resbuf, writeLength);
  dataLength = writeLength;
  if(retval && !peek) {
    // NOTE(review): any preset bytes beyond the handshake are discarded
    // here; confirm callers never preset more than one handshake.
    resbufLength = 0;
  }
  return retval;
}

// Reads up to length bytes from the socket into data, decrypting them with
// _decryptor when encryption is true. On return, length holds the number
// of bytes actually read (0 is treated by callers as EOF).
void PeerConnection::readData(unsigned char* data, size_t& length,
                              bool encryption)
{
  if(encryption) {
    unsigned char temp[MAX_PAYLOAD_LEN];
    assert(MAX_PAYLOAD_LEN >= length);
    socket->readData(temp, length);
    _decryptor->decrypt(data, length, temp, length);
  } else {
    socket->readData(data, length);
  }
}

// Writes length bytes from data to the socket. When encryption is true the
// data is encrypted with _encryptor in chunks of at most sizeof(temp) bytes
// before each write.
void PeerConnection::sendData(const unsigned char* data, size_t length,
                              bool encryption)
{
  if(encryption) {
    unsigned char temp[4096];
    const unsigned char* dptr = data;
    size_t r = length;
    while(r > 0) {
      size_t s = std::min(r, sizeof(temp));
      _encryptor->encrypt(temp, s, dptr, s);
      socket->writeData(temp, s);
      dptr += s;
      r -= s;
    }
  } else {
    socket->writeData(data, length);
  }
}

// Installs the ARC4 encryptor/decryptor and turns encryption on for all
// subsequent sendMessage/receiveMessage/receiveHandshake calls.
void PeerConnection::enableEncryption
(const SharedHandle<ARC4Encryptor>& encryptor,
 const SharedHandle<ARC4Decryptor>& decryptor)
{
  _encryptor = encryptor;
  _decryptor = decryptor;

  _encryptionEnabled = true;
}

// Pre-loads resbuf with bytes that were already received elsewhere.
// At most MAX_PAYLOAD_LEN bytes are kept; anything beyond that is dropped.
void PeerConnection::presetBuffer(const unsigned char* data, size_t length)
{
  size_t nwrite = std::min(static_cast<size_t>(MAX_PAYLOAD_LEN), length);
  memcpy(resbuf, data, nwrite);
  // Record only what was actually copied: storing `length` here would let
  // resbufLength exceed the buffer and make later reads/copies overrun it.
  resbufLength = nwrite;
}

} // namespace aria2