mirror of https://github.com/aria2/aria2
Reserve PeerConnection's buffer capacity according to number of pieces.
If the number of pieces gets bigger, the length of Bitfield message payload exceeds the initial buffer capacity of PeerConnection, which is MAX_PAYLOAD_LEN. We expand buffer as necessary so that PeerConnection can receive the Bitfield message.pull/12/head
parent
7489bbe1a7
commit
9b7e4219d9
|
@ -104,8 +104,8 @@ unsigned char* BtBitfieldMessage::createMessage() {
|
|||
/**
|
||||
* len --- 1+bitfieldLength, 4bytes
|
||||
* id --- 5, 1byte
|
||||
* bitfield --- bitfield, len bytes
|
||||
* total: 5+len bytes
|
||||
* bitfield --- bitfield, bitfieldLength bytes
|
||||
* total: 5+bitfieldLength bytes
|
||||
*/
|
||||
const size_t msgLength = 5+bitfieldLength_;
|
||||
unsigned char* msg = new unsigned char[msgLength];
|
||||
|
|
|
@ -57,7 +57,8 @@ PeerConnection::PeerConnection
|
|||
: cuid_(cuid),
|
||||
peer_(peer),
|
||||
socket_(socket),
|
||||
resbuf_(new unsigned char[MAX_PAYLOAD_LEN]),
|
||||
maxPayloadLength_(MAX_PAYLOAD_LEN),
|
||||
resbuf_(new unsigned char[maxPayloadLength_]),
|
||||
resbufLength_(0),
|
||||
currentPayloadLength_(0),
|
||||
lenbufLength_(0),
|
||||
|
@ -105,7 +106,7 @@ bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) {
|
|||
uint32_t payloadLength;
|
||||
memcpy(&payloadLength, lenbuf_, sizeof(payloadLength));
|
||||
payloadLength = ntohl(payloadLength);
|
||||
if(payloadLength > MAX_PAYLOAD_LEN) {
|
||||
if(payloadLength > maxPayloadLength_) {
|
||||
throw DL_ABORT_EX(fmt(EX_TOO_LONG_PAYLOAD, payloadLength));
|
||||
}
|
||||
currentPayloadLength_ = payloadLength;
|
||||
|
@ -201,7 +202,7 @@ void PeerConnection::enableEncryption
|
|||
|
||||
void PeerConnection::presetBuffer(const unsigned char* data, size_t length)
|
||||
{
|
||||
size_t nwrite = std::min((size_t)MAX_PAYLOAD_LEN, length);
|
||||
size_t nwrite = std::min(maxPayloadLength_, length);
|
||||
memcpy(resbuf_, data, nwrite);
|
||||
resbufLength_ = length;
|
||||
}
|
||||
|
@ -221,8 +222,19 @@ ssize_t PeerConnection::sendPendingData()
|
|||
unsigned char* PeerConnection::detachBuffer()
|
||||
{
|
||||
unsigned char* detachbuf = resbuf_;
|
||||
resbuf_ = new unsigned char[MAX_PAYLOAD_LEN];
|
||||
resbuf_ = new unsigned char[maxPayloadLength_];
|
||||
return detachbuf;
|
||||
}
|
||||
|
||||
void PeerConnection::reserveBuffer(size_t minSize)
|
||||
{
|
||||
if(maxPayloadLength_ < minSize) {
|
||||
maxPayloadLength_ = minSize;
|
||||
unsigned char *buf = new unsigned char[maxPayloadLength_];
|
||||
memcpy(buf, resbuf_, resbufLength_);
|
||||
delete [] resbuf_;
|
||||
resbuf_ = buf;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace aria2
|
||||
|
|
|
@ -59,6 +59,8 @@ private:
|
|||
SharedHandle<Peer> peer_;
|
||||
SharedHandle<SocketCore> socket_;
|
||||
|
||||
// Maximum payload length
|
||||
size_t maxPayloadLength_;
|
||||
unsigned char* resbuf_;
|
||||
size_t resbufLength_;
|
||||
size_t currentPayloadLength_;
|
||||
|
@ -122,6 +124,15 @@ public:
|
|||
}
|
||||
|
||||
unsigned char* detachBuffer();
|
||||
|
||||
// Reserves buffer at least minSize. Reallocate memory if current
|
||||
// buffer length < minSize
|
||||
void reserveBuffer(size_t minSize);
|
||||
|
||||
size_t getMaxPayloadLength()
|
||||
{
|
||||
return maxPayloadLength_;
|
||||
}
|
||||
};
|
||||
|
||||
typedef SharedHandle<PeerConnection> PeerConnectionHandle;
|
||||
|
|
|
@ -170,6 +170,14 @@ PeerInteractionCommand::PeerInteractionCommand
|
|||
getDownloadEngine()->setNoWait(true);
|
||||
}
|
||||
}
|
||||
// If the number of pieces gets bigger, the length of Bitfield
|
||||
// message payload exceeds the initial buffer capacity of
|
||||
// PeerConnection, which is MAX_PAYLOAD_LEN. We expand buffer as
|
||||
// necessary so that PeerConnection can receive the Bitfield
|
||||
// message.
|
||||
size_t bitfieldPayloadSize =
|
||||
1+(requestGroup_->getDownloadContext()->getNumPieces()+7)/8;
|
||||
peerConnection->reserveBuffer(bitfieldPayloadSize);
|
||||
|
||||
SharedHandle<DefaultBtMessageDispatcher> dispatcher
|
||||
(new DefaultBtMessageDispatcher());
|
||||
|
|
|
@ -203,7 +203,8 @@ aria2c_SOURCES += BtAllowedFastMessageTest.cc\
|
|||
extension_message_test_helper.h\
|
||||
LpdMessageDispatcherTest.cc\
|
||||
LpdMessageReceiverTest.cc\
|
||||
Bencode2Test.cc
|
||||
Bencode2Test.cc\
|
||||
PeerConnectionTest.cc
|
||||
endif # ENABLE_BITTORRENT
|
||||
|
||||
if ENABLE_METALINK
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
#include "PeerConnection.h"
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include <cppunit/extensions/HelperMacros.h>
|
||||
|
||||
#include "Peer.h"
|
||||
#include "SocketCore.h"
|
||||
|
||||
namespace aria2 {
|
||||
|
||||
class PeerConnectionTest:public CppUnit::TestFixture {
|
||||
|
||||
CPPUNIT_TEST_SUITE(PeerConnectionTest);
|
||||
CPPUNIT_TEST(testReserveBuffer);
|
||||
CPPUNIT_TEST_SUITE_END();
|
||||
public:
|
||||
void testReserveBuffer();
|
||||
};
|
||||
|
||||
CPPUNIT_TEST_SUITE_REGISTRATION(PeerConnectionTest);
|
||||
|
||||
void PeerConnectionTest::testReserveBuffer() {
|
||||
PeerConnection con(1, SharedHandle<Peer>(), SharedHandle<SocketCore>());
|
||||
con.presetBuffer((unsigned char*)"foo", 3);
|
||||
CPPUNIT_ASSERT_EQUAL((size_t)MAX_PAYLOAD_LEN, con.getMaxPayloadLength());
|
||||
CPPUNIT_ASSERT_EQUAL((size_t)3, con.getBufferLength());
|
||||
|
||||
size_t newLength = 32*1024;
|
||||
con.reserveBuffer(newLength);
|
||||
|
||||
CPPUNIT_ASSERT_EQUAL(newLength, con.getMaxPayloadLength());
|
||||
CPPUNIT_ASSERT_EQUAL((size_t)3, con.getBufferLength());
|
||||
CPPUNIT_ASSERT(memcmp("foo", con.getBuffer(), 3) == 0);
|
||||
}
|
||||
|
||||
} // namespace aria2
|
Loading…
Reference in New Issue