Vectorized write for SocketBuffer to avoid small packet

pull/43/head
Tatsuhiro Tsujikawa 2013-01-11 14:20:34 +09:00
parent 74e570de37
commit 8ba0d58ee1
32 changed files with 388 additions and 154 deletions

View File

@ -375,6 +375,7 @@ AC_CHECK_HEADERS([argz.h \
sys/socket.h \ sys/socket.h \
sys/time.h \ sys/time.h \
sys/types.h \ sys/types.h \
sys/uio.h \
termios.h \ termios.h \
unistd.h \ unistd.h \
utime.h \ utime.h \

View File

@ -36,6 +36,7 @@
#include "DlAbortEx.h" #include "DlAbortEx.h"
#include "Peer.h" #include "Peer.h"
#include "fmt.h" #include "fmt.h"
#include "SocketBuffer.h"
namespace aria2 { namespace aria2 {
@ -62,8 +63,24 @@ void BtAllowedFastMessage::doReceivedAction() {
getPeer()->addPeerAllowedIndex(getIndex()); getPeer()->addPeerAllowedIndex(getIndex());
} }
void BtAllowedFastMessage::onSendComplete() { namespace {
getPeer()->addAmAllowedIndex(getIndex()); struct ThisProgressUpdate : public ProgressUpdate {
ThisProgressUpdate(const SharedHandle<Peer>& peer, size_t index)
: peer(peer), index(index) {}
virtual void update(size_t length, bool complete)
{
if(complete) {
peer->addAmAllowedIndex(index);
}
}
SharedHandle<Peer> peer;
size_t index;
};
} // namespace
ProgressUpdate* BtAllowedFastMessage::getProgressUpdate()
{
return new ThisProgressUpdate(getPeer(), getIndex());
} }
} // namespace aria2 } // namespace aria2

View File

@ -52,8 +52,7 @@ public:
virtual void doReceivedAction(); virtual void doReceivedAction();
virtual void onSendComplete(); virtual ProgressUpdate* getProgressUpdate();
}; };
} // namespace aria2 } // namespace aria2

View File

@ -36,6 +36,7 @@
#include "Peer.h" #include "Peer.h"
#include "BtMessageDispatcher.h" #include "BtMessageDispatcher.h"
#include "BtRequestFactory.h" #include "BtRequestFactory.h"
#include "SocketBuffer.h"
namespace aria2 { namespace aria2 {
@ -64,10 +65,26 @@ bool BtChokeMessage::sendPredicate() const
return !getPeer()->amChoking(); return !getPeer()->amChoking();
} }
void BtChokeMessage::onSendComplete() namespace {
struct ThisProgressUpdate : public ProgressUpdate {
ThisProgressUpdate(const SharedHandle<Peer>& peer,
BtMessageDispatcher* disp)
: peer(peer), disp(disp) {}
virtual void update(size_t length, bool complete)
{
if(complete) {
peer->amChoking(true);
disp->doChokingAction();
}
}
SharedHandle<Peer> peer;
BtMessageDispatcher* disp;
};
} // namespace
ProgressUpdate* BtChokeMessage::getProgressUpdate()
{ {
getPeer()->amChoking(true); return new ThisProgressUpdate(getPeer(), getBtMessageDispatcher());
getBtMessageDispatcher()->doChokingAction();
} }
} // namespace aria2 } // namespace aria2

View File

@ -53,7 +53,7 @@ public:
virtual bool sendPredicate() const; virtual bool sendPredicate() const;
virtual void onSendComplete(); virtual ProgressUpdate* getProgressUpdate();
}; };
} // namespace aria2 } // namespace aria2

View File

@ -35,6 +35,7 @@
#include "BtInterestedMessage.h" #include "BtInterestedMessage.h"
#include "Peer.h" #include "Peer.h"
#include "PeerStorage.h" #include "PeerStorage.h"
#include "SocketBuffer.h"
namespace aria2 { namespace aria2 {
@ -66,8 +67,23 @@ bool BtInterestedMessage::sendPredicate() const
return !getPeer()->amInterested(); return !getPeer()->amInterested();
} }
void BtInterestedMessage::onSendComplete() { namespace {
getPeer()->amInterested(true); struct ThisProgressUpdate : public ProgressUpdate {
ThisProgressUpdate(const SharedHandle<Peer>& peer)
: peer(peer) {}
virtual void update(size_t length, bool complete)
{
if(complete) {
peer->amInterested(true);
}
}
SharedHandle<Peer> peer;
};
} // namespace
ProgressUpdate* BtInterestedMessage::getProgressUpdate()
{
return new ThisProgressUpdate(getPeer());
} }
void BtInterestedMessage::setPeerStorage void BtInterestedMessage::setPeerStorage

View File

@ -60,7 +60,7 @@ public:
virtual bool sendPredicate() const; virtual bool sendPredicate() const;
virtual void onSendComplete(); virtual ProgressUpdate* getProgressUpdate();
void setPeerStorage(const SharedHandle<PeerStorage>& peerStorage); void setPeerStorage(const SharedHandle<PeerStorage>& peerStorage);
}; };

View File

@ -35,6 +35,7 @@
#include "BtNotInterestedMessage.h" #include "BtNotInterestedMessage.h"
#include "Peer.h" #include "Peer.h"
#include "PeerStorage.h" #include "PeerStorage.h"
#include "SocketBuffer.h"
namespace aria2 { namespace aria2 {
@ -66,8 +67,23 @@ bool BtNotInterestedMessage::sendPredicate() const
return getPeer()->amInterested(); return getPeer()->amInterested();
} }
void BtNotInterestedMessage::onSendComplete() { namespace {
getPeer()->amInterested(false); struct ThisProgressUpdate : public ProgressUpdate {
ThisProgressUpdate(const SharedHandle<Peer>& peer)
: peer(peer) {}
virtual void update(size_t length, bool complete)
{
if(complete) {
peer->amInterested(false);
}
}
SharedHandle<Peer> peer;
};
} // namespace
ProgressUpdate* BtNotInterestedMessage::getProgressUpdate()
{
return new ThisProgressUpdate(getPeer());
} }
void BtNotInterestedMessage::setPeerStorage void BtNotInterestedMessage::setPeerStorage

