diff --git a/configure.ac b/configure.ac index 4001c8b1..16d1324a 100644 --- a/configure.ac +++ b/configure.ac @@ -375,6 +375,7 @@ AC_CHECK_HEADERS([argz.h \ sys/socket.h \ sys/time.h \ sys/types.h \ + sys/uio.h \ termios.h \ unistd.h \ utime.h \ diff --git a/src/BtAllowedFastMessage.cc b/src/BtAllowedFastMessage.cc index 14b53ce0..e0fa3b9b 100644 --- a/src/BtAllowedFastMessage.cc +++ b/src/BtAllowedFastMessage.cc @@ -36,6 +36,7 @@ #include "DlAbortEx.h" #include "Peer.h" #include "fmt.h" +#include "SocketBuffer.h" namespace aria2 { @@ -62,8 +63,24 @@ void BtAllowedFastMessage::doReceivedAction() { getPeer()->addPeerAllowedIndex(getIndex()); } -void BtAllowedFastMessage::onSendComplete() { - getPeer()->addAmAllowedIndex(getIndex()); +namespace { +struct ThisProgressUpdate : public ProgressUpdate { + ThisProgressUpdate(const SharedHandle& peer, size_t index) + : peer(peer), index(index) {} + virtual void update(size_t length, bool complete) + { + if(complete) { + peer->addAmAllowedIndex(index); + } + } + SharedHandle peer; + size_t index; +}; +} // namespace + +ProgressUpdate* BtAllowedFastMessage::getProgressUpdate() +{ + return new ThisProgressUpdate(getPeer(), getIndex()); } } // namespace aria2 diff --git a/src/BtAllowedFastMessage.h b/src/BtAllowedFastMessage.h index c32aa0d7..079ca84e 100644 --- a/src/BtAllowedFastMessage.h +++ b/src/BtAllowedFastMessage.h @@ -52,8 +52,7 @@ public: virtual void doReceivedAction(); - virtual void onSendComplete(); - + virtual ProgressUpdate* getProgressUpdate(); }; } // namespace aria2 diff --git a/src/BtChokeMessage.cc b/src/BtChokeMessage.cc index ed230f5d..ac82bd0a 100644 --- a/src/BtChokeMessage.cc +++ b/src/BtChokeMessage.cc @@ -36,6 +36,7 @@ #include "Peer.h" #include "BtMessageDispatcher.h" #include "BtRequestFactory.h" +#include "SocketBuffer.h" namespace aria2 { @@ -64,10 +65,26 @@ bool BtChokeMessage::sendPredicate() const return !getPeer()->amChoking(); } -void BtChokeMessage::onSendComplete() +namespace { +struct ThisProgressUpdate : public ProgressUpdate { + ThisProgressUpdate(const SharedHandle& peer, + BtMessageDispatcher* disp) + : peer(peer), disp(disp) {} + virtual void update(size_t length, bool complete) + { + if(complete) { + peer->amChoking(true); + disp->doChokingAction(); + } + } + SharedHandle peer; + BtMessageDispatcher* disp; +}; +} // namespace + +ProgressUpdate* BtChokeMessage::getProgressUpdate() { - getPeer()->amChoking(true); - getBtMessageDispatcher()->doChokingAction(); + return new ThisProgressUpdate(getPeer(), getBtMessageDispatcher()); } } // namespace aria2 diff --git a/src/BtChokeMessage.h b/src/BtChokeMessage.h index 73c357d7..ee02e415 100644 --- a/src/BtChokeMessage.h +++ b/src/BtChokeMessage.h @@ -53,7 +53,7 @@ public: virtual bool sendPredicate() const; - virtual void onSendComplete(); + virtual ProgressUpdate* getProgressUpdate(); }; } // namespace aria2 diff --git a/src/BtInterestedMessage.cc b/src/BtInterestedMessage.cc index e2391bc1..8fafb399 100644 --- a/src/BtInterestedMessage.cc +++ b/src/BtInterestedMessage.cc @@ -35,6 +35,7 @@ #include "BtInterestedMessage.h" #include "Peer.h" #include "PeerStorage.h" +#include "SocketBuffer.h" namespace aria2 { @@ -66,8 +67,23 @@ bool BtInterestedMessage::sendPredicate() const return !getPeer()->amInterested(); } -void BtInterestedMessage::onSendComplete() { - getPeer()->amInterested(true); +namespace { +struct ThisProgressUpdate : public ProgressUpdate { + ThisProgressUpdate(const SharedHandle& peer) + : peer(peer) {} + virtual void update(size_t length, bool complete) + { + if(complete) { + peer->amInterested(true); + } + } + SharedHandle peer; +}; +} // namespace + +ProgressUpdate* BtInterestedMessage::getProgressUpdate() +{ + return new ThisProgressUpdate(getPeer()); } void BtInterestedMessage::setPeerStorage diff --git a/src/BtInterestedMessage.h b/src/BtInterestedMessage.h index 1630db93..bc10b208 100644 --- a/src/BtInterestedMessage.h +++ b/src/BtInterestedMessage.h @@ -60,7 +60,7 @@ public: virtual bool sendPredicate() const; - virtual void onSendComplete(); + virtual ProgressUpdate* getProgressUpdate(); void setPeerStorage(const SharedHandle& peerStorage); }; diff --git a/src/BtNotInterestedMessage.cc b/src/BtNotInterestedMessage.cc index 33f3ec3c..60b810a4 100644 --- a/src/BtNotInterestedMessage.cc +++ b/src/BtNotInterestedMessage.cc @@ -35,6 +35,7 @@ #include "BtNotInterestedMessage.h" #include "Peer.h" #include "PeerStorage.h" +#include "SocketBuffer.h" namespace aria2 { @@ -66,8 +67,23 @@ bool BtNotInterestedMessage::sendPredicate() const return getPeer()->amInterested(); } -void BtNotInterestedMessage::onSendComplete() { - getPeer()->amInterested(false); +namespace { +struct ThisProgressUpdate : public ProgressUpdate { + ThisProgressUpdate(const SharedHandle& peer) + : peer(peer) {} + virtual void update(size_t length, bool complete) + { + if(complete) { + peer->amInterested(false); + } + } + SharedHandle peer; +}; +} // namespace + +ProgressUpdate* BtNotInterestedMessage::getProgressUpdate() +{ + return new ThisProgressUpdate(getPeer()); } void BtNotInterestedMessage::setPeerStorage diff --git a/src/BtNotInterestedMessage.h b/src/BtNotInterestedMessage.h index c4e47847..fd9a5844 100644 --- a/src/BtNotInterestedMessage.h +++ b/src/BtNotInterestedMessage.h @@ -60,7 +60,7 @@ public: virtual bool sendPredicate() const; - virtual void onSendComplete(); + virtual ProgressUpdate* getProgressUpdate(); void setPeerStorage(const SharedHandle& peerStorage); }; diff --git a/src/BtPieceMessage.cc b/src/BtPieceMessage.cc index 92f8193b..845493eb 100644 --- a/src/BtPieceMessage.cc +++ b/src/BtPieceMessage.cc @@ -71,7 +71,6 @@ BtPieceMessage::BtPieceMessage index_(index), begin_(begin), blockLength_(blockLength), - msgHdrLen_(0), data_(0) { setUploading(true); @@ -180,39 +179,33 @@ size_t BtPieceMessage::getMessageHeaderLength() return MESSAGE_HEADER_LENGTH; } +namespace { +struct PieceSendUpdate : public ProgressUpdate { + PieceSendUpdate(const SharedHandle& peer) + : peer(peer) {} + virtual void update(size_t length, bool complete) + { + peer->updateUploadLength(length); + } + SharedHandle peer; +}; +} // namespace + void BtPieceMessage::send() { if(isInvalidate()) { return; } - size_t writtenLength; - if(!isSendingInProgress()) { - A2_LOG_INFO(fmt(MSG_SEND_PEER_MESSAGE, - getCuid(), - getPeer()->getIPAddress().c_str(), - getPeer()->getPort(), - toString().c_str())); - unsigned char* msgHdr = createMessageHeader(); - msgHdrLen_ = getMessageHeaderLength(); - A2_LOG_DEBUG(fmt("msglength = %lu bytes", - static_cast(msgHdrLen_+blockLength_))); - getPeerConnection()->pushBytes(msgHdr, msgHdrLen_); - int64_t pieceDataOffset = - static_cast(index_)*downloadContext_->getPieceLength()+begin_; - pushPieceData(pieceDataOffset, blockLength_); - } - writtenLength = getPeerConnection()->sendPendingData(); - // Subtract msgHdrLen_ from writtenLength to get the uploaded data - // size. - if(writtenLength > msgHdrLen_) { - writtenLength -= msgHdrLen_; - msgHdrLen_ = 0; - getPeer()->updateUploadLength(writtenLength); - downloadContext_->updateUploadLength(writtenLength); - } else { - msgHdrLen_ -= writtenLength; - } - setSendingInProgress(!getPeerConnection()->sendBufferIsEmpty()); + A2_LOG_INFO(fmt(MSG_SEND_PEER_MESSAGE, + getCuid(), + getPeer()->getIPAddress().c_str(), + getPeer()->getPort(), + toString().c_str())); + getPeerConnection()->pushBytes(createMessageHeader(), + getMessageHeaderLength()); + int64_t pieceDataOffset = + static_cast(index_)*downloadContext_->getPieceLength()+begin_; + pushPieceData(pieceDataOffset, blockLength_); } void BtPieceMessage::pushPieceData(int64_t offset, int32_t length) const @@ -224,7 +217,11 @@ void BtPieceMessage::pushPieceData(int64_t offset, int32_t length) const if(r == length) { unsigned char* dbuf = buf; buf.reset(0); - getPeerConnection()->pushBytes(dbuf, length); + getPeerConnection()->pushBytes(dbuf, length, + new PieceSendUpdate(getPeer())); + // To avoid upload rate overflow, we update the length here at + // once. + downloadContext_->updateUploadLength(length); } else { throw DL_ABORT_EX(EX_DATA_READ); } diff --git a/src/BtPieceMessage.h b/src/BtPieceMessage.h index adcb20ec..26e01ea5 100644 --- a/src/BtPieceMessage.h +++ b/src/BtPieceMessage.h @@ -48,7 +48,6 @@ private: size_t index_; int32_t begin_; int32_t blockLength_; - size_t msgHdrLen_; const unsigned char* data_; SharedHandle downloadContext_; SharedHandle peerStorage_; diff --git a/src/BtUnchokeMessage.cc b/src/BtUnchokeMessage.cc index 74100689..9dfd4a8e 100644 --- a/src/BtUnchokeMessage.cc +++ b/src/BtUnchokeMessage.cc @@ -34,6 +34,7 @@ /* copyright --> */ #include "BtUnchokeMessage.h" #include "Peer.h" +#include "SocketBuffer.h" namespace aria2 { @@ -60,8 +61,23 @@ bool BtUnchokeMessage::sendPredicate() const return getPeer()->amChoking(); } -void BtUnchokeMessage::onSendComplete() { - getPeer()->amChoking(false); +namespace { +struct ThisProgressUpdate : public ProgressUpdate { + ThisProgressUpdate(const SharedHandle& peer) + : peer(peer) {} + virtual void update(size_t length, bool complete) + { + if(complete) { + peer->amChoking(false); + } + } + SharedHandle peer; +}; +} // namespace + +ProgressUpdate* BtUnchokeMessage::getProgressUpdate() +{ + return new ThisProgressUpdate(getPeer()); } } // namespace aria2 diff --git a/src/BtUnchokeMessage.h b/src/BtUnchokeMessage.h index b5474a84..710e0a23 100644 --- a/src/BtUnchokeMessage.h +++ b/src/BtUnchokeMessage.h @@ -56,7 +56,7 @@ public: virtual bool sendPredicate() const; - virtual void onSendComplete(); + virtual ProgressUpdate* getProgressUpdate(); }; } // namespace aria2 diff --git a/src/DefaultBtMessageDispatcher.cc b/src/DefaultBtMessageDispatcher.cc index 1d1dee50..56b1c6c7 100644 --- a/src/DefaultBtMessageDispatcher.cc +++ b/src/DefaultBtMessageDispatcher.cc @@ -56,6 +56,7 @@ #include "RequestGroup.h" #include "util.h" #include "fmt.h" +#include "PeerConnection.h" namespace aria2 { @@ -87,7 +88,8 @@ void DefaultBtMessageDispatcher::addMessageToQueue } } -void DefaultBtMessageDispatcher::sendMessages() { +void DefaultBtMessageDispatcher::sendMessagesInternal() +{ std::vector > tempQueue; while(!messageQueue_.empty()) { SharedHandle msg = messageQueue_.front(); @@ -100,10 +102,6 @@ void DefaultBtMessageDispatcher::sendMessages() { } } msg->send(); - if(msg->isSendingInProgress()) { - messageQueue_.push_front(msg); - break; - } } if(!tempQueue.empty()) { // Insert pending message to the front, so that message is likely sent in @@ -118,6 +116,16 @@ void DefaultBtMessageDispatcher::sendMessages() { } } +void DefaultBtMessageDispatcher::sendMessages() { + // First flush any pending data in the buffer. + peerConnection_->sendPendingData(); + if(!peerConnection_->sendBufferIsEmpty()) { + return; + } + sendMessagesInternal(); + peerConnection_->sendPendingData(); +} + // Cancel sending piece message to peer. void DefaultBtMessageDispatcher::doCancelSendingPieceAction (size_t index, int32_t begin, int32_t length) diff --git a/src/DefaultBtMessageDispatcher.h b/src/DefaultBtMessageDispatcher.h index 00069fdc..7a8c44e1 100644 --- a/src/DefaultBtMessageDispatcher.h +++ b/src/DefaultBtMessageDispatcher.h @@ -52,6 +52,7 @@ class BtMessageFactory; class Peer; class Piece; class RequestGroupMan; +class PeerConnection; class DefaultBtMessageDispatcher : public BtMessageDispatcher { private: @@ -61,6 +62,7 @@ private: SharedHandle downloadContext_; SharedHandle peerStorage_; SharedHandle pieceStorage_; + SharedHandle peerConnection_; BtMessageFactory* messageFactory_; SharedHandle peer_; RequestGroupMan* requestGroupMan_; @@ -77,6 +79,9 @@ public: virtual void sendMessages(); + // For unit tests without PeerConnection + void sendMessagesInternal(); + virtual void doCancelSendingPieceAction (size_t index, int32_t begin, int32_t length); @@ -143,6 +148,11 @@ public: { requestTimeout_ = requestTimeout; } + + void setPeerConnection(const SharedHandle& peerConnection) + { + peerConnection_ = peerConnection; + } }; } // namespace aria2 diff --git a/src/DefaultPeerStorage.cc b/src/DefaultPeerStorage.cc index 469f044e..ff596076 100644 --- a/src/DefaultPeerStorage.cc +++ b/src/DefaultPeerStorage.cc @@ -322,7 +322,7 @@ void DefaultPeerStorage::returnPeer(const SharedHandle& peer) bool DefaultPeerStorage::chokeRoundIntervalElapsed() { - const time_t CHOKE_ROUND_INTERVAL = 10; + const time_t CHOKE_ROUND_INTERVAL = 1;//10; if(pieceStorage_->downloadFinished()) { return seederStateChoke_->getLastRound(). difference(global::wallclock()) >= CHOKE_ROUND_INTERVAL; diff --git a/src/PeerConnection.cc b/src/PeerConnection.cc index 1039f9ec..c0c6c999 100644 --- a/src/PeerConnection.cc +++ b/src/PeerConnection.cc @@ -85,12 +85,13 @@ PeerConnection::~PeerConnection() delete [] resbuf_; } -void PeerConnection::pushBytes(unsigned char* data, size_t len) +void PeerConnection::pushBytes(unsigned char* data, size_t len, + ProgressUpdate* progressUpdate) { if(encryptionEnabled_) { encryptor_->encrypt(len, data, data); } - socketBuffer_.pushBytes(data, len); + socketBuffer_.pushBytes(data, len, progressUpdate); } bool PeerConnection::receiveMessage(unsigned char* data, size_t& dataLength) diff --git a/src/PeerConnection.h b/src/PeerConnection.h index beb7f3ff..c852db5b 100644 --- a/src/PeerConnection.h +++ b/src/PeerConnection.h @@ -96,7 +96,8 @@ public: // Pushes data into send buffer. After this call, this object gets // ownership of data, so caller must not delete or alter it. - void pushBytes(unsigned char* data, size_t len); + void pushBytes(unsigned char* data, size_t len, + ProgressUpdate* progressUpdate = 0); void pushStr(const std::string& data); diff --git a/src/PeerInteractionCommand.cc b/src/PeerInteractionCommand.cc index 2a84d38d..5fcc8faf 100644 --- a/src/PeerInteractionCommand.cc +++ b/src/PeerInteractionCommand.cc @@ -200,6 +200,7 @@ PeerInteractionCommand::PeerInteractionCommand dispatcherPtr->setBtMessageFactory(factoryPtr); dispatcherPtr->setRequestGroupMan (getDownloadEngine()->getRequestGroupMan().get()); + dispatcherPtr->setPeerConnection(peerConnection); SharedHandle dispatcher(dispatcherPtr); DefaultBtMessageReceiver* receiverPtr(new DefaultBtMessageReceiver()); diff --git a/src/SimpleBtMessage.cc b/src/SimpleBtMessage.cc index 565eab29..c7aa670a 100644 --- a/src/SimpleBtMessage.cc +++ b/src/SimpleBtMessage.cc @@ -48,29 +48,19 @@ SimpleBtMessage::SimpleBtMessage(uint8_t id, const char* name) {} void SimpleBtMessage::send() { - if(isInvalidate()) { + if(isInvalidate() || !sendPredicate()) { return; } - if(!sendPredicate() && !isSendingInProgress()) { - return; - } - if(!isSendingInProgress()) { - A2_LOG_INFO(fmt(MSG_SEND_PEER_MESSAGE, - getCuid(), - getPeer()->getIPAddress().c_str(), - getPeer()->getPort(), - toString().c_str())); - unsigned char* msg = createMessage(); - size_t msgLength = getMessageLength(); - A2_LOG_DEBUG(fmt("msglength = %lu bytes", - static_cast(msgLength))); - getPeerConnection()->pushBytes(msg, msgLength); - } - getPeerConnection()->sendPendingData(); - setSendingInProgress(!getPeerConnection()->sendBufferIsEmpty()); - if(!isSendingInProgress()) { - onSendComplete(); - } + A2_LOG_INFO(fmt(MSG_SEND_PEER_MESSAGE, + getCuid(), + getPeer()->getIPAddress().c_str(), + getPeer()->getPort(), + toString().c_str())); + unsigned char* msg = createMessage(); + size_t msgLength = getMessageLength(); + A2_LOG_DEBUG(fmt("msglength = %lu bytes", + static_cast(msgLength))); + getPeerConnection()->pushBytes(msg, msgLength, getProgressUpdate()); } } // namespace aria2 diff --git a/src/SimpleBtMessage.h b/src/SimpleBtMessage.h index c3dc6419..eefc9466 100644 --- a/src/SimpleBtMessage.h +++ b/src/SimpleBtMessage.h @@ -39,6 +39,8 @@ namespace aria2 { +class ProgressUpdate; + class SimpleBtMessage : public AbstractBtMessage { public: SimpleBtMessage(uint8_t id, const char* name); @@ -49,7 +51,7 @@ public: virtual size_t getMessageLength() = 0; - virtual void onSendComplete() {}; + virtual ProgressUpdate* getProgressUpdate() { return 0; }; virtual bool sendPredicate() const { return true; }; diff --git a/src/SocketBuffer.cc b/src/SocketBuffer.cc index 35cd2419..ae3972a4 100644 --- a/src/SocketBuffer.cc +++ b/src/SocketBuffer.cc @@ -41,12 +41,13 @@ #include "DlAbortEx.h" #include "message.h" #include "fmt.h" +#include "LogFactory.h" namespace aria2 { SocketBuffer::ByteArrayBufEntry::ByteArrayBufEntry -(unsigned char* bytes, size_t length) - : bytes_(bytes), length_(length) +(unsigned char* bytes, size_t length, ProgressUpdate* progressUpdate) + : BufEntry(progressUpdate), bytes_(bytes), length_(length) {} SocketBuffer::ByteArrayBufEntry::~ByteArrayBufEntry() @@ -65,11 +66,22 @@ bool SocketBuffer::ByteArrayBufEntry::final(size_t offset) const return length_ <= offset; } -SocketBuffer::StringBufEntry::StringBufEntry(const std::string& s) - : str_(s) +size_t SocketBuffer::ByteArrayBufEntry::getLength() const +{ + return length_; +} + +const unsigned char* SocketBuffer::ByteArrayBufEntry::getData() const +{ + return bytes_; +} + +SocketBuffer::StringBufEntry::StringBufEntry(const std::string& s, + ProgressUpdate* progressUpdate) + : BufEntry(progressUpdate), str_(s) {} -SocketBuffer::StringBufEntry::StringBufEntry() {} +// SocketBuffer::StringBufEntry::StringBufEntry() {} ssize_t SocketBuffer::StringBufEntry::send (const SharedHandle& socket, size_t offset) @@ -82,6 +94,16 @@ bool SocketBuffer::StringBufEntry::final(size_t offset) const return str_.size() <= offset; } +size_t SocketBuffer::StringBufEntry::getLength() const +{ + return str_.size(); +} + +const unsigned char* SocketBuffer::StringBufEntry::getData() const +{ + return reinterpret_cast(str_.c_str()); +} + void SocketBuffer::StringBufEntry::swap(std::string& s) { str_.swap(s); @@ -92,38 +114,82 @@ SocketBuffer::SocketBuffer(const SharedHandle& socket): SocketBuffer::~SocketBuffer() {} -void SocketBuffer::pushBytes(unsigned char* bytes, size_t len) +void SocketBuffer::pushBytes(unsigned char* bytes, size_t len, + ProgressUpdate* progressUpdate) { if(len > 0) { - bufq_.push_back(SharedHandle(new ByteArrayBufEntry(bytes, len))); + bufq_.push_back(SharedHandle + (new ByteArrayBufEntry(bytes, len, progressUpdate))); } } -void SocketBuffer::pushStr(const std::string& data) +void SocketBuffer::pushStr(const std::string& data, + ProgressUpdate* progressUpdate) { if(data.size() > 0) { - bufq_.push_back(SharedHandle(new StringBufEntry(data))); + bufq_.push_back(SharedHandle + (new StringBufEntry(data, progressUpdate))); } } ssize_t SocketBuffer::send() { + a2iovec iov[A2_IOV_MAX]; size_t totalslen = 0; while(!bufq_.empty()) { - const SharedHandle& buf = bufq_[0]; - ssize_t slen = buf->send(socket_, offset_); + size_t num; + ssize_t amount = 16*1024; + ssize_t firstlen = bufq_[0]->getLength() - offset_; + amount -= firstlen; + iov[0].A2IOVEC_BASE = + reinterpret_cast(const_cast + (bufq_[0]->getData() + offset_)); + iov[0].A2IOVEC_LEN = firstlen; + + for(num = 1; num < A2_IOV_MAX && num < bufq_.size() && amount > 0; ++num) { + const SharedHandle& buf = bufq_[num]; + ssize_t len = buf->getLength(); + if(amount >= len) { + amount -= len; + iov[num].A2IOVEC_BASE = + reinterpret_cast(const_cast(buf->getData())); + iov[num].A2IOVEC_LEN = len; + } else { + break; + } + } + ssize_t slen = socket_->writeVector(iov, num); if(slen == 0 && !socket_->wantRead() && !socket_->wantWrite()) { throw DL_ABORT_EX(fmt(EX_SOCKET_SEND, "Connection closed.")); } + //A2_LOG_NOTICE(fmt("SEND=%d", slen)); totalslen += slen; - offset_ += slen; - if(buf->final(offset_)) { + + if(firstlen > slen) { + offset_ += slen; + bufq_[0]->progressUpdate(slen, false); + } else { + slen -= firstlen; + bufq_[0]->progressUpdate(firstlen, true); bufq_.pop_front(); offset_ = 0; - } else { - break; + for(size_t i = 1; i < num; ++i) { + const SharedHandle& buf = bufq_[0]; + ssize_t len = buf->getLength(); + if(len > slen) { + offset_ = slen; + bufq_[0]->progressUpdate(slen, false); + goto fin; + break; + } else { + slen -= len; + bufq_[0]->progressUpdate(len, true); + bufq_.pop_front(); + } + } } } + fin: return totalslen; } diff --git a/src/SocketBuffer.h b/src/SocketBuffer.h index c433d8d0..1f0a32df 100644 --- a/src/SocketBuffer.h +++ b/src/SocketBuffer.h @@ -46,23 +46,46 @@ namespace aria2 { class SocketCore; +struct ProgressUpdate { + virtual ~ProgressUpdate() {} + virtual void update(size_t length, bool complete) = 0; +}; + class SocketBuffer { private: class BufEntry { public: - virtual ~BufEntry() {} + BufEntry(ProgressUpdate* progressUpdate) + : progressUpdate_(progressUpdate) {} + virtual ~BufEntry() + { + delete progressUpdate_; + } virtual ssize_t send (const SharedHandle& socket, size_t offset) = 0; virtual bool final(size_t offset) const = 0; + virtual size_t getLength() const = 0; + virtual const unsigned char* getData() const = 0; + void progressUpdate(size_t length, bool complete) + { + if(progressUpdate_) { + progressUpdate_->update(length, complete); + } + } + private: + ProgressUpdate* progressUpdate_; }; class ByteArrayBufEntry:public BufEntry { public: - ByteArrayBufEntry(unsigned char* bytes, size_t length); + ByteArrayBufEntry(unsigned char* bytes, size_t length, + ProgressUpdate* progressUpdate); virtual ~ByteArrayBufEntry(); virtual ssize_t send (const SharedHandle& socket, size_t offset); virtual bool final(size_t offset) const; + virtual size_t getLength() const; + virtual const unsigned char* getData() const; private: unsigned char* bytes_; size_t length_; @@ -70,11 +93,14 @@ private: class StringBufEntry:public BufEntry { public: - StringBufEntry(const std::string& s); + StringBufEntry(const std::string& s, + ProgressUpdate* progressUpdate); StringBufEntry(); virtual ssize_t send (const SharedHandle& socket, size_t offset); virtual bool final(size_t offset) const; + virtual size_t getLength() const; + virtual const unsigned char* getData() const; void swap(std::string& s); private: std::string str_; @@ -99,11 +125,18 @@ public: // Feeds data pointered by bytes with length len into queue. This // object gets ownership of bytes, so caller must not delete or - // later bytes after this call. This function doesn't send data. - void pushBytes(unsigned char* bytes, size_t len); + // later bytes after this call. This function doesn't send data. If + // progressUpdate is not null, its update() function will be called + // each time the data is sent. It will be deleted by this object. It + // can be null. + void pushBytes(unsigned char* bytes, size_t len, + ProgressUpdate* progressUpdate = 0); - // Feeds data into queue. This function doesn't send data. - void pushStr(const std::string& data); + // Feeds data into queue. This function doesn't send data. If + // progressUpdate is not null, its update() function will be called + // each time the data is sent. It will be deleted by this object. It + // can be null. + void pushStr(const std::string& data, ProgressUpdate* progressUpdate = 0); // Sends data in queue. Returns the number of bytes sent. ssize_t send(); diff --git a/src/SocketCore.cc b/src/SocketCore.cc index 1be9010b..d8596009 100644 --- a/src/SocketCore.cc +++ b/src/SocketCore.cc @@ -738,14 +738,57 @@ void SocketCore::gnutlsRecordCheckDirection() } #endif // HAVE_LIBGNUTLS -ssize_t SocketCore::writeData(const char* data, size_t len) +ssize_t SocketCore::writeVector(a2iovec *iov, size_t iovcnt) +{ + ssize_t ret = 0; + wantRead_ = false; + wantWrite_ = false; + if(!secure_) { +#ifdef __MINGW32__ + DWORD nsent; + int rv = WSASend(sockfd_, iov, iovcnt, &nsent, 0, 0, 0); + if(rv == 0) { + ret = nsent; + } else { + ret = -1; + } +#else // !__MINGW32__ + while((ret = writev(sockfd_, iov, iovcnt)) == -1 && + SOCKET_ERRNO == A2_EINTR); +#endif // !__MINGW32__ + int errNum = SOCKET_ERRNO; + if(ret == -1) { + if(A2_WOULDBLOCK(errNum)) { + wantWrite_ = true; + ret = 0; + } else { + throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, errorMsg(errNum).c_str())); + } + } + } else { + // For SSL/TLS, we could not use writev, so just iterate vector + // and write the data in normal way. + for(size_t i = 0; i < iovcnt; ++i) { + ssize_t rv = writeData(iov[i].A2IOVEC_BASE, iov[i].A2IOVEC_LEN); + if(rv == 0) { + break; + } + ret += rv; + } + } + return ret; +} + +ssize_t SocketCore::writeData(const void* data, size_t len) { ssize_t ret = 0; wantRead_ = false; wantWrite_ = false; if(!secure_) { - while((ret = send(sockfd_, data, len, 0)) == -1 && SOCKET_ERRNO == A2_EINTR); + // Cast for Windows send() + while((ret = send(sockfd_, reinterpret_cast(data), + len, 0)) == -1 && SOCKET_ERRNO == A2_EINTR); int errNum = SOCKET_ERRNO; if(ret == -1) { if(A2_WOULDBLOCK(errNum)) { @@ -1166,7 +1209,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname) #endif // ENABLE_SSL -ssize_t SocketCore::writeData(const char* data, size_t len, +ssize_t SocketCore::writeData(const void* data, size_t len, const std::string& host, uint16_t port) { wantRead_ = false; @@ -1184,7 +1227,9 @@ ssize_t SocketCore::writeData(const char* data, size_t len, ssize_t r = -1; int errNum = 0; for(rp = res; rp; rp = rp->ai_next) { - while((r = sendto(sockfd_, data, len, 0, rp->ai_addr, rp->ai_addrlen)) == -1 + // Cast for Windows sendto() + while((r = sendto(sockfd_, reinterpret_cast(data), len, 0, + rp->ai_addr, rp->ai_addrlen)) == -1 && A2_EINTR == SOCKET_ERRNO); errNum = SOCKET_ERRNO; if(r == static_cast(len)) { diff --git a/src/SocketCore.h b/src/SocketCore.h index 2c943c3b..470921a6 100644 --- a/src/SocketCore.h +++ b/src/SocketCore.h @@ -252,25 +252,16 @@ public: * @param data data to write * @param len length of data */ - ssize_t writeData(const char* data, size_t len); + ssize_t writeData(const void* data, size_t len); ssize_t writeData(const std::string& msg) { return writeData(msg.c_str(), msg.size()); } - ssize_t writeData(const unsigned char* data, size_t len) - { - return writeData(reinterpret_cast(data), len); - } - ssize_t writeData(const char* data, size_t len, + ssize_t writeData(const void* data, size_t len, const std::string& host, uint16_t port); - ssize_t writeData(const unsigned char* data, size_t len, - const std::string& host, - uint16_t port) - { - return writeData(reinterpret_cast(data), len, host, port); - } + ssize_t writeVector(a2iovec *iov, size_t iovcnt); /** * Reads up to len bytes from this socket. diff --git a/src/a2netcompat.h b/src/a2netcompat.h index 41378307..46af621b 100644 --- a/src/a2netcompat.h +++ b/src/a2netcompat.h @@ -87,6 +87,10 @@ # include #endif // HAVE_NETINET_IN_H +#ifdef HAVE_SYS_UIO_H +# include +#endif // HAVE_SYS_UIO_H + #ifndef HAVE_GETADDRINFO # include "getaddrinfo.h" # define HAVE_GAI_STRERROR @@ -141,4 +145,22 @@ union sockaddr_union { sockaddr_in in; }; +#define A2_DEFAULT_IOV_MAX 128 + +#if defined(IOV_MAX) && IOV_MAX < A2_DEFAULT_IOV_MAX +# define A2_IOV_MAX IOV_MAX +#else +# define A2_IOV_MAX A2_DEFAULT_IOV_MAX +#endif + +#ifdef __MINGW32__ +typedef WSABUF a2iovec; +# define A2IOVEC_BASE buf +# define A2IOVEC_LEN len +#else // !__MINGW32__ +typedef struct iovec a2iovec; +# define A2IOVEC_BASE iov_base +# define A2IOVEC_LEN iov_len +#endif // !__MINGW32__ + #endif // D_A2NETCOMPAT_H diff --git a/test/BtAllowedFastMessageTest.cc b/test/BtAllowedFastMessageTest.cc index 5e996660..e8d1fa4d 100644 --- a/test/BtAllowedFastMessageTest.cc +++ b/test/BtAllowedFastMessageTest.cc @@ -6,6 +6,7 @@ #include "bittorrent_helper.h" #include "util.h" #include "Peer.h" +#include "SocketBuffer.h" namespace aria2 { @@ -98,7 +99,8 @@ void BtAllowedFastMessageTest::testOnSendComplete() { peer->setFastExtensionEnabled(true); msg.setPeer(peer); CPPUNIT_ASSERT(!peer->isInAmAllowedIndexSet(1)); - msg.onSendComplete(); + SharedHandle pu(msg.getProgressUpdate()); + pu->update(0, true); CPPUNIT_ASSERT(peer->isInAmAllowedIndexSet(1)); } diff --git a/test/BtChokeMessageTest.cc b/test/BtChokeMessageTest.cc index 83465df5..31e8bfe8 100644 --- a/test/BtChokeMessageTest.cc +++ b/test/BtChokeMessageTest.cc @@ -9,6 +9,7 @@ #include "MockBtRequestFactory.h" #include "Peer.h" #include "FileEntry.h" +#include "SocketBuffer.h" namespace aria2 { @@ -123,7 +124,8 @@ void BtChokeMessageTest::testOnSendComplete() { SharedHandle dispatcher(new MockBtMessageDispatcher2()); msg.setBtMessageDispatcher(dispatcher.get()); - msg.onSendComplete(); + SharedHandle pu(msg.getProgressUpdate()); + pu->update(0, true); CPPUNIT_ASSERT(dispatcher->doChokingActionCalled); CPPUNIT_ASSERT(peer->amChoking()); diff --git a/test/BtInterestedMessageTest.cc b/test/BtInterestedMessageTest.cc index e0accffe..ccb6b8b2 100644 --- a/test/BtInterestedMessageTest.cc +++ b/test/BtInterestedMessageTest.cc @@ -7,6 +7,7 @@ #include "bittorrent_helper.h" #include "Peer.h" #include "MockPeerStorage.h" +#include "SocketBuffer.h" namespace aria2 { @@ -90,7 +91,8 @@ void BtInterestedMessageTest::testOnSendComplete() { peer->allocateSessionResource(1024, 1024*1024); msg.setPeer(peer); CPPUNIT_ASSERT(!peer->amInterested()); - msg.onSendComplete(); + SharedHandle pu(msg.getProgressUpdate()); + pu->update(0, true); CPPUNIT_ASSERT(peer->amInterested()); } diff --git a/test/BtNotInterestedMessageTest.cc b/test/BtNotInterestedMessageTest.cc index d9a8e5f1..1bd204d0 100644 --- a/test/BtNotInterestedMessageTest.cc +++ b/test/BtNotInterestedMessageTest.cc @@ -7,6 +7,7 @@ #include "bittorrent_helper.h" #include "Peer.h" #include "MockPeerStorage.h" +#include "SocketBuffer.h" namespace aria2 { @@ -92,7 +93,8 @@ void BtNotInterestedMessageTest::testOnSendComplete() { BtNotInterestedMessage msg; msg.setPeer(peer); CPPUNIT_ASSERT(peer->amInterested()); - msg.onSendComplete(); + SharedHandle pu(msg.getProgressUpdate()); + pu->update(0, true); CPPUNIT_ASSERT(!peer->amInterested()); } diff --git a/test/BtUnchokeMessageTest.cc b/test/BtUnchokeMessageTest.cc index 639a8406..17b7fe50 100644 --- a/test/BtUnchokeMessageTest.cc +++ b/test/BtUnchokeMessageTest.cc @@ -1,9 +1,12 @@ #include "BtUnchokeMessage.h" -#include "bittorrent_helper.h" -#include "Peer.h" + #include #include +#include "bittorrent_helper.h" +#include "Peer.h" +#include "SocketBuffer.h" + namespace aria2 { class BtUnchokeMessageTest:public CppUnit::TestFixture { @@ -83,7 +86,8 @@ void BtUnchokeMessageTest::testOnSendComplete() { msg.setPeer(peer); CPPUNIT_ASSERT(peer->amChoking()); - msg.onSendComplete(); + SharedHandle pu(msg.getProgressUpdate()); + pu->update(0, true); CPPUNIT_ASSERT(!peer->amChoking()); } diff --git a/test/DefaultBtMessageDispatcherTest.cc b/test/DefaultBtMessageDispatcherTest.cc index 9ca5a434..71cb3ed8 100644 --- a/test/DefaultBtMessageDispatcherTest.cc +++ b/test/DefaultBtMessageDispatcherTest.cc @@ -19,6 +19,7 @@ #include "RequestGroup.h" #include "DownloadContext.h" #include "bittorrent_helper.h" +#include "PeerConnection.h" namespace aria2 { @@ -30,7 +31,6 @@ class DefaultBtMessageDispatcherTest:public CppUnit::TestFixture { CPPUNIT_TEST(testSendMessages_underUploadLimit); // See the comment on the definition //CPPUNIT_TEST(testSendMessages_overUploadLimit); - CPPUNIT_TEST(testSendMessages_sendingInProgress); CPPUNIT_TEST(testDoCancelSendingPieceAction); CPPUNIT_TEST(testCheckRequestSlotAndDoNecessaryThing); CPPUNIT_TEST(testCheckRequestSlotAndDoNecessaryThing_timeout); @@ -58,7 +58,6 @@ public: void testSendMessages(); void testSendMessages_underUploadLimit(); void testSendMessages_overUploadLimit(); - void testSendMessages_sendingInProgress(); void testDoCancelSendingPieceAction(); void testCheckRequestSlotAndDoNecessaryThing(); void testCheckRequestSlotAndDoNecessaryThing_timeout(); @@ -188,7 +187,7 @@ void DefaultBtMessageDispatcherTest::testSendMessages() { msg2->setUploading(false); btMessageDispatcher->addMessageToQueue(msg1); btMessageDispatcher->addMessageToQueue(msg2); - btMessageDispatcher->sendMessages(); + btMessageDispatcher->sendMessagesInternal(); CPPUNIT_ASSERT(msg1->isSendCalled()); CPPUNIT_ASSERT(msg2->isSendCalled()); @@ -203,7 +202,7 @@ void DefaultBtMessageDispatcherTest::testSendMessages_underUploadLimit() { msg2->setUploading(true); btMessageDispatcher->addMessageToQueue(msg1); btMessageDispatcher->addMessageToQueue(msg2); - btMessageDispatcher->sendMessages(); + btMessageDispatcher->sendMessagesInternal(); CPPUNIT_ASSERT(msg1->isSendCalled()); CPPUNIT_ASSERT(msg2->isSendCalled()); @@ -232,7 +231,7 @@ void DefaultBtMessageDispatcherTest::testSendMessages_underUploadLimit() { // btMessageDispatcher->addMessageToQueue(msg1); // btMessageDispatcher->addMessageToQueue(msg2); // btMessageDispatcher->addMessageToQueue(msg3); -// btMessageDispatcher->sendMessages(); +// btMessageDispatcher->sendMessagesInternal(); // CPPUNIT_ASSERT(!msg1->isSendCalled()); // CPPUNIT_ASSERT(!msg2->isSendCalled()); @@ -242,31 +241,6 @@ void DefaultBtMessageDispatcherTest::testSendMessages_underUploadLimit() { // btMessageDispatcher->getMessageQueue().size()); // } -void DefaultBtMessageDispatcherTest::testSendMessages_sendingInProgress() { - SharedHandle msg1(new MockBtMessage2()); - msg1->setSendingInProgress(false); - msg1->setUploading(false); - SharedHandle msg2(new MockBtMessage2()); - msg2->setSendingInProgress(true); - msg2->setUploading(false); - SharedHandle msg3(new MockBtMessage2()); - msg3->setSendingInProgress(false); - msg3->setUploading(false); - - btMessageDispatcher->addMessageToQueue(msg1); - btMessageDispatcher->addMessageToQueue(msg2); - btMessageDispatcher->addMessageToQueue(msg3); - - btMessageDispatcher->sendMessages(); - - CPPUNIT_ASSERT(msg1->isSendCalled()); - CPPUNIT_ASSERT(msg2->isSendCalled()); - CPPUNIT_ASSERT(!msg3->isSendCalled()); - - CPPUNIT_ASSERT_EQUAL((size_t)2, - btMessageDispatcher->getMessageQueue().size()); -} - void DefaultBtMessageDispatcherTest::testDoCancelSendingPieceAction() { SharedHandle msg1(new MockBtMessage2()); SharedHandle msg2(new MockBtMessage2());