Wrap BtMessage objects by std::unique_ptr instead of std::shared_ptr

pull/103/head
Tatsuhiro Tsujikawa 2013-06-30 16:55:15 +09:00
parent abcb0745ed
commit 098f1571be
71 changed files with 816 additions and 852 deletions

View File

@ -45,7 +45,7 @@ const char BtAllowedFastMessage::NAME[] = "allowed fast";
BtAllowedFastMessage::BtAllowedFastMessage(size_t index): BtAllowedFastMessage::BtAllowedFastMessage(size_t index):
IndexBtMessage(ID, NAME, index) {} IndexBtMessage(ID, NAME, index) {}
BtAllowedFastMessage* BtAllowedFastMessage::create std::unique_ptr<BtAllowedFastMessage> BtAllowedFastMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
return IndexBtMessage::create<BtAllowedFastMessage>(data, dataLength); return IndexBtMessage::create<BtAllowedFastMessage>(data, dataLength);

View File

@ -47,7 +47,7 @@ public:
static const char NAME[]; static const char NAME[];
static BtAllowedFastMessage* create static std::unique_ptr<BtAllowedFastMessage> create
(const unsigned char* data, size_t dataLength); (const unsigned char* data, size_t dataLength);
virtual void doReceivedAction(); virtual void doReceivedAction();

View File

@ -78,12 +78,12 @@ void BtBitfieldMessage::setBitfield
memcpy(bitfield_, bitfield, bitfieldLength_); memcpy(bitfield_, bitfield, bitfieldLength_);
} }
BtBitfieldMessage* std::unique_ptr<BtBitfieldMessage>
BtBitfieldMessage::create(const unsigned char* data, size_t dataLength) BtBitfieldMessage::create(const unsigned char* data, size_t dataLength)
{ {
bittorrent::assertPayloadLengthGreater(1,dataLength, NAME); bittorrent::assertPayloadLengthGreater(1,dataLength, NAME);
bittorrent::assertID(ID, data, NAME); bittorrent::assertID(ID, data, NAME);
BtBitfieldMessage* message(new BtBitfieldMessage()); auto message = make_unique<BtBitfieldMessage>();
message->setBitfield(data+1, dataLength-1); message->setBitfield(data+1, dataLength-1);
return message; return message;
} }

View File

@ -60,7 +60,7 @@ public:
size_t getBitfieldLength() const { return bitfieldLength_; } size_t getBitfieldLength() const { return bitfieldLength_; }
static BtBitfieldMessage* create static std::unique_ptr<BtBitfieldMessage> create
(const unsigned char* data, size_t dataLength); (const unsigned char* data, size_t dataLength);
virtual void doReceivedAction(); virtual void doReceivedAction();

View File

@ -43,7 +43,7 @@ BtCancelMessage::BtCancelMessage
(size_t index, int32_t begin, int32_t length) (size_t index, int32_t begin, int32_t length)
:RangeBtMessage(ID, NAME, index, begin, length) {} :RangeBtMessage(ID, NAME, index, begin, length) {}
BtCancelMessage* BtCancelMessage::create std::unique_ptr<BtCancelMessage> BtCancelMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
return RangeBtMessage::create<BtCancelMessage>(data, dataLength); return RangeBtMessage::create<BtCancelMessage>(data, dataLength);

View File

@ -47,7 +47,8 @@ public:
static const char NAME[]; static const char NAME[];
static BtCancelMessage* create(const unsigned char* data, size_t dataLength); static std::unique_ptr<BtCancelMessage> create
(const unsigned char* data, size_t dataLength);
virtual void doReceivedAction(); virtual void doReceivedAction();
}; };

View File

@ -42,9 +42,9 @@ namespace aria2 {
const char BtChokeMessage::NAME[] = "choke"; const char BtChokeMessage::NAME[] = "choke";
BtChokeMessage::BtChokeMessage():ZeroBtMessage(ID, NAME) {} BtChokeMessage::BtChokeMessage():ZeroBtMessage{ID, NAME} {}
BtChokeMessage* BtChokeMessage::create std::unique_ptr<BtChokeMessage> BtChokeMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
return ZeroBtMessage::create<BtChokeMessage>(data, dataLength); return ZeroBtMessage::create<BtChokeMessage>(data, dataLength);

View File

@ -49,7 +49,8 @@ public:
virtual void doReceivedAction(); virtual void doReceivedAction();
static BtChokeMessage* create(const unsigned char* data, size_t dataLength); static std::unique_ptr<BtChokeMessage> create
(const unsigned char* data, size_t dataLength);
virtual bool sendPredicate() const; virtual bool sendPredicate() const;

View File

@ -96,7 +96,7 @@ std::string BtExtendedMessage::toString() const {
return s; return s;
} }
BtExtendedMessage* std::unique_ptr<BtExtendedMessage>
BtExtendedMessage::create(const std::shared_ptr<ExtensionMessageFactory>& factory, BtExtendedMessage::create(const std::shared_ptr<ExtensionMessageFactory>& factory,
const std::shared_ptr<Peer>& peer, const std::shared_ptr<Peer>& peer,
const unsigned char* data, size_t dataLength) const unsigned char* data, size_t dataLength)
@ -106,8 +106,7 @@ BtExtendedMessage::create(const std::shared_ptr<ExtensionMessageFactory>& factor
assert(factory); assert(factory);
std::shared_ptr<ExtensionMessage> extmsg = factory->createMessage(data+1, std::shared_ptr<ExtensionMessage> extmsg = factory->createMessage(data+1,
dataLength-1); dataLength-1);
BtExtendedMessage* message(new BtExtendedMessage(extmsg)); return make_unique<BtExtendedMessage>(extmsg);
return message;
} }
void BtExtendedMessage::doReceivedAction() void BtExtendedMessage::doReceivedAction()

View File

@ -56,7 +56,7 @@ public:
static const char NAME[]; static const char NAME[];
static BtExtendedMessage* create static std::unique_ptr<BtExtendedMessage> create
(const std::shared_ptr<ExtensionMessageFactory>& factory, (const std::shared_ptr<ExtensionMessageFactory>& factory,
const std::shared_ptr<Peer>& peer, const std::shared_ptr<Peer>& peer,
const unsigned char* data, const unsigned char* data,

View File

@ -75,10 +75,10 @@ void BtHandshakeMessage::init() {
reserved_[5] |= 0x10u; reserved_[5] |= 0x10u;
} }
std::shared_ptr<BtHandshakeMessage> std::unique_ptr<BtHandshakeMessage>
BtHandshakeMessage::create(const unsigned char* data, size_t dataLength) BtHandshakeMessage::create(const unsigned char* data, size_t dataLength)
{ {
std::shared_ptr<BtHandshakeMessage> message(new BtHandshakeMessage()); auto message = make_unique<BtHandshakeMessage>();
message->pstrlen_ = data[0]; message->pstrlen_ = data[0];
memcpy(message->pstr_, &data[1], PSTR_LENGTH); memcpy(message->pstr_, &data[1], PSTR_LENGTH);
memcpy(message->reserved_, &data[20], RESERVED_LENGTH); memcpy(message->reserved_, &data[20], RESERVED_LENGTH);

View File

@ -60,7 +60,7 @@ public:
*/ */
BtHandshakeMessage(const unsigned char* infoHash, const unsigned char* peerId); BtHandshakeMessage(const unsigned char* infoHash, const unsigned char* peerId);
static std::shared_ptr<BtHandshakeMessage> static std::unique_ptr<BtHandshakeMessage>
create(const unsigned char* data, size_t dataLength); create(const unsigned char* data, size_t dataLength);
virtual ~BtHandshakeMessage() { virtual ~BtHandshakeMessage() {

View File

@ -45,7 +45,7 @@ const char BtHaveAllMessage::NAME[] = "have all";
BtHaveAllMessage::BtHaveAllMessage():ZeroBtMessage(ID, NAME) {} BtHaveAllMessage::BtHaveAllMessage():ZeroBtMessage(ID, NAME) {}
BtHaveAllMessage* BtHaveAllMessage::create std::unique_ptr<BtHaveAllMessage> BtHaveAllMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
return ZeroBtMessage::create<BtHaveAllMessage>(data, dataLength); return ZeroBtMessage::create<BtHaveAllMessage>(data, dataLength);

View File

@ -47,7 +47,7 @@ public:
static const char NAME[]; static const char NAME[];
static BtHaveAllMessage* create static std::unique_ptr<BtHaveAllMessage> create
(const unsigned char* data, size_t dataLength); (const unsigned char* data, size_t dataLength);
virtual void doReceivedAction(); virtual void doReceivedAction();

View File

@ -44,7 +44,7 @@ const char BtHaveMessage::NAME[] = "have";
BtHaveMessage::BtHaveMessage(size_t index):IndexBtMessage(ID, NAME, index) {} BtHaveMessage::BtHaveMessage(size_t index):IndexBtMessage(ID, NAME, index) {}
BtHaveMessage* BtHaveMessage::create std::unique_ptr<BtHaveMessage> BtHaveMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
return IndexBtMessage::create<BtHaveMessage>(data, dataLength); return IndexBtMessage::create<BtHaveMessage>(data, dataLength);

View File

@ -47,7 +47,8 @@ public:
static const char NAME[]; static const char NAME[];
static BtHaveMessage* create(const unsigned char* data, size_t dataLength); static std::unique_ptr<BtHaveMessage> create
(const unsigned char* data, size_t dataLength);
virtual void doReceivedAction(); virtual void doReceivedAction();
}; };

View File

@ -43,7 +43,7 @@ const char BtHaveNoneMessage::NAME[] = "have none";
BtHaveNoneMessage::BtHaveNoneMessage():ZeroBtMessage(ID, NAME) {} BtHaveNoneMessage::BtHaveNoneMessage():ZeroBtMessage(ID, NAME) {}
BtHaveNoneMessage* BtHaveNoneMessage::create std::unique_ptr<BtHaveNoneMessage> BtHaveNoneMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
return ZeroBtMessage::create<BtHaveNoneMessage>(data, dataLength); return ZeroBtMessage::create<BtHaveNoneMessage>(data, dataLength);

View File

@ -47,7 +47,7 @@ public:
static const char NAME[]; static const char NAME[];
static BtHaveNoneMessage* create static std::unique_ptr<BtHaveNoneMessage> create
(const unsigned char* data, size_t dataLength); (const unsigned char* data, size_t dataLength);
virtual void doReceivedAction(); virtual void doReceivedAction();

View File

@ -48,7 +48,7 @@ BtInterestedMessage::BtInterestedMessage()
BtInterestedMessage::~BtInterestedMessage() {} BtInterestedMessage::~BtInterestedMessage() {}
BtInterestedMessage* BtInterestedMessage::create std::unique_ptr<BtInterestedMessage> BtInterestedMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
return ZeroBtMessage::create<BtInterestedMessage>(data, dataLength); return ZeroBtMessage::create<BtInterestedMessage>(data, dataLength);

View File

@ -53,7 +53,7 @@ public:
static const char NAME[]; static const char NAME[];
static BtInterestedMessage* create static std::unique_ptr<BtInterestedMessage> create
(const unsigned char* data, size_t dataLength); (const unsigned char* data, size_t dataLength);
virtual void doReceivedAction(); virtual void doReceivedAction();

View File

@ -51,10 +51,7 @@ class BtMessageDispatcher {
public: public:
virtual ~BtMessageDispatcher() {} virtual ~BtMessageDispatcher() {}
virtual void addMessageToQueue(const std::shared_ptr<BtMessage>& btMessage) = 0; virtual void addMessageToQueue(std::unique_ptr<BtMessage> btMessage) = 0;
virtual void
addMessageToQueue(const std::vector<std::shared_ptr<BtMessage> >& btMessages) =0;
virtual void sendMessages() = 0; virtual void sendMessages() = 0;

View File

@ -43,58 +43,77 @@ namespace aria2 {
class BtMessage; class BtMessage;
class BtHandshakeMessage; class BtHandshakeMessage;
class Piece; class BtAllowedFastMessage;
class BtBitfieldMessage;
class BtCancelMessage;
class BtChokeMessage;
class BtHaveAllMessage;
class BtHaveMessage;
class BtHaveNoneMessage;
class BtInterestedMessage;
class BtKeepAliveMessage;
class BtNotInterestedMessage;
class BtPieceMessage;
class BtPortMessage;
class BtRejectMessage;
class BtRequestMessage;
class BtUnchokeMessage;
class BtExtendedMessage;
class ExtensionMessage; class ExtensionMessage;
class Piece;
class BtMessageFactory { class BtMessageFactory {
public: public:
virtual ~BtMessageFactory() {} virtual ~BtMessageFactory() {}
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtMessage>
createBtMessage(const unsigned char* msg, size_t msgLength) = 0; createBtMessage(const unsigned char* msg, size_t msgLength) = 0;
virtual std::shared_ptr<BtHandshakeMessage> virtual std::unique_ptr<BtHandshakeMessage>
createHandshakeMessage(const unsigned char* msg, size_t msgLength) = 0; createHandshakeMessage(const unsigned char* msg, size_t msgLength) = 0;
virtual std::shared_ptr<BtHandshakeMessage> virtual std::unique_ptr<BtHandshakeMessage>
createHandshakeMessage(const unsigned char* infoHash, createHandshakeMessage(const unsigned char* infoHash,
const unsigned char* peerId) = 0; const unsigned char* peerId) = 0;
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtRequestMessage>
createRequestMessage(const std::shared_ptr<Piece>& piece, size_t blockIndex) = 0; createRequestMessage(const std::shared_ptr<Piece>& piece,
size_t blockIndex) = 0;
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtCancelMessage>
createCancelMessage(size_t index, int32_t begin, int32_t length) = 0; createCancelMessage(size_t index, int32_t begin, int32_t length) = 0;
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtPieceMessage>
createPieceMessage(size_t index, int32_t begin, int32_t length) = 0; createPieceMessage(size_t index, int32_t begin, int32_t length) = 0;
virtual std::shared_ptr<BtMessage> createHaveMessage(size_t index) = 0; virtual std::unique_ptr<BtHaveMessage> createHaveMessage(size_t index) = 0;
virtual std::shared_ptr<BtMessage> createChokeMessage() = 0; virtual std::unique_ptr<BtChokeMessage> createChokeMessage() = 0;
virtual std::shared_ptr<BtMessage> createUnchokeMessage() = 0; virtual std::unique_ptr<BtUnchokeMessage> createUnchokeMessage() = 0;
virtual std::shared_ptr<BtMessage> createInterestedMessage() = 0; virtual std::unique_ptr<BtInterestedMessage> createInterestedMessage() = 0;
virtual std::shared_ptr<BtMessage> createNotInterestedMessage() = 0; virtual std::unique_ptr<BtNotInterestedMessage>
createNotInterestedMessage() = 0;
virtual std::shared_ptr<BtMessage> createBitfieldMessage() = 0; virtual std::unique_ptr<BtBitfieldMessage> createBitfieldMessage() = 0;
virtual std::shared_ptr<BtMessage> createKeepAliveMessage() = 0; virtual std::unique_ptr<BtKeepAliveMessage> createKeepAliveMessage() = 0;
virtual std::shared_ptr<BtMessage> createHaveAllMessage() = 0; virtual std::unique_ptr<BtHaveAllMessage> createHaveAllMessage() = 0;
virtual std::shared_ptr<BtMessage> createHaveNoneMessage() = 0; virtual std::unique_ptr<BtHaveNoneMessage> createHaveNoneMessage() = 0;
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtRejectMessage>
createRejectMessage(size_t index, int32_t begin, int32_t length) = 0; createRejectMessage(size_t index, int32_t begin, int32_t length) = 0;
virtual std::shared_ptr<BtMessage> createAllowedFastMessage(size_t index) = 0; virtual std::unique_ptr<BtAllowedFastMessage>
createAllowedFastMessage(size_t index) = 0;
virtual std::shared_ptr<BtMessage> createPortMessage(uint16_t port) = 0; virtual std::unique_ptr<BtPortMessage> createPortMessage(uint16_t port) = 0;
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtExtendedMessage>
createBtExtendedMessage(const std::shared_ptr<ExtensionMessage>& msg) = 0; createBtExtendedMessage(const std::shared_ptr<ExtensionMessage>& msg) = 0;
}; };

View File

@ -48,7 +48,7 @@ BtNotInterestedMessage::BtNotInterestedMessage()
BtNotInterestedMessage::~BtNotInterestedMessage() {} BtNotInterestedMessage::~BtNotInterestedMessage() {}
BtNotInterestedMessage* BtNotInterestedMessage::create std::unique_ptr<BtNotInterestedMessage> BtNotInterestedMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
return ZeroBtMessage::create<BtNotInterestedMessage>(data, dataLength); return ZeroBtMessage::create<BtNotInterestedMessage>(data, dataLength);

View File

@ -53,7 +53,7 @@ public:
static const char NAME[]; static const char NAME[];
static BtNotInterestedMessage* create static std::unique_ptr<BtNotInterestedMessage> create
(const unsigned char* data, size_t dataLength); (const unsigned char* data, size_t dataLength);
virtual void doReceivedAction(); virtual void doReceivedAction();

View File

@ -60,6 +60,7 @@
#include "WrDiskCache.h" #include "WrDiskCache.h"
#include "WrDiskCacheEntry.h" #include "WrDiskCacheEntry.h"
#include "DownloadFailureException.h" #include "DownloadFailureException.h"
#include "BtRejectMessage.h"
namespace aria2 { namespace aria2 {
@ -86,16 +87,14 @@ void BtPieceMessage::setMsgPayload(const unsigned char* data)
data_ = data; data_ = data;
} }
BtPieceMessage* BtPieceMessage::create std::unique_ptr<BtPieceMessage> BtPieceMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
bittorrent::assertPayloadLengthGreater(9, dataLength, NAME); bittorrent::assertPayloadLengthGreater(9, dataLength, NAME);
bittorrent::assertID(ID, data, NAME); bittorrent::assertID(ID, data, NAME);
BtPieceMessage* message(new BtPieceMessage()); return make_unique<BtPieceMessage>(bittorrent::getIntParam(data, 1),
message->setIndex(bittorrent::getIntParam(data, 1)); bittorrent::getIntParam(data, 5),
message->setBegin(bittorrent::getIntParam(data, 5)); dataLength-9);
message->setBlockLength(dataLength-9);
return message;
} }
void BtPieceMessage::doReceivedAction() void BtPieceMessage::doReceivedAction()
@ -305,10 +304,9 @@ void BtPieceMessage::onChokingEvent(const BtChokingEvent& event)
begin_, begin_,
blockLength_)); blockLength_));
if(getPeer()->isFastExtensionEnabled()) { if(getPeer()->isFastExtensionEnabled()) {
std::shared_ptr<BtMessage> rej = getBtMessageDispatcher()->addMessageToQueue
getBtMessageFactory()->createRejectMessage (getBtMessageFactory()->createRejectMessage
(index_, begin_, blockLength_); (index_, begin_, blockLength_));
getBtMessageDispatcher()->addMessageToQueue(rej);
} }
setInvalidate(true); setInvalidate(true);
} }
@ -327,10 +325,9 @@ void BtPieceMessage::onCancelSendingPieceEvent
begin_, begin_,
blockLength_)); blockLength_));
if(getPeer()->isFastExtensionEnabled()) { if(getPeer()->isFastExtensionEnabled()) {
std::shared_ptr<BtMessage> rej = getBtMessageDispatcher()->addMessageToQueue
getBtMessageFactory()->createRejectMessage (getBtMessageFactory()->createRejectMessage
(index_, begin_, blockLength_); (index_, begin_, blockLength_));
getBtMessageDispatcher()->addMessageToQueue(rej);
} }
setInvalidate(true); setInvalidate(true);
} }