View File

@ -60,7 +60,7 @@ public:
virtual bool sendPredicate() const; virtual bool sendPredicate() const;
virtual void onSendComplete(); virtual ProgressUpdate* getProgressUpdate();
void setPeerStorage(const SharedHandle<PeerStorage>& peerStorage); void setPeerStorage(const SharedHandle<PeerStorage>& peerStorage);
}; };

View File

@ -71,7 +71,6 @@ BtPieceMessage::BtPieceMessage
index_(index), index_(index),
begin_(begin), begin_(begin),
blockLength_(blockLength), blockLength_(blockLength),
msgHdrLen_(0),
data_(0) data_(0)
{ {
setUploading(true); setUploading(true);
@ -180,39 +179,33 @@ size_t BtPieceMessage::getMessageHeaderLength()
return MESSAGE_HEADER_LENGTH; return MESSAGE_HEADER_LENGTH;
} }
namespace {
struct PieceSendUpdate : public ProgressUpdate {
PieceSendUpdate(const SharedHandle<Peer>& peer)
: peer(peer) {}
virtual void update(size_t length, bool complete)
{
peer->updateUploadLength(length);
}
SharedHandle<Peer> peer;
};
} // namespace
void BtPieceMessage::send() void BtPieceMessage::send()
{ {
if(isInvalidate()) { if(isInvalidate()) {
return; return;
} }
size_t writtenLength; A2_LOG_INFO(fmt(MSG_SEND_PEER_MESSAGE,
if(!isSendingInProgress()) { getCuid(),
A2_LOG_INFO(fmt(MSG_SEND_PEER_MESSAGE, getPeer()->getIPAddress().c_str(),
getCuid(), getPeer()->getPort(),
getPeer()->getIPAddress().c_str(), toString().c_str()));
getPeer()->getPort(), getPeerConnection()->pushBytes(createMessageHeader(),
toString().c_str())); getMessageHeaderLength());
unsigned char* msgHdr = createMessageHeader(); int64_t pieceDataOffset =
msgHdrLen_ = getMessageHeaderLength(); static_cast<int64_t>(index_)*downloadContext_->getPieceLength()+begin_;
A2_LOG_DEBUG(fmt("msglength = %lu bytes", pushPieceData(pieceDataOffset, blockLength_);
static_cast<unsigned long>(msgHdrLen_+blockLength_)));
getPeerConnection()->pushBytes(msgHdr, msgHdrLen_);
int64_t pieceDataOffset =
static_cast<int64_t>(index_)*downloadContext_->getPieceLength()+begin_;
pushPieceData(pieceDataOffset, blockLength_);
}
writtenLength = getPeerConnection()->sendPendingData();
// Subtract msgHdrLen_ from writtenLength to get the uploaded data
// size.
if(writtenLength > msgHdrLen_) {
writtenLength -= msgHdrLen_;
msgHdrLen_ = 0;
getPeer()->updateUploadLength(writtenLength);
downloadContext_->updateUploadLength(writtenLength);
} else {
msgHdrLen_ -= writtenLength;
}
setSendingInProgress(!getPeerConnection()->sendBufferIsEmpty());
} }
void BtPieceMessage::pushPieceData(int64_t offset, int32_t length) const void BtPieceMessage::pushPieceData(int64_t offset, int32_t length) const
@ -224,7 +217,11 @@ void BtPieceMessage::pushPieceData(int64_t offset, int32_t length) const
if(r == length) { if(r == length) {
unsigned char* dbuf = buf; unsigned char* dbuf = buf;
buf.reset(0); buf.reset(0);
getPeerConnection()->pushBytes(dbuf, length); getPeerConnection()->pushBytes(dbuf, length,
new PieceSendUpdate(getPeer()));
// To avoid upload rate overflow, we update the length here at
// once.
downloadContext_->updateUploadLength(length);
} else { } else {
throw DL_ABORT_EX(EX_DATA_READ); throw DL_ABORT_EX(EX_DATA_READ);
} }

View File

@ -48,7 +48,6 @@ private:
size_t index_; size_t index_;
int32_t begin_; int32_t begin_;
int32_t blockLength_; int32_t blockLength_;
size_t msgHdrLen_;
const unsigned char* data_; const unsigned char* data_;
SharedHandle<DownloadContext> downloadContext_; SharedHandle<DownloadContext> downloadContext_;
SharedHandle<PeerStorage> peerStorage_; SharedHandle<PeerStorage> peerStorage_;

View File

@ -34,6 +34,7 @@
/* copyright --> */ /* copyright --> */
#include "BtUnchokeMessage.h" #include "BtUnchokeMessage.h"
#include "Peer.h" #include "Peer.h"
#include "SocketBuffer.h"
namespace aria2 { namespace aria2 {
@ -60,8 +61,23 @@ bool BtUnchokeMessage::sendPredicate() const
return getPeer()->amChoking(); return getPeer()->amChoking();
} }
void BtUnchokeMessage::onSendComplete() { namespace {
getPeer()->amChoking(false); struct ThisProgressUpdate : public ProgressUpdate {
ThisProgressUpdate(const SharedHandle<Peer>& peer)
: peer(peer) {}
virtual void update(size_t length, bool complete)
{
if(complete) {
peer->amChoking(false);
}
}
SharedHandle<Peer> peer;
};
} // namespace
ProgressUpdate* BtUnchokeMessage::getProgressUpdate()
{
return new ThisProgressUpdate(getPeer());
} }
} // namespace aria2 } // namespace aria2

View File

@ -56,7 +56,7 @@ public:
virtual bool sendPredicate() const; virtual bool sendPredicate() const;
virtual void onSendComplete(); virtual ProgressUpdate* getProgressUpdate();
}; };
} // namespace aria2 } // namespace aria2

View File

