Rewritten PeerConnection::receiveMessage()

The old implementation calls at least 2 read(2) (4bytes length and
payload) to receive the message. This change will read as many bytes
as possible in one read(2) call. BtPieceMessage::data_ is now just a
const pointer to the internal buffer of PeerConnection.
pull/22/head
Tatsuhiro Tsujikawa 2012-06-25 21:29:48 +09:00
parent aa34c077cb
commit e816c5eee4
7 changed files with 141 additions and 95 deletions

View File

@ -73,13 +73,10 @@ BtPieceMessage::BtPieceMessage
} }
BtPieceMessage::~BtPieceMessage() BtPieceMessage::~BtPieceMessage()
{ {}
delete [] data_;
}
void BtPieceMessage::setRawMessage(unsigned char* data) void BtPieceMessage::setMsgPayload(const unsigned char* data)
{ {
delete [] data_;
data_ = data; data_ = data;
} }

View File

@ -51,7 +51,7 @@ private:
size_t index_; size_t index_;
int32_t begin_; int32_t begin_;
int32_t blockLength_; int32_t blockLength_;
unsigned char* data_; const unsigned char* data_;
SharedHandle<DownloadContext> downloadContext_; SharedHandle<DownloadContext> downloadContext_;
SharedHandle<PeerStorage> peerStorage_; SharedHandle<PeerStorage> peerStorage_;
@ -87,10 +87,9 @@ public:
int32_t getBlockLength() const { return blockLength_; } int32_t getBlockLength() const { return blockLength_; }
// Stores raw message data. After this function call, this object // Sets message payload data. Caller must not change or free data
// has ownership of data. Caller must not be free or alter data. // before doReceivedAction().
// Member block is pointed to block starting position in data. void setMsgPayload(const unsigned char* data);
void setRawMessage(unsigned char* data);
void setBlockLength(int32_t blockLength) { blockLength_ = blockLength; } void setBlockLength(int32_t blockLength) { blockLength_ = blockLength; }

View File

@ -121,12 +121,13 @@ BtMessageHandle DefaultBtMessageReceiver::receiveMessage() {
return SharedHandle<BtMessage>(); return SharedHandle<BtMessage>();
} }
BtMessageHandle msg = BtMessageHandle msg =
messageFactory_->createBtMessage(peerConnection_->getBuffer(), dataLength); messageFactory_->createBtMessage(peerConnection_->getMsgPayloadBuffer(),
dataLength);
msg->validate(); msg->validate();
if(msg->getId() == BtPieceMessage::ID) { if(msg->getId() == BtPieceMessage::ID) {
SharedHandle<BtPieceMessage> piecemsg = SharedHandle<BtPieceMessage> piecemsg =
static_pointer_cast<BtPieceMessage>(msg); static_pointer_cast<BtPieceMessage>(msg);
piecemsg->setRawMessage(peerConnection_->detachBuffer()); piecemsg->setMsgPayload(peerConnection_->getMsgPayloadBuffer());
} }
return msg; return msg;
} }

View File