View File

@ -92,7 +92,8 @@ public:
void setPeerStorage(PeerStorage* peerStorage); void setPeerStorage(PeerStorage* peerStorage);
static BtPieceMessage* create(const unsigned char* data, size_t dataLength); static std::unique_ptr<BtPieceMessage> create
(const unsigned char* data, size_t dataLength);
virtual void doReceivedAction(); virtual void doReceivedAction();

View File

@ -61,14 +61,13 @@ BtPortMessage::BtPortMessage(uint16_t port)
taskFactory_(0) taskFactory_(0)
{} {}
BtPortMessage* BtPortMessage::create std::unique_ptr<BtPortMessage> BtPortMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
bittorrent::assertPayloadLengthEqual(3, dataLength, NAME); bittorrent::assertPayloadLengthEqual(3, dataLength, NAME);
bittorrent::assertID(ID, data, NAME); bittorrent::assertID(ID, data, NAME);
uint16_t port = bittorrent::getShortIntParam(data, 1); uint16_t port = bittorrent::getShortIntParam(data, 1);
BtPortMessage* message(new BtPortMessage(port)); return make_unique<BtPortMessage>(port);
return message;
} }
void BtPortMessage::doReceivedAction() void BtPortMessage::doReceivedAction()

View File

@ -65,7 +65,8 @@ public:
uint16_t getPort() const { return port_; } uint16_t getPort() const { return port_; }
static BtPortMessage* create(const unsigned char* data, size_t dataLength); static std::unique_ptr<BtPortMessage> create
(const unsigned char* data, size_t dataLength);
virtual void doReceivedAction(); virtual void doReceivedAction();

View File

@ -47,7 +47,7 @@ BtRejectMessage::BtRejectMessage
(size_t index, int32_t begin, int32_t length): (size_t index, int32_t begin, int32_t length):
RangeBtMessage(ID, NAME, index, begin, length) {} RangeBtMessage(ID, NAME, index, begin, length) {}
BtRejectMessage* BtRejectMessage::create std::unique_ptr<BtRejectMessage> BtRejectMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
return RangeBtMessage::create<BtRejectMessage>(data, dataLength); return RangeBtMessage::create<BtRejectMessage>(data, dataLength);

View File

@ -47,7 +47,8 @@ public:
static const char NAME[]; static const char NAME[];
static BtRejectMessage* create(const unsigned char* data, size_t dataLength); static std::unique_ptr<BtRejectMessage> create
(const unsigned char* data, size_t dataLength);
virtual void doReceivedAction(); virtual void doReceivedAction();
}; };

View File

@ -43,7 +43,7 @@
namespace aria2 { namespace aria2 {
class Piece; class Piece;
class BtMessage; class BtRequestMessage;
class BtRequestFactory { class BtRequestFactory {
public: public:
@ -65,24 +65,18 @@ public:
/** /**
* Creates RequestMessage objects associated to the pieces added by * Creates RequestMessage objects associated to the pieces added by
* addTargetPiece() and returns them. * addTargetPiece() and returns them. The number of objects
* The number of objects returned is capped by max. * returned is capped by max. If |endGame| is true, returns
* requests in end game mode.
*/ */
virtual void createRequestMessages virtual std::vector<std::unique_ptr<BtRequestMessage>> createRequestMessages
(std::vector<std::shared_ptr<BtMessage> >& requests, size_t max) = 0; (size_t max, bool endGame) = 0;
/** /**
* Use this method in end game mode. * Returns the list of index of pieces added using addTargetPiece()
* * into indexes.
*/ */
virtual void createRequestMessagesOnEndGame virtual std::vector<size_t> getTargetPieceIndexes() const = 0;
(std::vector<std::shared_ptr<BtMessage> >& requests, size_t max) = 0;
/**
* Stores the list of index of pieces added using addTargetPiece() into
* indexes.
*/
virtual void getTargetPieceIndexes(std::vector<size_t>& indexes) const = 0;
}; };

View File

@ -38,6 +38,8 @@
#include "PieceStorage.h" #include "PieceStorage.h"
#include "BtMessageDispatcher.h" #include "BtMessageDispatcher.h"
#include "BtMessageFactory.h" #include "BtMessageFactory.h"
#include "BtPieceMessage.h"
#include "BtRejectMessage.h"
namespace aria2 { namespace aria2 {
@ -48,7 +50,7 @@ BtRequestMessage::BtRequestMessage
RangeBtMessage(ID, NAME, index, begin, length), RangeBtMessage(ID, NAME, index, begin, length),
blockIndex_(blockIndex) {} blockIndex_(blockIndex) {}
BtRequestMessage* BtRequestMessage::create std::unique_ptr<BtRequestMessage> BtRequestMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
return RangeBtMessage::create<BtRequestMessage>(data, dataLength); return RangeBtMessage::create<BtRequestMessage>(data, dataLength);
@ -63,16 +65,14 @@ void BtRequestMessage::doReceivedAction()
(!getPeer()->amChoking() || (!getPeer()->amChoking() ||
(getPeer()->amChoking() && (getPeer()->amChoking() &&
getPeer()->isInAmAllowedIndexSet(getIndex())))) { getPeer()->isInAmAllowedIndexSet(getIndex())))) {
std::shared_ptr<BtMessage> msg = getBtMessageDispatcher()->addMessageToQueue
getBtMessageFactory()->createPieceMessage (getBtMessageFactory()->createPieceMessage
(getIndex(), getBegin(), getLength()); (getIndex(), getBegin(), getLength()));
getBtMessageDispatcher()->addMessageToQueue(msg);
} else { } else {
if(getPeer()->isFastExtensionEnabled()) { if(getPeer()->isFastExtensionEnabled()) {
std::shared_ptr<BtMessage> msg = getBtMessageDispatcher()->addMessageToQueue
getBtMessageFactory()->createRejectMessage (getBtMessageFactory()->createRejectMessage
(getIndex(), getBegin(), getLength()); (getIndex(), getBegin(), getLength()));
getBtMessageDispatcher()->addMessageToQueue(msg);
} }
} }
} }

View File

@ -55,7 +55,7 @@ public:
size_t getBlockIndex() const { return blockIndex_; } size_t getBlockIndex() const { return blockIndex_; }
void setBlockIndex(size_t blockIndex) { blockIndex_ = blockIndex; } void setBlockIndex(size_t blockIndex) { blockIndex_ = blockIndex; }
static BtRequestMessage* create static std::unique_ptr<BtRequestMessage> create
(const unsigned char* data, size_t dataLength); (const unsigned char* data, size_t dataLength);
virtual void doReceivedAction(); virtual void doReceivedAction();

View File