@ -56,6 +56,7 @@
#include "RequestGroup.h" #include "RequestGroup.h"
#include "util.h" #include "util.h"
#include "fmt.h" #include "fmt.h"
#include "PeerConnection.h"
namespace aria2 { namespace aria2 {
@ -87,7 +88,8 @@ void DefaultBtMessageDispatcher::addMessageToQueue
} }
} }
void DefaultBtMessageDispatcher::sendMessages() { void DefaultBtMessageDispatcher::sendMessagesInternal()
{
std::vector<SharedHandle<BtMessage> > tempQueue; std::vector<SharedHandle<BtMessage> > tempQueue;
while(!messageQueue_.empty()) { while(!messageQueue_.empty()) {
SharedHandle<BtMessage> msg = messageQueue_.front(); SharedHandle<BtMessage> msg = messageQueue_.front();
@ -100,10 +102,6 @@ void DefaultBtMessageDispatcher::sendMessages() {
} }
} }
msg->send(); msg->send();
if(msg->isSendingInProgress()) {
messageQueue_.push_front(msg);
break;
}
} }
if(!tempQueue.empty()) { if(!tempQueue.empty()) {
// Insert pending message to the front, so that message is likely sent in // Insert pending message to the front, so that message is likely sent in
@ -118,6 +116,16 @@ void DefaultBtMessageDispatcher::sendMessages() {
} }
} }
void DefaultBtMessageDispatcher::sendMessages() {
// First flush any pending data in the buffer.
peerConnection_->sendPendingData();
if(!peerConnection_->sendBufferIsEmpty()) {
return;
}
sendMessagesInternal();
peerConnection_->sendPendingData();
}
// Cancel sending piece message to peer. // Cancel sending piece message to peer.
void DefaultBtMessageDispatcher::doCancelSendingPieceAction void DefaultBtMessageDispatcher::doCancelSendingPieceAction
(size_t index, int32_t begin, int32_t length) (size_t index, int32_t begin, int32_t length)

View File

@ -52,6 +52,7 @@ class BtMessageFactory;
class Peer; class Peer;
class Piece; class Piece;
class RequestGroupMan; class RequestGroupMan;
class PeerConnection;
class DefaultBtMessageDispatcher : public BtMessageDispatcher { class DefaultBtMessageDispatcher : public BtMessageDispatcher {
private: private:
@ -61,6 +62,7 @@ private:
SharedHandle<DownloadContext> downloadContext_; SharedHandle<DownloadContext> downloadContext_;
SharedHandle<PeerStorage> peerStorage_; SharedHandle<PeerStorage> peerStorage_;
SharedHandle<PieceStorage> pieceStorage_; SharedHandle<PieceStorage> pieceStorage_;
SharedHandle<PeerConnection> peerConnection_;
BtMessageFactory* messageFactory_; BtMessageFactory* messageFactory_;
SharedHandle<Peer> peer_; SharedHandle<Peer> peer_;
RequestGroupMan* requestGroupMan_; RequestGroupMan* requestGroupMan_;
@ -77,6 +79,9 @@ public:
virtual void sendMessages(); virtual void sendMessages();
// For unit tests without PeerConnection
void sendMessagesInternal();
virtual void doCancelSendingPieceAction virtual void doCancelSendingPieceAction
(size_t index, int32_t begin, int32_t length); (size_t index, int32_t begin, int32_t length);
@ -143,6 +148,11 @@ public:
{ {
requestTimeout_ = requestTimeout; requestTimeout_ = requestTimeout;
} }
void setPeerConnection(const SharedHandle<PeerConnection>& peerConnection)
{
peerConnection_ = peerConnection;
}
}; };
} // namespace aria2 } // namespace aria2

View File

@ -322,7 +322,7 @@ void DefaultPeerStorage::returnPeer(const SharedHandle<Peer>& peer)
bool DefaultPeerStorage::chokeRoundIntervalElapsed() bool DefaultPeerStorage::chokeRoundIntervalElapsed()
{ {
const time_t CHOKE_ROUND_INTERVAL = 10; const time_t CHOKE_ROUND_INTERVAL = 1;//10;
if(pieceStorage_->downloadFinished()) { if(pieceStorage_->downloadFinished()) {
return seederStateChoke_->getLastRound(). return seederStateChoke_->getLastRound().
difference(global::wallclock()) >= CHOKE_ROUND_INTERVAL; difference(global::wallclock()) >= CHOKE_ROUND_INTERVAL;

View File

@ -85,12 +85,13 @@ PeerConnection::~PeerConnection()
delete [] resbuf_; delete [] resbuf_;
} }
void PeerConnection::pushBytes(unsigned char* data, size_t len) void PeerConnection::pushBytes(unsigned char* data, size_t len,
ProgressUpdate* progressUpdate)
{ {
if(encryptionEnabled_) { if(encryptionEnabled_) {
encryptor_->encrypt(len, data, data); encryptor_->encrypt(len, data, data);
} }
socketBuffer_.pushBytes(data, len); socketBuffer_.pushBytes(data, len, progressUpdate);
} }
bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength)

View File

@ -96,7 +96,8 @@ public:
// Pushes data into send buffer. After this call, this object gets // Pushes data into send buffer. After this call, this object gets
// ownership of data, so caller must not delete or alter it. // ownership of data, so caller must not delete or alter it.
void pushBytes(unsigned char* data, size_t len); void pushBytes(unsigned char* data, size_t len,
ProgressUpdate* progressUpdate = 0);
void pushStr(const std::string& data); void pushStr(const std::string& data);

View File

@ -200,6 +200,7 @@ PeerInteractionCommand::PeerInteractionCommand
dispatcherPtr->setBtMessageFactory(factoryPtr); dispatcherPtr->setBtMessageFactory(factoryPtr);
dispatcherPtr->setRequestGroupMan dispatcherPtr->setRequestGroupMan
(getDownloadEngine()->getRequestGroupMan().get()); (getDownloadEngine()->getRequestGroupMan().get());
dispatcherPtr->setPeerConnection(peerConnection);
SharedHandle<BtMessageDispatcher> dispatcher(dispatcherPtr); SharedHandle<BtMessageDispatcher> dispatcher(dispatcherPtr);
DefaultBtMessageReceiver* receiverPtr(new DefaultBtMessageReceiver()); DefaultBtMessageReceiver* receiverPtr(new DefaultBtMessageReceiver());

View File

@ -48,29 +48,19 @@ SimpleBtMessage::SimpleBtMessage(uint8_t id, const char* name)
{} {}
void SimpleBtMessage::send() { void SimpleBtMessage::send() {
if(isInvalidate()) { if(isInvalidate() || !sendPredicate()) {
return; return;
} }
if(!sendPredicate() && !isSendingInProgress()) { A2_LOG_INFO(fmt(MSG_SEND_PEER_MESSAGE,
return; getCuid(),
} getPeer()->getIPAddress().c_str(),
if(!isSendingInProgress()) { getPeer()->getPort(),
A2_LOG_INFO(fmt(MSG_SEND_PEER_MESSAGE, toString().c_str()));
getCuid(), unsigned char* msg = createMessage();
getPeer()->getIPAddress().c_str(), size_t msgLength = getMessageLength();
getPeer()->getPort(), A2_LOG_DEBUG(fmt("msglength = %lu bytes",
toString().c_str())); static_cast<unsigned long>(msgLength)));
unsigned char* msg = createMessage(); getPeerConnection()->pushBytes(msg, msgLength, getProgressUpdate());
size_t msgLength = getMessageLength();
A2_LOG_DEBUG(fmt("msglength = %lu bytes",
static_cast<unsigned long>(msgLength)));
getPeerConnection()->pushBytes(msg, msgLength);
}
getPeerConnection()->sendPendingData();
setSendingInProgress(!getPeerConnection()->sendBufferIsEmpty());
if(!isSendingInProgress()) {
onSendComplete();
}
} }
} // namespace aria2 } // namespace aria2

