/* */
#include "PeerConnection.h"
#include "message.h"
#include "DlAbortEx.h"
#include "LogFactory.h"
#include "Logger.h"
#include "BtHandshakeMessage.h"
#include "Socket.h"
#include "a2netcompat.h"
#include "ARC4Encryptor.h"
#include "ARC4Decryptor.h"
#include "StringFormat.h"
#include <cstring>
#include <cassert>
#include <algorithm>
#include <string>

namespace aria2 {

// Wraps an already-created socket for peer-wire message exchange.
// Encryption starts disabled (see enableEncryption()), and all
// receive-side bookkeeping (length prefix and payload buffers) starts
// empty. The same socket handle is also used to build _socketBuffer,
// which queues outgoing data.
PeerConnection::PeerConnection(int32_t cuid,
                               const SocketHandle& socket,
                               const Option* op)
  :cuid(cuid),
   socket(socket),
   option(op),
   logger(LogFactory::getInstance()),
   resbufLength(0),
   currentPayloadLength(0),
   lenbufLength(0),
   _socketBuffer(socket),
   _encryptionEnabled(false),
   _prevPeek(false)
{}

PeerConnection::~PeerConnection() {}

// Queues dataLength bytes from data for sending (encrypted when
// encryption is enabled) and flushes as much as possible.
// Returns the value of _socketBuffer.send() — presumably the number of
// bytes written by this call; confirm against SocketBuffer.
ssize_t PeerConnection::sendMessage(const unsigned char* data,
                                    size_t dataLength)
{
  ssize_t writtenLength = sendData(data, dataLength, _encryptionEnabled);
  // ssize_t may be wider than int; cast so the argument matches "%d".
  logger->debug("sent %d byte(s).", static_cast<int>(writtenLength));
  return writtenLength;
}

// Reads one length-prefixed message without waiting on the socket.
//
// Wire format: a 4-byte payload length in network byte order followed by
// that many payload bytes. Partial reads are accumulated across calls in
// lenbuf/resbuf, so the caller just calls again until this returns true.
//
// data:       receives the payload; must have room for up to
//             MAX_PAYLOAD_LEN bytes (larger lengths are rejected below).
// dataLength: set to the payload length when true is returned.
// Returns true when a complete message was copied into data, false when
// more bytes are still needed.
// Throws DlAbortEx when the peer closes the connection mid-message or
// announces a payload longer than MAX_PAYLOAD_LEN.
bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength)
{
  if(resbufLength == 0 && 4 > lenbufLength) {
    if(!socket->isReadable(0)) {
      return false;
    }
    // read payload size, 32bit unsigned integer
    size_t remaining = 4-lenbufLength;
    size_t temp = remaining;
    readData(lenbuf+lenbufLength, remaining, _encryptionEnabled);
    if(remaining == 0) {
      // Nothing was read. If the socket only asks to be retried
      // (wantRead/wantWrite), try again on a later call.
      if(socket->wantRead() || socket->wantWrite()) {
        return false;
      }
      // we got EOF
      logger->debug("CUID#%d - In PeerConnection::receiveMessage(), remain=%lu",
                    cuid, static_cast<unsigned long>(temp));
      throw DlAbortEx(EX_EOF_FROM_PEER);
    }
    lenbufLength += remaining;
    if(4 > lenbufLength) {
      // still 4-lenbufLength bytes to go
      return false;
    }
    // Copy the prefix out rather than dereferencing a reinterpret_cast'ed
    // pointer: lenbuf is a byte buffer with no uint32_t alignment
    // guarantee, and the cast also violates strict aliasing.
    uint32_t netPayloadLength;
    memcpy(&netPayloadLength, lenbuf, sizeof(netPayloadLength));
    uint32_t payloadLength = ntohl(netPayloadLength);
    if(payloadLength > MAX_PAYLOAD_LEN) {
      throw DlAbortEx(StringFormat(EX_TOO_LONG_PAYLOAD, payloadLength).str());
    }
    currentPayloadLength = payloadLength;
  }
  // we have currentPayloadLen-resbufLen bytes to read
  size_t remaining = currentPayloadLength-resbufLength;
  if(remaining > 0) {
    // Poll the socket only when payload bytes are actually missing. A
    // zero-length message is already complete once its prefix is read
    // and must not wait for unrelated data to arrive.
    if(!socket->isReadable(0)) {
      return false;
    }
    size_t temp = remaining;
    readData(resbuf+resbufLength, remaining, _encryptionEnabled);
    if(remaining == 0) {
      if(socket->wantRead() || socket->wantWrite()) {
        return false;
      }
      // we got EOF
      logger->debug("CUID#%d - In PeerConnection::receiveMessage(),"
                    " payloadlen=%lu, remaining=%lu",
                    cuid, static_cast<unsigned long>(currentPayloadLength),
                    static_cast<unsigned long>(temp));
      throw DlAbortEx(EX_EOF_FROM_PEER);
    }
    resbufLength += remaining;
    if(currentPayloadLength > resbufLength) {
      return false;
    }
  }
  // we got whole payload.
  resbufLength = 0;
  lenbufLength = 0;
  memcpy(data, resbuf, currentPayloadLength);
  dataLength = currentPayloadLength;
  return true;
}