@ -36,9 +36,13 @@
namespace aria2 { namespace aria2 {
BtSuggestPieceMessage::BtSuggestPieceMessage(size_t index)
: IndexBtMessage{ID, NAME, index}
{}
const char BtSuggestPieceMessage::NAME[] = "suggest piece"; const char BtSuggestPieceMessage::NAME[] = "suggest piece";
BtSuggestPieceMessage* BtSuggestPieceMessage::create std::unique_ptr<BtSuggestPieceMessage> BtSuggestPieceMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
return IndexBtMessage::create<BtSuggestPieceMessage>(data, dataLength); return IndexBtMessage::create<BtSuggestPieceMessage>(data, dataLength);

View File

@ -41,13 +41,13 @@ namespace aria2 {
class BtSuggestPieceMessage : public IndexBtMessage { class BtSuggestPieceMessage : public IndexBtMessage {
public: public:
BtSuggestPieceMessage():IndexBtMessage(ID, NAME, 0) {} BtSuggestPieceMessage(size_t index = 0);
static const uint8_t ID = 13; static const uint8_t ID = 13;
static const char NAME[]; static const char NAME[];
static BtSuggestPieceMessage* create static std::unique_ptr<BtSuggestPieceMessage> create
(const unsigned char* data, size_t dataLength); (const unsigned char* data, size_t dataLength);
virtual void doReceivedAction() { virtual void doReceivedAction() {

View File

@ -42,7 +42,7 @@ const char BtUnchokeMessage::NAME[] = "unchoke";
BtUnchokeMessage::BtUnchokeMessage():ZeroBtMessage(ID, NAME) {} BtUnchokeMessage::BtUnchokeMessage():ZeroBtMessage(ID, NAME) {}
BtUnchokeMessage* BtUnchokeMessage::create std::unique_ptr<BtUnchokeMessage> BtUnchokeMessage::create
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
return ZeroBtMessage::create<BtUnchokeMessage>(data, dataLength); return ZeroBtMessage::create<BtUnchokeMessage>(data, dataLength);

View File

@ -49,7 +49,7 @@ public:
static const char NAME[]; static const char NAME[];
static BtUnchokeMessage* create static std::unique_ptr<BtUnchokeMessage> create
(const unsigned char* data, size_t dataLength); (const unsigned char* data, size_t dataLength);
virtual void doReceivedAction(); virtual void doReceivedAction();

View File

@ -46,6 +46,14 @@
#include "BtUnchokeMessage.h" #include "BtUnchokeMessage.h"
#include "BtRequestMessage.h" #include "BtRequestMessage.h"
#include "BtPieceMessage.h" #include "BtPieceMessage.h"
#include "BtPortMessage.h"
#include "BtInterestedMessage.h"
#include "BtNotInterestedMessage.h"
#include "BtHaveMessage.h"
#include "BtHaveAllMessage.h"
#include "BtBitfieldMessage.h"
#include "BtHaveNoneMessage.h"
#include "BtAllowedFastMessage.h"
#include "DlAbortEx.h" #include "DlAbortEx.h"
#include "BtExtendedMessage.h" #include "BtExtendedMessage.h"
#include "HandshakeExtensionMessage.h" #include "HandshakeExtensionMessage.h"
@ -104,10 +112,10 @@ DefaultBtInteractive::DefaultBtInteractive
DefaultBtInteractive::~DefaultBtInteractive() {} DefaultBtInteractive::~DefaultBtInteractive() {}
void DefaultBtInteractive::initiateHandshake() { void DefaultBtInteractive::initiateHandshake() {
std::shared_ptr<BtMessage> message = dispatcher_->addMessageToQueue
messageFactory_->createHandshakeMessage (messageFactory_->createHandshakeMessage
(bittorrent::getInfoHash(downloadContext_), bittorrent::getStaticPeerId()); (bittorrent::getInfoHash(downloadContext_),
dispatcher_->addMessageToQueue(message); bittorrent::getStaticPeerId()));
dispatcher_->sendMessages(); dispatcher_->sendMessages();
} }
@ -199,8 +207,7 @@ void DefaultBtInteractive::addHandshakeExtendedMessageToQueue()
if(!attrs->metadata.empty()) { if(!attrs->metadata.empty()) {
m->setMetadataSize(attrs->metadataSize); m->setMetadataSize(attrs->metadataSize);
} }
std::shared_ptr<BtMessage> msg = messageFactory_->createBtExtendedMessage(m); dispatcher_->addMessageToQueue(messageFactory_->createBtExtendedMessage(m));
dispatcher_->addMessageToQueue(msg);
} }
void DefaultBtInteractive::addBitfieldMessageToQueue() { void DefaultBtInteractive::addBitfieldMessageToQueue() {
@ -301,9 +308,6 @@ size_t DefaultBtInteractive::receiveMessages() {
message->doReceivedAction(); message->doReceivedAction();
switch(message->getId()) { switch(message->getId()) {
case BtKeepAliveMessage::ID:
floodingStat_.incKeepAliveCount();
break;
case BtChokeMessage::ID: case BtChokeMessage::ID:
if(!peer_->peerChoking()) { if(!peer_->peerChoking()) {
floodingStat_.incChokeUnchokeCount(); floodingStat_.incChokeUnchokeCount();
@ -314,10 +318,13 @@ size_t DefaultBtInteractive::receiveMessages() {
floodingStat_.incChokeUnchokeCount(); floodingStat_.incChokeUnchokeCount();
} }
break; break;
case BtPieceMessage::ID:
case BtRequestMessage::ID: case BtRequestMessage::ID:
case BtPieceMessage::ID:
inactiveTimer_ = global::wallclock(); inactiveTimer_ = global::wallclock();
break; break;
case BtKeepAliveMessage::ID:
floodingStat_.incKeepAliveCount();
break;
} }
} }
@ -359,11 +366,9 @@ void DefaultBtInteractive::fillPiece(size_t maxMissingBlock) {
if(peer_->peerChoking()) { if(peer_->peerChoking()) {
if(peer_->isFastExtensionEnabled()) { if(peer_->isFastExtensionEnabled()) {
if(pieceStorage_->isEndGame()) { if(pieceStorage_->isEndGame()) {
std::vector<size_t> excludedIndexes;
excludedIndexes.reserve(btRequestFactory_->countTargetPiece());
btRequestFactory_->getTargetPieceIndexes(excludedIndexes);
pieceStorage_->getMissingFastPiece pieceStorage_->getMissingFastPiece
(pieces, diffMissingBlock, peer_, excludedIndexes, cuid_); (pieces, diffMissingBlock, peer_,
btRequestFactory_->getTargetPieceIndexes(), cuid_);
} else { } else {
pieces.reserve(diffMissingBlock); pieces.reserve(diffMissingBlock);
pieceStorage_->getMissingFastPiece pieceStorage_->getMissingFastPiece
@ -372,11 +377,9 @@ void DefaultBtInteractive::fillPiece(size_t maxMissingBlock) {
} }
} else { } else {
if(pieceStorage_->isEndGame()) { if(pieceStorage_->isEndGame()) {
std::vector<size_t> excludedIndexes;
excludedIndexes.reserve(btRequestFactory_->countTargetPiece());
btRequestFactory_->getTargetPieceIndexes(excludedIndexes);
pieceStorage_->getMissingPiece pieceStorage_->getMissingPiece
(pieces, diffMissingBlock, peer_, excludedIndexes, cuid_); (pieces, diffMissingBlock, peer_,
btRequestFactory_->getTargetPieceIndexes(), cuid_);
} else { } else {
pieces.reserve(diffMissingBlock); pieces.reserve(diffMissingBlock);
pieceStorage_->getMissingPiece(pieces, diffMissingBlock, peer_, cuid_); pieceStorage_->getMissingPiece(pieces, diffMissingBlock, peer_, cuid_);
@ -399,14 +402,12 @@ void DefaultBtInteractive::addRequests() {
0 : maxOutstandingRequest_-dispatcher_->countOutstandingRequest(); 0 : maxOutstandingRequest_-dispatcher_->countOutstandingRequest();
if(reqNumToCreate > 0) { if(reqNumToCreate > 0) {
std::vector<std::shared_ptr<BtMessage> > requests; auto requests =
requests.reserve(reqNumToCreate); btRequestFactory_->createRequestMessages(reqNumToCreate,
if(pieceStorage_->isEndGame()) { pieceStorage_->isEndGame());
btRequestFactory_->createRequestMessagesOnEndGame(requests,reqNumToCreate); for(auto& i : requests) {
} else { dispatcher_->addMessageToQueue(std::move(i));
btRequestFactory_->createRequestMessages(requests, reqNumToCreate);
} }
dispatcher_->addMessageToQueue(requests);
} }
} }
@ -502,8 +503,8 @@ void DefaultBtInteractive::addPeerExchangeMessage()
} }
} }
std::shared_ptr<BtMessage> msg = messageFactory_->createBtExtendedMessage(m); dispatcher_->addMessageToQueue
dispatcher_->addMessageToQueue(msg); (messageFactory_->createBtExtendedMessage(m));
pexTimer_ = global::wallclock(); pexTimer_ = global::wallclock();
} }
} }
@ -520,9 +521,11 @@ void DefaultBtInteractive::doInteractionProcessing() {
downloadContext_->getTotalLength() > 0) { downloadContext_->getTotalLength() > 0) {
size_t num = utMetadataRequestTracker_->avail(); size_t num = utMetadataRequestTracker_->avail();
if(num > 0) { if(num > 0) {
std::vector<std::shared_ptr<BtMessage> > requests; auto requests =
utMetadataRequestFactory_->create(requests, num, pieceStorage_); utMetadataRequestFactory_->create(num, pieceStorage_);
dispatcher_->addMessageToQueue(requests); for(auto& i : requests) {
dispatcher_->addMessageToQueue(std::move(i));
}
} }
if(perSecTimer_.difference(global::wallclock()) >= 1) { if(perSecTimer_.difference(global::wallclock()) >= 1) {
perSecTimer_ = global::wallclock(); perSecTimer_ = global::wallclock();

View File

@ -57,18 +57,19 @@
#include "util.h" #include "util.h"
#include "fmt.h" #include "fmt.h"
#include "PeerConnection.h" #include "PeerConnection.h"
#include "BtCancelMessage.h"
namespace aria2 { namespace aria2 {
DefaultBtMessageDispatcher::DefaultBtMessageDispatcher() DefaultBtMessageDispatcher::DefaultBtMessageDispatcher()
: cuid_(0), : cuid_{0},
downloadContext_{0}, downloadContext_{nullptr},
peerStorage_{0}, peerStorage_{nullptr},
pieceStorage_{0}, pieceStorage_{nullptr},
peerConnection_{0}, peerConnection_{nullptr},
messageFactory_(0), messageFactory_{nullptr},
requestGroupMan_(0), requestGroupMan_{nullptr},
requestTimeout_(0) requestTimeout_{0}
{} {}
DefaultBtMessageDispatcher::~DefaultBtMessageDispatcher() DefaultBtMessageDispatcher::~DefaultBtMessageDispatcher()
@ -77,39 +78,31 @@ DefaultBtMessageDispatcher::~DefaultBtMessageDispatcher()
} }
void DefaultBtMessageDispatcher::addMessageToQueue void DefaultBtMessageDispatcher::addMessageToQueue
(const std::shared_ptr<BtMessage>& btMessage) (std::unique_ptr<BtMessage> btMessage)
{ {
btMessage->onQueued(); btMessage->onQueued();
messageQueue_.push_back(btMessage); messageQueue_.push_back(std::move(btMessage));
}
void DefaultBtMessageDispatcher::addMessageToQueue
(const std::vector<std::shared_ptr<BtMessage> >& btMessages)
{
for(std::vector<std::shared_ptr<BtMessage> >::const_iterator itr =
btMessages.begin(), eoi = btMessages.end(); itr != eoi; ++itr) {
addMessageToQueue(*itr);
}
} }
void DefaultBtMessageDispatcher::sendMessagesInternal() void DefaultBtMessageDispatcher::sendMessagesInternal()
{ {
std::vector<std::shared_ptr<BtMessage> > tempQueue; auto tempQueue = std::vector<std::unique_ptr<BtMessage>>{};
while(!messageQueue_.empty()) { while(!messageQueue_.empty()) {
std::shared_ptr<BtMessage> msg = messageQueue_.front(); auto msg = std::move(messageQueue_.front());
messageQueue_.pop_front(); messageQueue_.pop_front();
if(msg->isUploading()) { if(msg->isUploading()) {
if(requestGroupMan_->doesOverallUploadSpeedExceed() || if(requestGroupMan_->doesOverallUploadSpeedExceed() ||
downloadContext_->getOwnerRequestGroup()->doesUploadSpeedExceed()) { downloadContext_->getOwnerRequestGroup()->doesUploadSpeedExceed()) {
tempQueue.push_back(msg); tempQueue.push_back(std::move(msg));
continue; continue;
} }
} }
msg->send(); msg->send();
} }
if(!tempQueue.empty()) { if(!tempQueue.empty()) {
messageQueue_.insert(messageQueue_.begin(), messageQueue_.insert(std::begin(messageQueue_),
tempQueue.begin(), tempQueue.end()); std::make_move_iterator(std::begin(tempQueue)),
std::make_move_iterator(std::end(tempQueue)));
} }
} }
@ -121,15 +114,26 @@ void DefaultBtMessageDispatcher::sendMessages()
peerConnection_->sendPendingData(); peerConnection_->sendPendingData();
} }
namespace {
std::vector<BtMessage*> toRawPointers
(const std::deque<std::unique_ptr<BtMessage>>& v)
{
auto x = std::vector<BtMessage*>{};
x.reserve(v.size());
for(auto& i : v) {
x.push_back(i.get());
}
return x;
}
} // namespace
// Cancel sending piece message to peer. // Cancel sending piece message to peer.
void DefaultBtMessageDispatcher::doCancelSendingPieceAction void DefaultBtMessageDispatcher::doCancelSendingPieceAction
(size_t index, int32_t begin, int32_t length) (size_t index, int32_t begin, int32_t length)
{ {
BtCancelSendingPieceEvent event(index, begin, length); BtCancelSendingPieceEvent event(index, begin, length);
auto q = toRawPointers(messageQueue_);
std::vector<std::shared_ptr<BtMessage> > tempQueue for(auto i : q) {
(messageQueue_.begin(), messageQueue_.end());
for(const auto& i : tempQueue) {
i->onCancelSendingPieceEvent(event); i->onCancelSendingPieceEvent(event);
} }
} }
@ -173,9 +177,8 @@ void DefaultBtMessageDispatcher::doAbortOutstandingRequestAction
BtAbortOutstandingRequestEvent event(piece); BtAbortOutstandingRequestEvent event(piece);
std::vector<std::shared_ptr<BtMessage> > tempQueue auto tempQueue = toRawPointers(messageQueue_);
(messageQueue_.begin(), messageQueue_.end()); for(auto i : tempQueue) {
for(const auto& i : tempQueue) {
i->onAbortOutstandingRequestEvent(event); i->onAbortOutstandingRequestEvent(event);
} }
} }
@ -207,9 +210,8 @@ void DefaultBtMessageDispatcher::doChokingAction()
{ {
BtChokingEvent event; BtChokingEvent event;
std::vector<std::shared_ptr<BtMessage> > tempQueue auto tempQueue = toRawPointers(messageQueue_);
(messageQueue_.begin(), messageQueue_.end()); for(auto i : tempQueue) {
for(const auto& i : tempQueue) {
i->onChokingEvent(event); i->onChokingEvent(event);
} }
} }
@ -298,7 +300,7 @@ void DefaultBtMessageDispatcher::addOutstandingRequest
size_t DefaultBtMessageDispatcher::countOutstandingUpload() size_t DefaultBtMessageDispatcher::countOutstandingUpload()
{ {
return std::count_if(messageQueue_.begin(), messageQueue_.end(), return std::count_if(std::begin(messageQueue_), std::end(messageQueue_),
std::mem_fn(&BtMessage::isUploading)); std::mem_fn(&BtMessage::isUploading));
} }

View File

@ -57,7 +57,7 @@ class PeerConnection;
class DefaultBtMessageDispatcher : public BtMessageDispatcher { class DefaultBtMessageDispatcher : public BtMessageDispatcher {
private: private:
cuid_t cuid_; cuid_t cuid_;
std::deque<std::shared_ptr<BtMessage> > messageQueue_; std::deque<std::unique_ptr<BtMessage> > messageQueue_;
std::deque<std::unique_ptr<RequestSlot>> requestSlots_; std::deque<std::unique_ptr<RequestSlot>> requestSlots_;
DownloadContext* downloadContext_; DownloadContext* downloadContext_;
PeerStorage* peerStorage_; PeerStorage* peerStorage_;
@ -72,10 +72,7 @@ public:
virtual ~DefaultBtMessageDispatcher(); virtual ~DefaultBtMessageDispatcher();
virtual void addMessageToQueue(const std::shared_ptr<BtMessage>& btMessage); virtual void addMessageToQueue(std::unique_ptr<BtMessage> btMessage);
virtual void addMessageToQueue
(const std::vector<std::shared_ptr<BtMessage> >& btMessages);
virtual void sendMessages(); virtual void sendMessages();
@ -117,7 +114,7 @@ public:
virtual size_t countOutstandingUpload(); virtual size_t countOutstandingUpload();
const std::deque<std::shared_ptr<BtMessage> >& getMessageQueue() const const std::deque<std::unique_ptr<BtMessage>>& getMessageQueue() const
{ {
return messageQueue_; return messageQueue_;
} }

View File

@ -70,58 +70,53 @@
namespace aria2 { namespace aria2 {
DefaultBtMessageFactory::DefaultBtMessageFactory(): DefaultBtMessageFactory::DefaultBtMessageFactory()
cuid_(0), : cuid_{0},
downloadContext_(0), downloadContext_{nullptr},
pieceStorage_(0), pieceStorage_{nullptr},
peerStorage_(0), peerStorage_{nullptr},
dhtEnabled_(false), dhtEnabled_(false),
dispatcher_(0), dispatcher_{nullptr},
requestFactory_(0), requestFactory_{nullptr},
peerConnection_(0), peerConnection_{nullptr},
localNode_(0), localNode_{nullptr},
routingTable_(0), routingTable_{nullptr},
taskQueue_(0), taskQueue_{nullptr},
taskFactory_(0), taskFactory_{nullptr},
metadataGetMode_(false) metadataGetMode_(false)
{} {}
DefaultBtMessageFactory::~DefaultBtMessageFactory() {} std::unique_ptr<BtMessage>
std::shared_ptr<BtMessage>
DefaultBtMessageFactory::createBtMessage DefaultBtMessageFactory::createBtMessage
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
std::shared_ptr<AbstractBtMessage> msg; auto msg = std::unique_ptr<AbstractBtMessage>{};
if(dataLength == 0) { if(dataLength == 0) {
// keep-alive // keep-alive
msg.reset(new BtKeepAliveMessage()); msg = make_unique<BtKeepAliveMessage>();
} else { } else {
uint8_t id = bittorrent::getId(data); uint8_t id = bittorrent::getId(data);
switch(id) { switch(id) {
case BtChokeMessage::ID: case BtChokeMessage::ID:
msg.reset(BtChokeMessage::create(data, dataLength)); msg = BtChokeMessage::create(data, dataLength);
break; break;
case BtUnchokeMessage::ID: case BtUnchokeMessage::ID:
msg.reset(BtUnchokeMessage::create(data, dataLength)); msg = BtUnchokeMessage::create(data, dataLength);
break; break;
case BtInterestedMessage::ID: case BtInterestedMessage::ID: {
{ auto m = BtInterestedMessage::create(data, dataLength);
BtInterestedMessage* m = BtInterestedMessage::create(data, dataLength);
m->setPeerStorage(peerStorage_); m->setPeerStorage(peerStorage_);
msg.reset(m); msg = std::move(m);
}
break; break;
case BtNotInterestedMessage::ID: }
{ case BtNotInterestedMessage::ID: {
BtNotInterestedMessage* m = auto m = BtNotInterestedMessage::create(data, dataLength);
BtNotInterestedMessage::create(data, dataLength);
m->setPeerStorage(peerStorage_); m->setPeerStorage(peerStorage_);
msg.reset(m); msg = std::move(m);
}
break; break;
}
case BtHaveMessage::ID: case BtHaveMessage::ID:
msg.reset(BtHaveMessage::create(data, dataLength)); msg = BtHaveMessage::create(data, dataLength);
if(!metadataGetMode_) { if(!metadataGetMode_) {
msg->setBtMessageValidator(make_unique<IndexBtMessageValidator> msg->setBtMessageValidator(make_unique<IndexBtMessageValidator>
(static_cast<BtHaveMessage*>(msg.get()), (static_cast<BtHaveMessage*>(msg.get()),
@ -129,7 +124,7 @@ DefaultBtMessageFactory::createBtMessage
} }
break; break;
case BtBitfieldMessage::ID: case BtBitfieldMessage::ID:
msg.reset(BtBitfieldMessage::create(data, dataLength)); msg = BtBitfieldMessage::create(data, dataLength);
if(!metadataGetMode_) { if(!metadataGetMode_) {
msg->setBtMessageValidator(make_unique<BtBitfieldMessageValidator> msg->setBtMessageValidator(make_unique<BtBitfieldMessageValidator>
(static_cast<BtBitfieldMessage*>(msg.get()), (static_cast<BtBitfieldMessage*>(msg.get()),
@ -137,95 +132,94 @@ DefaultBtMessageFactory::createBtMessage
} }
break; break;
case BtRequestMessage::ID: { case BtRequestMessage::ID: {
BtRequestMessage* m = BtRequestMessage::create(data, dataLength); auto m = BtRequestMessage::create(data, dataLength);
if(!metadataGetMode_) { if(!metadataGetMode_) {
m->setBtMessageValidator m->setBtMessageValidator
(make_unique<RangeBtMessageValidator> (make_unique<RangeBtMessageValidator>
(m, (static_cast<BtRequestMessage*>(m.get()),
downloadContext_->getNumPieces(), downloadContext_->getNumPieces(),
pieceStorage_->getPieceLength(m->getIndex()))); pieceStorage_->getPieceLength(m->getIndex())));
} }
msg.reset(m); msg = std::move(m);
break;
}
case BtCancelMessage::ID: {
BtCancelMessage* m = BtCancelMessage::create(data, dataLength);
if(!metadataGetMode_) {
m->setBtMessageValidator
(make_unique<RangeBtMessageValidator>
(m,
downloadContext_->getNumPieces(),
pieceStorage_->getPieceLength(m->getIndex())));
}
msg.reset(m);
break; break;
} }
case BtPieceMessage::ID: { case BtPieceMessage::ID: {
BtPieceMessage* m = BtPieceMessage::create(data, dataLength); auto m = BtPieceMessage::create(data, dataLength);
if(!metadataGetMode_) { if(!metadataGetMode_) {
m->setBtMessageValidator m->setBtMessageValidator
(make_unique<BtPieceMessageValidator> (make_unique<BtPieceMessageValidator>
(m, (static_cast<BtPieceMessage*>(m.get()),
downloadContext_->getNumPieces(), downloadContext_->getNumPieces(),
pieceStorage_->getPieceLength(m->getIndex()))); pieceStorage_->getPieceLength(m->getIndex())));
} }
m->setDownloadContext(downloadContext_); m->setDownloadContext(downloadContext_);
m->setPeerStorage(peerStorage_); m->setPeerStorage(peerStorage_);
msg.reset(m); msg = std::move(m);
break; break;
} }
case BtHaveAllMessage::ID: case BtCancelMessage::ID: {
msg.reset(BtHaveAllMessage::create(data, dataLength)); auto m = BtCancelMessage::create(data, dataLength);
break;
case BtHaveNoneMessage::ID:
msg.reset(BtHaveNoneMessage::create(data, dataLength));
break;
case BtRejectMessage::ID: {
BtRejectMessage* m = BtRejectMessage::create(data, dataLength);
if(!metadataGetMode_) { if(!metadataGetMode_) {
m->setBtMessageValidator m->setBtMessageValidator
(make_unique<RangeBtMessageValidator> (make_unique<RangeBtMessageValidator>
(m, (static_cast<BtCancelMessage*>(m.get()),
downloadContext_->getNumPieces(), downloadContext_->getNumPieces(),
pieceStorage_->getPieceLength(m->getIndex()))); pieceStorage_->getPieceLength(m->getIndex())));
} }
msg.reset(m); msg = std::move(m);
break;
}
case BtSuggestPieceMessage::ID: {
BtSuggestPieceMessage* m =
BtSuggestPieceMessage::create(data, dataLength);
if(!metadataGetMode_) {
m->setBtMessageValidator(make_unique<IndexBtMessageValidator>
(m, downloadContext_->getNumPieces()));
}
msg.reset(m);
break;
}
case BtAllowedFastMessage::ID: {
BtAllowedFastMessage* m = BtAllowedFastMessage::create(data, dataLength);
if(!metadataGetMode_) {
std::shared_ptr<BtMessageValidator> validator
(new IndexBtMessageValidator(m, downloadContext_->getNumPieces()));
m->setBtMessageValidator(make_unique<IndexBtMessageValidator>
(m, downloadContext_->getNumPieces()));
}
msg.reset(m);
break; break;
} }
case BtPortMessage::ID: { case BtPortMessage::ID: {
BtPortMessage* m = BtPortMessage::create(data, dataLength); auto m = BtPortMessage::create(data, dataLength);
m->setLocalNode(localNode_); m->setLocalNode(localNode_);
m->setRoutingTable(routingTable_); m->setRoutingTable(routingTable_);
m->setTaskQueue(taskQueue_); m->setTaskQueue(taskQueue_);
m->setTaskFactory(taskFactory_); m->setTaskFactory(taskFactory_);
msg.reset(m); msg = std::move(m);
break;
}
case BtSuggestPieceMessage::ID: {
auto m = BtSuggestPieceMessage::create(data, dataLength);
if(!metadataGetMode_) {
m->setBtMessageValidator(make_unique<IndexBtMessageValidator>
(static_cast<BtSuggestPieceMessage*>(m.get()),
downloadContext_->getNumPieces()));
}
msg = std::move(m);
break;
}
case BtHaveAllMessage::ID:
msg = BtHaveAllMessage::create(data, dataLength);
break;
case BtHaveNoneMessage::ID:
msg = BtHaveNoneMessage::create(data, dataLength);
break;
case BtRejectMessage::ID: {
auto m = BtRejectMessage::create(data, dataLength);
if(!metadataGetMode_) {
m->setBtMessageValidator
(make_unique<RangeBtMessageValidator>
(static_cast<BtRejectMessage*>(m.get()),
downloadContext_->getNumPieces(),
pieceStorage_->getPieceLength(m->getIndex())));
}
msg = std::move(m);
break;
}
case BtAllowedFastMessage::ID: {
auto m = BtAllowedFastMessage::create(data, dataLength);
if(!metadataGetMode_) {
m->setBtMessageValidator(make_unique<IndexBtMessageValidator>
(static_cast<BtAllowedFastMessage*>(m.get()),
downloadContext_->getNumPieces()));
}
msg = std::move(m);
break; break;
} }
case BtExtendedMessage::ID: { case BtExtendedMessage::ID: {
if(peer_->isExtendedMessagingEnabled()) { if(peer_->isExtendedMessagingEnabled()) {
msg.reset(BtExtendedMessage::create(extensionMessageFactory_, msg = BtExtendedMessage::create(extensionMessageFactory_,
peer_, data, dataLength)); peer_, data, dataLength);
} else { } else {
throw DL_ABORT_EX("Received extended message from peer during" throw DL_ABORT_EX("Received extended message from peer during"
" a session with extended messaging disabled."); " a session with extended messaging disabled.");
@ -237,7 +231,7 @@ DefaultBtMessageFactory::createBtMessage
} }
} }
setCommonProperty(msg.get()); setCommonProperty(msg.get());
return msg; return std::move(msg);
} }
void DefaultBtMessageFactory::setCommonProperty(AbstractBtMessage* msg) void DefaultBtMessageFactory::setCommonProperty(AbstractBtMessage* msg)
@ -254,12 +248,11 @@ void DefaultBtMessageFactory::setCommonProperty(AbstractBtMessage* msg)
} }
} }
std::shared_ptr<BtHandshakeMessage> std::unique_ptr<BtHandshakeMessage>
DefaultBtMessageFactory::createHandshakeMessage DefaultBtMessageFactory::createHandshakeMessage
(const unsigned char* data, size_t dataLength) (const unsigned char* data, size_t dataLength)
{ {
std::shared_ptr<BtHandshakeMessage> msg = auto msg = BtHandshakeMessage::create(data, dataLength);
BtHandshakeMessage::create(data, dataLength);
msg->setBtMessageValidator(make_unique<BtHandshakeMessageValidator> msg->setBtMessageValidator(make_unique<BtHandshakeMessageValidator>
(msg.get(), (msg.get(),
bittorrent::getInfoHash(downloadContext_))); bittorrent::getInfoHash(downloadContext_)));
@ -267,155 +260,153 @@ DefaultBtMessageFactory::createHandshakeMessage
return msg; return msg;
} }
std::shared_ptr<BtHandshakeMessage> std::unique_ptr<BtHandshakeMessage>
DefaultBtMessageFactory::createHandshakeMessage(const unsigned char* infoHash, DefaultBtMessageFactory::createHandshakeMessage(const unsigned char* infoHash,
const unsigned char* peerId) const unsigned char* peerId)
{ {
std::shared_ptr<BtHandshakeMessage> msg auto msg = make_unique<BtHandshakeMessage>(infoHash, peerId);
(new BtHandshakeMessage(infoHash, peerId));
msg->setDHTEnabled(dhtEnabled_); msg->setDHTEnabled(dhtEnabled_);
setCommonProperty(msg.get()); setCommonProperty(msg.get());
return msg; return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtRequestMessage>
DefaultBtMessageFactory::createRequestMessage DefaultBtMessageFactory::createRequestMessage
(const std::shared_ptr<Piece>& piece, size_t blockIndex) (const std::shared_ptr<Piece>& piece, size_t blockIndex)
{ {
BtRequestMessage* msg auto msg = make_unique<BtRequestMessage>(piece->getIndex(),
(new BtRequestMessage(piece->getIndex(),
blockIndex*piece->getBlockLength(), blockIndex*piece->getBlockLength(),
piece->getBlockLength(blockIndex), piece->getBlockLength(blockIndex),
blockIndex)); blockIndex);
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtCancelMessage>
DefaultBtMessageFactory::createCancelMessage DefaultBtMessageFactory::createCancelMessage
(size_t index, int32_t begin, int32_t length) (size_t index, int32_t begin, int32_t length)
{ {
BtCancelMessage* msg(new BtCancelMessage(index, begin, length)); auto msg = make_unique<BtCancelMessage>(index, begin, length);
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtPieceMessage>
DefaultBtMessageFactory::createPieceMessage DefaultBtMessageFactory::createPieceMessage
(size_t index, int32_t begin, int32_t length) (size_t index, int32_t begin, int32_t length)
{ {
BtPieceMessage* msg(new BtPieceMessage(index, begin, length)); auto msg = make_unique<BtPieceMessage>(index, begin, length);
msg->setDownloadContext(downloadContext_); msg->setDownloadContext(downloadContext_);
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtHaveMessage>
DefaultBtMessageFactory::createHaveMessage(size_t index) DefaultBtMessageFactory::createHaveMessage(size_t index)
{ {
BtHaveMessage* msg(new BtHaveMessage(index)); auto msg = make_unique<BtHaveMessage>(index);
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtChokeMessage>
DefaultBtMessageFactory::createChokeMessage() DefaultBtMessageFactory::createChokeMessage()
{ {
BtChokeMessage* msg(new BtChokeMessage()); auto msg = make_unique<BtChokeMessage>();
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtUnchokeMessage>
DefaultBtMessageFactory::createUnchokeMessage() DefaultBtMessageFactory::createUnchokeMessage()
{ {
BtUnchokeMessage* msg(new BtUnchokeMessage()); auto msg = make_unique<BtUnchokeMessage>();
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtInterestedMessage>
DefaultBtMessageFactory::createInterestedMessage() DefaultBtMessageFactory::createInterestedMessage()
{ {
BtInterestedMessage* msg(new BtInterestedMessage()); auto msg = make_unique<BtInterestedMessage>();
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtNotInterestedMessage>
DefaultBtMessageFactory::createNotInterestedMessage() DefaultBtMessageFactory::createNotInterestedMessage()
{ {
BtNotInterestedMessage* msg(new BtNotInterestedMessage()); auto msg = make_unique<BtNotInterestedMessage>();
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtBitfieldMessage>
DefaultBtMessageFactory::createBitfieldMessage() DefaultBtMessageFactory::createBitfieldMessage()
{ {
BtBitfieldMessage* msg auto msg = make_unique<BtBitfieldMessage>
(new BtBitfieldMessage(pieceStorage_->getBitfield(), (pieceStorage_->getBitfield(),
pieceStorage_->getBitfieldLength())); pieceStorage_->getBitfieldLength());
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtKeepAliveMessage>
DefaultBtMessageFactory::createKeepAliveMessage() DefaultBtMessageFactory::createKeepAliveMessage()
{ {
BtKeepAliveMessage* msg(new BtKeepAliveMessage()); auto msg = make_unique<BtKeepAliveMessage>();
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtHaveAllMessage>
DefaultBtMessageFactory::createHaveAllMessage() DefaultBtMessageFactory::createHaveAllMessage()
{ {
BtHaveAllMessage* msg(new BtHaveAllMessage()); auto msg = make_unique<BtHaveAllMessage>();
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtHaveNoneMessage>
DefaultBtMessageFactory::createHaveNoneMessage() DefaultBtMessageFactory::createHaveNoneMessage()
{ {
BtHaveNoneMessage* msg(new BtHaveNoneMessage()); auto msg = make_unique<BtHaveNoneMessage>();
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtRejectMessage>
DefaultBtMessageFactory::createRejectMessage DefaultBtMessageFactory::createRejectMessage
(size_t index, int32_t begin, int32_t length) (size_t index, int32_t begin, int32_t length)
{ {
BtRejectMessage* msg(new BtRejectMessage(index, begin, length)); auto msg = make_unique<BtRejectMessage>(index, begin, length);
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtAllowedFastMessage>
DefaultBtMessageFactory::createAllowedFastMessage(size_t index) DefaultBtMessageFactory::createAllowedFastMessage(size_t index)
{ {
BtAllowedFastMessage* msg(new BtAllowedFastMessage(index)); auto msg = make_unique<BtAllowedFastMessage>(index);
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtPortMessage>
DefaultBtMessageFactory::createPortMessage(uint16_t port) DefaultBtMessageFactory::createPortMessage(uint16_t port)
{ {
BtPortMessage* msg(new BtPortMessage(port)); auto msg = make_unique<BtPortMessage>(port);
setCommonProperty(msg); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(msg); return msg;
} }
std::shared_ptr<BtMessage> std::unique_ptr<BtExtendedMessage>
DefaultBtMessageFactory::createBtExtendedMessage DefaultBtMessageFactory::createBtExtendedMessage
(const std::shared_ptr<ExtensionMessage>& msg) (const std::shared_ptr<ExtensionMessage>& exmsg)
{ {
BtExtendedMessage* m(new BtExtendedMessage(msg)); auto msg = make_unique<BtExtendedMessage>(exmsg);
setCommonProperty(m); setCommonProperty(msg.get());
return std::shared_ptr<BtMessage>(m); return msg;
} }
void DefaultBtMessageFactory::setTaskQueue(DHTTaskQueue* taskQueue) void DefaultBtMessageFactory::setTaskQueue(DHTTaskQueue* taskQueue)

View File

@ -86,53 +86,53 @@ private:
public: public:
DefaultBtMessageFactory(); DefaultBtMessageFactory();
virtual ~DefaultBtMessageFactory(); virtual std::unique_ptr<BtMessage>
virtual std::shared_ptr<BtMessage>
createBtMessage(const unsigned char* msg, size_t msgLength); createBtMessage(const unsigned char* msg, size_t msgLength);
virtual std::shared_ptr<BtHandshakeMessage> virtual std::unique_ptr<BtHandshakeMessage>
createHandshakeMessage(const unsigned char* msg, size_t msgLength); createHandshakeMessage(const unsigned char* msg, size_t msgLength);
virtual std::shared_ptr<BtHandshakeMessage> virtual std::unique_ptr<BtHandshakeMessage>
createHandshakeMessage(const unsigned char* infoHash, createHandshakeMessage(const unsigned char* infoHash,
const unsigned char* peerId); const unsigned char* peerId);
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtRequestMessage>
createRequestMessage(const std::shared_ptr<Piece>& piece, size_t blockIndex); createRequestMessage(const std::shared_ptr<Piece>& piece,
size_t blockIndex);
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtCancelMessage>
createCancelMessage(size_t index, int32_t begin, int32_t length); createCancelMessage(size_t index, int32_t begin, int32_t length);
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtPieceMessage>
createPieceMessage(size_t index, int32_t begin, int32_t length); createPieceMessage(size_t index, int32_t begin, int32_t length);
virtual std::shared_ptr<BtMessage> createHaveMessage(size_t index); virtual std::unique_ptr<BtHaveMessage> createHaveMessage(size_t index);
virtual std::shared_ptr<BtMessage> createChokeMessage(); virtual std::unique_ptr<BtChokeMessage> createChokeMessage();
virtual std::shared_ptr<BtMessage> createUnchokeMessage(); virtual std::unique_ptr<BtUnchokeMessage> createUnchokeMessage();
virtual std::shared_ptr<BtMessage> createInterestedMessage(); virtual std::unique_ptr<BtInterestedMessage> createInterestedMessage();
virtual std::shared_ptr<BtMessage> createNotInterestedMessage(); virtual std::unique_ptr<BtNotInterestedMessage> createNotInterestedMessage();
virtual std::shared_ptr<BtMessage> createBitfieldMessage(); virtual std::unique_ptr<BtBitfieldMessage> createBitfieldMessage();
virtual std::shared_ptr<BtMessage> createKeepAliveMessage(); virtual std::unique_ptr<BtKeepAliveMessage> createKeepAliveMessage();
virtual std::shared_ptr<BtMessage> createHaveAllMessage(); virtual std::unique_ptr<BtHaveAllMessage> createHaveAllMessage();
virtual std::shared_ptr<BtMessage> createHaveNoneMessage(); virtual std::unique_ptr<BtHaveNoneMessage> createHaveNoneMessage();
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtRejectMessage>
createRejectMessage(size_t index, int32_t begin, int32_t length); createRejectMessage(size_t index, int32_t begin, int32_t length);
virtual std::shared_ptr<BtMessage> createAllowedFastMessage(size_t index); virtual std::unique_ptr<BtAllowedFastMessage> createAllowedFastMessage
(size_t index);
virtual std::shared_ptr<BtMessage> createPortMessage(uint16_t port); virtual std::unique_ptr<BtPortMessage> createPortMessage(uint16_t port);
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtExtendedMessage>
createBtExtendedMessage(const std::shared_ptr<ExtensionMessage>& msg); createBtExtendedMessage(const std::shared_ptr<ExtensionMessage>& msg);
void setPeer(const std::shared_ptr<Peer>& peer); void setPeer(const std::shared_ptr<Peer>& peer);

View File

@ -108,10 +108,10 @@ DefaultBtMessageReceiver::receiveAndSendHandshake()
} }
void DefaultBtMessageReceiver::sendHandshake() { void DefaultBtMessageReceiver::sendHandshake() {
std::shared_ptr<BtMessage> msg = dispatcher_->addMessageToQueue
messageFactory_->createHandshakeMessage (messageFactory_->createHandshakeMessage
(bittorrent::getInfoHash(downloadContext_), bittorrent::getStaticPeerId()); (bittorrent::getInfoHash(downloadContext_),
dispatcher_->addMessageToQueue(msg); bittorrent::getStaticPeerId()));
dispatcher_->sendMessages(); dispatcher_->sendMessages();
} }

View File

@ -48,6 +48,7 @@
#include "SimpleRandomizer.h" #include "SimpleRandomizer.h"
#include "array_fun.h" #include "array_fun.h"
#include "fmt.h" #include "fmt.h"
#include "BtRequestMessage.h"
namespace aria2 { namespace aria2 {
@ -157,48 +158,50 @@ void DefaultBtRequestFactory::removeAllTargetPiece() {
pieces_.clear(); pieces_.clear();
} }
void DefaultBtRequestFactory::createRequestMessages std::vector<std::unique_ptr<BtRequestMessage>>
(std::vector<std::shared_ptr<BtMessage> >& requests, size_t max) DefaultBtRequestFactory::createRequestMessages(size_t max, bool endGame)
{ {
if(requests.size() >= max) { if(endGame) {
return; return createRequestMessagesOnEndGame(max);
} }
auto requests = std::vector<std::unique_ptr<BtRequestMessage>>{};
size_t getnum = max-requests.size(); size_t getnum = max-requests.size();
std::vector<size_t> blockIndexes; auto blockIndexes = std::vector<size_t>{};
blockIndexes.reserve(getnum); blockIndexes.reserve(getnum);
for(std::deque<std::shared_ptr<Piece> >::iterator itr = pieces_.begin(), for(auto itr = std::begin(pieces_), eoi = std::end(pieces_);
eoi = pieces_.end(); itr != eoi && getnum; ++itr) { itr != eoi && getnum; ++itr) {
std::shared_ptr<Piece>& piece = *itr; auto& piece = *itr;
if(piece->getMissingUnusedBlockIndex(blockIndexes, getnum)) { if(piece->getMissingUnusedBlockIndex(blockIndexes, getnum)) {
getnum -= blockIndexes.size(); getnum -= blockIndexes.size();
for(std::vector<size_t>::const_iterator i = blockIndexes.begin(), for(auto i = std::begin(blockIndexes), eoi2 = std::end(blockIndexes);
eoi2 = blockIndexes.end(); i != eoi2; ++i) { i != eoi2; ++i) {
A2_LOG_DEBUG A2_LOG_DEBUG
(fmt("Creating RequestMessage index=%lu, begin=%u," (fmt("Creating RequestMessage index=%lu, begin=%u,"
" blockIndex=%lu", " blockIndex=%lu",
static_cast<unsigned long>(piece->getIndex()), static_cast<unsigned long>(piece->getIndex()),
static_cast<unsigned int>((*i)*piece->getBlockLength()), static_cast<unsigned int>((*i)*piece->getBlockLength()),
static_cast<unsigned long>(*i))); static_cast<unsigned long>(*i)));
requests.push_back requests.push_back(messageFactory_->createRequestMessage(piece, *i));
(messageFactory_->createRequestMessage(piece, *i));
} }
blockIndexes.clear(); blockIndexes.clear();
} }
} }
return requests;
} }
void DefaultBtRequestFactory::createRequestMessagesOnEndGame std::vector<std::unique_ptr<BtRequestMessage>>
(std::vector<std::shared_ptr<BtMessage> >& requests, size_t max) DefaultBtRequestFactory::createRequestMessagesOnEndGame(size_t max)
{ {
for(std::deque<std::shared_ptr<Piece> >::iterator itr = pieces_.begin(), auto requests = std::vector<std::unique_ptr<BtRequestMessage>>{};
eoi = pieces_.end(); itr != eoi && requests.size() < max; ++itr) { for(auto itr = std::begin(pieces_), eoi = std::end(pieces_);
std::shared_ptr<Piece>& piece = *itr; itr != eoi && requests.size() < max; ++itr) {
auto& piece = *itr;
const size_t mislen = piece->getBitfieldLength(); const size_t mislen = piece->getBitfieldLength();
array_ptr<unsigned char> misbitfield(new unsigned char[mislen]); array_ptr<unsigned char> misbitfield(new unsigned char[mislen]);
piece->getAllMissingBlockIndexes(misbitfield, mislen); piece->getAllMissingBlockIndexes(misbitfield, mislen);
std::vector<size_t> missingBlockIndexes; auto missingBlockIndexes = std::vector<size_t>{};
size_t blockIndex = 0; size_t blockIndex = 0;
for(size_t i = 0; i < mislen; ++i) { for(size_t i = 0; i < mislen; ++i) {
unsigned char bits = misbitfield[i]; unsigned char bits = misbitfield[i];
@ -209,12 +212,13 @@ void DefaultBtRequestFactory::createRequestMessagesOnEndGame
} }
} }
} }
std::random_shuffle(missingBlockIndexes.begin(), missingBlockIndexes.end(), std::random_shuffle(std::begin(missingBlockIndexes),
std::end(missingBlockIndexes),
*SimpleRandomizer::getInstance()); *SimpleRandomizer::getInstance());
for(std::vector<size_t>::const_iterator bitr = missingBlockIndexes.begin(), for(auto bitr = std::begin(missingBlockIndexes),
eoi2 = missingBlockIndexes.end(); eoi2 = std::end(missingBlockIndexes);
bitr != eoi2 && requests.size() < max; ++bitr) { bitr != eoi2 && requests.size() < max; ++bitr) {
const size_t& blockIndex = *bitr; size_t blockIndex = *bitr;
if(!dispatcher_->isOutstandingRequest(piece->getIndex(), if(!dispatcher_->isOutstandingRequest(piece->getIndex(),
blockIndex)) { blockIndex)) {
A2_LOG_DEBUG A2_LOG_DEBUG
@ -228,6 +232,7 @@ void DefaultBtRequestFactory::createRequestMessagesOnEndGame
} }
} }
} }
return requests;
} }
namespace { namespace {
@ -256,11 +261,14 @@ size_t DefaultBtRequestFactory::countMissingBlock()
CountMissingBlock()).getNumMissingBlock(); CountMissingBlock()).getNumMissingBlock();
} }
void DefaultBtRequestFactory::getTargetPieceIndexes std::vector<size_t>
(std::vector<size_t>& indexes) const DefaultBtRequestFactory::getTargetPieceIndexes() const
{ {
std::transform(pieces_.begin(), pieces_.end(), std::back_inserter(indexes), auto res = std::vector<size_t>{};
std::mem_fn(&Piece::getIndex)); res.reserve(pieces_.size());
std::transform(std::begin(pieces_), std::end(pieces_),
std::back_inserter(res), std::mem_fn(&Piece::getIndex));
return res;
} }
void DefaultBtRequestFactory::setPieceStorage(PieceStorage* pieceStorage) void DefaultBtRequestFactory::setPieceStorage(PieceStorage* pieceStorage)

View File

@ -51,6 +51,9 @@ class Piece;
class DefaultBtRequestFactory : public BtRequestFactory { class DefaultBtRequestFactory : public BtRequestFactory {
private: private:
std::vector<std::unique_ptr<BtRequestMessage>> createRequestMessagesOnEndGame
(size_t max);
PieceStorage* pieceStorage_; PieceStorage* pieceStorage_;
std::shared_ptr<Peer> peer_; std::shared_ptr<Peer> peer_;
BtMessageDispatcher* dispatcher_; BtMessageDispatcher* dispatcher_;
@ -78,13 +81,10 @@ public:
virtual void doChokedAction(); virtual void doChokedAction();
virtual void createRequestMessages virtual std::vector<std::unique_ptr<BtRequestMessage>> createRequestMessages
(std::vector<std::shared_ptr<BtMessage> >& requests, size_t max); (size_t max, bool endGame);
virtual void createRequestMessagesOnEndGame virtual std::vector<size_t> getTargetPieceIndexes() const;
(std::vector<std::shared_ptr<BtMessage> >& requests, size_t max);
virtual void getTargetPieceIndexes(std::vector<size_t>& indexes) const;
std::deque<std::shared_ptr<Piece> >& getTargetPieces() std::deque<std::shared_ptr<Piece> >& getTargetPieces()
{ {

View File

@ -46,7 +46,7 @@ public:
virtual std::string getPayload() = 0; virtual std::string getPayload() = 0;
virtual uint8_t getExtensionMessageID() = 0; virtual uint8_t getExtensionMessageID() const = 0;
virtual const char* getExtensionName() const = 0; virtual const char* getExtensionName() const = 0;

View File

@ -67,7 +67,7 @@ public:
virtual std::string getPayload(); virtual std::string getPayload();
virtual uint8_t getExtensionMessageID() virtual uint8_t getExtensionMessageID() const
{ {
return 0; return 0;
} }

View File

@ -47,13 +47,12 @@ private:
static const size_t MESSAGE_LENGTH = 9; static const size_t MESSAGE_LENGTH = 9;
protected: protected:
template<typename T> template<typename T>
static T* create(const unsigned char* data, size_t dataLength) static std::unique_ptr<T> create(const unsigned char* data,
size_t dataLength)
{ {
bittorrent::assertPayloadLengthEqual(5, dataLength, T::NAME); bittorrent::assertPayloadLengthEqual(5, dataLength, T::NAME);
bittorrent::assertID(T::ID, data, T::NAME); bittorrent::assertID(T::ID, data, T::NAME);
T* message(new T()); return make_unique<T>(bittorrent::getIntParam(data, 1));
message->setIndex(bittorrent::getIntParam(data, 1));
return message;
} }
public: public:
IndexBtMessage(uint8_t id, const char* name, size_t index) IndexBtMessage(uint8_t id, const char* name, size_t index)

View File

@ -49,15 +49,14 @@ private:
static const size_t MESSAGE_LENGTH = 17; static const size_t MESSAGE_LENGTH = 17;
protected: protected:
template<typename T> template<typename T>
static T* create(const unsigned char* data, size_t dataLength) static std::unique_ptr<T> create(const unsigned char* data,
size_t dataLength)
{ {
bittorrent::assertPayloadLengthEqual(13, dataLength, T::NAME); bittorrent::assertPayloadLengthEqual(13, dataLength, T::NAME);
bittorrent::assertID(T::ID, data, T::NAME); bittorrent::assertID(T::ID, data, T::NAME);
T* message(new T()); return make_unique<T>(bittorrent::getIntParam(data, 1),
message->setIndex(bittorrent::getIntParam(data, 1)); bittorrent::getIntParam(data, 5),
message->setBegin(bittorrent::getIntParam(data, 5)); bittorrent::getIntParam(data, 9));
message->setLength(bittorrent::getIntParam(data, 9));
return message;
} }
public: public:
RangeBtMessage(uint8_t id, const char* name, RangeBtMessage(uint8_t id, const char* name,

View File

@ -47,7 +47,7 @@ private:
public: public:
UTMetadataExtensionMessage(uint8_t extensionMessageID); UTMetadataExtensionMessage(uint8_t extensionMessageID);
virtual uint8_t getExtensionMessageID() virtual uint8_t getExtensionMessageID() const
{ {
return extensionMessageID_; return extensionMessageID_;
} }

View File

@ -49,6 +49,7 @@
#include "BtMessage.h" #include "BtMessage.h"
#include "PieceStorage.h" #include "PieceStorage.h"
#include "ExtensionMessageRegistry.h" #include "ExtensionMessageRegistry.h"
#include "BtExtendedMessage.h"
namespace aria2 { namespace aria2 {
@ -83,8 +84,8 @@ void UTMetadataRequestExtensionMessage::doReceivedAction()
std::shared_ptr<UTMetadataRejectExtensionMessage> m std::shared_ptr<UTMetadataRejectExtensionMessage> m
(new UTMetadataRejectExtensionMessage(id)); (new UTMetadataRejectExtensionMessage(id));
m->setIndex(getIndex()); m->setIndex(getIndex());
std::shared_ptr<BtMessage> msg = messageFactory_->createBtExtendedMessage(m); dispatcher_->addMessageToQueue
dispatcher_->addMessageToQueue(msg); (messageFactory_->createBtExtendedMessage(m));
}else if(getIndex()*METADATA_PIECE_SIZE < attrs->metadataSize) { }else if(getIndex()*METADATA_PIECE_SIZE < attrs->metadataSize) {
std::shared_ptr<UTMetadataDataExtensionMessage> m std::shared_ptr<UTMetadataDataExtensionMessage> m
(new UTMetadataDataExtensionMessage(id)); (new UTMetadataDataExtensionMessage(id));
@ -97,8 +98,8 @@ void UTMetadataRequestExtensionMessage::doReceivedAction()
attrs->metadata.begin()+(getIndex()+1)*METADATA_PIECE_SIZE: attrs->metadata.begin()+(getIndex()+1)*METADATA_PIECE_SIZE:
attrs->metadata.end(); attrs->metadata.end();
m->setData(begin, end); m->setData(begin, end);
std::shared_ptr<BtMessage> msg = messageFactory_->createBtExtendedMessage(m); dispatcher_->addMessageToQueue
dispatcher_->addMessageToQueue(msg); (messageFactory_->createBtExtendedMessage(m));
} else { } else {
throw DL_ABORT_EX throw DL_ABORT_EX
(fmt("Metadata piece index is too big. piece=%lu", (fmt("Metadata piece index is too big. piece=%lu",

View File

@ -45,24 +45,24 @@
#include "LogFactory.h" #include "LogFactory.h"
#include "fmt.h" #include "fmt.h"
#include "ExtensionMessageRegistry.h" #include "ExtensionMessageRegistry.h"
#include "BtExtendedMessage.h"
namespace aria2 { namespace aria2 {
UTMetadataRequestFactory::UTMetadataRequestFactory() UTMetadataRequestFactory::UTMetadataRequestFactory()
: dispatcher_(0), : dispatcher_{nullptr},
messageFactory_(0), messageFactory_{nullptr},
tracker_(0), tracker_{nullptr},
cuid_(0) cuid_(0)
{} {}
void UTMetadataRequestFactory::create std::vector<std::unique_ptr<BtMessage>> UTMetadataRequestFactory::create
(std::vector<std::shared_ptr<BtMessage> >& msgs, size_t num, (size_t num, const std::shared_ptr<PieceStorage>& pieceStorage)
const std::shared_ptr<PieceStorage>& pieceStorage)
{ {
auto msgs = std::vector<std::unique_ptr<BtMessage>>{};
while(num) { while(num) {
std::vector<size_t> metadataRequests = tracker_->getAllTrackedIndex(); auto metadataRequests = tracker_->getAllTrackedIndex();
std::shared_ptr<Piece> p = auto p = pieceStorage->getMissingPiece(peer_, metadataRequests, cuid_);
pieceStorage->getMissingPiece(peer_, metadataRequests, cuid_);
if(!p) { if(!p) {
A2_LOG_DEBUG("No ut_metadata piece is available to download."); A2_LOG_DEBUG("No ut_metadata piece is available to download.");
break; break;
@ -79,10 +79,10 @@ void UTMetadataRequestFactory::create
m->setBtMessageFactory(messageFactory_); m->setBtMessageFactory(messageFactory_);
m->setPeer(peer_); m->setPeer(peer_);
std::shared_ptr<BtMessage> msg = messageFactory_->createBtExtendedMessage(m); msgs.push_back(messageFactory_->createBtExtendedMessage(m));
msgs.push_back(msg);
tracker_->add(p->getIndex()); tracker_->add(p->getIndex());
} }
return msgs;
} }
} // namespace aria2 } // namespace aria2

View File

@ -67,10 +67,10 @@ private:
public: public:
UTMetadataRequestFactory(); UTMetadataRequestFactory();
// Creates at most num of ut_metadata request message and appends // Creates and returns at most num of ut_metadata request
// them to msgs. pieceStorage is used to identify missing piece. // message. pieceStorage is used to identify missing piece.
void create(std::vector<std::shared_ptr<BtMessage> >& msgs, size_t num, std::vector<std::unique_ptr<BtMessage>> create
const std::shared_ptr<PieceStorage>& pieceStorage); (size_t num, const std::shared_ptr<PieceStorage>& pieceStorage);
void setDownloadContext(const std::shared_ptr<DownloadContext>& dctx) void setDownloadContext(const std::shared_ptr<DownloadContext>& dctx)
{ {

View File

@ -75,7 +75,7 @@ public:
virtual std::string getPayload(); virtual std::string getPayload();
virtual uint8_t getExtensionMessageID() virtual uint8_t getExtensionMessageID() const
{ {
return extensionMessageID_; return extensionMessageID_;
} }

View File

@ -37,6 +37,10 @@
namespace aria2 { namespace aria2 {
ZeroBtMessage::ZeroBtMessage(uint8_t id, const char* name)
: SimpleBtMessage{id, name}
{}
unsigned char* ZeroBtMessage::createMessage() unsigned char* ZeroBtMessage::createMessage()
{ {
/** /**

View File

@ -45,17 +45,16 @@ private:
static const size_t MESSAGE_LENGTH = 5; static const size_t MESSAGE_LENGTH = 5;
protected: protected:
template<typename T> template<typename T>
static T* create(const unsigned char* data, size_t dataLength) static std::unique_ptr<T> create(const unsigned char* data,
size_t dataLength)
{ {
bittorrent::assertPayloadLengthEqual(1, dataLength, T::NAME); bittorrent::assertPayloadLengthEqual(1, dataLength, T::NAME);
bittorrent::assertID(T::ID, data, T::NAME); bittorrent::assertID(T::ID, data, T::NAME);
T* message(new T()); return make_unique<T>();
return message;
} }
public: public:
ZeroBtMessage(uint8_t id, const char* name): ZeroBtMessage(uint8_t id, const char* name);
SimpleBtMessage(id, name) {}
virtual unsigned char* createMessage(); virtual unsigned char* createMessage();

View File

@ -41,7 +41,7 @@ public:
length(0) {} length(0) {}
virtual void doCancelSendingPieceAction virtual void doCancelSendingPieceAction
(size_t index, int32_t begin, int32_t length) { (size_t index, int32_t begin, int32_t length) override {
this->index = index; this->index = index;
this->begin = begin; this->begin = begin;
this->length = length; this->length = length;
@ -58,7 +58,7 @@ void BtCancelMessageTest::testCreate() {
bittorrent::setIntParam(&msg[5], 12345); bittorrent::setIntParam(&msg[5], 12345);
bittorrent::setIntParam(&msg[9], 256); bittorrent::setIntParam(&msg[9], 256);
bittorrent::setIntParam(&msg[13], 1024); bittorrent::setIntParam(&msg[13], 1024);
std::shared_ptr<BtCancelMessage> pm(BtCancelMessage::create(&msg[4], 13)); auto pm = BtCancelMessage::create(&msg[4], 13);
CPPUNIT_ASSERT_EQUAL((uint8_t)8, pm->getId()); CPPUNIT_ASSERT_EQUAL((uint8_t)8, pm->getId());
CPPUNIT_ASSERT_EQUAL((size_t)12345, pm->getIndex()); CPPUNIT_ASSERT_EQUAL((size_t)12345, pm->getIndex());
CPPUNIT_ASSERT_EQUAL(256, pm->getBegin()); CPPUNIT_ASSERT_EQUAL(256, pm->getBegin());
@ -103,8 +103,7 @@ void BtCancelMessageTest::testDoReceivedAction() {
msg.setBegin(2*16*1024); msg.setBegin(2*16*1024);
msg.setLength(16*1024); msg.setLength(16*1024);
msg.setPeer(peer); msg.setPeer(peer);
std::shared_ptr<MockBtMessageDispatcher2> dispatcher auto dispatcher = make_unique<MockBtMessageDispatcher2>();
(new MockBtMessageDispatcher2());
msg.setBtMessageDispatcher(dispatcher.get()); msg.setBtMessageDispatcher(dispatcher.get());
msg.doReceivedAction(); msg.doReceivedAction();

View File

@ -43,13 +43,18 @@ public:
bool doChokedActionCalled; bool doChokedActionCalled;
bool doChokingActionCalled; bool doChokingActionCalled;
public: public:
MockBtMessageDispatcher2():doChokedActionCalled(false), doChokingActionCalled(false) {} MockBtMessageDispatcher2()
: doChokedActionCalled{false},
doChokingActionCalled{false}
{}
virtual void doChokedAction() { virtual void doChokedAction() override
{
doChokedActionCalled = true; doChokedActionCalled = true;
} }
virtual void doChokingAction() { virtual void doChokingAction() override
{
doChokingActionCalled = true; doChokingActionCalled = true;
} }
}; };
@ -58,9 +63,10 @@ public:
public: public:
bool doChokedActionCalled; bool doChokedActionCalled;
public: public:
MockBtRequestFactory2():doChokedActionCalled(false) {} MockBtRequestFactory2():doChokedActionCalled{false} {}
virtual void doChokedAction() { virtual void doChokedAction() override
{
doChokedActionCalled = true; doChokedActionCalled = true;
} }
}; };
@ -72,7 +78,7 @@ CPPUNIT_TEST_SUITE_REGISTRATION(BtChokeMessageTest);
void BtChokeMessageTest::testCreate() { void BtChokeMessageTest::testCreate() {
unsigned char msg[5]; unsigned char msg[5];
bittorrent::createPeerMessageString(msg, sizeof(msg), 1, 0); bittorrent::createPeerMessageString(msg, sizeof(msg), 1, 0);
std::shared_ptr<BtChokeMessage> pm(BtChokeMessage::create(&msg[4], 1)); auto pm = BtChokeMessage::create(&msg[4], 1);
CPPUNIT_ASSERT_EQUAL((uint8_t)0, pm->getId()); CPPUNIT_ASSERT_EQUAL((uint8_t)0, pm->getId());
// case: payload size is wrong // case: payload size is wrong
@ -106,9 +112,9 @@ void BtChokeMessageTest::testDoReceivedAction() {
BtChokeMessage msg; BtChokeMessage msg;
msg.setPeer(peer); msg.setPeer(peer);
std::shared_ptr<MockBtMessageDispatcher2> dispatcher(new MockBtMessageDispatcher2()); auto dispatcher = make_unique<MockBtMessageDispatcher2>();
msg.setBtMessageDispatcher(dispatcher.get()); msg.setBtMessageDispatcher(dispatcher.get());
std::shared_ptr<MockBtRequestFactory2> requestFactory(new MockBtRequestFactory2()); auto requestFactory = make_unique<MockBtRequestFactory2>();
msg.setBtRequestFactory(requestFactory.get()); msg.setBtRequestFactory(requestFactory.get());
msg.doReceivedAction(); msg.doReceivedAction();
@ -121,10 +127,10 @@ void BtChokeMessageTest::testOnSendComplete() {
BtChokeMessage msg; BtChokeMessage msg;
msg.setPeer(peer); msg.setPeer(peer);
std::shared_ptr<MockBtMessageDispatcher2> dispatcher(new MockBtMessageDispatcher2()); auto dispatcher = make_unique<MockBtMessageDispatcher2>();
msg.setBtMessageDispatcher(dispatcher.get()); msg.setBtMessageDispatcher(dispatcher.get());
std::shared_ptr<ProgressUpdate> pu(msg.getProgressUpdate()); auto pu = std::unique_ptr<ProgressUpdate>{msg.getProgressUpdate()};
pu->update(0, true); pu->update(0, true);
CPPUNIT_ASSERT(dispatcher->doChokingActionCalled); CPPUNIT_ASSERT(dispatcher->doChokingActionCalled);

View File

@ -15,6 +15,7 @@
#include "Piece.h" #include "Piece.h"
#include "BtHandshakeMessage.h" #include "BtHandshakeMessage.h"
#include "DownloadContext.h" #include "DownloadContext.h"
#include "BtRejectMessage.h"
namespace aria2 { namespace aria2 {
@ -47,44 +48,33 @@ public:
void testCancelSendingPieceEvent_invalidate(); void testCancelSendingPieceEvent_invalidate();
void testToString(); void testToString();
class MockBtMessage2 : public MockBtMessage {
public:
size_t index;
uint32_t begin;
size_t length;
public:
MockBtMessage2(size_t index, uint32_t begin, size_t length):index(index), begin(begin), length(length) {}
};
class MockBtMessageFactory2 : public MockBtMessageFactory { class MockBtMessageFactory2 : public MockBtMessageFactory {
public: public:
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtRejectMessage>
createRejectMessage(size_t index, createRejectMessage(size_t index,
int32_t begin, int32_t begin,
int32_t length) { int32_t length) {
std::shared_ptr<MockBtMessage2> msg(new MockBtMessage2(index, begin, length)); return make_unique<BtRejectMessage>(index, begin, length);
return msg;
} }
}; };
std::shared_ptr<DownloadContext> dctx_; std::unique_ptr<DownloadContext> dctx_;
std::shared_ptr<MockBtMessageDispatcher> btMessageDispatcher; std::unique_ptr<MockBtMessageDispatcher> btMessageDispatcher;
std::shared_ptr<MockBtMessageFactory> btMessageFactory_; std::unique_ptr<MockBtMessageFactory> btMessageFactory_;
std::shared_ptr<Peer> peer; std::shared_ptr<Peer> peer;
std::shared_ptr<BtPieceMessage> msg; std::unique_ptr<BtPieceMessage> msg;
void setUp() { void setUp() {
dctx_.reset(new DownloadContext(16*1024, 256*1024, "/path/to/file")); dctx_ = make_unique<DownloadContext>(16*1024, 256*1024, "/path/to/file");
peer.reset(new Peer("host", 6969)); peer = std::make_shared<Peer>("host", 6969);
peer->allocateSessionResource(dctx_->getPieceLength(), peer->allocateSessionResource(dctx_->getPieceLength(),
dctx_->getTotalLength()); dctx_->getTotalLength());
btMessageDispatcher.reset(new MockBtMessageDispatcher()); btMessageDispatcher = make_unique<MockBtMessageDispatcher>();
btMessageFactory_.reset(new MockBtMessageFactory2()); btMessageFactory_ = make_unique<MockBtMessageFactory2>();
msg.reset(new BtPieceMessage()); msg = make_unique<BtPieceMessage>();
msg->setIndex(1); msg->setIndex(1);
msg->setBegin(1024); msg->setBegin(1024);
msg->setBlockLength(16*1024); msg->setBlockLength(16*1024);
@ -166,11 +156,11 @@ void BtPieceMessageTest::testChokingEvent_allowedFastEnabled() {
CPPUNIT_ASSERT(msg->isInvalidate()); CPPUNIT_ASSERT(msg->isInvalidate());
CPPUNIT_ASSERT_EQUAL((size_t)1, btMessageDispatcher->messageQueue.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, btMessageDispatcher->messageQueue.size());
auto rej = std::dynamic_pointer_cast<MockBtMessage2> auto rej = static_cast<const BtRejectMessage*>
(btMessageDispatcher->messageQueue.front()); (btMessageDispatcher->messageQueue.front().get());
CPPUNIT_ASSERT_EQUAL((size_t)1, rej->index); CPPUNIT_ASSERT_EQUAL((size_t)1, rej->getIndex());
CPPUNIT_ASSERT_EQUAL((uint32_t)1024, rej->begin); CPPUNIT_ASSERT_EQUAL((int32_t)1024, rej->getBegin());
CPPUNIT_ASSERT_EQUAL((size_t)16*1024, rej->length); CPPUNIT_ASSERT_EQUAL((int32_t)16*1024, rej->getLength());
} }
void BtPieceMessageTest::testChokingEvent_inAmAllowedIndexSet() { void BtPieceMessageTest::testChokingEvent_inAmAllowedIndexSet() {
@ -234,11 +224,11 @@ void BtPieceMessageTest::testCancelSendingPieceEvent_allowedFastEnabled() {
CPPUNIT_ASSERT(msg->isInvalidate()); CPPUNIT_ASSERT(msg->isInvalidate());
CPPUNIT_ASSERT_EQUAL((size_t)1, btMessageDispatcher->messageQueue.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, btMessageDispatcher->messageQueue.size());
auto rej = std::dynamic_pointer_cast<MockBtMessage2> auto rej = static_cast<const BtRejectMessage*>
(btMessageDispatcher->messageQueue.front()); (btMessageDispatcher->messageQueue.front().get());
CPPUNIT_ASSERT_EQUAL((size_t)1, rej->index); CPPUNIT_ASSERT_EQUAL((size_t)1, rej->getIndex());
CPPUNIT_ASSERT_EQUAL((uint32_t)1024, rej->begin); CPPUNIT_ASSERT_EQUAL((int32_t)1024, rej->getBegin());
CPPUNIT_ASSERT_EQUAL((size_t)16*1024, rej->length); CPPUNIT_ASSERT_EQUAL((int32_t)16*1024, rej->getLength());
} }
void BtPieceMessageTest::testCancelSendingPieceEvent_invalidate() { void BtPieceMessageTest::testCancelSendingPieceEvent_invalidate() {

View File

@ -54,62 +54,44 @@ public:
class MockPieceStorage2 : public MockPieceStorage { class MockPieceStorage2 : public MockPieceStorage {
public: public:
virtual bool hasPiece(size_t index) { virtual bool hasPiece(size_t index) override
{
return index == 1; return index == 1;
} }
}; };
class MockBtMessage2 : public MockBtMessage {
public:
std::string type;
size_t index;
uint32_t begin;
size_t length;
public:
MockBtMessage2(std::string type, size_t index, uint32_t begin,
size_t length)
:
type(type), index(index), begin(begin), length(length) {}
};
typedef std::shared_ptr<MockBtMessage2> MockBtMessage2Handle;
class MockBtMessageFactory2 : public MockBtMessageFactory { class MockBtMessageFactory2 : public MockBtMessageFactory {
public: public:
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtPieceMessage>
createPieceMessage(size_t index, int32_t begin, int32_t length) { createPieceMessage(size_t index, int32_t begin, int32_t length) override
std::shared_ptr<MockBtMessage2> btMsg {
(new MockBtMessage2("piece", index, begin, length)); return make_unique<BtPieceMessage>(index, begin, length);
return btMsg;
} }
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtRejectMessage>
createRejectMessage(size_t index, int32_t begin, int32_t length) { createRejectMessage(size_t index, int32_t begin, int32_t length) override
std::shared_ptr<MockBtMessage2> btMsg {
(new MockBtMessage2("reject", index, begin, length)); return make_unique<BtRejectMessage>(index, begin, length);
return btMsg;
} }
}; };
typedef std::shared_ptr<MockBtMessageFactory2> MockBtMessageFactory2Handle; std::unique_ptr<MockPieceStorage> pieceStorage_;
std::shared_ptr<MockPieceStorage> pieceStorage_;
std::shared_ptr<Peer> peer_; std::shared_ptr<Peer> peer_;
std::shared_ptr<MockBtMessageDispatcher> dispatcher_; std::unique_ptr<MockBtMessageDispatcher> dispatcher_;
std::shared_ptr<MockBtMessageFactory> messageFactory_; std::unique_ptr<MockBtMessageFactory> messageFactory_;
std::shared_ptr<BtRequestMessage> msg; std::unique_ptr<BtRequestMessage> msg;
void setUp() { void setUp() {
pieceStorage_.reset(new MockPieceStorage2()); pieceStorage_ = make_unique<MockPieceStorage2>();
peer_.reset(new Peer("host", 6969)); peer_ = std::make_shared<Peer>("host", 6969);
peer_->allocateSessionResource(16*1024, 256*1024); peer_->allocateSessionResource(16*1024, 256*1024);
dispatcher_.reset(new MockBtMessageDispatcher()); dispatcher_ = make_unique<MockBtMessageDispatcher>();
messageFactory_.reset(new MockBtMessageFactory2()); messageFactory_ = make_unique<MockBtMessageFactory2>();
msg.reset(new BtRequestMessage()); msg = make_unique<BtRequestMessage>();
msg->setPeer(peer_); msg->setPeer(peer_);
msg->setIndex(1); msg->setIndex(1);
msg->setBegin(16); msg->setBegin(16);
@ -130,7 +112,7 @@ void BtRequestMessageTest::testCreate() {
bittorrent::setIntParam(&msg[5], 12345); bittorrent::setIntParam(&msg[5], 12345);
bittorrent::setIntParam(&msg[9], 256); bittorrent::setIntParam(&msg[9], 256);
bittorrent::setIntParam(&msg[13], 1024); bittorrent::setIntParam(&msg[13], 1024);
std::shared_ptr<BtRequestMessage> pm(BtRequestMessage::create(&msg[4], 13)); auto pm = BtRequestMessage::create(&msg[4], 13);
CPPUNIT_ASSERT_EQUAL((uint8_t)6, pm->getId()); CPPUNIT_ASSERT_EQUAL((uint8_t)6, pm->getId());
CPPUNIT_ASSERT_EQUAL((size_t)12345, pm->getIndex()); CPPUNIT_ASSERT_EQUAL((size_t)12345, pm->getIndex());
CPPUNIT_ASSERT_EQUAL(256, pm->getBegin()); CPPUNIT_ASSERT_EQUAL(256, pm->getBegin());
@ -174,12 +156,13 @@ void BtRequestMessageTest::testDoReceivedAction_hasPieceAndAmNotChoking() {
msg->doReceivedAction(); msg->doReceivedAction();
CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher_->messageQueue.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher_->messageQueue.size());
auto pieceMsg = std::dynamic_pointer_cast<MockBtMessage2> CPPUNIT_ASSERT(BtPieceMessage::ID ==
(dispatcher_->messageQueue.front()); dispatcher_->messageQueue.front()->getId());
CPPUNIT_ASSERT_EQUAL(std::string("piece"), pieceMsg->type); auto pieceMsg = static_cast<const BtPieceMessage*>
CPPUNIT_ASSERT_EQUAL((size_t)1, pieceMsg->index); (dispatcher_->messageQueue.front().get());
CPPUNIT_ASSERT_EQUAL((uint32_t)16, pieceMsg->begin); CPPUNIT_ASSERT_EQUAL((size_t)1, pieceMsg->getIndex());
CPPUNIT_ASSERT_EQUAL((size_t)32, pieceMsg->length); CPPUNIT_ASSERT_EQUAL((int32_t)16, pieceMsg->getBegin());
CPPUNIT_ASSERT_EQUAL((int32_t)32, pieceMsg->getBlockLength());
} }
void BtRequestMessageTest::testDoReceivedAction_hasPieceAndAmChokingAndFastExtensionEnabled() { void BtRequestMessageTest::testDoReceivedAction_hasPieceAndAmChokingAndFastExtensionEnabled() {
@ -188,12 +171,13 @@ void BtRequestMessageTest::testDoReceivedAction_hasPieceAndAmChokingAndFastExten
msg->doReceivedAction(); msg->doReceivedAction();
CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher_->messageQueue.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher_->messageQueue.size());
auto pieceMsg = std::dynamic_pointer_cast<MockBtMessage2> CPPUNIT_ASSERT(BtRejectMessage::ID ==
(dispatcher_->messageQueue.front()); dispatcher_->messageQueue.front()->getId());
CPPUNIT_ASSERT_EQUAL(std::string("reject"), pieceMsg->type); auto rejectMsg = static_cast<const BtRejectMessage*>
CPPUNIT_ASSERT_EQUAL((size_t)1, pieceMsg->index); (dispatcher_->messageQueue.front().get());
CPPUNIT_ASSERT_EQUAL((uint32_t)16, pieceMsg->begin); CPPUNIT_ASSERT_EQUAL((size_t)1, rejectMsg->getIndex());
CPPUNIT_ASSERT_EQUAL((size_t)32, pieceMsg->length); CPPUNIT_ASSERT_EQUAL((int32_t)16, rejectMsg->getBegin());
CPPUNIT_ASSERT_EQUAL((int32_t)32, rejectMsg->getLength());
} }
void BtRequestMessageTest::testDoReceivedAction_hasPieceAndAmChokingAndFastExtensionDisabled() { void BtRequestMessageTest::testDoReceivedAction_hasPieceAndAmChokingAndFastExtensionDisabled() {
@ -210,12 +194,13 @@ void BtRequestMessageTest::testDoReceivedAction_doesntHavePieceAndFastExtensionE
msg->doReceivedAction(); msg->doReceivedAction();
CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher_->messageQueue.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, dispatcher_->messageQueue.size());
auto pieceMsg = std::dynamic_pointer_cast<MockBtMessage2> CPPUNIT_ASSERT(BtRejectMessage::ID ==
(dispatcher_->messageQueue.front()); dispatcher_->messageQueue.front()->getId());
CPPUNIT_ASSERT_EQUAL(std::string("reject"), pieceMsg->type); auto rejectMsg = static_cast<const BtRejectMessage*>
CPPUNIT_ASSERT_EQUAL((size_t)2, pieceMsg->index); (dispatcher_->messageQueue.front().get());
CPPUNIT_ASSERT_EQUAL((uint32_t)16, pieceMsg->begin); CPPUNIT_ASSERT_EQUAL((size_t)2, rejectMsg->getIndex());
CPPUNIT_ASSERT_EQUAL((size_t)32, pieceMsg->length); CPPUNIT_ASSERT_EQUAL((int32_t)16, rejectMsg->getBegin());
CPPUNIT_ASSERT_EQUAL((int32_t)32, rejectMsg->getLength());
} }
void BtRequestMessageTest::testDoReceivedAction_doesntHavePieceAndFastExtensionDisabled() { void BtRequestMessageTest::testDoReceivedAction_doesntHavePieceAndFastExtensionDisabled() {
@ -227,21 +212,21 @@ void BtRequestMessageTest::testDoReceivedAction_doesntHavePieceAndFastExtensionD
} }
void BtRequestMessageTest::testHandleAbortRequestEvent() { void BtRequestMessageTest::testHandleAbortRequestEvent() {
std::shared_ptr<Piece> piece(new Piece(1, 16*1024)); auto piece = std::make_shared<Piece>(1, 16*1024);
CPPUNIT_ASSERT(!msg->isInvalidate()); CPPUNIT_ASSERT(!msg->isInvalidate());
msg->onAbortOutstandingRequestEvent(BtAbortOutstandingRequestEvent(piece)); msg->onAbortOutstandingRequestEvent(BtAbortOutstandingRequestEvent(piece));
CPPUNIT_ASSERT(msg->isInvalidate()); CPPUNIT_ASSERT(msg->isInvalidate());
} }
void BtRequestMessageTest::testHandleAbortRequestEvent_indexNoMatch() { void BtRequestMessageTest::testHandleAbortRequestEvent_indexNoMatch() {
std::shared_ptr<Piece> piece(new Piece(2, 16*1024)); auto piece = std::make_shared<Piece>(2, 16*1024);
CPPUNIT_ASSERT(!msg->isInvalidate()); CPPUNIT_ASSERT(!msg->isInvalidate());
msg->onAbortOutstandingRequestEvent(BtAbortOutstandingRequestEvent(piece)); msg->onAbortOutstandingRequestEvent(BtAbortOutstandingRequestEvent(piece));
CPPUNIT_ASSERT(!msg->isInvalidate()); CPPUNIT_ASSERT(!msg->isInvalidate());
} }
void BtRequestMessageTest::testHandleAbortRequestEvent_alreadyInvalidated() { void BtRequestMessageTest::testHandleAbortRequestEvent_alreadyInvalidated() {
std::shared_ptr<Piece> piece(new Piece(1, 16*1024)); auto piece = std::make_shared<Piece>(1, 16*1024);
msg->setInvalidate(true); msg->setInvalidate(true);
CPPUNIT_ASSERT(msg->isInvalidate()); CPPUNIT_ASSERT(msg->isInvalidate());
msg->onAbortOutstandingRequestEvent(BtAbortOutstandingRequestEvent(piece)); msg->onAbortOutstandingRequestEvent(BtAbortOutstandingRequestEvent(piece));

View File

@ -32,9 +32,8 @@ void BtSuggestPieceMessageTest::testCreate() {
unsigned char msg[9]; unsigned char msg[9];
bittorrent::createPeerMessageString(msg, sizeof(msg), 5, 13); bittorrent::createPeerMessageString(msg, sizeof(msg), 5, 13);
bittorrent::setIntParam(&msg[5], 12345); bittorrent::setIntParam(&msg[5], 12345);
std::shared_ptr<BtSuggestPieceMessage> pm auto pm = BtSuggestPieceMessage::create(&msg[4], 5);
(BtSuggestPieceMessage::create(&msg[4], 5)); CPPUNIT_ASSERT(BtSuggestPieceMessage::ID == pm->getId());
CPPUNIT_ASSERT_EQUAL((uint8_t)13, pm->getId());
CPPUNIT_ASSERT_EQUAL((size_t)12345, pm->getIndex()); CPPUNIT_ASSERT_EQUAL((size_t)12345, pm->getIndex());
// case: payload size is wrong // case: payload size is wrong

View File

@ -43,13 +43,13 @@ class DefaultBtMessageDispatcherTest:public CppUnit::TestFixture {
private: private:
std::shared_ptr<DownloadContext> dctx_; std::shared_ptr<DownloadContext> dctx_;
std::shared_ptr<Peer> peer; std::shared_ptr<Peer> peer;
std::shared_ptr<DefaultBtMessageDispatcher> btMessageDispatcher; std::unique_ptr<DefaultBtMessageDispatcher> btMessageDispatcher;
std::shared_ptr<MockPeerStorage> peerStorage; std::unique_ptr<MockPeerStorage> peerStorage;
std::shared_ptr<MockPieceStorage> pieceStorage; std::unique_ptr<MockPieceStorage> pieceStorage;
std::shared_ptr<MockBtMessageFactory> messageFactory_; std::unique_ptr<MockBtMessageFactory> messageFactory_;
std::shared_ptr<RequestGroupMan> rgman_; std::unique_ptr<RequestGroupMan> rgman_;
std::shared_ptr<Option> option_; std::shared_ptr<Option> option_;
std::shared_ptr<RequestGroup> rg_; std::unique_ptr<RequestGroup> rg_;
public: public:
void tearDown() {} void tearDown() {}
@ -66,45 +66,43 @@ public:
void testGetOutstandingRequest(); void testGetOutstandingRequest();
void testRemoveOutstandingRequest(); void testRemoveOutstandingRequest();
class MockBtMessage2 : public MockBtMessage { struct EventCheck {
private: EventCheck() : onQueuedCalled{false}, sendCalled{false},
doCancelActionCalled{false}
{}
bool onQueuedCalled; bool onQueuedCalled;
bool sendCalled; bool sendCalled;
bool doCancelActionCalled; bool doCancelActionCalled;
};
class MockBtMessage2 : public MockBtMessage {
public: public:
EventCheck* evcheck;
std::string type; std::string type;
public: MockBtMessage2(EventCheck* evcheck = nullptr)
MockBtMessage2():onQueuedCalled(false), : evcheck{evcheck}
sendCalled(false),
doCancelActionCalled(false)
{} {}
virtual ~MockBtMessage2() {} virtual void onQueued() override
{
virtual void onQueued() { if(evcheck){
onQueuedCalled = true; evcheck->onQueuedCalled = true;
}
} }
bool isOnQueuedCalled() const { virtual void send() override
return onQueuedCalled; {
if(evcheck) {
evcheck->sendCalled = true;
} }
virtual void send() {
sendCalled = true;
}
bool isSendCalled() const {
return sendCalled;
} }
virtual void onCancelSendingPieceEvent virtual void onCancelSendingPieceEvent
(const BtCancelSendingPieceEvent& event) (const BtCancelSendingPieceEvent& event) override
{ {
doCancelActionCalled = true; if(evcheck) {
evcheck->doCancelActionCalled = true;
} }
bool isDoCancelActionCalled() const {
return doCancelActionCalled;
} }
}; };
@ -112,48 +110,50 @@ public:
private: private:
std::shared_ptr<Piece> piece; std::shared_ptr<Piece> piece;
public: public:
virtual std::shared_ptr<Piece> getPiece(size_t index) { virtual std::shared_ptr<Piece> getPiece(size_t index) override
{
return piece; return piece;
} }
void setPiece(const std::shared_ptr<Piece>& piece) { void setPiece(const std::shared_ptr<Piece>& piece)
{
this->piece = piece; this->piece = piece;
} }
}; };
class MockBtMessageFactory2 : public MockBtMessageFactory { class MockBtMessageFactory2 : public MockBtMessageFactory {
public: public:
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtCancelMessage>
createCancelMessage(size_t index, int32_t begin, int32_t length) { createCancelMessage(size_t index, int32_t begin, int32_t length) override
std::shared_ptr<MockBtMessage2> btMsg(new MockBtMessage2()); {
btMsg->type = "cancel"; return make_unique<BtCancelMessage>(index, begin, length);
return btMsg;
} }
}; };
void setUp() { void setUp()
option_.reset(new Option()); {
option_ = std::make_shared<Option>();
option_->put(PREF_DIR, "."); option_->put(PREF_DIR, ".");
rg_.reset(new RequestGroup(GroupId::create(), option_)); rg_ = make_unique<RequestGroup>(GroupId::create(), option_);
dctx_.reset(new DownloadContext()); dctx_ = std::make_shared<DownloadContext>();
bittorrent::load(A2_TEST_DIR"/test.torrent", dctx_, option_); bittorrent::load(A2_TEST_DIR"/test.torrent", dctx_, option_);
rg_->setDownloadContext(dctx_); rg_->setDownloadContext(dctx_);
peer.reset(new Peer("192.168.0.1", 6969)); peer = std::make_shared<Peer>("192.168.0.1", 6969);
peer->allocateSessionResource peer->allocateSessionResource
(dctx_->getPieceLength(), dctx_->getTotalLength()); (dctx_->getPieceLength(), dctx_->getTotalLength());
peerStorage.reset(new MockPeerStorage()); peerStorage = make_unique<MockPeerStorage>();
pieceStorage.reset(new MockPieceStorage()); pieceStorage = make_unique<MockPieceStorage>();
messageFactory_.reset(new MockBtMessageFactory2()); messageFactory_ = make_unique<MockBtMessageFactory2>();
rgman_.reset(new RequestGroupMan(std::vector<std::shared_ptr<RequestGroup> >(), rgman_ = make_unique<RequestGroupMan>
0, option_.get())); (std::vector<std::shared_ptr<RequestGroup>>{}, 0, option_.get());
btMessageDispatcher.reset(new DefaultBtMessageDispatcher()); btMessageDispatcher = make_unique<DefaultBtMessageDispatcher>();
btMessageDispatcher->setPeer(peer); btMessageDispatcher->setPeer(peer);
btMessageDispatcher->setDownloadContext(dctx_.get()); btMessageDispatcher->setDownloadContext(dctx_.get());
btMessageDispatcher->setPieceStorage(pieceStorage.get()); btMessageDispatcher->setPieceStorage(pieceStorage.get());
@ -167,87 +167,66 @@ public:
CPPUNIT_TEST_SUITE_REGISTRATION(DefaultBtMessageDispatcherTest); CPPUNIT_TEST_SUITE_REGISTRATION(DefaultBtMessageDispatcherTest);
void DefaultBtMessageDispatcherTest::testAddMessage() { void DefaultBtMessageDispatcherTest::testAddMessage()
std::shared_ptr<MockBtMessage2> msg(new MockBtMessage2()); {
CPPUNIT_ASSERT_EQUAL(false, msg->isOnQueuedCalled()); auto evcheck = EventCheck{};
btMessageDispatcher->addMessageToQueue(msg); auto msg = make_unique<MockBtMessage2>(&evcheck);
CPPUNIT_ASSERT_EQUAL(true, msg->isOnQueuedCalled()); btMessageDispatcher->addMessageToQueue(std::move(msg));
CPPUNIT_ASSERT_EQUAL(true, evcheck.onQueuedCalled);
CPPUNIT_ASSERT_EQUAL((size_t)1, CPPUNIT_ASSERT_EQUAL((size_t)1,
btMessageDispatcher->getMessageQueue().size()); btMessageDispatcher->getMessageQueue().size());
} }
void DefaultBtMessageDispatcherTest::testSendMessages() { void DefaultBtMessageDispatcherTest::testSendMessages() {
std::shared_ptr<MockBtMessage2> msg1(new MockBtMessage2()); auto evcheck1 = EventCheck{};
auto msg1 = make_unique<MockBtMessage2>(&evcheck1);
msg1->setUploading(false); msg1->setUploading(false);
std::shared_ptr<MockBtMessage2> msg2(new MockBtMessage2()); auto evcheck2 = EventCheck{};
auto msg2 = make_unique<MockBtMessage2>(&evcheck2);
msg2->setUploading(false); msg2->setUploading(false);
btMessageDispatcher->addMessageToQueue(msg1); btMessageDispatcher->addMessageToQueue(std::move(msg1));
btMessageDispatcher->addMessageToQueue(msg2); btMessageDispatcher->addMessageToQueue(std::move(msg2));
btMessageDispatcher->sendMessagesInternal(); btMessageDispatcher->sendMessagesInternal();
CPPUNIT_ASSERT(msg1->isSendCalled()); CPPUNIT_ASSERT(evcheck1.sendCalled);
CPPUNIT_ASSERT(msg2->isSendCalled()); CPPUNIT_ASSERT(evcheck2.sendCalled);
} }
void DefaultBtMessageDispatcherTest::testSendMessages_underUploadLimit() { void DefaultBtMessageDispatcherTest::testSendMessages_underUploadLimit() {
std::shared_ptr<MockBtMessage2> msg1(new MockBtMessage2()); auto evcheck1 = EventCheck{};
auto msg1 = make_unique<MockBtMessage2>(&evcheck1);
msg1->setUploading(true); msg1->setUploading(true);
std::shared_ptr<MockBtMessage2> msg2(new MockBtMessage2()); auto evcheck2 = EventCheck{};
auto msg2 = make_unique<MockBtMessage2>(&evcheck2);
msg2->setUploading(true); msg2->setUploading(true);
btMessageDispatcher->addMessageToQueue(msg1); btMessageDispatcher->addMessageToQueue(std::move(msg1));
btMessageDispatcher->addMessageToQueue(msg2); btMessageDispatcher->addMessageToQueue(std::move(msg2));
btMessageDispatcher->sendMessagesInternal(); btMessageDispatcher->sendMessagesInternal();
CPPUNIT_ASSERT(msg1->isSendCalled()); CPPUNIT_ASSERT(evcheck1.sendCalled);
CPPUNIT_ASSERT(msg2->isSendCalled()); CPPUNIT_ASSERT(evcheck2.sendCalled);
} }
// TODO Because we no longer directly use PeerStorage::calculateStat() void DefaultBtMessageDispatcherTest::testDoCancelSendingPieceAction()
// and Neither RequestGroup nor RequestGroupMan can be stubbed, this {
// test is commented out for now. auto evcheck1 = EventCheck{};
// auto msg1 = make_unique<MockBtMessage2>(&evcheck1);
// void DefaultBtMessageDispatcherTest::testSendMessages_overUploadLimit() { auto evcheck2 = EventCheck{};
// btMessageDispatcher->setMaxUploadSpeedLimit(100); auto msg2 = make_unique<MockBtMessage2>(&evcheck2);
// TransferStat stat;
// stat.setUploadSpeed(150);
// peerStorage->setStat(stat);
// std::shared_ptr<MockBtMessage2> msg1(new MockBtMessage2()); btMessageDispatcher->addMessageToQueue(std::move(msg1));
// msg1->setUploading(true); btMessageDispatcher->addMessageToQueue(std::move(msg2));
// std::shared_ptr<MockBtMessage2> msg2(new MockBtMessage2());
// msg2->setUploading(true);
// std::shared_ptr<MockBtMessage2> msg3(new MockBtMessage2());
// msg3->setUploading(false);
// btMessageDispatcher->addMessageToQueue(msg1);
// btMessageDispatcher->addMessageToQueue(msg2);
// btMessageDispatcher->addMessageToQueue(msg3);
// btMessageDispatcher->sendMessagesInternal();
// CPPUNIT_ASSERT(!msg1->isSendCalled());
// CPPUNIT_ASSERT(!msg2->isSendCalled());
// CPPUNIT_ASSERT(msg3->isSendCalled());
// CPPUNIT_ASSERT_EQUAL((size_t)2,
// btMessageDispatcher->getMessageQueue().size());
// }
void DefaultBtMessageDispatcherTest::testDoCancelSendingPieceAction() {
std::shared_ptr<MockBtMessage2> msg1(new MockBtMessage2());
std::shared_ptr<MockBtMessage2> msg2(new MockBtMessage2());
btMessageDispatcher->addMessageToQueue(msg1);
btMessageDispatcher->addMessageToQueue(msg2);
btMessageDispatcher->doCancelSendingPieceAction(0, 0, 0); btMessageDispatcher->doCancelSendingPieceAction(0, 0, 0);
CPPUNIT_ASSERT_EQUAL(true, msg1->isDoCancelActionCalled()); CPPUNIT_ASSERT(evcheck1.doCancelActionCalled);
CPPUNIT_ASSERT_EQUAL(true, msg2->isDoCancelActionCalled()); CPPUNIT_ASSERT(evcheck2.doCancelActionCalled);
} }
int MY_PIECE_LENGTH = 16*1024; int MY_PIECE_LENGTH = 16*1024;
void DefaultBtMessageDispatcherTest::testCheckRequestSlotAndDoNecessaryThing() { void DefaultBtMessageDispatcherTest::testCheckRequestSlotAndDoNecessaryThing()
{
auto piece = std::make_shared<Piece>(0, MY_PIECE_LENGTH); auto piece = std::make_shared<Piece>(0, MY_PIECE_LENGTH);
size_t index; size_t index;
CPPUNIT_ASSERT(piece->getMissingUnusedBlockIndex(index)); CPPUNIT_ASSERT(piece->getMissingUnusedBlockIndex(index));

View File

@ -24,22 +24,22 @@ class DefaultBtMessageFactoryTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testCreatePortMessage); CPPUNIT_TEST(testCreatePortMessage);
CPPUNIT_TEST_SUITE_END(); CPPUNIT_TEST_SUITE_END();
private: private:
std::shared_ptr<DownloadContext> dctx_; std::unique_ptr<DownloadContext> dctx_;
std::shared_ptr<Peer> peer_; std::shared_ptr<Peer> peer_;
std::shared_ptr<MockExtensionMessageFactory> exmsgFactory_; std::shared_ptr<MockExtensionMessageFactory> exmsgFactory_;
std::shared_ptr<DefaultBtMessageFactory> factory_; std::unique_ptr<DefaultBtMessageFactory> factory_;
public: public:
void setUp() void setUp()
{ {
dctx_.reset(new DownloadContext()); dctx_ = make_unique<DownloadContext>();
peer_.reset(new Peer("192.168.0.1", 6969)); peer_ = std::make_shared<Peer>("192.168.0.1", 6969);
peer_->allocateSessionResource(1024, 1024*1024); peer_->allocateSessionResource(1024, 1024*1024);
peer_->setExtendedMessagingEnabled(true); peer_->setExtendedMessagingEnabled(true);
exmsgFactory_.reset(new MockExtensionMessageFactory()); exmsgFactory_ = std::make_shared<MockExtensionMessageFactory>();
factory_.reset(new DefaultBtMessageFactory()); factory_ = make_unique<DefaultBtMessageFactory>();
factory_->setDownloadContext(dctx_.get()); factory_->setDownloadContext(dctx_.get());
factory_->setPeer(peer_); factory_->setPeer(peer_);
factory_->setExtensionMessageFactory(exmsgFactory_); factory_->setExtensionMessageFactory(exmsgFactory_);
@ -62,9 +62,9 @@ void DefaultBtMessageFactoryTest::testCreateBtMessage_BtExtendedMessage()
msg[5] = 1; // Set dummy extended message ID 1 msg[5] = 1; // Set dummy extended message ID 1
memcpy(msg+6, payload.c_str(), payload.size()); memcpy(msg+6, payload.c_str(), payload.size());
auto m = std::dynamic_pointer_cast<BtExtendedMessage> auto m =
(factory_->createBtMessage((const unsigned char*)msg+4, sizeof(msg))); factory_->createBtMessage((const unsigned char*)msg+4, sizeof(msg));
CPPUNIT_ASSERT(BtExtendedMessage::ID == m->getId());
try { try {
// disable extended messaging // disable extended messaging
peer_->setExtendedMessagingEnabled(false); peer_->setExtendedMessagingEnabled(false);
@ -82,17 +82,16 @@ void DefaultBtMessageFactoryTest::testCreatePortMessage()
bittorrent::createPeerMessageString(data, sizeof(data), 3, 9); bittorrent::createPeerMessageString(data, sizeof(data), 3, 9);
bittorrent::setShortIntParam(&data[5], 6881); bittorrent::setShortIntParam(&data[5], 6881);
try { try {
auto m = std::dynamic_pointer_cast<BtPortMessage> auto r = factory_->createBtMessage(&data[4], sizeof(data)-4);
(factory_->createBtMessage(&data[4], sizeof(data)-4)); CPPUNIT_ASSERT(BtPortMessage::ID == r->getId());
CPPUNIT_ASSERT(m); auto m = static_cast<const BtPortMessage*>(r.get());
CPPUNIT_ASSERT_EQUAL((uint16_t)6881, m->getPort()); CPPUNIT_ASSERT_EQUAL((uint16_t)6881, m->getPort());
} catch(Exception& e) { } catch(Exception& e) {
CPPUNIT_FAIL(e.stackTrace()); CPPUNIT_FAIL(e.stackTrace());
} }
} }
{ {
auto m = std::dynamic_pointer_cast<BtPortMessage> auto m = factory_->createPortMessage(6881);
(factory_->createPortMessage(6881));
CPPUNIT_ASSERT_EQUAL((uint16_t)6881, m->getPort()); CPPUNIT_ASSERT_EQUAL((uint16_t)6881, m->getPort());
} }
} }

View File

@ -28,11 +28,10 @@ class DefaultBtRequestFactoryTest:public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE_END(); CPPUNIT_TEST_SUITE_END();
private: private:
std::shared_ptr<Peer> peer_; std::shared_ptr<Peer> peer_;
std::shared_ptr<DefaultBtRequestFactory> requestFactory_; std::unique_ptr<DefaultBtRequestFactory> requestFactory_;
std::shared_ptr<DownloadContext> dctx_; std::unique_ptr<MockPieceStorage> pieceStorage_;
std::shared_ptr<MockPieceStorage> pieceStorage_; std::unique_ptr<MockBtMessageFactory> messageFactory_;
std::shared_ptr<MockBtMessageFactory> messageFactory_; std::unique_ptr<MockBtMessageDispatcher> dispatcher_;
std::shared_ptr<MockBtMessageDispatcher> dispatcher_;
public: public:
void testAddTargetPiece(); void testAddTargetPiece();
void testRemoveCompletedPiece(); void testRemoveCompletedPiece();
@ -50,53 +49,43 @@ public:
index(index), blockIndex(blockIndex) {} index(index), blockIndex(blockIndex) {}
}; };
typedef std::shared_ptr<MockBtRequestMessage> MockBtRequestMessageHandle;
class MockBtMessageFactory2 : public MockBtMessageFactory { class MockBtMessageFactory2 : public MockBtMessageFactory {
public: public:
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtRequestMessage>
createRequestMessage(const std::shared_ptr<Piece>& piece, size_t blockIndex) { createRequestMessage(const std::shared_ptr<Piece>& piece,
return std::shared_ptr<BtMessage> size_t blockIndex) override
(new MockBtRequestMessage(piece->getIndex(), blockIndex)); {
return make_unique<BtRequestMessage>(piece->getIndex(), 0, 0,
blockIndex);
} }
}; };
class MockBtMessageDispatcher2 : public MockBtMessageDispatcher { class MockBtMessageDispatcher2 : public MockBtMessageDispatcher {
public: public:
virtual bool isOutstandingRequest(size_t index, size_t blockIndex) { virtual bool isOutstandingRequest(size_t index, size_t blockIndex) override
{
return index == 0 && blockIndex == 0; return index == 0 && blockIndex == 0;
} }
}; };
class SortMockBtRequestMessage { class BtRequestMessageSorter {
public: public:
bool operator()(const std::shared_ptr<MockBtRequestMessage>& a, bool operator()(const std::unique_ptr<BtRequestMessage>& a,
const std::shared_ptr<MockBtRequestMessage>& b) { const std::unique_ptr<BtRequestMessage>& b)
if(a->index < b->index) { {
return true; return a->getIndex() < b->getIndex() ||
} else if(b->index < a->index) { (a->getIndex() == b->getIndex() &&
return false; a->getBlockIndex() < b->getBlockIndex());
} else if(a->blockIndex < b->blockIndex) {
return true;
} else if(b->blockIndex < a->blockIndex) {
return false;
} else {
return true;
}
} }
}; };
void setUp() void setUp()
{ {
pieceStorage_.reset(new MockPieceStorage()); pieceStorage_ = make_unique<MockPieceStorage>();
peer_ = std::make_shared<Peer>("host", 6969);
peer_.reset(new Peer("host", 6969)); messageFactory_ = make_unique<MockBtMessageFactory2>();
dispatcher_ = make_unique<MockBtMessageDispatcher>();
messageFactory_.reset(new MockBtMessageFactory2()); requestFactory_ = make_unique<DefaultBtRequestFactory>();
dispatcher_.reset(new MockBtMessageDispatcher());
requestFactory_.reset(new DefaultBtRequestFactory());
requestFactory_->setPieceStorage(pieceStorage_.get()); requestFactory_->setPieceStorage(pieceStorage_.get());
requestFactory_->setPeer(peer_); requestFactory_->setPeer(peer_);
requestFactory_->setBtMessageDispatcher(dispatcher_.get()); requestFactory_->setBtMessageDispatcher(dispatcher_.get());
@ -107,14 +96,15 @@ public:
CPPUNIT_TEST_SUITE_REGISTRATION(DefaultBtRequestFactoryTest); CPPUNIT_TEST_SUITE_REGISTRATION(DefaultBtRequestFactoryTest);
void DefaultBtRequestFactoryTest::testAddTargetPiece() { void DefaultBtRequestFactoryTest::testAddTargetPiece()
{ {
std::shared_ptr<Piece> piece(new Piece(0, 16*1024*10)); {
auto piece = std::make_shared<Piece>(0, 16*1024*10);
requestFactory_->addTargetPiece(piece); requestFactory_->addTargetPiece(piece);
CPPUNIT_ASSERT_EQUAL((size_t)1, requestFactory_->countTargetPiece()); CPPUNIT_ASSERT_EQUAL((size_t)1, requestFactory_->countTargetPiece());
} }
{ {
std::shared_ptr<Piece> piece(new Piece(1, 16*1024*9)); auto piece = std::make_shared<Piece>(1, 16*1024*9);
piece->completeBlock(0); piece->completeBlock(0);
requestFactory_->addTargetPiece(piece); requestFactory_->addTargetPiece(piece);
CPPUNIT_ASSERT_EQUAL((size_t)2, requestFactory_->countTargetPiece()); CPPUNIT_ASSERT_EQUAL((size_t)2, requestFactory_->countTargetPiece());
@ -122,9 +112,10 @@ void DefaultBtRequestFactoryTest::testAddTargetPiece() {
CPPUNIT_ASSERT_EQUAL((size_t)18, requestFactory_->countMissingBlock()); CPPUNIT_ASSERT_EQUAL((size_t)18, requestFactory_->countMissingBlock());
} }
void DefaultBtRequestFactoryTest::testRemoveCompletedPiece() { void DefaultBtRequestFactoryTest::testRemoveCompletedPiece()
std::shared_ptr<Piece> piece1(new Piece(0, 16*1024)); {
std::shared_ptr<Piece> piece2(new Piece(1, 16*1024)); auto piece1 = std::make_shared<Piece>(0, 16*1024);
auto piece2 = std::make_shared<Piece>(1, 16*1024);
piece2->setAllBlock(); piece2->setAllBlock();
requestFactory_->addTargetPiece(piece1); requestFactory_->addTargetPiece(piece1);
requestFactory_->addTargetPiece(piece2); requestFactory_->addTargetPiece(piece2);
@ -135,77 +126,63 @@ void DefaultBtRequestFactoryTest::testRemoveCompletedPiece() {
requestFactory_->getTargetPieces().front()->getIndex()); requestFactory_->getTargetPieces().front()->getIndex());
} }
void DefaultBtRequestFactoryTest::testCreateRequestMessages() { void DefaultBtRequestFactoryTest::testCreateRequestMessages()
{
int PIECE_LENGTH = 16*1024*2; int PIECE_LENGTH = 16*1024*2;
std::shared_ptr<Piece> piece1(new Piece(0, PIECE_LENGTH)); auto piece1 = std::make_shared<Piece>(0, PIECE_LENGTH);
std::shared_ptr<Piece> piece2(new Piece(1, PIECE_LENGTH)); auto piece2 = std::make_shared<Piece>(1, PIECE_LENGTH);
requestFactory_->addTargetPiece(piece1); requestFactory_->addTargetPiece(piece1);
requestFactory_->addTargetPiece(piece2); requestFactory_->addTargetPiece(piece2);
std::vector<std::shared_ptr<BtMessage> > msgs; auto msgs = requestFactory_->createRequestMessages(3, false);
requestFactory_->createRequestMessages(msgs, 3);
CPPUNIT_ASSERT_EQUAL((size_t)3, msgs.size()); CPPUNIT_ASSERT_EQUAL((size_t)3, msgs.size());
std::vector<std::shared_ptr<BtMessage> >::iterator itr = msgs.begin(); auto msg = msgs[0].get();
auto msg = std::dynamic_pointer_cast<MockBtRequestMessage>(*itr); CPPUNIT_ASSERT_EQUAL((size_t)0, msg->getIndex());
CPPUNIT_ASSERT_EQUAL((size_t)0, msg->index); CPPUNIT_ASSERT_EQUAL((size_t)0, msg->getBlockIndex());
CPPUNIT_ASSERT_EQUAL((size_t)0, msg->blockIndex); msg = msgs[1].get();
++itr; CPPUNIT_ASSERT_EQUAL((size_t)0, msg->getIndex());
msg = std::dynamic_pointer_cast<MockBtRequestMessage>(*itr); CPPUNIT_ASSERT_EQUAL((size_t)1, msg->getBlockIndex());
CPPUNIT_ASSERT_EQUAL((size_t)0, msg->index); msg = msgs[2].get();
CPPUNIT_ASSERT_EQUAL((size_t)1, msg->blockIndex); CPPUNIT_ASSERT_EQUAL((size_t)1, msg->getIndex());
++itr; CPPUNIT_ASSERT_EQUAL((size_t)0, msg->getBlockIndex());
msg = std::dynamic_pointer_cast<MockBtRequestMessage>(*itr);
CPPUNIT_ASSERT_EQUAL((size_t)1, msg->index);
CPPUNIT_ASSERT_EQUAL((size_t)0, msg->blockIndex);
{ {
std::vector<std::shared_ptr<BtMessage> > msgs; auto msgs = requestFactory_->createRequestMessages(3, false);
requestFactory_->createRequestMessages(msgs, 3);
CPPUNIT_ASSERT_EQUAL((size_t)1, msgs.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, msgs.size());
} }
} }
void DefaultBtRequestFactoryTest::testCreateRequestMessages_onEndGame() { void DefaultBtRequestFactoryTest::testCreateRequestMessages_onEndGame()
std::shared_ptr<MockBtMessageDispatcher2> dispatcher {
(new MockBtMessageDispatcher2()); auto dispatcher = make_unique<MockBtMessageDispatcher2>();
requestFactory_->setBtMessageDispatcher(dispatcher.get()); requestFactory_->setBtMessageDispatcher(dispatcher.get());
int PIECE_LENGTH = 16*1024*2; int PIECE_LENGTH = 16*1024*2;
std::shared_ptr<Piece> piece1(new Piece(0, PIECE_LENGTH)); auto piece1 = std::make_shared<Piece>(0, PIECE_LENGTH);
std::shared_ptr<Piece> piece2(new Piece(1, PIECE_LENGTH)); auto piece2 = std::make_shared<Piece>(1, PIECE_LENGTH);
requestFactory_->addTargetPiece(piece1); requestFactory_->addTargetPiece(piece1);
requestFactory_->addTargetPiece(piece2); requestFactory_->addTargetPiece(piece2);
std::vector<std::shared_ptr<BtMessage> > msgs; auto msgs = requestFactory_->createRequestMessages(3, true);
requestFactory_->createRequestMessagesOnEndGame(msgs, 3); std::sort(std::begin(msgs), std::end(msgs), BtRequestMessageSorter());
std::vector<std::shared_ptr<MockBtRequestMessage> > mmsgs; CPPUNIT_ASSERT_EQUAL((size_t)3, msgs.size());
for(std::vector<std::shared_ptr<BtMessage> >::iterator i = msgs.begin(); auto msg = msgs[0].get();
i != msgs.end(); ++i) { CPPUNIT_ASSERT_EQUAL((size_t)0, msg->getIndex());
mmsgs.push_back(std::dynamic_pointer_cast<MockBtRequestMessage>(*i)); CPPUNIT_ASSERT_EQUAL((size_t)1, msg->getBlockIndex());
msg = msgs[1].get();
CPPUNIT_ASSERT_EQUAL((size_t)1, msg->getIndex());
CPPUNIT_ASSERT_EQUAL((size_t)0, msg->getBlockIndex());
msg = msgs[2].get();
CPPUNIT_ASSERT_EQUAL((size_t)1, msg->getIndex());
CPPUNIT_ASSERT_EQUAL((size_t)1, msg->getBlockIndex());
} }
std::sort(mmsgs.begin(), mmsgs.end(), SortMockBtRequestMessage()); void DefaultBtRequestFactoryTest::testRemoveTargetPiece()
{
CPPUNIT_ASSERT_EQUAL((size_t)3, mmsgs.size()); auto piece1 = std::make_shared<Piece>(0, 16*1024);
std::vector<std::shared_ptr<MockBtRequestMessage> >::iterator itr =mmsgs.begin();
MockBtRequestMessage* msg = (*itr).get();
CPPUNIT_ASSERT_EQUAL((size_t)0, msg->index);
CPPUNIT_ASSERT_EQUAL((size_t)1, msg->blockIndex);
++itr;
msg = (*itr).get();
CPPUNIT_ASSERT_EQUAL((size_t)1, msg->index);
CPPUNIT_ASSERT_EQUAL((size_t)0, msg->blockIndex);
++itr;
msg = (*itr).get();
CPPUNIT_ASSERT_EQUAL((size_t)1, msg->index);
CPPUNIT_ASSERT_EQUAL((size_t)1, msg->blockIndex);
}
void DefaultBtRequestFactoryTest::testRemoveTargetPiece() {
std::shared_ptr<Piece> piece1(new Piece(0, 16*1024));
requestFactory_->addTargetPiece(piece1); requestFactory_->addTargetPiece(piece1);
@ -224,16 +201,15 @@ void DefaultBtRequestFactoryTest::testRemoveTargetPiece() {
void DefaultBtRequestFactoryTest::testGetTargetPieceIndexes() void DefaultBtRequestFactoryTest::testGetTargetPieceIndexes()
{ {
std::shared_ptr<Piece> piece1(new Piece(1, 16*1024)); auto piece1 = std::make_shared<Piece>(1, 16*1024);
std::shared_ptr<Piece> piece3(new Piece(3, 16*1024)); auto piece3 = std::make_shared<Piece>(3, 16*1024);
std::shared_ptr<Piece> piece5(new Piece(5, 16*1024)); auto piece5 = std::make_shared<Piece>(5, 16*1024);
requestFactory_->addTargetPiece(piece3); requestFactory_->addTargetPiece(piece3);
requestFactory_->addTargetPiece(piece1); requestFactory_->addTargetPiece(piece1);
requestFactory_->addTargetPiece(piece5); requestFactory_->addTargetPiece(piece5);
std::vector<size_t> indexes; auto indexes = requestFactory_->getTargetPieceIndexes();
requestFactory_->getTargetPieceIndexes(indexes);
CPPUNIT_ASSERT_EQUAL((size_t)3, indexes.size()); CPPUNIT_ASSERT_EQUAL((size_t)3, indexes.size());
CPPUNIT_ASSERT_EQUAL((size_t)3, indexes[0]); CPPUNIT_ASSERT_EQUAL((size_t)3, indexes[0]);
CPPUNIT_ASSERT_EQUAL((size_t)1, indexes[1]); CPPUNIT_ASSERT_EQUAL((size_t)1, indexes[1]);

View File

@ -12,18 +12,12 @@ namespace aria2 {
class MockBtMessageDispatcher : public BtMessageDispatcher { class MockBtMessageDispatcher : public BtMessageDispatcher {
public: public:
std::deque<std::shared_ptr<BtMessage> > messageQueue; std::deque<std::unique_ptr<BtMessage>> messageQueue;
virtual ~MockBtMessageDispatcher() {} virtual ~MockBtMessageDispatcher() {}
virtual void addMessageToQueue(const std::shared_ptr<BtMessage>& btMessage) { virtual void addMessageToQueue(std::unique_ptr<BtMessage> btMessage) {
messageQueue.push_back(btMessage); messageQueue.push_back(std::move(btMessage));
}
virtual void addMessageToQueue
(const std::vector<std::shared_ptr<BtMessage> >& btMessages)
{
std::copy(btMessages.begin(), btMessages.end(), back_inserter(messageQueue));
} }
virtual void sendMessages() {} virtual void sendMessages() {}

View File

@ -3,6 +3,24 @@
#include "BtMessageFactory.h" #include "BtMessageFactory.h"
#include "BtHandshakeMessage.h"
#include "BtRequestMessage.h"
#include "BtCancelMessage.h"
#include "BtPieceMessage.h"
#include "BtHaveMessage.h"
#include "BtChokeMessage.h"
#include "BtUnchokeMessage.h"
#include "BtInterestedMessage.h"
#include "BtNotInterestedMessage.h"
#include "BtBitfieldMessage.h"
#include "BtKeepAliveMessage.h"
#include "BtHaveAllMessage.h"
#include "BtHaveNoneMessage.h"
#include "BtRejectMessage.h"
#include "BtAllowedFastMessage.h"
#include "BtPortMessage.h"
#include "BtExtendedMessage.h"
namespace aria2 { namespace aria2 {
class ExtensionMessage; class ExtensionMessage;
@ -13,91 +31,92 @@ public:
virtual ~MockBtMessageFactory() {} virtual ~MockBtMessageFactory() {}
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtMessage>
createBtMessage(const unsigned char* msg, size_t msgLength) { createBtMessage(const unsigned char* msg, size_t msgLength) {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtMessage>{};
}; };
virtual std::shared_ptr<BtHandshakeMessage> virtual std::unique_ptr<BtHandshakeMessage>
createHandshakeMessage(const unsigned char* msg, size_t msgLength) { createHandshakeMessage(const unsigned char* msg, size_t msgLength) {
return std::shared_ptr<BtHandshakeMessage>(); return std::unique_ptr<BtHandshakeMessage>{};
} }
virtual std::shared_ptr<BtHandshakeMessage> virtual std::unique_ptr<BtHandshakeMessage>
createHandshakeMessage(const unsigned char* infoHash, createHandshakeMessage(const unsigned char* infoHash,
const unsigned char* peerId) { const unsigned char* peerId) {
return std::shared_ptr<BtHandshakeMessage>(); return std::unique_ptr<BtHandshakeMessage>{};
} }
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtRequestMessage>
createRequestMessage(const std::shared_ptr<Piece>& piece, size_t blockIndex) { createRequestMessage(const std::shared_ptr<Piece>& piece, size_t blockIndex) {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtRequestMessage>{};
} }
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtCancelMessage>
createCancelMessage(size_t index, int32_t begin, int32_t length) { createCancelMessage(size_t index, int32_t begin, int32_t length) {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtCancelMessage>{};
} }
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtPieceMessage>
createPieceMessage(size_t index, int32_t begin, int32_t length) { createPieceMessage(size_t index, int32_t begin, int32_t length) {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtPieceMessage>{};
} }
virtual std::shared_ptr<BtMessage> createHaveMessage(size_t index) { virtual std::unique_ptr<BtHaveMessage> createHaveMessage(size_t index) {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtHaveMessage>{};
} }
virtual std::shared_ptr<BtMessage> createChokeMessage() { virtual std::unique_ptr<BtChokeMessage> createChokeMessage() {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtChokeMessage>{};
} }
virtual std::shared_ptr<BtMessage> createUnchokeMessage() { virtual std::unique_ptr<BtUnchokeMessage> createUnchokeMessage() {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtUnchokeMessage>{};
} }
virtual std::shared_ptr<BtMessage> createInterestedMessage() { virtual std::unique_ptr<BtInterestedMessage> createInterestedMessage() {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtInterestedMessage>{};
} }
virtual std::shared_ptr<BtMessage> createNotInterestedMessage() { virtual std::unique_ptr<BtNotInterestedMessage> createNotInterestedMessage() {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtNotInterestedMessage>{};
} }
virtual std::shared_ptr<BtMessage> createBitfieldMessage() { virtual std::unique_ptr<BtBitfieldMessage> createBitfieldMessage() {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtBitfieldMessage>{};
} }
virtual std::shared_ptr<BtMessage> createKeepAliveMessage() { virtual std::unique_ptr<BtKeepAliveMessage> createKeepAliveMessage() {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtKeepAliveMessage>{};
} }
virtual std::shared_ptr<BtMessage> createHaveAllMessage() { virtual std::unique_ptr<BtHaveAllMessage> createHaveAllMessage() {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtHaveAllMessage>{};
} }
virtual std::shared_ptr<BtMessage> createHaveNoneMessage() { virtual std::unique_ptr<BtHaveNoneMessage> createHaveNoneMessage() {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtHaveNoneMessage>{};
} }
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtRejectMessage>
createRejectMessage(size_t index, int32_t begin, int32_t length) { createRejectMessage(size_t index, int32_t begin, int32_t length) {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtRejectMessage>{};
} }
virtual std::shared_ptr<BtMessage> createAllowedFastMessage(size_t index) { virtual std::unique_ptr<BtAllowedFastMessage> createAllowedFastMessage
return std::shared_ptr<BtMessage>(); (size_t index) {
return std::unique_ptr<BtAllowedFastMessage>{};
} }
virtual std::shared_ptr<BtMessage> createPortMessage(uint16_t port) virtual std::unique_ptr<BtPortMessage> createPortMessage(uint16_t port)
{ {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtPortMessage>{};
} }
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtExtendedMessage>
createBtExtendedMessage(const std::shared_ptr<ExtensionMessage>& extmsg) createBtExtendedMessage(const std::shared_ptr<ExtensionMessage>& extmsg)
{ {
return std::shared_ptr<BtMessage>(); return std::unique_ptr<BtExtendedMessage>{};
} }
}; };

View File

@ -2,6 +2,7 @@
#define D_MOCK_BT_REQUEST_FACTORY_H #define D_MOCK_BT_REQUEST_FACTORY_H
#include "BtRequestFactory.h" #include "BtRequestFactory.h"
#include "BtRequestMessage.h"
namespace aria2 { namespace aria2 {
@ -23,13 +24,16 @@ public:
virtual void doChokedAction() {} virtual void doChokedAction() {}
virtual void createRequestMessages virtual std::vector<std::unique_ptr<BtRequestMessage>> createRequestMessages
(std::vector<std::shared_ptr<BtMessage> >& requests, size_t max) {} (size_t max, bool endGame)
{
return std::vector<std::unique_ptr<BtRequestMessage>>{};
}
virtual void createRequestMessagesOnEndGame virtual std::vector<size_t> getTargetPieceIndexes() const
(std::vector<std::shared_ptr<BtMessage> >& requests, size_t max) {} {
return std::vector<size_t>{};
virtual void getTargetPieceIndexes(std::vector<size_t>& indexes) const {} }
}; };
} // namespace aria2 } // namespace aria2

View File

@ -15,27 +15,28 @@ public:
MockExtensionMessage(const std::string& extensionName, MockExtensionMessage(const std::string& extensionName,
uint8_t extensionMessageID, uint8_t extensionMessageID,
const unsigned char* data, const unsigned char* data,
size_t length):extensionName_(extensionName), size_t length)
extensionMessageID_(extensionMessageID), : extensionName_{extensionName},
data_(&data[0], &data[length]), extensionMessageID_{extensionMessageID},
doReceivedActionCalled_(false) {} data_{&data[0], &data[length]},
doReceivedActionCalled_{false}
{}
MockExtensionMessage(const std::string& extensionName, MockExtensionMessage(const std::string& extensionName,
uint8_t extensionMessageID, uint8_t extensionMessageID,
const std::string& data): const std::string& data)
extensionName_(extensionName), : extensionName_{extensionName},
extensionMessageID_(extensionMessageID), extensionMessageID_{extensionMessageID},
data_(data), data_{data},
doReceivedActionCalled_(false) {} doReceivedActionCalled_{false}
{}
virtual ~MockExtensionMessage() {}
virtual std::string getPayload() virtual std::string getPayload()
{ {
return data_; return data_;
} }
virtual uint8_t getExtensionMessageID() virtual uint8_t getExtensionMessageID() const
{ {
return extensionMessageID_; return extensionMessageID_;
} }

View File

@ -32,28 +32,29 @@ class UTMetadataRequestExtensionMessageTest:public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE_END(); CPPUNIT_TEST_SUITE_END();
public: public:
std::shared_ptr<DownloadContext> dctx_; std::shared_ptr<DownloadContext> dctx_;
std::shared_ptr<WrapExtBtMessageFactory> messageFactory_; std::unique_ptr<WrapExtBtMessageFactory> messageFactory_;
std::shared_ptr<MockBtMessageDispatcher> dispatcher_; std::unique_ptr<MockBtMessageDispatcher> dispatcher_;
std::shared_ptr<Peer> peer_; std::shared_ptr<Peer> peer_;
void setUp() void setUp()
{ {
messageFactory_.reset(new WrapExtBtMessageFactory()); messageFactory_ = make_unique<WrapExtBtMessageFactory>();
dispatcher_.reset(new MockBtMessageDispatcher()); dispatcher_ = make_unique<MockBtMessageDispatcher>();
dctx_.reset(new DownloadContext()); dctx_ = std::make_shared<DownloadContext>();
dctx_->setAttribute(CTX_ATTR_BT, make_unique<TorrentAttribute>()); dctx_->setAttribute(CTX_ATTR_BT, make_unique<TorrentAttribute>());
peer_.reset(new Peer("host", 6880)); peer_ = std::make_shared<Peer>("host", 6880);
peer_->allocateSessionResource(0, 0); peer_->allocateSessionResource(0, 0);
peer_->setExtension(ExtensionMessageRegistry::UT_METADATA, 1); peer_->setExtension(ExtensionMessageRegistry::UT_METADATA, 1);
} }
template<typename T> template<typename T>
std::shared_ptr<T> getFirstDispatchedMessage() const T* getFirstDispatchedMessage()
{ {
auto wrapmsg = std::dynamic_pointer_cast<WrapExtBtMessage> CPPUNIT_ASSERT(BtExtendedMessage::ID ==
(dispatcher_->messageQueue.front()); dispatcher_->messageQueue.front()->getId());
auto msg = static_cast<const BtExtendedMessage*>
return std::dynamic_pointer_cast<T>(wrapmsg->m_); (dispatcher_->messageQueue.front().get());
return dynamic_cast<const T*>(msg->getExtensionMessage().get());
} }
void testGetExtensionMessageID(); void testGetExtensionMessageID();
@ -106,8 +107,7 @@ void UTMetadataRequestExtensionMessageTest::testDoReceivedAction_reject()
msg.setBtMessageDispatcher(dispatcher_.get()); msg.setBtMessageDispatcher(dispatcher_.get());
msg.doReceivedAction(); msg.doReceivedAction();
std::shared_ptr<UTMetadataRejectExtensionMessage> m = auto m = getFirstDispatchedMessage<UTMetadataRejectExtensionMessage>();
getFirstDispatchedMessage<UTMetadataRejectExtensionMessage>();
CPPUNIT_ASSERT(m); CPPUNIT_ASSERT(m);
CPPUNIT_ASSERT_EQUAL((size_t)10, m->getIndex()); CPPUNIT_ASSERT_EQUAL((size_t)10, m->getIndex());
@ -132,8 +132,7 @@ void UTMetadataRequestExtensionMessageTest::testDoReceivedAction_data()
msg.doReceivedAction(); msg.doReceivedAction();
std::shared_ptr<UTMetadataDataExtensionMessage> m = auto m = getFirstDispatchedMessage<UTMetadataDataExtensionMessage>();
getFirstDispatchedMessage<UTMetadataDataExtensionMessage>();
CPPUNIT_ASSERT(m); CPPUNIT_ASSERT(m);
CPPUNIT_ASSERT_EQUAL((size_t)1, m->getIndex()); CPPUNIT_ASSERT_EQUAL((size_t)1, m->getIndex());

View File

@ -66,16 +66,14 @@ void UTMetadataRequestFactoryTest::testCreate()
(new UTMetadataRequestTracker()); (new UTMetadataRequestTracker());
factory.setUTMetadataRequestTracker(tracker.get()); factory.setUTMetadataRequestTracker(tracker.get());
std::vector<std::shared_ptr<BtMessage> > msgs; auto msgs = factory.create(1, ps);
factory.create(msgs, 1, ps);
CPPUNIT_ASSERT_EQUAL((size_t)1, msgs.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, msgs.size());
factory.create(msgs, 1, ps); msgs = factory.create(1, ps);
CPPUNIT_ASSERT_EQUAL((size_t)2, msgs.size()); CPPUNIT_ASSERT_EQUAL((size_t)1, msgs.size());
factory.create(msgs, 1, ps); msgs = factory.create(1, ps);
CPPUNIT_ASSERT_EQUAL((size_t)2, msgs.size()); CPPUNIT_ASSERT_EQUAL((size_t)0, msgs.size());
} }
} // namespace aria2 } // namespace aria2

View File

@ -1,19 +1,17 @@
#ifndef D_EXTENSION_MESSAGE_TEST_HELPER_H #ifndef D_EXTENSION_MESSAGE_TEST_HELPER_H
#define D_EXTENSION_MESSAGE_TEST_HELPER_H #define D_EXTENSION_MESSAGE_TEST_HELPER_H
#include "MockBtMessage.h"
#include "MockBtMessageFactory.h" #include "MockBtMessageFactory.h"
namespace aria2 { namespace aria2 {
typedef WrapBtMessage<ExtensionMessage> WrapExtBtMessage;
class WrapExtBtMessageFactory:public MockBtMessageFactory { class WrapExtBtMessageFactory:public MockBtMessageFactory {
public: public:
virtual std::shared_ptr<BtMessage> virtual std::unique_ptr<BtExtendedMessage>
createBtExtendedMessage(const std::shared_ptr<ExtensionMessage>& extmsg) createBtExtendedMessage(const std::shared_ptr<ExtensionMessage>& extmsg)
override
{ {
return std::shared_ptr<BtMessage>(new WrapExtBtMessage(extmsg)); return make_unique<BtExtendedMessage>(extmsg);
} }
}; };