View File

@ -39,6 +39,8 @@
namespace aria2 { namespace aria2 {
class ProgressUpdate;
class SimpleBtMessage : public AbstractBtMessage { class SimpleBtMessage : public AbstractBtMessage {
public: public:
SimpleBtMessage(uint8_t id, const char* name); SimpleBtMessage(uint8_t id, const char* name);
@ -49,7 +51,7 @@ public:
virtual size_t getMessageLength() = 0; virtual size_t getMessageLength() = 0;
virtual void onSendComplete() {}; virtual ProgressUpdate* getProgressUpdate() { return 0; };
virtual bool sendPredicate() const { return true; }; virtual bool sendPredicate() const { return true; };

View File

@ -41,12 +41,13 @@
#include "DlAbortEx.h" #include "DlAbortEx.h"
#include "message.h" #include "message.h"
#include "fmt.h" #include "fmt.h"
#include "LogFactory.h"
namespace aria2 { namespace aria2 {
SocketBuffer::ByteArrayBufEntry::ByteArrayBufEntry SocketBuffer::ByteArrayBufEntry::ByteArrayBufEntry
(unsigned char* bytes, size_t length) (unsigned char* bytes, size_t length, ProgressUpdate* progressUpdate)
: bytes_(bytes), length_(length) : BufEntry(progressUpdate), bytes_(bytes), length_(length)
{} {}
SocketBuffer::ByteArrayBufEntry::~ByteArrayBufEntry() SocketBuffer::ByteArrayBufEntry::~ByteArrayBufEntry()
@ -65,11 +66,22 @@ bool SocketBuffer::ByteArrayBufEntry::final(size_t offset) const
return length_ <= offset; return length_ <= offset;
} }
SocketBuffer::StringBufEntry::StringBufEntry(const std::string& s) size_t SocketBuffer::ByteArrayBufEntry::getLength() const
: str_(s) {
return length_;
}
const unsigned char* SocketBuffer::ByteArrayBufEntry::getData() const
{
return bytes_;
}
SocketBuffer::StringBufEntry::StringBufEntry(const std::string& s,
ProgressUpdate* progressUpdate)
: BufEntry(progressUpdate), str_(s)
{} {}
SocketBuffer::StringBufEntry::StringBufEntry() {} // SocketBuffer::StringBufEntry::StringBufEntry() {}
ssize_t SocketBuffer::StringBufEntry::send ssize_t SocketBuffer::StringBufEntry::send
(const SharedHandle<SocketCore>& socket, size_t offset) (const SharedHandle<SocketCore>& socket, size_t offset)
@ -82,6 +94,16 @@ bool SocketBuffer::StringBufEntry::final(size_t offset) const
return str_.size() <= offset; return str_.size() <= offset;
} }
size_t SocketBuffer::StringBufEntry::getLength() const
{
return str_.size();
}
const unsigned char* SocketBuffer::StringBufEntry::getData() const
{
return reinterpret_cast<const unsigned char*>(str_.c_str());
}
void SocketBuffer::StringBufEntry::swap(std::string& s) void SocketBuffer::StringBufEntry::swap(std::string& s)
{ {
str_.swap(s); str_.swap(s);
@ -92,38 +114,82 @@ SocketBuffer::SocketBuffer(const SharedHandle<SocketCore>& socket):
SocketBuffer::~SocketBuffer() {} SocketBuffer::~SocketBuffer() {}
void SocketBuffer::pushBytes(unsigned char* bytes, size_t len) void SocketBuffer::pushBytes(unsigned char* bytes, size_t len,
ProgressUpdate* progressUpdate)
{ {
if(len > 0) { if(len > 0) {
bufq_.push_back(SharedHandle<BufEntry>(new ByteArrayBufEntry(bytes, len))); bufq_.push_back(SharedHandle<BufEntry>
(new ByteArrayBufEntry(bytes, len, progressUpdate)));
} }
} }
void SocketBuffer::pushStr(const std::string& data) void SocketBuffer::pushStr(const std::string& data,
ProgressUpdate* progressUpdate)
{ {
if(data.size() > 0) { if(data.size() > 0) {
bufq_.push_back(SharedHandle<BufEntry>(new StringBufEntry(data))); bufq_.push_back(SharedHandle<BufEntry>
(new StringBufEntry(data, progressUpdate)));
} }
} }
ssize_t SocketBuffer::send() ssize_t SocketBuffer::send()
{ {
a2iovec iov[A2_IOV_MAX];
size_t totalslen = 0; size_t totalslen = 0;
while(!bufq_.empty()) { while(!bufq_.empty()) {
const SharedHandle<BufEntry>& buf = bufq_[0]; size_t num;
ssize_t slen = buf->send(socket_, offset_); ssize_t amount = 16*1024;
ssize_t firstlen = bufq_[0]->getLength() - offset_;
amount -= firstlen;
iov[0].A2IOVEC_BASE =
reinterpret_cast<char*>(const_cast<unsigned char*>
(bufq_[0]->getData() + offset_));
iov[0].A2IOVEC_LEN = firstlen;
for(num = 1; num < A2_IOV_MAX && num < bufq_.size() && amount > 0; ++num) {
const SharedHandle<BufEntry>& buf = bufq_[num];
ssize_t len = buf->getLength();
if(amount >= len) {
amount -= len;
iov[num].A2IOVEC_BASE =
reinterpret_cast<char*>(const_cast<unsigned char*>(buf->getData()));
iov[num].A2IOVEC_LEN = len;
} else {
break;
}
}
ssize_t slen = socket_->writeVector(iov, num);
if(slen == 0 && !socket_->wantRead() && !socket_->wantWrite()) { if(slen == 0 && !socket_->wantRead() && !socket_->wantWrite()) {
throw DL_ABORT_EX(fmt(EX_SOCKET_SEND, "Connection closed.")); throw DL_ABORT_EX(fmt(EX_SOCKET_SEND, "Connection closed."));
} }
//A2_LOG_NOTICE(fmt("SEND=%d", slen));
totalslen += slen; totalslen += slen;
offset_ += slen;
if(buf->final(offset_)) { if(firstlen > slen) {
offset_ += slen;
bufq_[0]->progressUpdate(slen, false);
} else {
slen -= firstlen;
bufq_[0]->progressUpdate(firstlen, true);
bufq_.pop_front(); bufq_.pop_front();
offset_ = 0; offset_ = 0;
} else { for(size_t i = 1; i < num; ++i) {
break; const SharedHandle<BufEntry>& buf = bufq_[0];
ssize_t len = buf->getLength();
if(len > slen) {
offset_ = slen;
bufq_[0]->progressUpdate(slen, false);
goto fin;
break;
} else {
slen -= len;
bufq_[0]->progressUpdate(len, true);
bufq_.pop_front();
}
}
} }
} }
fin:
return totalslen; return totalslen;
} }

View File

@ -46,23 +46,46 @@ namespace aria2 {
class SocketCore; class SocketCore;
struct ProgressUpdate {
virtual ~ProgressUpdate() {}
virtual void update(size_t length, bool complete) = 0;
};
class SocketBuffer { class SocketBuffer {
private: private:
class BufEntry { class BufEntry {
public: public:
virtual ~BufEntry() {} BufEntry(ProgressUpdate* progressUpdate)
: progressUpdate_(progressUpdate) {}
virtual ~BufEntry()
{
delete progressUpdate_;
}
virtual ssize_t send virtual ssize_t send
(const SharedHandle<SocketCore>& socket, size_t offset) = 0; (const SharedHandle<SocketCore>& socket, size_t offset) = 0;
virtual bool final(size_t offset) const = 0; virtual bool final(size_t offset) const = 0;
virtual size_t getLength() const = 0;
virtual const unsigned char* getData() const = 0;
void progressUpdate(size_t length, bool complete)
{
if(progressUpdate_) {
progressUpdate_->update(length, complete);
}
}
private:
ProgressUpdate* progressUpdate_;
}; };
class ByteArrayBufEntry:public BufEntry { class ByteArrayBufEntry:public BufEntry {
public: public:
ByteArrayBufEntry(unsigned char* bytes, size_t length); ByteArrayBufEntry(unsigned char* bytes, size_t length,
ProgressUpdate* progressUpdate);
virtual ~ByteArrayBufEntry(); virtual ~ByteArrayBufEntry();
virtual ssize_t send virtual ssize_t send
(const SharedHandle<SocketCore>& socket, size_t offset); (const SharedHandle<SocketCore>& socket, size_t offset);
virtual bool final(size_t offset) const; virtual bool final(size_t offset) const;
virtual size_t getLength() const;
virtual const unsigned char* getData() const;
private: private:
unsigned char* bytes_; unsigned char* bytes_;
size_t length_; size_t length_;
@ -70,11 +93,14 @@ private:
class StringBufEntry:public BufEntry { class StringBufEntry:public BufEntry {
public: public:
StringBufEntry(const std::string& s); StringBufEntry(const std::string& s,
ProgressUpdate* progressUpdate);
StringBufEntry(); StringBufEntry();
virtual ssize_t send virtual ssize_t send
(const SharedHandle<SocketCore>& socket, size_t offset); (const SharedHandle<SocketCore>& socket, size_t offset);
virtual bool final(size_t offset) const; virtual bool final(size_t offset) const;
virtual size_t getLength() const;
virtual const unsigned char* getData() const;
void swap(std::string& s); void swap(std::string& s);
private: private:
std::string str_; std::string str_;
@ -99,11 +125,18 @@ public:
// Feeds data pointered by bytes with length len into queue. This // Feeds data pointered by bytes with length len into queue. This
// object gets ownership of bytes, so caller must not delete or // object gets ownership of bytes, so caller must not delete or
// later bytes after this call. This function doesn't send data. // later bytes after this call. This function doesn't send data. If
void pushBytes(unsigned char* bytes, size_t len); // progressUpdate is not null, its update() function will be called
// each time the data is sent. It will be deleted by this object. It
// can be null.
void pushBytes(unsigned char* bytes, size_t len,
ProgressUpdate* progressUpdate = 0);
// Feeds data into queue. This function doesn't send data. // Feeds data into queue. This function doesn't send data. If
void pushStr(const std::string& data); // progressUpdate is not null, its update() function will be called
// each time the data is sent. It will be deleted by this object. It
// can be null.
void pushStr(const std::string& data, ProgressUpdate* progressUpdate = 0);
// Sends data in queue. Returns the number of bytes sent. // Sends data in queue. Returns the number of bytes sent.
ssize_t send(); ssize_t send();

View File

@ -738,14 +738,57 @@ void SocketCore::gnutlsRecordCheckDirection()
} }
#endif // HAVE_LIBGNUTLS #endif // HAVE_LIBGNUTLS
ssize_t SocketCore::writeData(const char* data, size_t len) ssize_t SocketCore::writeVector(a2iovec *iov, size_t iovcnt)
{
ssize_t ret = 0;
wantRead_ = false;
wantWrite_ = false;
if(!secure_) {
#ifdef __MINGW32__
DWORD nsent;
int rv = WSASend(sockfd_, iov, iovcnt, &nsent, 0, 0, 0);
if(rv == 0) {
ret = nsent;
} else {
ret = -1;
}
#else // !__MINGW32__
while((ret = writev(sockfd_, iov, iovcnt)) == -1 &&
SOCKET_ERRNO == A2_EINTR);
#endif // !__MINGW32__
int errNum = SOCKET_ERRNO;
if(ret == -1) {
if(A2_WOULDBLOCK(errNum)) {
wantWrite_ = true;
ret = 0;
} else {
throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, errorMsg(errNum).c_str()));
}
}
} else {
// For SSL/TLS, we could not use writev, so just iterate vector
// and write the data in normal way.
for(size_t i = 0; i < iovcnt; ++i) {
ssize_t rv = writeData(iov[i].A2IOVEC_BASE, iov[i].A2IOVEC_LEN);
if(rv == 0) {
break;
}
ret += rv;
}
}
return ret;
}
ssize_t SocketCore::writeData(const void* data, size_t len)
{ {
ssize_t ret = 0; ssize_t ret = 0;
wantRead_ = false; wantRead_ = false;
wantWrite_ = false; wantWrite_ = false;
if(!secure_) { if(!secure_) {
while((ret = send(sockfd_, data, len, 0)) == -1 && SOCKET_ERRNO == A2_EINTR); // Cast for Windows send()
while((ret = send(sockfd_, reinterpret_cast<const char*>(data),
len, 0)) == -1 && SOCKET_ERRNO == A2_EINTR);
int errNum = SOCKET_ERRNO; int errNum = SOCKET_ERRNO;
if(ret == -1) { if(ret == -1) {
if(A2_WOULDBLOCK(errNum)) { if(A2_WOULDBLOCK(errNum)) {
@ -1166,7 +1209,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
#endif // ENABLE_SSL #endif // ENABLE_SSL
ssize_t SocketCore::writeData(const char* data, size_t len, ssize_t SocketCore::writeData(const void* data, size_t len,
const std::string& host, uint16_t port) const std::string& host, uint16_t port)
{ {
wantRead_ = false; wantRead_ = false;
@ -1184,7 +1227,9 @@ ssize_t SocketCore::writeData(const char* data, size_t len,
ssize_t r = -1; ssize_t r = -1;
int errNum = 0; int errNum = 0;
for(rp = res; rp; rp = rp->ai_next) { for(rp = res; rp; rp = rp->ai_next) {
while((r = sendto(sockfd_, data, len, 0, rp->ai_addr, rp->ai_addrlen)) == -1 // Cast for Windows sendto()
while((r = sendto(sockfd_, reinterpret_cast<const char*>(data), len, 0,
rp->ai_addr, rp->ai_addrlen)) == -1
&& A2_EINTR == SOCKET_ERRNO); && A2_EINTR == SOCKET_ERRNO);
errNum = SOCKET_ERRNO; errNum = SOCKET_ERRNO;
if(r == static_cast<ssize_t>(len)) { if(r == static_cast<ssize_t>(len)) {

View File

@ -252,25 +252,16 @@ public:
* @param data data to write * @param data data to write
* @param len length of data * @param len length of data
*/ */
ssize_t writeData(const char* data, size_t len); ssize_t writeData(const void* data, size_t len);
ssize_t writeData(const std::string& msg) ssize_t writeData(const std::string& msg)
{ {
return writeData(msg.c_str(), msg.size()); return writeData(msg.c_str(), msg.size());
} }
ssize_t writeData(const unsigned char* data, size_t len)
{
return writeData(reinterpret_cast<const char*>(data), len);
}
ssize_t writeData(const char* data, size_t len, ssize_t writeData(const void* data, size_t len,
const std::string& host, uint16_t port); const std::string& host, uint16_t port);
ssize_t writeData(const unsigned char* data, size_t len, ssize_t writeVector(a2iovec *iov, size_t iovcnt);
const std::string& host,
uint16_t port)
{
return writeData(reinterpret_cast<const char*>(data), len, host, port);
}
/** /**
* Reads up to len bytes from this socket. * Reads up to len bytes from this socket.

View File

@ -87,6 +87,10 @@
# include <netinet/in.h> # include <netinet/in.h>
#endif // HAVE_NETINET_IN_H #endif // HAVE_NETINET_IN_H
#ifdef HAVE_SYS_UIO_H
# include <sys/uio.h>
#endif // HAVE_SYS_UIO_H
#ifndef HAVE_GETADDRINFO #ifndef HAVE_GETADDRINFO
# include "getaddrinfo.h" # include "getaddrinfo.h"
# define HAVE_GAI_STRERROR # define HAVE_GAI_STRERROR
@ -141,4 +145,22 @@ union sockaddr_union {
sockaddr_in in; sockaddr_in in;
}; };
#define A2_DEFAULT_IOV_MAX 128
#if defined(IOV_MAX) && IOV_MAX < A2_DEFAULT_IOV_MAX
# define A2_IOV_MAX IOV_MAX
#else
# define A2_IOV_MAX A2_DEFAULT_IOV_MAX
#endif
#ifdef __MINGW32__
typedef WSABUF a2iovec;
# define A2IOVEC_BASE buf
# define A2IOVEC_LEN len
#else // !__MINGW32__
typedef struct iovec a2iovec;
# define A2IOVEC_BASE iov_base
# define A2IOVEC_LEN iov_len
#endif // !__MINGW32__
#endif // D_A2NETCOMPAT_H #endif // D_A2NETCOMPAT_H

View File

@ -6,6 +6,7 @@
#include "bittorrent_helper.h" #include "bittorrent_helper.h"
#include "util.h" #include "util.h"
#include "Peer.h" #include "Peer.h"
#include "SocketBuffer.h"
namespace aria2 { namespace aria2 {
@ -98,7 +99,8 @@ void BtAllowedFastMessageTest::testOnSendComplete() {
peer->setFastExtensionEnabled(true); peer->setFastExtensionEnabled(true);
msg.setPeer(peer); msg.setPeer(peer);
CPPUNIT_ASSERT(!peer->isInAmAllowedIndexSet(1)); CPPUNIT_ASSERT(!peer->isInAmAllowedIndexSet(1));
msg.onSendComplete(); SharedHandle<ProgressUpdate> pu(msg.getProgressUpdate());
pu->update(0, true);
CPPUNIT_ASSERT(peer->isInAmAllowedIndexSet(1)); CPPUNIT_ASSERT(peer->isInAmAllowedIndexSet(1));
} }

View File

@ -9,6 +9,7 @@
#include "MockBtRequestFactory.h" #include "MockBtRequestFactory.h"
#include "Peer.h" #include "Peer.h"
#include "FileEntry.h" #include "FileEntry.h"
#include "SocketBuffer.h"
namespace aria2 { namespace aria2 {
@ -123,7 +124,8 @@ void BtChokeMessageTest::testOnSendComplete() {
SharedHandle<MockBtMessageDispatcher2> dispatcher(new MockBtMessageDispatcher2()); SharedHandle<MockBtMessageDispatcher2> dispatcher(new MockBtMessageDispatcher2());
msg.setBtMessageDispatcher(dispatcher.get()); msg.setBtMessageDispatcher(dispatcher.get());
msg.onSendComplete(); SharedHandle<ProgressUpdate> pu(msg.getProgressUpdate());
pu->update(0, true);
CPPUNIT_ASSERT(dispatcher->doChokingActionCalled); CPPUNIT_ASSERT(dispatcher->doChokingActionCalled);
CPPUNIT_ASSERT(peer->amChoking()); CPPUNIT_ASSERT(peer->amChoking());

View File

@ -7,6 +7,7 @@
#include "bittorrent_helper.h" #include "bittorrent_helper.h"
#include "Peer.h" #include "Peer.h"
#include "MockPeerStorage.h" #include "MockPeerStorage.h"
#include "SocketBuffer.h"
namespace aria2 { namespace aria2 {
@ -90,7 +91,8 @@ void BtInterestedMessageTest::testOnSendComplete() {
peer->allocateSessionResource(1024, 1024*1024); peer->allocateSessionResource(1024, 1024*1024);
msg.setPeer(peer); msg.setPeer(peer);
CPPUNIT_ASSERT(!peer->amInterested()); CPPUNIT_ASSERT(!peer->amInterested());
msg.onSendComplete(); SharedHandle<ProgressUpdate> pu(msg.getProgressUpdate());
pu->update(0, true);
CPPUNIT_ASSERT(peer->amInterested()); CPPUNIT_ASSERT(peer->amInterested());
} }

View File

@ -7,6 +7,7 @@
#include "bittorrent_helper.h" #include "bittorrent_helper.h"
#include "Peer.h" #include "Peer.h"
#include "MockPeerStorage.h" #include "MockPeerStorage.h"
#include "SocketBuffer.h"
namespace aria2 { namespace aria2 {
@ -92,7 +93,8 @@ void BtNotInterestedMessageTest::testOnSendComplete() {
BtNotInterestedMessage msg; BtNotInterestedMessage msg;
msg.setPeer(peer); msg.setPeer(peer);
CPPUNIT_ASSERT(peer->amInterested()); CPPUNIT_ASSERT(peer->amInterested());
msg.onSendComplete(); SharedHandle<ProgressUpdate> pu(msg.getProgressUpdate());
pu->update(0, true);
CPPUNIT_ASSERT(!peer->amInterested()); CPPUNIT_ASSERT(!peer->amInterested());
} }

View File

@ -1,9 +1,12 @@
#include "BtUnchokeMessage.h" #include "BtUnchokeMessage.h"
#include "bittorrent_helper.h"
#include "Peer.h"
#include <cstring> #include <cstring>
#include <cppunit/extensions/HelperMacros.h> #include <cppunit/extensions/HelperMacros.h>
#include "bittorrent_helper.h"
#include "Peer.h"
#include "SocketBuffer.h"
namespace aria2 { namespace aria2 {
class BtUnchokeMessageTest:public CppUnit::TestFixture { class BtUnchokeMessageTest:public CppUnit::TestFixture {
@ -83,7 +86,8 @@ void BtUnchokeMessageTest::testOnSendComplete() {
msg.setPeer(peer); msg.setPeer(peer);
CPPUNIT_ASSERT(peer->amChoking()); CPPUNIT_ASSERT(peer->amChoking());
msg.onSendComplete(); SharedHandle<ProgressUpdate> pu(msg.getProgressUpdate());
pu->update(0, true);
CPPUNIT_ASSERT(!peer->amChoking()); CPPUNIT_ASSERT(!peer->amChoking());
} }

View File

@ -19,6 +19,7 @@
#include "RequestGroup.h" #include "RequestGroup.h"
#include "DownloadContext.h" #include "DownloadContext.h"
#include "bittorrent_helper.h" #include "bittorrent_helper.h"
#include "PeerConnection.h"
namespace aria2 { namespace aria2 {
@ -30,7 +31,6 @@ class DefaultBtMessageDispatcherTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testSendMessages_underUploadLimit); CPPUNIT_TEST(testSendMessages_underUploadLimit);
// See the comment on the definition // See the comment on the definition
//CPPUNIT_TEST(testSendMessages_overUploadLimit); //CPPUNIT_TEST(testSendMessages_overUploadLimit);
CPPUNIT_TEST(testSendMessages_sendingInProgress);
CPPUNIT_TEST(testDoCancelSendingPieceAction); CPPUNIT_TEST(testDoCancelSendingPieceAction);
CPPUNIT_TEST(testCheckRequestSlotAndDoNecessaryThing); CPPUNIT_TEST(testCheckRequestSlotAndDoNecessaryThing);
CPPUNIT_TEST(testCheckRequestSlotAndDoNecessaryThing_timeout); CPPUNIT_TEST(testCheckRequestSlotAndDoNecessaryThing_timeout);
@ -58,7 +58,6 @@ public:
void testSendMessages(); void testSendMessages();
void testSendMessages_underUploadLimit(); void testSendMessages_underUploadLimit();
void testSendMessages_overUploadLimit(); void testSendMessages_overUploadLimit();
void testSendMessages_sendingInProgress();
void testDoCancelSendingPieceAction(); void testDoCancelSendingPieceAction();
void testCheckRequestSlotAndDoNecessaryThing(); void testCheckRequestSlotAndDoNecessaryThing();
void testCheckRequestSlotAndDoNecessaryThing_timeout(); void testCheckRequestSlotAndDoNecessaryThing_timeout();
@ -188,7 +187,7 @@ void DefaultBtMessageDispatcherTest::testSendMessages() {
msg2->setUploading(false); msg2->setUploading(false);
btMessageDispatcher->addMessageToQueue(msg1); btMessageDispatcher->addMessageToQueue(msg1);
btMessageDispatcher->addMessageToQueue(msg2); btMessageDispatcher->addMessageToQueue(msg2);
btMessageDispatcher->sendMessages(); btMessageDispatcher->sendMessagesInternal();
CPPUNIT_ASSERT(msg1->isSendCalled()); CPPUNIT_ASSERT(msg1->isSendCalled());
CPPUNIT_ASSERT(msg2->isSendCalled()); CPPUNIT_ASSERT(msg2->isSendCalled());
@ -203,7 +202,7 @@ void DefaultBtMessageDispatcherTest::testSendMessages_underUploadLimit() {
msg2->setUploading(true); msg2->setUploading(true);
btMessageDispatcher->addMessageToQueue(msg1); btMessageDispatcher->addMessageToQueue(msg1);
btMessageDispatcher->addMessageToQueue(msg2); btMessageDispatcher->addMessageToQueue(msg2);
btMessageDispatcher->sendMessages(); btMessageDispatcher->sendMessagesInternal();
CPPUNIT_ASSERT(msg1->isSendCalled()); CPPUNIT_ASSERT(msg1->isSendCalled());
CPPUNIT_ASSERT(msg2->isSendCalled()); CPPUNIT_ASSERT(msg2->isSendCalled());
@ -232,7 +231,7 @@ void DefaultBtMessageDispatcherTest::testSendMessages_underUploadLimit() {
// btMessageDispatcher->addMessageToQueue(msg1); // btMessageDispatcher->addMessageToQueue(msg1);
// btMessageDispatcher->addMessageToQueue(msg2); // btMessageDispatcher->addMessageToQueue(msg2);
// btMessageDispatcher->addMessageToQueue(msg3); // btMessageDispatcher->addMessageToQueue(msg3);
// btMessageDispatcher->sendMessages(); // btMessageDispatcher->sendMessagesInternal();
// CPPUNIT_ASSERT(!msg1->isSendCalled()); // CPPUNIT_ASSERT(!msg1->isSendCalled());
// CPPUNIT_ASSERT(!msg2->isSendCalled()); // CPPUNIT_ASSERT(!msg2->isSendCalled());
@ -242,31 +241,6 @@ void DefaultBtMessageDispatcherTest::testSendMessages_underUploadLimit() {
// btMessageDispatcher->getMessageQueue().size()); // btMessageDispatcher->getMessageQueue().size());
// } // }
void DefaultBtMessageDispatcherTest::testSendMessages_sendingInProgress() {
SharedHandle<MockBtMessage2> msg1(new MockBtMessage2());
msg1->setSendingInProgress(false);
msg1->setUploading(false);
SharedHandle<MockBtMessage2> msg2(new MockBtMessage2());
msg2->setSendingInProgress(true);
msg2->setUploading(false);
SharedHandle<MockBtMessage2> msg3(new MockBtMessage2());
msg3->setSendingInProgress(false);
msg3->setUploading(false);
btMessageDispatcher->addMessageToQueue(msg1);
btMessageDispatcher->addMessageToQueue(msg2);
btMessageDispatcher->addMessageToQueue(msg3);
btMessageDispatcher->sendMessages();
CPPUNIT_ASSERT(msg1->isSendCalled());
CPPUNIT_ASSERT(msg2->isSendCalled());
CPPUNIT_ASSERT(!msg3->isSendCalled());
CPPUNIT_ASSERT_EQUAL((size_t)2,
btMessageDispatcher->getMessageQueue().size());
}
void DefaultBtMessageDispatcherTest::testDoCancelSendingPieceAction() { void DefaultBtMessageDispatcherTest::testDoCancelSendingPieceAction() {
SharedHandle<MockBtMessage2> msg1(new MockBtMessage2()); SharedHandle<MockBtMessage2> msg1(new MockBtMessage2());
SharedHandle<MockBtMessage2> msg2(new MockBtMessage2()); SharedHandle<MockBtMessage2> msg2(new MockBtMessage2());