@ -52,16 +52,29 @@
namespace aria2 { namespace aria2 {
namespace {
enum {
// Before reading first byte of message length
BT_MSG_PREV_READ_LENGTH,
// Reading 4 bytes message length
BT_MSG_READ_LENGTH,
// Reading message payload following message length
BT_MSG_READ_PAYLOAD
};
} // namespace
PeerConnection::PeerConnection PeerConnection::PeerConnection
(cuid_t cuid, const SharedHandle<Peer>& peer, const SocketHandle& socket) (cuid_t cuid, const SharedHandle<Peer>& peer, const SocketHandle& socket)
: cuid_(cuid), : cuid_(cuid),
peer_(peer), peer_(peer),
socket_(socket), socket_(socket),
maxPayloadLength_(MAX_PAYLOAD_LEN), msgState_(BT_MSG_PREV_READ_LENGTH),
resbuf_(new unsigned char[maxPayloadLength_]), bufferCapacity_(MAX_BUFFER_CAPACITY),
resbuf_(new unsigned char[bufferCapacity_]),
resbufLength_(0), resbufLength_(0),
currentPayloadLength_(0), currentPayloadLength_(0),
lenbufLength_(0), resbufOffset_(0),
msgOffset_(0),
socketBuffer_(socket), socketBuffer_(socket),
encryptionEnabled_(false), encryptionEnabled_(false),
prevPeek_(false) prevPeek_(false)
@ -80,71 +93,98 @@ void PeerConnection::pushBytes(unsigned char* data, size_t len)
socketBuffer_.pushBytes(data, len); socketBuffer_.pushBytes(data, len);
} }
bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) { bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength)
if(resbufLength_ == 0 && 4 > lenbufLength_) { {
// read payload size, 32bit unsigned integer while(1) {
size_t remaining = 4-lenbufLength_; bool done = false;
size_t temp = remaining; size_t i;
readData(lenbuf_+lenbufLength_, remaining, encryptionEnabled_); for(i = resbufOffset_; i < resbufLength_ && !done; ++i) {
if(remaining == 0) { unsigned char c = resbuf_[i];
if(socket_->wantRead() || socket_->wantWrite()) { switch(msgState_) {
return false; case(BT_MSG_PREV_READ_LENGTH):
msgOffset_ = i;
currentPayloadLength_ = 0;
msgState_ = BT_MSG_READ_LENGTH;
// Fall through
case(BT_MSG_READ_LENGTH):
currentPayloadLength_ <<= 8;
currentPayloadLength_ += c;
// The message length is uint32_t
if(i - msgOffset_ == 3) {
if(currentPayloadLength_ + 4 > bufferCapacity_) {
throw DL_ABORT_EX(fmt(EX_TOO_LONG_PAYLOAD, currentPayloadLength_));
}
if(currentPayloadLength_ == 0) {
// Length == 0 means keep-alive message.
done = true;
msgState_ = BT_MSG_PREV_READ_LENGTH;
} else {
msgState_ = BT_MSG_READ_PAYLOAD;
}
}
break;
case(BT_MSG_READ_PAYLOAD):
// We chosen the bufferCapacity_ so that whole message,
// including 4 bytes length and payload, in it. So here we
// just make sure that it happens.
if(resbufLength_ - msgOffset_ >= 4 + currentPayloadLength_) {
i = msgOffset_ + 4 + currentPayloadLength_ - 1;
done = true;
msgState_ = BT_MSG_PREV_READ_LENGTH;
} else {
// We need another read.
i = resbufLength_-1;
}
break;
} }
// we got EOF
A2_LOG_DEBUG(fmt("CUID#%lld - In PeerConnection::receiveMessage(),"
" remain=%lu",
cuid_,
static_cast<unsigned long>(temp)));
peer_->setDisconnectedGracefully(true);
throw DL_ABORT_EX(EX_EOF_FROM_PEER);
} }
lenbufLength_ += remaining; resbufOffset_ = i;
if(4 > lenbufLength_) { if(done) {
// still 4-lenbufLength_ bytes to go if(data) {
return false; memcpy(data, resbuf_ + msgOffset_ + 4, currentPayloadLength_);
} }
uint32_t payloadLength; dataLength = currentPayloadLength_;
memcpy(&payloadLength, lenbuf_, sizeof(payloadLength)); return true;
payloadLength = ntohl(payloadLength); } else {
if(payloadLength > maxPayloadLength_) { assert(resbufOffset_ == resbufLength_);
throw DL_ABORT_EX(fmt(EX_TOO_LONG_PAYLOAD, payloadLength)); if(resbufLength_ != 0) {
} if(msgOffset_ == 0 && resbufLength_ == currentPayloadLength_ + 4) {
currentPayloadLength_ = payloadLength; // All bytes in buffer have been processed, so clear it
} // away.
if(!socket_->isReadable(0)) { resbufLength_ = 0;
return false; resbufOffset_ = 0;
} msgOffset_ = 0;
// we have currentPayloadLen-resbufLen bytes to read } else {
size_t remaining = currentPayloadLength_-resbufLength_; // Shift buffer so that resbuf_[msgOffset_] moves to
size_t temp = remaining; // rebuf_[0].
if(remaining > 0) { memmove(resbuf_, resbuf_ + msgOffset_, resbufLength_ - msgOffset_);
readData(resbuf_+resbufLength_, remaining, encryptionEnabled_); resbufLength_ -= msgOffset_;
if(remaining == 0) { resbufOffset_ = resbufLength_;
if(socket_->wantRead() || socket_->wantWrite()) { msgOffset_ = 0;
return false; }
}
size_t nread;
// To reduce the amount of copy involved in buffer shift, large
// payload will be read exactly.
if(currentPayloadLength_ > 4096) {
nread = currentPayloadLength_ + 4 - resbufLength_;
} else {
nread = bufferCapacity_ - resbufLength_;
}
readData(resbuf_+resbufLength_, nread, encryptionEnabled_);
if(nread == 0) {
if(socket_->wantRead() || socket_->wantWrite()) {
break;
} else {
peer_->setDisconnectedGracefully(true);
throw DL_ABORT_EX(EX_EOF_FROM_PEER);
}
} else {
resbufLength_ += nread;
} }
// we got EOF
A2_LOG_DEBUG(fmt("CUID#%lld - In PeerConnection::receiveMessage(),"
" payloadlen=%lu, remaining=%lu",
cuid_,
static_cast<unsigned long>(currentPayloadLength_),
static_cast<unsigned long>(temp)));
peer_->setDisconnectedGracefully(true);
throw DL_ABORT_EX(EX_EOF_FROM_PEER);
}
resbufLength_ += remaining;
if(currentPayloadLength_ > resbufLength_) {
return false;
} }
} }
// we got whole payload. return false;
resbufLength_ = 0;
lenbufLength_ = 0;
if(data) {
memcpy(data, resbuf_, currentPayloadLength_);
}
dataLength = currentPayloadLength_;
return true;
} }
bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength, bool PeerConnection::receiveHandshake(unsigned char* data, size_t& dataLength,
@ -202,7 +242,7 @@ void PeerConnection::enableEncryption
void PeerConnection::presetBuffer(const unsigned char* data, size_t length) void PeerConnection::presetBuffer(const unsigned char* data, size_t length)
{ {
size_t nwrite = std::min(maxPayloadLength_, length); size_t nwrite = std::min(bufferCapacity_, length);
memcpy(resbuf_, data, nwrite); memcpy(resbuf_, data, nwrite);
resbufLength_ = length; resbufLength_ = length;
} }
@ -219,18 +259,16 @@ ssize_t PeerConnection::sendPendingData()
return writtenLength; return writtenLength;
} }
unsigned char* PeerConnection::detachBuffer() const unsigned char* PeerConnection::getMsgPayloadBuffer() const
{ {
unsigned char* detachbuf = resbuf_; return resbuf_ + msgOffset_ + 4;
resbuf_ = new unsigned char[maxPayloadLength_];
return detachbuf;
} }
void PeerConnection::reserveBuffer(size_t minSize) void PeerConnection::reserveBuffer(size_t minSize)
{ {
if(maxPayloadLength_ < minSize) { if(bufferCapacity_ < minSize) {
maxPayloadLength_ = minSize; bufferCapacity_ = minSize;
unsigned char *buf = new unsigned char[maxPayloadLength_]; unsigned char *buf = new unsigned char[bufferCapacity_];
memcpy(buf, resbuf_, resbufLength_); memcpy(buf, resbuf_, resbufLength_);
delete [] resbuf_; delete [] resbuf_;
resbuf_ = buf; resbuf_ = buf;

View File

@ -49,9 +49,10 @@ class Peer;
class SocketCore; class SocketCore;
class ARC4Encryptor; class ARC4Encryptor;
// The maximum length of payload. Messages beyond that length are // The maximum length of buffer. If the message length (including 4
// bytes length and payload length) is larger than this value, it is
// dropped. // dropped.
#define MAX_PAYLOAD_LEN (16*1024+128) #define MAX_BUFFER_CAPACITY (16*1024+128)
class PeerConnection { class PeerConnection {
private: private:
@ -59,13 +60,19 @@ private:
SharedHandle<Peer> peer_; SharedHandle<Peer> peer_;
SharedHandle<SocketCore> socket_; SharedHandle<SocketCore> socket_;
// Maximum payload length int msgState_;
size_t maxPayloadLength_; // The capacity of the buffer resbuf_
size_t bufferCapacity_;
// The internal buffer of incoming handshakes and messages
unsigned char* resbuf_; unsigned char* resbuf_;
// The number of bytes written in resbuf_
size_t resbufLength_; size_t resbufLength_;
size_t currentPayloadLength_; // The length of message (not handshake) currently receiving
unsigned char lenbuf_[4]; uint32_t currentPayloadLength_;
size_t lenbufLength_; // The number of bytes processed in resbuf_
size_t resbufOffset_;
// The offset in resbuf_ where the 4 bytes message length begins
size_t msgOffset_;
SocketBuffer socketBuffer_; SocketBuffer socketBuffer_;
@ -123,15 +130,17 @@ public:
return resbufLength_; return resbufLength_;
} }
unsigned char* detachBuffer(); // Returns the pointer to the message in wire format. This method
// must be called after receiveMessage() returned true.
const unsigned char* getMsgPayloadBuffer() const;
// Reserves buffer at least minSize. Reallocate memory if current // Reserves buffer at least minSize. Reallocate memory if current
// buffer length < minSize // buffer length < minSize
void reserveBuffer(size_t minSize); void reserveBuffer(size_t minSize);
size_t getMaxPayloadLength() size_t getBufferCapacity()
{ {
return maxPayloadLength_; return bufferCapacity_;
} }
}; };

View File

@ -79,6 +79,8 @@ public:
{ {
return bufLen_ == 0; return bufLen_ == 0;
} }
void pushBuffer(const unsigned char* data, size_t len);
private: private:
SharedHandle<SocketCore> socket_; SharedHandle<SocketCore> socket_;
size_t capacity_; size_t capacity_;

View File

@ -23,13 +23,13 @@ CPPUNIT_TEST_SUITE_REGISTRATION(PeerConnectionTest);
void PeerConnectionTest::testReserveBuffer() { void PeerConnectionTest::testReserveBuffer() {
PeerConnection con(1, SharedHandle<Peer>(), SharedHandle<SocketCore>()); PeerConnection con(1, SharedHandle<Peer>(), SharedHandle<SocketCore>());
con.presetBuffer((unsigned char*)"foo", 3); con.presetBuffer((unsigned char*)"foo", 3);
CPPUNIT_ASSERT_EQUAL((size_t)MAX_PAYLOAD_LEN, con.getMaxPayloadLength()); CPPUNIT_ASSERT_EQUAL((size_t)MAX_BUFFER_CAPACITY, con.getBufferCapacity());
CPPUNIT_ASSERT_EQUAL((size_t)3, con.getBufferLength()); CPPUNIT_ASSERT_EQUAL((size_t)3, con.getBufferLength());
size_t newLength = 32*1024; size_t newLength = 32*1024;
con.reserveBuffer(newLength); con.reserveBuffer(newLength);
CPPUNIT_ASSERT_EQUAL(newLength, con.getMaxPayloadLength()); CPPUNIT_ASSERT_EQUAL(newLength, con.getBufferCapacity());
CPPUNIT_ASSERT_EQUAL((size_t)3, con.getBufferLength()); CPPUNIT_ASSERT_EQUAL((size_t)3, con.getBufferLength());
CPPUNIT_ASSERT(memcmp("foo", con.getBuffer(), 3) == 0); CPPUNIT_ASSERT(memcmp("foo", con.getBuffer(), 3) == 0);
} }