diff --git a/src/BtPieceMessage.cc b/src/BtPieceMessage.cc index 55eba7cd..338bebe3 100644 --- a/src/BtPieceMessage.cc +++ b/src/BtPieceMessage.cc @@ -73,13 +73,10 @@ BtPieceMessage::BtPieceMessage } BtPieceMessage::~BtPieceMessage() -{ - delete [] data_; -} +{} -void BtPieceMessage::setRawMessage(unsigned char* data) +void BtPieceMessage::setMsgPayload(const unsigned char* data) { - delete [] data_; data_ = data; } diff --git a/src/BtPieceMessage.h b/src/BtPieceMessage.h index 5b3c320a..f9756f9b 100644 --- a/src/BtPieceMessage.h +++ b/src/BtPieceMessage.h @@ -51,7 +51,7 @@ private: size_t index_; int32_t begin_; int32_t blockLength_; - unsigned char* data_; + const unsigned char* data_; SharedHandle downloadContext_; SharedHandle peerStorage_; @@ -87,10 +87,9 @@ public: int32_t getBlockLength() const { return blockLength_; } - // Stores raw message data. After this function call, this object - // has ownership of data. Caller must not be free or alter data. - // Member block is pointed to block starting position in data. - void setRawMessage(unsigned char* data); + // Sets message payload data. Caller must not change or free data + // before doReceivedAction(). + void setMsgPayload(const unsigned char* data); void setBlockLength(int32_t blockLength) { blockLength_ = blockLength; } diff --git a/src/DefaultBtMessageReceiver.cc b/src/DefaultBtMessageReceiver.cc index 8d746d92..9832d5ac 100644 --- a/src/DefaultBtMessageReceiver.cc +++ b/src/DefaultBtMessageReceiver.cc @@ -121,12 +121,13 @@ BtMessageHandle DefaultBtMessageReceiver::receiveMessage() { return SharedHandle(); } BtMessageHandle msg = - messageFactory_->createBtMessage(peerConnection_->getBuffer(), dataLength); + messageFactory_->createBtMessage(peerConnection_->getMsgPayloadBuffer(), + dataLength); msg->validate(); if(msg->getId() == BtPieceMessage::ID) { SharedHandle piecemsg = static_pointer_cast(msg); - piecemsg->setRawMessage(peerConnection_->detachBuffer()); + piecemsg->setMsgPayload(peerConnection_->getMsgPayloadBuffer()); } return msg; } diff --git a/src/PeerConnection.cc b/src/PeerConnection.cc index da715929..4dd9a3fc 100644 --- a/src/PeerConnection.cc +++ b/src/PeerConnection.cc @@ -52,16 +52,29 @@ namespace aria2 { +namespace { +enum { + // Before reading first byte of message length + BT_MSG_PREV_READ_LENGTH, + // Reading 4 bytes message length + BT_MSG_READ_LENGTH, + // Reading message payload following message length + BT_MSG_READ_PAYLOAD +}; +} // namespace + PeerConnection::PeerConnection (cuid_t cuid, const SharedHandle& peer, const SocketHandle& socket) : cuid_(cuid), peer_(peer), socket_(socket), - maxPayloadLength_(MAX_PAYLOAD_LEN), - resbuf_(new unsigned char[maxPayloadLength_]), + msgState_(BT_MSG_PREV_READ_LENGTH), + bufferCapacity_(MAX_BUFFER_CAPACITY), + resbuf_(new unsigned char[bufferCapacity_]), resbufLength_(0), currentPayloadLength_(0), - lenbufLength_(0), + resbufOffset_(0), + msgOffset_(0), socketBuffer_(socket), encryptionEnabled_(false), prevPeek_(false) @@ -80,71 +93,98 @@ void PeerConnection::pushBytes(unsigned char* data, size_t len) socketBuffer_.pushBytes(data, len); } -bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) { - if(resbufLength_ == 0 && 4 > lenbufLength_) { - // read payload size, 32bit unsigned integer - size_t remaining = 4-lenbufLength_; - size_t temp = remaining; - readData(lenbuf_+lenbufLength_, remaining, encryptionEnabled_); - if(remaining == 0) { - if(socket_->wantRead() || socket_->wantWrite()) { - return false; +bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) +{ + while(1) { + bool done = false; + size_t i; + for(i = resbufOffset_; i < resbufLength_ && !done; ++i) { + unsigned char c = resbuf_[i]; + switch(msgState_) { + case(BT_MSG_PREV_READ_LENGTH): + msgOffset_ = i; + currentPayloadLength_ = 0; + msgState_ = BT_MSG_READ_LENGTH; + // Fall through + case(BT_MSG_READ_LENGTH): + currentPayloadLength_ <<= 8; + currentPayloadLength_ += c; + // The message length is uint32_t + if(i - msgOffset_ == 3) { + if(currentPayloadLength_ + 4 > bufferCapacity_) { + throw DL_ABORT_EX(fmt(EX_TOO_LONG_PAYLOAD, currentPayloadLength_)); + } + if(currentPayloadLength_ == 0) { + // Length == 0 means keep-alive message. + done = true; + msgState_ = BT_MSG_PREV_READ_LENGTH; + } else { + msgState_ = BT_MSG_READ_PAYLOAD; + } + } + break; + case(BT_MSG_READ_PAYLOAD): + // We chosen the bufferCapacity_ so that whole message, + // including 4 bytes length and payload, in it. So here we + // just make sure that it happens. + if(resbufLength_ - msgOffset_ >= 4 + currentPayloadLength_) { + i = msgOffset_ + 4 + currentPayloadLength_ - 1; + done = true; + msgState_ = BT_MSG_PREV_READ_LENGTH; + } else { + // We need another read. + i = resbufLength_-1; + } + break; } - // we got EOF - A2_LOG_DEBUG(fmt("CUID#%lld - In PeerConnection::receiveMessage()," - " remain=%lu", - cuid_, - static_cast(temp))); - peer_->setDisconnectedGracefully(true); - throw DL_ABORT_EX(EX_EOF_FROM_PEER); } - lenbufLength_ += remaining; - if(4 > lenbufLength_) { - // still 4-lenbufLength_ bytes to go - return false; - } - uint32_t payloadLength; - memcpy(&payloadLength, lenbuf_, sizeof(payloadLength)); - payloadLength = ntohl(payloadLength); - if(payloadLength > maxPayloadLength_) { - throw DL_ABORT_EX(fmt(EX_TOO_LONG_PAYLOAD, payloadLength)); - } - currentPayloadLength_ = payloadLength; - } - if(!socket_->isReadable(0)) { - return false; - } - // we have currentPayloadLen-resbufLen bytes to read - size_t remaining = currentPayloadLength_-resbufLength_; - size_t temp = remaining; - if(remaining > 0) { - readData(resbuf_+resbufLength_, remaining, encryptionEnabled_); - if(remaining == 0) { - if(socket_->wantRead() || socket_->wantWrite()) { - return false; + resbufOffset_ = i; + if(done) { + if(data) { + memcpy(data, resbuf_ + msgOffset_ + 4, currentPayloadLength_); + } + dataLength = currentPayloadLength_; + return true; + } else { + assert(resbufOffset_ == resbufLength_); + if(resbufLength_ != 0) { + if(msgOffset_ == 0 && resbufLength_ == currentPayloadLength_ + 4) { + // All bytes in buffer have been processed, so clear it + // away. + resbufLength_ = 0; + resbufOffset_ = 0; + msgOffset_ = 0; + } else { + // Shift buffer so that resbuf_[msgOffset_] moves to + // rebuf_[0]. + memmove(resbuf_, resbuf_ + msgOffset_, resbufLength_ - msgOffset_); + resbufLength_ -= msgOffset_; + resbufOffset_ = resbufLength_; + msgOffset_ = 0; + } + } + size_t nread; + // To reduce the amount of copy involved in buffer shift, large + // payload will be read exactly. + if(currentPayloadLength_ > 4096) { + nread = currentPayloadLength_ + 4 - resbufLength_; + } else { + nread = bufferCapacity_ - resbufLength_; + } + readData(resbuf_+resbufLength_, nread, encryptionEnabled_); + if(nread == 0) { + if(socket_->wantRead() || socket_->wantWrite()) { + break; + } else { + peer_->setDisconnectedGracefully(true); + throw DL_ABORT_EX(EX_EOF_FROM_PEER); + } + } else { + resbufLength_ += nread; } - // we got EOF - A2_LOG_DEBUG(fmt("CUID#%lld - In PeerConnection::receiveMessage()," - " payloadlen=%lu, remaining=%lu", - cuid_, - static_cast(currentPayloadLength_), - static_cast(temp))); - peer_->setDisconnectedGracefully(true); - throw DL_ABORT_EX(EX_EOF_FROM_PEER); - } - resbufLength_ += remaining; - if(currentPayloadLength_ > resbufLength_) { - return false; } } - // we got whole payload. - resbufLength_ = 0; - lenbufLength_ = 0; - if(data) { - memcpy(data, resbuf_, currentPayloadLength_); - } - dataLength = currentPayloadLength_; - return true; + return false; } bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength, @@ -202,7 +242,7 @@ void PeerConnection::enableEncryption void PeerConnection::presetBuffer(const unsigned char* data, size_t length) { - size_t nwrite = std::min(maxPayloadLength_, length); + size_t nwrite = std::min(bufferCapacity_, length); memcpy(resbuf_, data, nwrite); resbufLength_ = length; } @@ -219,18 +259,16 @@ ssize_t PeerConnection::sendPendingData() return writtenLength; } -unsigned char* PeerConnection::detachBuffer() +const unsigned char* PeerConnection::getMsgPayloadBuffer() const { - unsigned char* detachbuf = resbuf_; - resbuf_ = new unsigned char[maxPayloadLength_]; - return detachbuf; + return resbuf_ + msgOffset_ + 4; } void PeerConnection::reserveBuffer(size_t minSize) { - if(maxPayloadLength_ < minSize) { - maxPayloadLength_ = minSize; - unsigned char *buf = new unsigned char[maxPayloadLength_]; + if(bufferCapacity_ < minSize) { + bufferCapacity_ = minSize; + unsigned char *buf = new unsigned char[bufferCapacity_]; memcpy(buf, resbuf_, resbufLength_); delete [] resbuf_; resbuf_ = buf; diff --git a/src/PeerConnection.h b/src/PeerConnection.h index 02383023..b4889e0e 100644 --- a/src/PeerConnection.h +++ b/src/PeerConnection.h @@ -49,9 +49,10 @@ class Peer; class SocketCore; class ARC4Encryptor; -// The maximum length of payload. Messages beyond that length are +// The maximum length of buffer. If the message length (including 4 +// bytes length and payload length) is larger than this value, it is // dropped. -#define MAX_PAYLOAD_LEN (16*1024+128) +#define MAX_BUFFER_CAPACITY (16*1024+128) class PeerConnection { private: @@ -59,13 +60,19 @@ private: SharedHandle peer_; SharedHandle socket_; - // Maximum payload length - size_t maxPayloadLength_; + int msgState_; + // The capacity of the buffer resbuf_ + size_t bufferCapacity_; + // The internal buffer of incoming handshakes and messages unsigned char* resbuf_; + // The number of bytes written in resbuf_ size_t resbufLength_; - size_t currentPayloadLength_; - unsigned char lenbuf_[4]; - size_t lenbufLength_; + // The length of message (not handshake) currently receiving + uint32_t currentPayloadLength_; + // The number of bytes processed in resbuf_ + size_t resbufOffset_; + // The offset in resbuf_ where the 4 bytes message length begins + size_t msgOffset_; SocketBuffer socketBuffer_; @@ -123,15 +130,17 @@ public: return resbufLength_; } - unsigned char* detachBuffer(); + // Returns the pointer to the message in wire format. This method + // must be called after receiveMessage() returned true. + const unsigned char* getMsgPayloadBuffer() const; // Reserves buffer at least minSize. Reallocate memory if current // buffer length < minSize void reserveBuffer(size_t minSize); - size_t getMaxPayloadLength() + size_t getBufferCapacity() { - return maxPayloadLength_; + return bufferCapacity_; } }; diff --git a/src/SocketRecvBuffer.h b/src/SocketRecvBuffer.h index 8f8e50f5..693dd2c0 100644 --- a/src/SocketRecvBuffer.h +++ b/src/SocketRecvBuffer.h @@ -79,6 +79,8 @@ public: { return bufLen_ == 0; } + + void pushBuffer(const unsigned char* data, size_t len); private: SharedHandle socket_; size_t capacity_; diff --git a/test/PeerConnectionTest.cc b/test/PeerConnectionTest.cc index d435f3b4..c2a4f5a2 100644 --- a/test/PeerConnectionTest.cc +++ b/test/PeerConnectionTest.cc @@ -23,13 +23,13 @@ CPPUNIT_TEST_SUITE_REGISTRATION(PeerConnectionTest); void PeerConnectionTest::testReserveBuffer() { PeerConnection con(1, SharedHandle(), SharedHandle()); con.presetBuffer((unsigned char*)"foo", 3); - CPPUNIT_ASSERT_EQUAL((size_t)MAX_PAYLOAD_LEN, con.getMaxPayloadLength()); + CPPUNIT_ASSERT_EQUAL((size_t)MAX_BUFFER_CAPACITY, con.getBufferCapacity()); CPPUNIT_ASSERT_EQUAL((size_t)3, con.getBufferLength()); size_t newLength = 32*1024; con.reserveBuffer(newLength); - CPPUNIT_ASSERT_EQUAL(newLength, con.getMaxPayloadLength()); + CPPUNIT_ASSERT_EQUAL(newLength, con.getBufferCapacity()); CPPUNIT_ASSERT_EQUAL((size_t)3, con.getBufferLength()); CPPUNIT_ASSERT(memcmp("foo", con.getBuffer(), 3) == 0); }