// Accumulates up to BtHandshakeMessage::MESSAGE_LENGTH handshake bytes
// in resbuf without waiting on the socket.
//
// data:       receives a copy of the buffered bytes.
// dataLength: on entry, the capacity of data. On return it is
//             min(resbufLength, capacity), except that it is set to 0
//             when the socket is not readable and left unchanged when
//             the socket reports wantRead/wantWrite.
// peek:       when true, buffered bytes are kept in resbuf even after a
//             full handshake has arrived, so a later call sees them again.
// Returns true once a full handshake is buffered.
// Throws DlAbortEx when the peer closes the connection before the
// handshake is complete.
bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
                                      bool peek)
{
  assert(BtHandshakeMessage::MESSAGE_LENGTH >= resbufLength);
  bool retval = true;
  if(_prevPeek && !peek && resbufLength) {
    // We have data in previous peek.
    // There is a chance that socket is readable because of EOF, for example,
    // official bttrack shutdowns socket after sending first 48 bytes of
    // handshake in its NAT checking.
    // So if there are data in resbuf, return it without checking socket
    // status.
    _prevPeek = false;
    retval = BtHandshakeMessage::MESSAGE_LENGTH <= resbufLength;
  } else {
    _prevPeek = peek;
    size_t remaining = BtHandshakeMessage::MESSAGE_LENGTH-resbufLength;
    if(remaining > 0 && !socket->isReadable(0)) {
      dataLength = 0;
      return false;
    }
    if(remaining > 0) {
      size_t temp = remaining;
      readData(resbuf+resbufLength, remaining, _encryptionEnabled);
      if(remaining == 0) {
        if(socket->wantRead() || socket->wantWrite()) {
          return false;
        }
        // we got EOF
        logger->debug
          ("CUID#%d - In PeerConnection::receiveHandshake(), remain=%lu",
           cuid, static_cast<unsigned long>(temp));
        throw DlAbortEx(EX_EOF_FROM_PEER);
      }
      resbufLength += remaining;
      if(BtHandshakeMessage::MESSAGE_LENGTH > resbufLength) {
        retval = false;
      }
    }
  }
  size_t writeLength = std::min(resbufLength, dataLength);
  memcpy(data, resbuf, writeLength);
  dataLength = writeLength;
  if(retval && !peek) {
    resbufLength = 0;
  }
  return retval;
}

// Reads up to `length` bytes from the socket into data. On return,
// `length` holds the number of bytes actually read (0 when nothing was
// read). With encryption, the ciphertext is read into a stack buffer of
// MAX_PAYLOAD_LEN bytes and decrypted into data, so `length` must not
// exceed MAX_PAYLOAD_LEN.
void PeerConnection::readData(unsigned char* data, size_t& length,
                              bool encryption)
{
  if(encryption) {
    unsigned char temp[MAX_PAYLOAD_LEN];
    assert(MAX_PAYLOAD_LEN >= length);
    socket->readData(temp, length);
    _decryptor->decrypt(data, length, temp, length);
  } else {
    socket->readData(data, length);
  }
}

// Appends `length` bytes from data to the send buffer, encrypting them in
// 4096-byte chunks when `encryption` is true, then flushes the buffer.
// Returns the value of _socketBuffer.send().
ssize_t PeerConnection::sendData(const unsigned char* data, size_t length,
                                 bool encryption)
{
  if(encryption) {
    unsigned char temp[4096];
    const unsigned char* dptr = data;
    size_t r = length;
    while(r > 0) {
      size_t s = std::min(r, sizeof(temp));
      _encryptor->encrypt(temp, s, dptr, s);
      _socketBuffer.feedSendBuffer(std::string(temp, temp+s));
      dptr += s;
      r -= s;
    }
  } else {
    // Pointer arithmetic (not &data[i]) stays well-defined even when
    // length == 0 and data is null.
    _socketBuffer.feedSendBuffer(std::string(data, data+length));
  }
  return _socketBuffer.send();
}

// Installs the cipher pair and routes all subsequent reads and writes
// through it.
void PeerConnection::enableEncryption
(const SharedHandle<ARC4Encryptor>& encryptor,
 const SharedHandle<ARC4Decryptor>& decryptor)
{
  _encryptor = encryptor;
  _decryptor = decryptor;

  _encryptionEnabled = true;
}

// Seeds the receive buffer with bytes that were obtained elsewhere
// (presumably data read during an earlier negotiation step; verify
// against callers). At most MAX_PAYLOAD_LEN bytes are kept.
void PeerConnection::presetBuffer(const unsigned char* data, size_t length)
{
  size_t nwrite = std::min((size_t)MAX_PAYLOAD_LEN, length);
  memcpy(resbuf, data, nwrite);
  // Record only what was actually copied. Recording `length` would let
  // later reads write past resbuf and read bytes that were never copied.
  resbufLength = nwrite;
}

// Returns true when no queued outgoing data remains in _socketBuffer.
bool PeerConnection::sendBufferIsEmpty() const
{
  return _socketBuffer.sendBufferIsEmpty();
}

// Flushes previously queued outgoing data and returns the value of
// _socketBuffer.send().
ssize_t PeerConnection::sendPendingData()
{
  ssize_t writtenLength = _socketBuffer.send();
  // ssize_t may be wider than int; cast so the argument matches "%d".
  logger->debug("sent %d byte(s).", static_cast<int>(writtenLength));
  return writtenLength;
}

} // namespace aria2