diff --git a/src/BtBitfieldMessage.cc b/src/BtBitfieldMessage.cc index 6992d281..6ba1d97e 100644 --- a/src/BtBitfieldMessage.cc +++ b/src/BtBitfieldMessage.cc @@ -104,8 +104,8 @@ unsigned char* BtBitfieldMessage::createMessage() { /** * len --- 1+bitfieldLength, 4bytes * id --- 5, 1byte - * bitfield --- bitfield, len bytes - * total: 5+len bytes + * bitfield --- bitfield, bitfieldLength bytes + * total: 5+bitfieldLength bytes */ const size_t msgLength = 5+bitfieldLength_; unsigned char* msg = new unsigned char[msgLength]; diff --git a/src/PeerConnection.cc b/src/PeerConnection.cc index 2af5383b..da715929 100644 --- a/src/PeerConnection.cc +++ b/src/PeerConnection.cc @@ -57,7 +57,8 @@ PeerConnection::PeerConnection : cuid_(cuid), peer_(peer), socket_(socket), - resbuf_(new unsigned char[MAX_PAYLOAD_LEN]), + maxPayloadLength_(MAX_PAYLOAD_LEN), + resbuf_(new unsigned char[maxPayloadLength_]), resbufLength_(0), currentPayloadLength_(0), lenbufLength_(0), @@ -105,7 +106,7 @@ bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) { uint32_t payloadLength; memcpy(&payloadLength, lenbuf_, sizeof(payloadLength)); payloadLength = ntohl(payloadLength); - if(payloadLength > MAX_PAYLOAD_LEN) { + if(payloadLength > maxPayloadLength_) { throw DL_ABORT_EX(fmt(EX_TOO_LONG_PAYLOAD, payloadLength)); } currentPayloadLength_ = payloadLength; @@ -201,7 +202,7 @@ void PeerConnection::enableEncryption void PeerConnection::presetBuffer(const unsigned char* data, size_t length) { - size_t nwrite = std::min((size_t)MAX_PAYLOAD_LEN, length); + size_t nwrite = std::min(maxPayloadLength_, length); memcpy(resbuf_, data, nwrite); resbufLength_ = length; } @@ -221,8 +222,19 @@ ssize_t PeerConnection::sendPendingData() unsigned char* PeerConnection::detachBuffer() { unsigned char* detachbuf = resbuf_; - resbuf_ = new unsigned char[MAX_PAYLOAD_LEN]; + resbuf_ = new unsigned char[maxPayloadLength_]; return detachbuf; } +void PeerConnection::reserveBuffer(size_t minSize) +{ + if(maxPayloadLength_ < minSize) { + maxPayloadLength_ = minSize; + unsigned char *buf = new unsigned char[maxPayloadLength_]; + memcpy(buf, resbuf_, resbufLength_); + delete [] resbuf_; + resbuf_ = buf; + } +} + } // namespace aria2 diff --git a/src/PeerConnection.h b/src/PeerConnection.h index 3da56eb4..02383023 100644 --- a/src/PeerConnection.h +++ b/src/PeerConnection.h @@ -59,6 +59,8 @@ private: SharedHandle peer_; SharedHandle socket_; + // Maximum payload length + size_t maxPayloadLength_; unsigned char* resbuf_; size_t resbufLength_; size_t currentPayloadLength_; @@ -122,6 +124,15 @@ public: } unsigned char* detachBuffer(); + + // Reserves buffer at least minSize. Reallocate memory if current + // buffer length < minSize + void reserveBuffer(size_t minSize); + + size_t getMaxPayloadLength() + { + return maxPayloadLength_; + } }; typedef SharedHandle PeerConnectionHandle; diff --git a/src/PeerInteractionCommand.cc b/src/PeerInteractionCommand.cc index 6c176999..592c20d9 100644 --- a/src/PeerInteractionCommand.cc +++ b/src/PeerInteractionCommand.cc @@ -170,6 +170,14 @@ PeerInteractionCommand::PeerInteractionCommand getDownloadEngine()->setNoWait(true); } } + // If the number of pieces gets bigger, the length of Bitfield + // message payload exceeds the initial buffer capacity of + // PeerConnection, which is MAX_PAYLOAD_LEN. We expand buffer as + // necessary so that PeerConnection can receive the Bitfield + // message. + size_t bitfieldPayloadSize = + 1+(requestGroup_->getDownloadContext()->getNumPieces()+7)/8; + peerConnection->reserveBuffer(bitfieldPayloadSize); SharedHandle dispatcher (new DefaultBtMessageDispatcher()); diff --git a/test/Makefile.am b/test/Makefile.am index 127bd2e2..5e328c8c 100644 --- a/test/Makefile.am +++ b/test/Makefile.am @@ -203,7 +203,8 @@ aria2c_SOURCES += BtAllowedFastMessageTest.cc\ extension_message_test_helper.h\ LpdMessageDispatcherTest.cc\ LpdMessageReceiverTest.cc\ - Bencode2Test.cc + Bencode2Test.cc\ + PeerConnectionTest.cc endif # ENABLE_BITTORRENT if ENABLE_METALINK diff --git a/test/PeerConnectionTest.cc b/test/PeerConnectionTest.cc new file mode 100644 index 00000000..d435f3b4 --- /dev/null +++ b/test/PeerConnectionTest.cc @@ -0,0 +1,37 @@ +#include "PeerConnection.h" + +#include + +#include + +#include "Peer.h" +#include "SocketCore.h" + +namespace aria2 { + +class PeerConnectionTest:public CppUnit::TestFixture { + + CPPUNIT_TEST_SUITE(PeerConnectionTest); + CPPUNIT_TEST(testReserveBuffer); + CPPUNIT_TEST_SUITE_END(); +public: + void testReserveBuffer(); +}; + +CPPUNIT_TEST_SUITE_REGISTRATION(PeerConnectionTest); + +void PeerConnectionTest::testReserveBuffer() { + PeerConnection con(1, SharedHandle(), SharedHandle()); + con.presetBuffer((unsigned char*)"foo", 3); + CPPUNIT_ASSERT_EQUAL((size_t)MAX_PAYLOAD_LEN, con.getMaxPayloadLength()); + CPPUNIT_ASSERT_EQUAL((size_t)3, con.getBufferLength()); + + size_t newLength = 32*1024; + con.reserveBuffer(newLength); + + CPPUNIT_ASSERT_EQUAL(newLength, con.getMaxPayloadLength()); + CPPUNIT_ASSERT_EQUAL((size_t)3, con.getBufferLength()); + CPPUNIT_ASSERT(memcmp("foo", con.getBuffer(), 3) == 0); +} + +} // namespace aria2