Use std::unique_ptr for ExtensionMessage

pull/103/head
Tatsuhiro Tsujikawa 2013-07-01 21:42:51 +09:00
parent 9e35750bb8
commit 529b9fdceb
34 changed files with 383 additions and 398 deletions

View File

@ -51,14 +51,12 @@ namespace aria2 {
const char BtExtendedMessage::NAME[] = "extended";
BtExtendedMessage::BtExtendedMessage
(const std::shared_ptr<ExtensionMessage>& extensionMessage):
(std::unique_ptr<ExtensionMessage> extensionMessage):
SimpleBtMessage(ID, NAME),
extensionMessage_(extensionMessage),
extensionMessage_(std::move(extensionMessage)),
msgLength_(0)
{}
BtExtendedMessage::~BtExtendedMessage() {}
unsigned char* BtExtendedMessage::createMessage()
{
/**
@ -97,16 +95,15 @@ std::string BtExtendedMessage::toString() const {
}
std::unique_ptr<BtExtendedMessage>
BtExtendedMessage::create(const std::shared_ptr<ExtensionMessageFactory>& factory,
BtExtendedMessage::create(ExtensionMessageFactory* factory,
const std::shared_ptr<Peer>& peer,
const unsigned char* data, size_t dataLength)
{
bittorrent::assertPayloadLengthGreater(1, dataLength, NAME);
bittorrent::assertID(ID, data, NAME);
assert(factory);
std::shared_ptr<ExtensionMessage> extmsg = factory->createMessage(data+1,
dataLength-1);
return make_unique<BtExtendedMessage>(extmsg);
return make_unique<BtExtendedMessage>
(factory->createMessage(data+1, dataLength-1));
}
void BtExtendedMessage::doReceivedAction()
@ -116,4 +113,10 @@ void BtExtendedMessage::doReceivedAction()
}
}
const std::unique_ptr<ExtensionMessage>&
BtExtendedMessage::getExtensionMessage() const
{
return extensionMessage_;
}
} // namespace aria2

View File

@ -44,20 +44,19 @@ class ExtensionMessageFactory;
class BtExtendedMessage:public SimpleBtMessage
{
private:
std::shared_ptr<ExtensionMessage> extensionMessage_;
std::unique_ptr<ExtensionMessage> extensionMessage_;
size_t msgLength_;
public:
BtExtendedMessage(const std::shared_ptr<ExtensionMessage>& extensionMessage =
std::shared_ptr<ExtensionMessage>());
virtual ~BtExtendedMessage();
BtExtendedMessage(std::unique_ptr<ExtensionMessage> extensionMessage =
std::unique_ptr<ExtensionMessage>{});
static const uint8_t ID = 20;
static const char NAME[];
static std::unique_ptr<BtExtendedMessage> create
(const std::shared_ptr<ExtensionMessageFactory>& factory,
(ExtensionMessageFactory* factory,
const std::shared_ptr<Peer>& peer,
const unsigned char* data,
size_t dataLength);
@ -72,11 +71,7 @@ public:
virtual std::string toString() const;
const std::shared_ptr<ExtensionMessage>& getExtensionMessage() const
{
return extensionMessage_;
}
const std::unique_ptr<ExtensionMessage>& getExtensionMessage() const;
};
} // namespace aria2

View File

@ -114,7 +114,7 @@ public:
virtual std::unique_ptr<BtPortMessage> createPortMessage(uint16_t port) = 0;
virtual std::unique_ptr<BtExtendedMessage>
createBtExtendedMessage(const std::shared_ptr<ExtensionMessage>& msg) = 0;
createBtExtendedMessage(std::unique_ptr<ExtensionMessage> msg) = 0;
};
} // namespace aria2

View File

@ -199,7 +199,7 @@ void DefaultBtInteractive::addPortMessageToQueue()
void DefaultBtInteractive::addHandshakeExtendedMessageToQueue()
{
std::shared_ptr<HandshakeExtensionMessage> m(new HandshakeExtensionMessage());
auto m = make_unique<HandshakeExtensionMessage>();
m->setClientVersion("aria2/" PACKAGE_VERSION);
m->setTCPPort(tcpPort_);
m->setExtensions(extensionMessageRegistry_->getExtensions());
@ -207,7 +207,8 @@ void DefaultBtInteractive::addHandshakeExtendedMessageToQueue()
if(!attrs->metadata.empty()) {
m->setMetadataSize(attrs->metadataSize);
}
dispatcher_->addMessageToQueue(messageFactory_->createBtExtendedMessage(m));
dispatcher_->addMessageToQueue
(messageFactory_->createBtExtendedMessage(std::move(m)));
}
void DefaultBtInteractive::addBitfieldMessageToQueue() {
@ -479,32 +480,26 @@ void DefaultBtInteractive::checkActiveInteraction()
void DefaultBtInteractive::addPeerExchangeMessage()
{
if(pexTimer_.
difference(global::wallclock()) >= UTPexExtensionMessage::DEFAULT_INTERVAL) {
std::shared_ptr<UTPexExtensionMessage> m
(new UTPexExtensionMessage(peer_->getExtensionMessageID
(ExtensionMessageRegistry::UT_PEX)));
const PeerSet& usedPeers = peerStorage_->getUsedPeers();
for(PeerSet::const_iterator i = usedPeers.begin(), eoi = usedPeers.end();
if(pexTimer_.difference(global::wallclock()) >=
UTPexExtensionMessage::DEFAULT_INTERVAL) {
auto m = make_unique<UTPexExtensionMessage>
(peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_PEX));
auto& usedPeers = peerStorage_->getUsedPeers();
for(auto i = std::begin(usedPeers), eoi = std::end(usedPeers);
i != eoi && !m->freshPeersAreFull(); ++i) {
if((*i)->isActive() && peer_->getIPAddress() != (*i)->getIPAddress()) {
m->addFreshPeer(*i);
}
}
const std::deque<std::shared_ptr<Peer> >& droppedPeers =
peerStorage_->getDroppedPeers();
for(std::deque<std::shared_ptr<Peer> >::const_iterator i =
droppedPeers.begin(), eoi = droppedPeers.end();
i != eoi && !m->droppedPeersAreFull();
++i) {
auto& droppedPeers = peerStorage_->getDroppedPeers();
for(auto i = std::begin(droppedPeers), eoi = std::end(droppedPeers);
i != eoi && !m->droppedPeersAreFull(); ++i) {
if(peer_->getIPAddress() != (*i)->getIPAddress()) {
m->addDroppedPeer(*i);
}
}
dispatcher_->addMessageToQueue
(messageFactory_->createBtExtendedMessage(m));
(messageFactory_->createBtExtendedMessage(std::move(m)));
pexTimer_ = global::wallclock();
}
}
@ -522,7 +517,7 @@ void DefaultBtInteractive::doInteractionProcessing() {
size_t num = utMetadataRequestTracker_->avail();
if(num > 0) {
auto requests =
utMetadataRequestFactory_->create(num, pieceStorage_);
utMetadataRequestFactory_->create(num, pieceStorage_.get());
for(auto& i : requests) {
dispatcher_->addMessageToQueue(std::move(i));
}
@ -531,11 +526,9 @@ void DefaultBtInteractive::doInteractionProcessing() {
perSecTimer_ = global::wallclock();
// Drop timeout request after queuing message to give a chance
// to other connection to request piece.
std::vector<size_t> indexes =
utMetadataRequestTracker_->removeTimeoutEntry();
for(std::vector<size_t>::const_iterator i = indexes.begin(),
eoi = indexes.end(); i != eoi; ++i) {
pieceStorage_->cancelPiece(pieceStorage_->getPiece(*i), cuid_);
auto indexes = utMetadataRequestTracker_->removeTimeoutEntry();
for(auto idx : indexes) {
pieceStorage_->cancelPiece(pieceStorage_->getPiece(idx), cuid_);
}
}
if(pieceStorage_->downloadFinished()) {
@ -645,9 +638,9 @@ void DefaultBtInteractive::setPeerConnection
}
void DefaultBtInteractive::setExtensionMessageFactory
(const std::shared_ptr<ExtensionMessageFactory>& factory)
(std::unique_ptr<ExtensionMessageFactory> factory)
{
extensionMessageFactory_ = factory;
extensionMessageFactory_ = std::move(factory);
}
void DefaultBtInteractive::setBtMessageFactory

View File

@ -117,7 +117,7 @@ private:
// holds the reference so that peerConnection_ is not deleted.
std::shared_ptr<PeerConnection> peerConnection_;
std::shared_ptr<BtMessageFactory> messageFactory_;
std::shared_ptr<ExtensionMessageFactory> extensionMessageFactory_;
std::unique_ptr<ExtensionMessageFactory> extensionMessageFactory_;
std::shared_ptr<ExtensionMessageRegistry> extensionMessageRegistry_;
std::shared_ptr<UTMetadataRequestFactory> utMetadataRequestFactory_;
std::shared_ptr<UTMetadataRequestTracker> utMetadataRequestTracker_;
@ -219,7 +219,7 @@ public:
void setBtMessageFactory(const std::shared_ptr<BtMessageFactory>& factory);
void setExtensionMessageFactory
(const std::shared_ptr<ExtensionMessageFactory>& factory);
(std::unique_ptr<ExtensionMessageFactory> factory);
void setExtensionMessageRegistry
(const std::shared_ptr<ExtensionMessageRegistry>& registry)

View File

@ -79,6 +79,7 @@ DefaultBtMessageFactory::DefaultBtMessageFactory()
dispatcher_{nullptr},
requestFactory_{nullptr},
peerConnection_{nullptr},
extensionMessageFactory_{nullptr},
localNode_{nullptr},
routingTable_{nullptr},
taskQueue_{nullptr},
@ -402,9 +403,9 @@ DefaultBtMessageFactory::createPortMessage(uint16_t port)
std::unique_ptr<BtExtendedMessage>
DefaultBtMessageFactory::createBtExtendedMessage
(const std::shared_ptr<ExtensionMessage>& exmsg)
(std::unique_ptr<ExtensionMessage> exmsg)
{
auto msg = make_unique<BtExtendedMessage>(exmsg);
auto msg = make_unique<BtExtendedMessage>(std::move(exmsg));
setCommonProperty(msg.get());
return msg;
}
@ -447,7 +448,7 @@ void DefaultBtMessageFactory::setBtMessageDispatcher
}
void DefaultBtMessageFactory::setExtensionMessageFactory
(const std::shared_ptr<ExtensionMessageFactory>& factory)
(ExtensionMessageFactory* factory)
{
extensionMessageFactory_ = factory;
}

View File

@ -70,7 +70,7 @@ private:
PeerConnection* peerConnection_;
std::shared_ptr<ExtensionMessageFactory> extensionMessageFactory_;
ExtensionMessageFactory* extensionMessageFactory_;
DHTNode* localNode_;
@ -133,7 +133,7 @@ public:
virtual std::unique_ptr<BtPortMessage> createPortMessage(uint16_t port);
virtual std::unique_ptr<BtExtendedMessage>
createBtExtendedMessage(const std::shared_ptr<ExtensionMessage>& msg);
createBtExtendedMessage(std::unique_ptr<ExtensionMessage> msg);
void setPeer(const std::shared_ptr<Peer>& peer);
@ -159,8 +159,7 @@ public:
void setPeerConnection(PeerConnection* connection);
void setExtensionMessageFactory
(const std::shared_ptr<ExtensionMessageFactory>& factory);
void setExtensionMessageFactory(ExtensionMessageFactory* factory);
void setLocalNode(DHTNode* localNode);

View File

@ -58,34 +58,31 @@
namespace aria2 {
DefaultExtensionMessageFactory::DefaultExtensionMessageFactory()
: messageFactory_(0),
dispatcher_(0),
tracker_(0)
: DefaultExtensionMessageFactory{std::shared_ptr<Peer>{}, nullptr}
{}
DefaultExtensionMessageFactory::DefaultExtensionMessageFactory
(const std::shared_ptr<Peer>& peer,
const std::shared_ptr<ExtensionMessageRegistry>& registry)
: peer_(peer),
registry_(registry),
messageFactory_(0),
dispatcher_(0),
tracker_(0)
(const std::shared_ptr<Peer>& peer, ExtensionMessageRegistry* registry)
: peerStorage_{nullptr},
peer_{peer},
registry_{registry},
dctx_{nullptr},
messageFactory_{nullptr},
dispatcher_{nullptr},
tracker_{nullptr}
{}
DefaultExtensionMessageFactory::~DefaultExtensionMessageFactory() {}
std::shared_ptr<ExtensionMessage>
DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t length)
std::unique_ptr<ExtensionMessage>
DefaultExtensionMessageFactory::createMessage
(const unsigned char* data, size_t length)
{
uint8_t extensionMessageID = *data;
if(extensionMessageID == 0) {
// handshake
HandshakeExtensionMessage* m =
HandshakeExtensionMessage::create(data, length);
auto m = HandshakeExtensionMessage::create(data, length);
m->setPeer(peer_);
m->setDownloadContext(dctx_);
return std::shared_ptr<ExtensionMessage>(m);
return std::move(m);
} else {
const char* extensionName = registry_->getExtensionName(extensionMessageID);
if(!extensionName) {
@ -95,9 +92,9 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
}
if(strcmp(extensionName, "ut_pex") == 0) {
// uTorrent compatible Peer-Exchange
UTPexExtensionMessage* m = UTPexExtensionMessage::create(data, length);
auto m = UTPexExtensionMessage::create(data, length);
m->setPeerStorage(peerStorage_);
return std::shared_ptr<ExtensionMessage>(m);
return std::move(m);
} else if(strcmp(extensionName, "ut_metadata") == 0) {
if(length == 0) {
throw DL_ABORT_EX
@ -106,8 +103,7 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
static_cast<unsigned long>(length)));
}
size_t end;
std::shared_ptr<ValueBase> decoded =
bencode2::decode(data+1, length - 1, end);
auto decoded = bencode2::decode(data+1, length - 1, end);
const Dict* dict = downcast<Dict>(decoded);
if(!dict) {
throw DL_ABORT_EX("Bad ut_metadata: dictionary not found");
@ -122,14 +118,14 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
}
switch(msgType->i()) {
case 0: {
UTMetadataRequestExtensionMessage* m
(new UTMetadataRequestExtensionMessage(extensionMessageID));
auto m = make_unique<UTMetadataRequestExtensionMessage>
(extensionMessageID);
m->setIndex(index->i());
m->setDownloadContext(dctx_);
m->setPeer(peer_);
m->setBtMessageFactory(messageFactory_);
m->setBtMessageDispatcher(dispatcher_);
return std::shared_ptr<ExtensionMessage>(m);
return std::move(m);
}
case 1: {
if(end == length) {
@ -139,22 +135,23 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
if(!totalSize) {
throw DL_ABORT_EX("Bad ut_metadata data: total_size not found");
}
UTMetadataDataExtensionMessage* m
(new UTMetadataDataExtensionMessage(extensionMessageID));
auto m = make_unique<UTMetadataDataExtensionMessage>
(extensionMessageID);
m->setIndex(index->i());
m->setTotalSize(totalSize->i());
m->setData(&data[1+end], &data[length]);
m->setUTMetadataRequestTracker(tracker_);
m->setPieceStorage(dctx_->getOwnerRequestGroup()->getPieceStorage());
m->setPieceStorage
(dctx_->getOwnerRequestGroup()->getPieceStorage().get());
m->setDownloadContext(dctx_);
return std::shared_ptr<ExtensionMessage>(m);
return std::move(m);
}
case 2: {
UTMetadataRejectExtensionMessage* m
(new UTMetadataRejectExtensionMessage(extensionMessageID));
auto m = make_unique<UTMetadataRejectExtensionMessage>
(extensionMessageID);
m->setIndex(index->i());
// No need to inject tracker because peer will be disconnected.
return std::shared_ptr<ExtensionMessage>(m);
return std::move(m);
}
default:
throw DL_ABORT_EX
@ -170,8 +167,7 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
}
}
void DefaultExtensionMessageFactory::setPeerStorage
(const std::shared_ptr<PeerStorage>& peerStorage)
void DefaultExtensionMessageFactory::setPeerStorage(PeerStorage* peerStorage)
{
peerStorage_ = peerStorage;
}
@ -181,4 +177,33 @@ void DefaultExtensionMessageFactory::setPeer(const std::shared_ptr<Peer>& peer)
peer_ = peer;
}
void DefaultExtensionMessageFactory::setExtensionMessageRegistry
(ExtensionMessageRegistry* registry)
{
registry_ = registry;
}
void DefaultExtensionMessageFactory::setDownloadContext(DownloadContext* dctx)
{
dctx_ = dctx;
}
void DefaultExtensionMessageFactory::setBtMessageFactory
(BtMessageFactory* factory)
{
messageFactory_ = factory;
}
void DefaultExtensionMessageFactory::setBtMessageDispatcher
(BtMessageDispatcher* disp)
{
dispatcher_ = disp;
}
void DefaultExtensionMessageFactory::setUTMetadataRequestTracker
(UTMetadataRequestTracker* tracker)
{
tracker_ = tracker;
}
} // namespace aria2

View File

@ -49,13 +49,13 @@ class UTMetadataRequestTracker;
class DefaultExtensionMessageFactory:public ExtensionMessageFactory {
private:
std::shared_ptr<PeerStorage> peerStorage_;
PeerStorage* peerStorage_;
std::shared_ptr<Peer> peer_;
std::shared_ptr<ExtensionMessageRegistry> registry_;
ExtensionMessageRegistry* registry_;
std::shared_ptr<DownloadContext> dctx_;
DownloadContext* dctx_;
BtMessageFactory* messageFactory_;
@ -66,43 +66,24 @@ public:
DefaultExtensionMessageFactory();
DefaultExtensionMessageFactory
(const std::shared_ptr<Peer>& peer,
const std::shared_ptr<ExtensionMessageRegistry>& registry);
(const std::shared_ptr<Peer>& peer, ExtensionMessageRegistry* registry);
virtual ~DefaultExtensionMessageFactory();
virtual std::shared_ptr<ExtensionMessage>
virtual std::unique_ptr<ExtensionMessage>
createMessage(const unsigned char* data, size_t length);
void setPeerStorage(const std::shared_ptr<PeerStorage>& peerStorage);
void setPeerStorage(PeerStorage* peerStorage);
void setPeer(const std::shared_ptr<Peer>& peer);
void setExtensionMessageRegistry
(const std::shared_ptr<ExtensionMessageRegistry>& registry)
{
registry_ = registry;
}
void setExtensionMessageRegistry(ExtensionMessageRegistry* registry);
void setDownloadContext(const std::shared_ptr<DownloadContext>& dctx)
{
dctx_ = dctx;
}
void setDownloadContext(DownloadContext* dctx);
void setBtMessageFactory(BtMessageFactory* factory)
{
messageFactory_ = factory;
}
void setBtMessageFactory(BtMessageFactory* factory);
void setBtMessageDispatcher(BtMessageDispatcher* disp)
{
dispatcher_ = disp;
}
void setBtMessageDispatcher(BtMessageDispatcher* disp);
void setUTMetadataRequestTracker(UTMetadataRequestTracker* tracker)
{
tracker_ = tracker;
}
void setUTMetadataRequestTracker(UTMetadataRequestTracker* tracker);
};
} // namespace aria2

View File

@ -47,7 +47,7 @@ class ExtensionMessageFactory {
public:
virtual ~ExtensionMessageFactory() {}
virtual std::shared_ptr<ExtensionMessage>
virtual std::unique_ptr<ExtensionMessage>
createMessage(const unsigned char* data, size_t length) = 0;
};

View File

@ -52,12 +52,9 @@ namespace aria2 {
const char HandshakeExtensionMessage::EXTENSION_NAME[] = "handshake";
HandshakeExtensionMessage::HandshakeExtensionMessage()
: tcpPort_(0),
metadataSize_(0)
: tcpPort_{0}, metadataSize_{0}, dctx_{nullptr}
{}
HandshakeExtensionMessage::~HandshakeExtensionMessage() {}
std::string HandshakeExtensionMessage::getPayload()
{
Dict dict;
@ -127,12 +124,9 @@ void HandshakeExtensionMessage::doReceivedAction()
dctx_->getFirstFileEntry()->setLength(metadataSize_);
dctx_->markTotalLengthIsKnown();
dctx_->getOwnerRequestGroup()->initPieceStorage();
std::shared_ptr<PieceStorage> pieceStorage =
dctx_->getOwnerRequestGroup()->getPieceStorage();
// We enter 'end game' mode from the start to get metadata
// quickly.
pieceStorage->enterEndGame();
dctx_->getOwnerRequestGroup()->getPieceStorage()->enterEndGame();
}
peer_->reconfigureSessionResource(dctx_->getPieceLength(),
dctx_->getTotalLength());
@ -164,7 +158,7 @@ uint8_t HandshakeExtensionMessage::getExtensionMessageID(int key) const
return extreg_.getExtensionMessageID(key);
}
HandshakeExtensionMessage*
std::unique_ptr<HandshakeExtensionMessage>
HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
{
if(length < 1) {
@ -174,13 +168,13 @@ HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
}
A2_LOG_DEBUG(fmt("Creating HandshakeExtensionMessage from %s",
util::percentEncode(data, length).c_str()));
std::shared_ptr<ValueBase> decoded = bencode2::decode(data+1, length - 1);
auto decoded = bencode2::decode(data+1, length - 1);
const Dict* dict = downcast<Dict>(decoded);
if(!dict) {
throw DL_ABORT_EX
("Unexpected payload format for extended message handshake");
}
HandshakeExtensionMessage* msg(new HandshakeExtensionMessage());
auto msg = make_unique<HandshakeExtensionMessage>();
const Integer* port = downcast<Integer>(dict->get("p"));
if(port && 0 < port->i() && port->i() < 65536) {
msg->tcpPort_ = port->i();

View File

@ -57,14 +57,12 @@ private:
ExtensionMessageRegistry extreg_;
std::shared_ptr<DownloadContext> dctx_;
DownloadContext* dctx_;
std::shared_ptr<Peer> peer_;
public:
HandshakeExtensionMessage();
virtual ~HandshakeExtensionMessage();
virtual std::string getPayload();
virtual uint8_t getExtensionMessageID() const
@ -113,7 +111,7 @@ public:
metadataSize_ = size;
}
void setDownloadContext(const std::shared_ptr<DownloadContext>& dctx)
void setDownloadContext(DownloadContext* dctx)
{
dctx_ = dctx;
}
@ -126,7 +124,7 @@ public:
void setPeer(const std::shared_ptr<Peer>& peer);
static HandshakeExtensionMessage*
static std::unique_ptr<HandshakeExtensionMessage>
create(const unsigned char* data, size_t dataLength);
};

View File

@ -132,25 +132,22 @@ PeerInteractionCommand::PeerInteractionCommand
utMetadataRequestTracker.reset(new UTMetadataRequestTracker());
}
DefaultExtensionMessageFactory* extensionMessageFactoryPtr
(new DefaultExtensionMessageFactory(getPeer(), exMsgRegistry));
extensionMessageFactoryPtr->setPeerStorage(peerStorage);
extensionMessageFactoryPtr->setDownloadContext
(requestGroup_->getDownloadContext());
extensionMessageFactoryPtr->setUTMetadataRequestTracker
auto extensionMessageFactory =
make_unique<DefaultExtensionMessageFactory>(getPeer(),
exMsgRegistry.get());
extensionMessageFactory->setPeerStorage(peerStorage.get());
extensionMessageFactory->setDownloadContext
(requestGroup_->getDownloadContext().get());
extensionMessageFactory->setUTMetadataRequestTracker
(utMetadataRequestTracker.get());
// PieceStorage will be set later.
std::shared_ptr<ExtensionMessageFactory> extensionMessageFactory
(extensionMessageFactoryPtr);
DefaultBtMessageFactory* factoryPtr(new DefaultBtMessageFactory());
factoryPtr->setCuid(cuid);
factoryPtr->setDownloadContext(requestGroup_->getDownloadContext().get());
factoryPtr->setPieceStorage(pieceStorage.get());
factoryPtr->setPeerStorage(peerStorage.get());
factoryPtr->setExtensionMessageFactory(extensionMessageFactory);
factoryPtr->setExtensionMessageFactory(extensionMessageFactory.get());
factoryPtr->setPeer(getPeer());
if(family == AF_INET) {
factoryPtr->setLocalNode(DHTRegistry::getData().localNode.get());
@ -218,6 +215,30 @@ PeerInteractionCommand::PeerInteractionCommand
reqFactoryPtr->setCuid(cuid);
std::shared_ptr<BtRequestFactory> reqFactory(reqFactoryPtr);
// reverse depends
factoryPtr->setBtMessageDispatcher(dispatcherPtr);
factoryPtr->setBtRequestFactory(reqFactoryPtr);
factoryPtr->setPeerConnection(peerConnection.get());
extensionMessageFactory->setBtMessageDispatcher(dispatcherPtr);
extensionMessageFactory->setBtMessageFactory(factoryPtr);
if(metadataGetMode) {
utMetadataRequestFactory->setCuid(cuid);
utMetadataRequestFactory->setDownloadContext
(requestGroup_->getDownloadContext().get());
utMetadataRequestFactory->setBtMessageDispatcher(dispatcherPtr);
utMetadataRequestFactory->setBtMessageFactory(factoryPtr);
utMetadataRequestFactory->setPeer(getPeer());
utMetadataRequestFactory->setUTMetadataRequestTracker
(utMetadataRequestTracker.get());
}
getPeer()->allocateSessionResource
(requestGroup_->getDownloadContext()->getPieceLength(),
requestGroup_->getDownloadContext()->getTotalLength());
getPeer()->setBtMessageDispatcher(dispatcherPtr);
DefaultBtInteractive* btInteractivePtr
(new DefaultBtInteractive(requestGroup_->getDownloadContext(), getPeer()));
btInteractivePtr->setBtRuntime(btRuntime_);
@ -228,7 +249,8 @@ PeerInteractionCommand::PeerInteractionCommand
btInteractivePtr->setDispatcher(dispatcher);
btInteractivePtr->setBtRequestFactory(reqFactory);
btInteractivePtr->setPeerConnection(peerConnection);
btInteractivePtr->setExtensionMessageFactory(extensionMessageFactory);
btInteractivePtr->setExtensionMessageFactory
(std::move(extensionMessageFactory));
btInteractivePtr->setExtensionMessageRegistry(exMsgRegistry);
btInteractivePtr->setKeepAliveInterval
(getOption()->getAsInt(PREF_BT_KEEP_ALIVE_INTERVAL));
@ -264,29 +286,6 @@ PeerInteractionCommand::PeerInteractionCommand
btInteractive_ = btInteractive;
// reverse depends
factoryPtr->setBtMessageDispatcher(dispatcherPtr);
factoryPtr->setBtRequestFactory(reqFactoryPtr);
factoryPtr->setPeerConnection(peerConnection.get());
extensionMessageFactoryPtr->setBtMessageDispatcher(dispatcherPtr);
extensionMessageFactoryPtr->setBtMessageFactory(factoryPtr);
if(metadataGetMode) {
utMetadataRequestFactory->setCuid(cuid);
utMetadataRequestFactory->setDownloadContext
(requestGroup_->getDownloadContext());
utMetadataRequestFactory->setBtMessageDispatcher(dispatcherPtr);
utMetadataRequestFactory->setBtMessageFactory(factoryPtr);
utMetadataRequestFactory->setPeer(getPeer());
utMetadataRequestFactory->setUTMetadataRequestTracker
(utMetadataRequestTracker.get());
}
getPeer()->allocateSessionResource
(requestGroup_->getDownloadContext()->getPieceLength(),
requestGroup_->getDownloadContext()->getTotalLength());
getPeer()->setBtMessageDispatcher(dispatcherPtr);
btRuntime_->increaseConnections();
requestGroup_->increaseNumCommand();

View File

@ -54,13 +54,13 @@ namespace aria2 {
UTMetadataDataExtensionMessage::UTMetadataDataExtensionMessage
(uint8_t extensionMessageID)
: UTMetadataExtensionMessage(extensionMessageID),
totalSize_(0),
tracker_(0)
: UTMetadataExtensionMessage{extensionMessageID},
totalSize_{0},
dctx_{nullptr},
pieceStorage_{nullptr},
tracker_{nullptr}
{}
UTMetadataDataExtensionMessage::~UTMetadataDataExtensionMessage() {}
std::string UTMetadataDataExtensionMessage::getPayload()
{
Dict dict;
@ -109,19 +109,39 @@ void UTMetadataDataExtensionMessage::doReceivedAction()
}
}
void UTMetadataDataExtensionMessage::setTotalSize(size_t totalSize)
{
totalSize_ = totalSize;
}
size_t UTMetadataDataExtensionMessage::getTotalSize() const
{
return totalSize_;
}
void UTMetadataDataExtensionMessage::setData(const std::string& data)
{
data_ = data;
}
const std::string& UTMetadataDataExtensionMessage::getData() const
{
return data_;
}
void UTMetadataDataExtensionMessage::setPieceStorage
(const std::shared_ptr<PieceStorage>& pieceStorage)
(PieceStorage* pieceStorage)
{
pieceStorage_ = pieceStorage;
}
void UTMetadataDataExtensionMessage::setDownloadContext
(const std::shared_ptr<DownloadContext>& dctx)
void UTMetadataDataExtensionMessage::setUTMetadataRequestTracker
(UTMetadataRequestTracker* tracker)
{
tracker_ = tracker;
}
void UTMetadataDataExtensionMessage::setDownloadContext(DownloadContext* dctx)
{
dctx_ = dctx;
}

View File

@ -51,31 +51,23 @@ private:
std::string data_;
std::shared_ptr<DownloadContext> dctx_;
DownloadContext* dctx_;
std::shared_ptr<PieceStorage> pieceStorage_;
PieceStorage* pieceStorage_;
UTMetadataRequestTracker* tracker_;
public:
UTMetadataDataExtensionMessage(uint8_t extensionMessageID);
~UTMetadataDataExtensionMessage();
virtual std::string getPayload();
virtual std::string toString() const;
virtual void doReceivedAction();
void setTotalSize(size_t totalSize)
{
totalSize_ = totalSize;
}
void setTotalSize(size_t totalSize);
size_t getTotalSize() const
{
return totalSize_;
}
size_t getTotalSize() const;
void setData(const std::string& data);
@ -85,19 +77,13 @@ public:
data_.assign(first, last);
}
const std::string& getData() const
{
return data_;
}
const std::string& getData() const;
void setPieceStorage(const std::shared_ptr<PieceStorage>& pieceStorage);
void setPieceStorage(PieceStorage* pieceStorage);
void setUTMetadataRequestTracker(UTMetadataRequestTracker* tracker)
{
tracker_ = tracker;
}
void setUTMetadataRequestTracker(UTMetadataRequestTracker* tracker);
void setDownloadContext(const std::shared_ptr<DownloadContext>& dctx);
void setDownloadContext(DownloadContext* dctx);
};
} // namespace aria2

View File

@ -41,8 +41,9 @@
namespace aria2 {
UTMetadataRejectExtensionMessage::UTMetadataRejectExtensionMessage
(uint8_t extensionMessageID):
UTMetadataExtensionMessage(extensionMessageID) {}
(uint8_t extensionMessageID)
: UTMetadataExtensionMessage{extensionMessageID}
{}
std::string UTMetadataRejectExtensionMessage::getPayload()
{

View File

@ -54,13 +54,13 @@
namespace aria2 {
UTMetadataRequestExtensionMessage::UTMetadataRequestExtensionMessage
(uint8_t extensionMessageID):UTMetadataExtensionMessage(extensionMessageID),
dispatcher_(0),
messageFactory_(0)
(uint8_t extensionMessageID)
: UTMetadataExtensionMessage{extensionMessageID},
dctx_{nullptr},
dispatcher_{nullptr},
messageFactory_{nullptr}
{}
UTMetadataRequestExtensionMessage::~UTMetadataRequestExtensionMessage() {}
std::string UTMetadataRequestExtensionMessage::getPayload()
{
Dict dict;
@ -81,25 +81,22 @@ void UTMetadataRequestExtensionMessage::doReceivedAction()
uint8_t id = peer_->getExtensionMessageID
(ExtensionMessageRegistry::UT_METADATA);
if(attrs->metadata.empty()) {
std::shared_ptr<UTMetadataRejectExtensionMessage> m
(new UTMetadataRejectExtensionMessage(id));
auto m = make_unique<UTMetadataRejectExtensionMessage>(id);
m->setIndex(getIndex());
dispatcher_->addMessageToQueue
(messageFactory_->createBtExtendedMessage(m));
(messageFactory_->createBtExtendedMessage(std::move(m)));
}else if(getIndex()*METADATA_PIECE_SIZE < attrs->metadataSize) {
std::shared_ptr<UTMetadataDataExtensionMessage> m
(new UTMetadataDataExtensionMessage(id));
auto m = make_unique<UTMetadataDataExtensionMessage>(id);
m->setIndex(getIndex());
m->setTotalSize(attrs->metadataSize);
std::string::const_iterator begin =
attrs->metadata.begin()+getIndex()*METADATA_PIECE_SIZE;
std::string::const_iterator end =
auto begin = std::begin(attrs->metadata)+getIndex()*METADATA_PIECE_SIZE;
auto end =
(getIndex()+1)*METADATA_PIECE_SIZE <= attrs->metadata.size()?
attrs->metadata.begin()+(getIndex()+1)*METADATA_PIECE_SIZE:
attrs->metadata.end();
std::begin(attrs->metadata)+(getIndex()+1)*METADATA_PIECE_SIZE:
std::end(attrs->metadata);
m->setData(begin, end);
dispatcher_->addMessageToQueue
(messageFactory_->createBtExtendedMessage(m));
(messageFactory_->createBtExtendedMessage(std::move(m)));
} else {
throw DL_ABORT_EX
(fmt("Metadata piece index is too big. piece=%lu",
@ -108,14 +105,27 @@ void UTMetadataRequestExtensionMessage::doReceivedAction()
}
void UTMetadataRequestExtensionMessage::setDownloadContext
(const std::shared_ptr<DownloadContext>& dctx)
(DownloadContext* dctx)
{
dctx_ = dctx;
}
void UTMetadataRequestExtensionMessage::setPeer(const std::shared_ptr<Peer>& peer)
void UTMetadataRequestExtensionMessage::setPeer
(const std::shared_ptr<Peer>& peer)
{
peer_ = peer;
}
void UTMetadataRequestExtensionMessage::setBtMessageDispatcher
(BtMessageDispatcher* disp)
{
dispatcher_ = disp;
}
void UTMetadataRequestExtensionMessage::setBtMessageFactory
(BtMessageFactory* factory)
{
messageFactory_ = factory;
}
} // namespace aria2

View File

@ -48,7 +48,7 @@ class Peer;
class UTMetadataRequestExtensionMessage:public UTMetadataExtensionMessage {
private:
std::shared_ptr<DownloadContext> dctx_;
DownloadContext* dctx_;
std::shared_ptr<Peer> peer_;
@ -58,25 +58,17 @@ private:
public:
UTMetadataRequestExtensionMessage(uint8_t extensionMessageID);
~UTMetadataRequestExtensionMessage();
virtual std::string getPayload();
virtual std::string toString() const;
virtual void doReceivedAction();
void setDownloadContext(const std::shared_ptr<DownloadContext>& dctx);
void setDownloadContext(DownloadContext* dctx);
void setBtMessageDispatcher(BtMessageDispatcher* disp)
{
dispatcher_ = disp;
}
void setBtMessageDispatcher(BtMessageDispatcher* disp);
void setBtMessageFactory(BtMessageFactory* factory)
{
messageFactory_ = factory;
}
void setBtMessageFactory(BtMessageFactory* factory);
void setPeer(const std::shared_ptr<Peer>& peer);
};

View File

@ -50,14 +50,15 @@
namespace aria2 {
UTMetadataRequestFactory::UTMetadataRequestFactory()
: dispatcher_{nullptr},
: dctx_{nullptr},
dispatcher_{nullptr},
messageFactory_{nullptr},
tracker_{nullptr},
cuid_(0)
cuid_{0}
{}
std::vector<std::unique_ptr<BtMessage>> UTMetadataRequestFactory::create
(size_t num, const std::shared_ptr<PieceStorage>& pieceStorage)
(size_t num, PieceStorage* pieceStorage)
{
auto msgs = std::vector<std::unique_ptr<BtMessage>>{};
while(num) {
@ -70,16 +71,15 @@ std::vector<std::unique_ptr<BtMessage>> UTMetadataRequestFactory::create
--num;
A2_LOG_DEBUG(fmt("Creating ut_metadata request index=%lu",
static_cast<unsigned long>(p->getIndex())));
std::shared_ptr<UTMetadataRequestExtensionMessage> m
(new UTMetadataRequestExtensionMessage
(peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_METADATA)));
auto m = make_unique<UTMetadataRequestExtensionMessage>
(peer_->getExtensionMessageID(ExtensionMessageRegistry::UT_METADATA));
m->setIndex(p->getIndex());
m->setDownloadContext(dctx_);
m->setBtMessageDispatcher(dispatcher_);
m->setBtMessageFactory(messageFactory_);
m->setPeer(peer_);
msgs.push_back(messageFactory_->createBtExtendedMessage(m));
msgs.push_back(messageFactory_->createBtExtendedMessage(std::move(m)));
tracker_->add(p->getIndex());
}
return msgs;

View File

@ -54,7 +54,7 @@ class BtMessage;
class UTMetadataRequestFactory {
private:
std::shared_ptr<DownloadContext> dctx_;
DownloadContext* dctx_;
std::shared_ptr<Peer> peer_;
@ -70,9 +70,9 @@ public:
// Creates and returns at most num of ut_metadata request
// message. pieceStorage is used to identify missing piece.
std::vector<std::unique_ptr<BtMessage>> create
(size_t num, const std::shared_ptr<PieceStorage>& pieceStorage);
(size_t num, PieceStorage* pieceStorage);
void setDownloadContext(const std::shared_ptr<DownloadContext>& dctx)
void setDownloadContext(DownloadContext* dctx)
{
dctx_ = dctx;
}

View File

@ -57,23 +57,18 @@ const size_t DEFAULT_MAX_DROPPED_PEER = 50;
const char UTPexExtensionMessage::EXTENSION_NAME[] = "ut_pex";
UTPexExtensionMessage::UTPexExtensionMessage(uint8_t extensionMessageID):
extensionMessageID_(extensionMessageID),
interval_(DEFAULT_INTERVAL),
maxFreshPeer_(DEFAULT_MAX_FRESH_PEER),
maxDroppedPeer_(DEFAULT_MAX_DROPPED_PEER) {}
UTPexExtensionMessage::~UTPexExtensionMessage() {}
UTPexExtensionMessage::UTPexExtensionMessage(uint8_t extensionMessageID)
: extensionMessageID_{extensionMessageID},
peerStorage_{nullptr},
interval_{DEFAULT_INTERVAL},
maxFreshPeer_{DEFAULT_MAX_FRESH_PEER},
maxDroppedPeer_{DEFAULT_MAX_DROPPED_PEER}
{}
std::string UTPexExtensionMessage::getPayload()
{
std::pair<std::pair<std::string, std::string>,
std::pair<std::string, std::string> > freshPeerPair =
createCompactPeerListAndFlag(freshPeers_);
std::pair<std::pair<std::string, std::string>,
std::pair<std::string, std::string> > droppedPeerPair =
createCompactPeerListAndFlag(droppedPeers_);
auto freshPeerPair = createCompactPeerListAndFlag(freshPeers_);
auto droppedPeerPair = createCompactPeerListAndFlag(droppedPeers_);
Dict dict;
if(!freshPeerPair.first.first.empty()) {
dict.put("added", freshPeerPair.first.first);
@ -102,8 +97,7 @@ UTPexExtensionMessage::createCompactPeerListAndFlag
std::string flagstring;
std::string addrstring6;
std::string flagstring6;
for(std::vector<std::shared_ptr<Peer> >::const_iterator itr = peers.begin(),
eoi = peers.end(); itr != eoi; ++itr) {
for(auto itr = std::begin(peers), eoi = std::end(peers); itr != eoi; ++itr) {
unsigned char compact[COMPACT_LEN_IPV6];
int compactlen = bittorrent::packcompact
(compact, (*itr)->getIPAddress(), (*itr)->getPort());
@ -115,8 +109,10 @@ UTPexExtensionMessage::createCompactPeerListAndFlag
flagstring6 += (*itr)->isSeeder() ? 0x02u : 0x00u;
}
}
return std::make_pair(std::make_pair(addrstring, flagstring),
std::make_pair(addrstring6, flagstring6));
return std::make_pair(std::make_pair(std::move(addrstring),
std::move(flagstring)),
std::make_pair(std::move(addrstring6),
std::move(flagstring6)));
}
std::string UTPexExtensionMessage::toString() const
@ -143,6 +139,12 @@ bool UTPexExtensionMessage::addFreshPeer(const std::shared_ptr<Peer>& peer)
}
}
const std::vector<std::shared_ptr<Peer>>&
UTPexExtensionMessage::getFreshPeers() const
{
return freshPeers_;
}
bool UTPexExtensionMessage::freshPeersAreFull() const
{
return freshPeers_.size() >= maxFreshPeer_;
@ -159,6 +161,12 @@ bool UTPexExtensionMessage::addDroppedPeer(const std::shared_ptr<Peer>& peer)
}
}
const std::vector<std::shared_ptr<Peer>>&
UTPexExtensionMessage::getDroppedPeers() const
{
return droppedPeers_;
}
bool UTPexExtensionMessage::droppedPeersAreFull() const
{
return droppedPeers_.size() >= maxDroppedPeer_;
@ -174,13 +182,12 @@ void UTPexExtensionMessage::setMaxDroppedPeer(size_t maxDroppedPeer)
maxDroppedPeer_ = maxDroppedPeer;
}
void UTPexExtensionMessage::setPeerStorage
(const std::shared_ptr<PeerStorage>& peerStorage)
void UTPexExtensionMessage::setPeerStorage(PeerStorage* peerStorage)
{
peerStorage_ = peerStorage;
}
UTPexExtensionMessage*
std::unique_ptr<UTPexExtensionMessage>
UTPexExtensionMessage::create(const unsigned char* data, size_t len)
{
if(len < 1) {
@ -188,9 +195,9 @@ UTPexExtensionMessage::create(const unsigned char* data, size_t len)
EXTENSION_NAME,
static_cast<unsigned long>(len)));
}
UTPexExtensionMessage* msg(new UTPexExtensionMessage(*data));
auto msg = make_unique<UTPexExtensionMessage>(*data);
std::shared_ptr<ValueBase> decoded = bencode2::decode(data+1, len - 1);
auto decoded = bencode2::decode(data+1, len - 1);
const Dict* dict = downcast<Dict>(decoded);
if(dict) {
const String* added = downcast<String>(dict->get("added"));

View File

@ -56,7 +56,7 @@ private:
std::vector<std::shared_ptr<Peer> > droppedPeers_;
std::shared_ptr<PeerStorage> peerStorage_;
PeerStorage* peerStorage_;
time_t interval_;
@ -65,14 +65,12 @@ private:
size_t maxDroppedPeer_;
std::pair<std::pair<std::string, std::string>,
std::pair<std::string, std::string> >
createCompactPeerListAndFlag(const std::vector<std::shared_ptr<Peer> >& peers);
std::pair<std::string, std::string>>
createCompactPeerListAndFlag(const std::vector<std::shared_ptr<Peer>>& peers);
public:
UTPexExtensionMessage(uint8_t extensionMessageID);
virtual ~UTPexExtensionMessage();
virtual std::string getPayload();
virtual uint8_t getExtensionMessageID() const
@ -93,25 +91,19 @@ public:
bool addFreshPeer(const std::shared_ptr<Peer>& peer);
const std::vector<std::shared_ptr<Peer> >& getFreshPeers() const
{
return freshPeers_;
}
const std::vector<std::shared_ptr<Peer>>& getFreshPeers() const;
bool freshPeersAreFull() const;
bool addDroppedPeer(const std::shared_ptr<Peer>& peer);
const std::vector<std::shared_ptr<Peer> >& getDroppedPeers() const
{
return droppedPeers_;
}
const std::vector<std::shared_ptr<Peer>>& getDroppedPeers() const;
bool droppedPeersAreFull() const;
void setPeerStorage(const std::shared_ptr<PeerStorage>& peerStorage);
void setPeerStorage(PeerStorage* peerStorage);
static UTPexExtensionMessage*
static std::unique_ptr<UTPexExtensionMessage>
create(const unsigned char* data, size_t len);
void setMaxFreshPeer(size_t maxFreshPeer);

View File

@ -33,11 +33,9 @@ public:
CPPUNIT_TEST_SUITE_REGISTRATION(BtExtendedMessageTest);
void BtExtendedMessageTest::testCreate() {
std::shared_ptr<Peer> peer(new Peer("192.168.0.1", 6969));
auto peer = std::make_shared<Peer>("192.168.0.1", 6969);
peer->allocateSessionResource(1024, 1024*1024);
std::shared_ptr<MockExtensionMessageFactory> exmsgFactory
(new MockExtensionMessageFactory());
auto exmsgFactory = MockExtensionMessageFactory{};
// payload:{4:name3:foo}->11bytes
std::string payload = "4:name3:foo";
@ -45,16 +43,14 @@ void BtExtendedMessageTest::testCreate() {
bittorrent::createPeerMessageString((unsigned char*)msg, sizeof(msg), 13, 20);
msg[5] = 1; // Set dummy extended message ID 1
memcpy(msg+6, payload.c_str(), payload.size());
std::shared_ptr<BtExtendedMessage> pm(BtExtendedMessage::create(exmsgFactory,
peer,
&msg[4], 13));
auto pm = BtExtendedMessage::create(&exmsgFactory, peer, &msg[4], 13);
CPPUNIT_ASSERT_EQUAL((uint8_t)20, pm->getId());
// case: payload size is wrong
try {
unsigned char msg[5];
bittorrent::createPeerMessageString(msg, sizeof(msg), 1, 20);
BtExtendedMessage::create(exmsgFactory, peer, &msg[4], 1);
BtExtendedMessage::create(&exmsgFactory, peer, &msg[4], 1);
CPPUNIT_FAIL("exception must be thrown.");
} catch(Exception& e) {
std::cerr << e.stackTrace() << std::endl;
@ -63,7 +59,7 @@ void BtExtendedMessageTest::testCreate() {
try {
unsigned char msg[6];
bittorrent::createPeerMessageString(msg, sizeof(msg), 2, 21);
BtExtendedMessage::create(exmsgFactory, peer, &msg[4], 2);
BtExtendedMessage::create(&exmsgFactory, peer, &msg[4], 2);
CPPUNIT_FAIL("exception must be thrown.");
} catch(Exception& e) {
std::cerr << e.stackTrace() << std::endl;
@ -73,10 +69,8 @@ void BtExtendedMessageTest::testCreate() {
void BtExtendedMessageTest::testCreateMessage() {
std::string payload = "4:name3:foo";
uint8_t extendedMessageID = 1;
std::shared_ptr<MockExtensionMessage> exmsg
(new MockExtensionMessage("charlie", extendedMessageID, payload));
BtExtendedMessage msg(exmsg);
BtExtendedMessage msg{make_unique<MockExtensionMessage>
("charlie", extendedMessageID, payload, nullptr)};
unsigned char data[17];
bittorrent::createPeerMessageString(data, sizeof(data), 13, 20);
*(data+5) = extendedMessageID;
@ -88,19 +82,18 @@ void BtExtendedMessageTest::testCreateMessage() {
}
void BtExtendedMessageTest::testDoReceivedAction() {
std::shared_ptr<MockExtensionMessage> exmsg
(new MockExtensionMessage("charlie", 1, ""));
BtExtendedMessage msg(exmsg);
auto evcheck = MockExtensionMessageEventCheck{};
BtExtendedMessage msg{make_unique<MockExtensionMessage>
("charlie", 1, "", &evcheck)};
msg.doReceivedAction();
CPPUNIT_ASSERT(exmsg->doReceivedActionCalled_);
CPPUNIT_ASSERT(evcheck.doReceivedActionCalled);
}
void BtExtendedMessageTest::testToString() {
std::string payload = "4:name3:foo";
uint8_t extendedMessageID = 1;
std::shared_ptr<MockExtensionMessage> exmsg
(new MockExtensionMessage("charlie", extendedMessageID, payload));
BtExtendedMessage msg(exmsg);
BtExtendedMessage msg{make_unique<MockExtensionMessage>
("charlie", extendedMessageID, payload, nullptr)};
CPPUNIT_ASSERT_EQUAL(std::string("extended charlie"), msg.toString());
}

View File

@ -42,7 +42,7 @@ public:
factory_ = make_unique<DefaultBtMessageFactory>();
factory_->setDownloadContext(dctx_.get());
factory_->setPeer(peer_);
factory_->setExtensionMessageFactory(exmsgFactory_);
factory_->setExtensionMessageFactory(exmsgFactory_.get());
}
void testCreateBtMessage_BtExtendedMessage();

View File

@ -38,42 +38,38 @@ class DefaultExtensionMessageFactoryTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testCreateMessage_UTMetadataReject);
CPPUNIT_TEST_SUITE_END();
private:
std::shared_ptr<MockPeerStorage> peerStorage_;
std::unique_ptr<MockPeerStorage> peerStorage_;
std::shared_ptr<Peer> peer_;
std::shared_ptr<DefaultExtensionMessageFactory> factory_;
std::shared_ptr<ExtensionMessageRegistry> registry_;
std::shared_ptr<MockBtMessageDispatcher> dispatcher_;
std::shared_ptr<MockBtMessageFactory> messageFactory_;
std::unique_ptr<DefaultExtensionMessageFactory> factory_;
std::unique_ptr<ExtensionMessageRegistry> registry_;
std::unique_ptr<MockBtMessageDispatcher> dispatcher_;
std::unique_ptr<MockBtMessageFactory> messageFactory_;
std::shared_ptr<DownloadContext> dctx_;
std::shared_ptr<RequestGroup> requestGroup_;
std::unique_ptr<RequestGroup> requestGroup_;
public:
void setUp()
{
peerStorage_.reset(new MockPeerStorage());
peerStorage_ = make_unique<MockPeerStorage>();
peer_.reset(new Peer("192.168.0.1", 6969));
peer_ = std::make_shared<Peer>("192.168.0.1", 6969);
peer_->allocateSessionResource(1024, 1024*1024);
peer_->setExtension(ExtensionMessageRegistry::UT_PEX, 1);
registry_.reset(new ExtensionMessageRegistry());
dispatcher_.reset(new MockBtMessageDispatcher());
messageFactory_.reset(new MockBtMessageFactory());
dctx_.reset(new DownloadContext());
std::shared_ptr<Option> option(new Option());
requestGroup_.reset(new RequestGroup(GroupId::create(), option));
registry_ = make_unique<ExtensionMessageRegistry>();
dispatcher_ = make_unique<MockBtMessageDispatcher>();
messageFactory_ = make_unique<MockBtMessageFactory>();
dctx_ = std::make_shared<DownloadContext>();
auto option = std::make_shared<Option>();
requestGroup_ = make_unique<RequestGroup>(GroupId::create(), option);
requestGroup_->setDownloadContext(dctx_);
factory_.reset(new DefaultExtensionMessageFactory());
factory_->setPeerStorage(peerStorage_);
factory_ = make_unique<DefaultExtensionMessageFactory>();
factory_->setPeerStorage(peerStorage_.get());
factory_->setPeer(peer_);
factory_->setExtensionMessageRegistry(registry_);
factory_->setExtensionMessageRegistry(registry_.get());
factory_->setBtMessageDispatcher(dispatcher_.get());
factory_->setBtMessageFactory(messageFactory_.get());
factory_->setDownloadContext(dctx_);
factory_->setDownloadContext(dctx_.get());
}
std::string getExtensionMessageID(int key)
@ -85,9 +81,10 @@ public:
template<typename T>
std::shared_ptr<T> createMessage(const std::string& data)
{
return std::dynamic_pointer_cast<T>
(factory_->createMessage
(reinterpret_cast<const unsigned char*>(data.c_str()), data.size()));
auto m = factory_->createMessage
(reinterpret_cast<const unsigned char*>(data.c_str()), data.size());
return std::dynamic_pointer_cast<T>(std::shared_ptr<T>
{static_cast<T*>(m.release())});
}
void testCreateMessage_unknown();
@ -98,7 +95,6 @@ public:
void testCreateMessage_UTMetadataReject();
};
CPPUNIT_TEST_SUITE_REGISTRATION(DefaultExtensionMessageFactoryTest);
void DefaultExtensionMessageFactoryTest::testCreateMessage_unknown()
@ -123,8 +119,7 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_Handshake()
char id[1] = { 0 };
std::string data = std::string(&id[0], &id[1])+"d1:v5:aria2e";
std::shared_ptr<HandshakeExtensionMessage> m =
createMessage<HandshakeExtensionMessage>(data);
auto m = createMessage<HandshakeExtensionMessage>(data);
CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion());
}
@ -148,8 +143,7 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTPex()
std::string(&c3[0], &c3[6])+std::string(&c4[0], &c4[6])+
"e";
std::shared_ptr<UTPexExtensionMessage> m =
createMessage<UTPexExtensionMessage>(data);
auto m = createMessage<UTPexExtensionMessage>(data);
CPPUNIT_ASSERT_EQUAL(registry_->getExtensionMessageID
(ExtensionMessageRegistry::UT_PEX),
m->getExtensionMessageID());
@ -162,8 +156,7 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataRequest()
std::string data = getExtensionMessageID
(ExtensionMessageRegistry::UT_METADATA)+
"d8:msg_typei0e5:piecei1ee";
std::shared_ptr<UTMetadataRequestExtensionMessage> m =
createMessage<UTMetadataRequestExtensionMessage>(data);
auto m = createMessage<UTMetadataRequestExtensionMessage>(data);
CPPUNIT_ASSERT_EQUAL((size_t)1, m->getIndex());
}
@ -174,8 +167,7 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataData()
std::string data = getExtensionMessageID
(ExtensionMessageRegistry::UT_METADATA)+
"d8:msg_typei1e5:piecei1e10:total_sizei300ee0000000000";
std::shared_ptr<UTMetadataDataExtensionMessage> m =
createMessage<UTMetadataDataExtensionMessage>(data);
auto m = createMessage<UTMetadataDataExtensionMessage>(data);
CPPUNIT_ASSERT_EQUAL((size_t)1, m->getIndex());
CPPUNIT_ASSERT_EQUAL((size_t)300, m->getTotalSize());
CPPUNIT_ASSERT_EQUAL(std::string(10, '0'), m->getData());
@ -188,8 +180,7 @@ void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataReject()
std::string data = getExtensionMessageID
(ExtensionMessageRegistry::UT_METADATA)+
"d8:msg_typei2e5:piecei1ee";
std::shared_ptr<UTMetadataRejectExtensionMessage> m =
createMessage<UTMetadataRejectExtensionMessage>(data);
auto m = createMessage<UTMetadataRejectExtensionMessage>(data);
CPPUNIT_ASSERT_EQUAL((size_t)1, m->getIndex());
}

View File

@ -91,16 +91,15 @@ void HandshakeExtensionMessageTest::testToString()
void HandshakeExtensionMessageTest::testDoReceivedAction()
{
std::shared_ptr<DownloadContext> dctx
(new DownloadContext(METADATA_PIECE_SIZE, 0));
std::shared_ptr<Option> op(new Option());
auto dctx = std::make_shared<DownloadContext>(METADATA_PIECE_SIZE, 0);
auto op = std::make_shared<Option>();
RequestGroup rg(GroupId::create(), op);
rg.setDownloadContext(dctx);
dctx->setAttribute(CTX_ATTR_BT, make_unique<TorrentAttribute>());
dctx->markTotalLengthIsUnknown();
std::shared_ptr<Peer> peer(new Peer("192.168.0.1", 0));
auto peer = std::make_shared<Peer>("192.168.0.1", 0);
peer->allocateSessionResource(1024, 1024*1024);
HandshakeExtensionMessage msg;
msg.setClientVersion("aria2");
@ -109,7 +108,7 @@ void HandshakeExtensionMessageTest::testDoReceivedAction()
msg.setExtension(ExtensionMessageRegistry::UT_METADATA, 3);
msg.setMetadataSize(1024);
msg.setPeer(peer);
msg.setDownloadContext(dctx);
msg.setDownloadContext(dctx.get());
msg.doReceivedAction();

View File

@ -20,6 +20,7 @@
#include "BtAllowedFastMessage.h"
#include "BtPortMessage.h"
#include "BtExtendedMessage.h"
#include "ExtensionMessage.h"
namespace aria2 {
@ -114,7 +115,7 @@ public:
}
virtual std::unique_ptr<BtExtendedMessage>
createBtExtendedMessage(const std::shared_ptr<ExtensionMessage>& extmsg)
createBtExtendedMessage(std::unique_ptr<ExtensionMessage> extmsg)
{
return std::unique_ptr<BtExtendedMessage>{};
}

View File

@ -5,30 +5,38 @@
namespace aria2 {
struct MockExtensionMessageEventCheck {
MockExtensionMessageEventCheck() : doReceivedActionCalled{false}
{}
bool doReceivedActionCalled;
};
class MockExtensionMessage:public ExtensionMessage {
public:
std::string extensionName_;
uint8_t extensionMessageID_;
std::string data_;
bool doReceivedActionCalled_;
public:
MockExtensionMessageEventCheck* evcheck_;
MockExtensionMessage(const std::string& extensionName,
uint8_t extensionMessageID,
const unsigned char* data,
size_t length)
size_t length,
MockExtensionMessageEventCheck* evcheck)
: extensionName_{extensionName},
extensionMessageID_{extensionMessageID},
data_{&data[0], &data[length]},
doReceivedActionCalled_{false}
evcheck_{evcheck}
{}
MockExtensionMessage(const std::string& extensionName,
uint8_t extensionMessageID,
const std::string& data)
const std::string& data,
MockExtensionMessageEventCheck* evcheck)
: extensionName_{extensionName},
extensionMessageID_{extensionMessageID},
data_{data},
doReceivedActionCalled_{false}
evcheck_{evcheck}
{}
virtual std::string getPayload()
@ -53,7 +61,9 @@ public:
virtual void doReceivedAction()
{
doReceivedActionCalled_ = true;
if(evcheck_) {
evcheck_->doReceivedActionCalled = true;
}
}
};

View File

@ -10,11 +10,11 @@ class MockExtensionMessageFactory:public ExtensionMessageFactory {
public:
virtual ~MockExtensionMessageFactory() {}
virtual std::shared_ptr<ExtensionMessage> createMessage(const unsigned char* data,
size_t length)
virtual std::unique_ptr<ExtensionMessage>
createMessage(const unsigned char* data, size_t length)
{
return std::shared_ptr<ExtensionMessage>
(new MockExtensionMessage("a2_mock", *data, data+1, length-1));
return make_unique<MockExtensionMessage>
("a2_mock", *data, data+1, length-1, nullptr);
}
};

View File

@ -65,14 +65,13 @@ void UTMetadataDataExtensionMessageTest::testToString()
void UTMetadataDataExtensionMessageTest::testDoReceivedAction()
{
std::shared_ptr<DirectDiskAdaptor> diskAdaptor(new DirectDiskAdaptor());
std::shared_ptr<ByteArrayDiskWriter> diskWriter(new ByteArrayDiskWriter());
auto diskAdaptor = std::make_shared<DirectDiskAdaptor>();
auto diskWriter = std::make_shared<ByteArrayDiskWriter>();
diskAdaptor->setDiskWriter(diskWriter);
std::shared_ptr<MockPieceStorage> pieceStorage(new MockPieceStorage());
auto pieceStorage = make_unique<MockPieceStorage>();
pieceStorage->setDiskAdaptor(diskAdaptor);
std::shared_ptr<UTMetadataRequestTracker> tracker
(new UTMetadataRequestTracker());
std::shared_ptr<DownloadContext> dctx(new DownloadContext());
auto tracker = make_unique<UTMetadataRequestTracker>();
auto dctx = make_unique<DownloadContext>();
std::string piece0 = std::string(METADATA_PIECE_SIZE, '0');
std::string piece1 = std::string(METADATA_PIECE_SIZE, '1');
@ -88,9 +87,9 @@ void UTMetadataDataExtensionMessageTest::testDoReceivedAction()
dctx->setAttribute(CTX_ATTR_BT, std::move(attrs));
}
UTMetadataDataExtensionMessage m(1);
m.setPieceStorage(pieceStorage);
m.setPieceStorage(pieceStorage.get());
m.setUTMetadataRequestTracker(tracker.get());
m.setDownloadContext(dctx);
m.setDownloadContext(dctx.get());
m.setIndex(1);
m.setData(piece1);

View File

@ -31,7 +31,7 @@ class UTMetadataRequestExtensionMessageTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testDoReceivedAction_data);
CPPUNIT_TEST_SUITE_END();
public:
std::shared_ptr<DownloadContext> dctx_;
std::unique_ptr<DownloadContext> dctx_;
std::unique_ptr<WrapExtBtMessageFactory> messageFactory_;
std::unique_ptr<MockBtMessageDispatcher> dispatcher_;
std::shared_ptr<Peer> peer_;
@ -40,7 +40,7 @@ public:
{
messageFactory_ = make_unique<WrapExtBtMessageFactory>();
dispatcher_ = make_unique<MockBtMessageDispatcher>();
dctx_ = std::make_shared<DownloadContext>();
dctx_ = make_unique<DownloadContext>();
dctx_->setAttribute(CTX_ATTR_BT, make_unique<TorrentAttribute>());
peer_ = std::make_shared<Peer>("host", 6880);
peer_->allocateSessionResource(0, 0);
@ -101,7 +101,7 @@ void UTMetadataRequestExtensionMessageTest::testDoReceivedAction_reject()
{
UTMetadataRequestExtensionMessage msg(1);
msg.setIndex(10);
msg.setDownloadContext(dctx_);
msg.setDownloadContext(dctx_.get());
msg.setPeer(peer_);
msg.setBtMessageFactory(messageFactory_.get());
msg.setBtMessageDispatcher(dispatcher_.get());
@ -118,13 +118,13 @@ void UTMetadataRequestExtensionMessageTest::testDoReceivedAction_data()
{
UTMetadataRequestExtensionMessage msg(1);
msg.setIndex(1);
msg.setDownloadContext(dctx_);
msg.setDownloadContext(dctx_.get());
msg.setPeer(peer_);
msg.setBtMessageFactory(messageFactory_.get());
msg.setBtMessageDispatcher(dispatcher_.get());
size_t metadataSize = METADATA_PIECE_SIZE*2;
auto attrs = bittorrent::getTorrentAttrs(dctx_);
auto attrs = bittorrent::getTorrentAttrs(dctx_.get());
std::string first(METADATA_PIECE_SIZE, '0');
std::string second(METADATA_PIECE_SIZE, '1');
attrs->metadata = first+second;

View File

@ -38,7 +38,7 @@ public:
} else {
size_t index = missingIndexes.front();
missingIndexes.pop_front();
return std::shared_ptr<Piece>(new Piece(index, 0));
return std::make_shared<Piece>(index, 0);
}
}
};
@ -50,29 +50,26 @@ CPPUNIT_TEST_SUITE_REGISTRATION(UTMetadataRequestFactoryTest);
void UTMetadataRequestFactoryTest::testCreate()
{
UTMetadataRequestFactory factory;
std::shared_ptr<DownloadContext> dctx
(new DownloadContext(METADATA_PIECE_SIZE, METADATA_PIECE_SIZE*2));
factory.setDownloadContext(dctx);
std::shared_ptr<MockPieceStorage2> ps(new MockPieceStorage2());
ps->missingIndexes.push_back(0);
ps->missingIndexes.push_back(1);
std::shared_ptr<WrapExtBtMessageFactory> messageFactory
(new WrapExtBtMessageFactory());
factory.setBtMessageFactory(messageFactory.get());
std::shared_ptr<Peer> peer(new Peer("peer", 6880));
DownloadContext dctx{METADATA_PIECE_SIZE, METADATA_PIECE_SIZE*2};
factory.setDownloadContext(&dctx);
MockPieceStorage2 ps;
ps.missingIndexes.push_back(0);
ps.missingIndexes.push_back(1);
WrapExtBtMessageFactory messageFactory;
factory.setBtMessageFactory(&messageFactory);
auto peer = std::make_shared<Peer>("peer", 6880);
peer->allocateSessionResource(0, 0);
factory.setPeer(peer);
std::shared_ptr<UTMetadataRequestTracker> tracker
(new UTMetadataRequestTracker());
factory.setUTMetadataRequestTracker(tracker.get());
UTMetadataRequestTracker tracker;
factory.setUTMetadataRequestTracker(&tracker);
auto msgs = factory.create(1, ps);
auto msgs = factory.create(1, &ps);
CPPUNIT_ASSERT_EQUAL((size_t)1, msgs.size());
msgs = factory.create(1, ps);
msgs = factory.create(1, &ps);
CPPUNIT_ASSERT_EQUAL((size_t)1, msgs.size());
msgs = factory.create(1, ps);
msgs = factory.create(1, &ps);
CPPUNIT_ASSERT_EQUAL((size_t)0, msgs.size());
}

View File

@ -31,11 +31,11 @@ class UTPexExtensionMessageTest:public CppUnit::TestFixture {
CPPUNIT_TEST(testDroppedPeersAreFull);
CPPUNIT_TEST_SUITE_END();
private:
std::shared_ptr<MockPeerStorage> peerStorage_;
std::unique_ptr<MockPeerStorage> peerStorage_;
public:
void setUp()
{
peerStorage_.reset(new MockPeerStorage());
peerStorage_ = make_unique<MockPeerStorage>();
global::wallclock().reset();
}
@ -70,23 +70,23 @@ void UTPexExtensionMessageTest::testGetExtensionName()
void UTPexExtensionMessageTest::testGetBencodedData()
{
UTPexExtensionMessage msg(1);
std::shared_ptr<Peer> p1(new Peer("192.168.0.1", 6881));
auto p1 = std::make_shared<Peer>("192.168.0.1", 6881);
p1->allocateSessionResource(256*1024, 1024*1024);
p1->setAllBitfield();
CPPUNIT_ASSERT(msg.addFreshPeer(p1));// added seeder, check add.f flag
std::shared_ptr<Peer> p2(new Peer("10.1.1.2", 9999));
auto p2 = std::make_shared<Peer>("10.1.1.2", 9999);
CPPUNIT_ASSERT(msg.addFreshPeer(p2));
std::shared_ptr<Peer> p3(new Peer("192.168.0.2", 6882));
auto p3 = std::make_shared<Peer>("192.168.0.2", 6882);
p3->startDrop();
CPPUNIT_ASSERT(msg.addDroppedPeer(p3));
std::shared_ptr<Peer> p4(new Peer("10.1.1.3", 10000));
auto p4 = std::make_shared<Peer>("10.1.1.3", 10000);
p4->startDrop();
CPPUNIT_ASSERT(msg.addDroppedPeer(p4));
std::shared_ptr<Peer> p5(new Peer("1002:1035:4527:3546:7854:1237:3247:3217",
6881));
auto p5 = std::make_shared<Peer>("1002:1035:4527:3546:7854:1237:3247:3217",
6881);
CPPUNIT_ASSERT(msg.addFreshPeer(p5));
std::shared_ptr<Peer> p6(new Peer("2001:db8:bd05:1d2:288a:1fc0:1:10ee", 6882));
auto p6 = std::make_shared<Peer>("2001:db8:bd05:1d2:288a:1fc0:1:10ee", 6882);
p6->startDrop();
CPPUNIT_ASSERT(msg.addDroppedPeer(p6));
@ -150,7 +150,7 @@ void UTPexExtensionMessageTest::testDoReceivedAction()
std::shared_ptr<Peer> p4(new Peer("2001:db8:bd05:1d2:288a:1fc0:1:10ee", 10000));
p4->startDrop();
msg.addDroppedPeer(p4);
msg.setPeerStorage(peerStorage_);
msg.setPeerStorage(peerStorage_.get());
msg.doReceivedAction();
@ -203,9 +203,8 @@ void UTPexExtensionMessageTest::testCreate()
"8:dropped618:"+std::string(&c6[0], &c6[COMPACT_LEN_IPV6])+
"e";
std::shared_ptr<UTPexExtensionMessage> msg
(UTPexExtensionMessage::create
(reinterpret_cast<const unsigned char*>(data.c_str()), data.size()));
auto msg = UTPexExtensionMessage::create
(reinterpret_cast<const unsigned char*>(data.c_str()), data.size());
CPPUNIT_ASSERT_EQUAL((uint8_t)1, msg->getExtensionMessageID());
CPPUNIT_ASSERT_EQUAL((size_t)3, msg->getFreshPeers().size());
CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.1"),

View File

@ -8,10 +8,10 @@ namespace aria2 {
class WrapExtBtMessageFactory:public MockBtMessageFactory {
public:
virtual std::unique_ptr<BtExtendedMessage>
createBtExtendedMessage(const std::shared_ptr<ExtensionMessage>& extmsg)
createBtExtendedMessage(std::unique_ptr<ExtensionMessage> extmsg)
override
{
return make_unique<BtExtendedMessage>(extmsg);
return make_unique<BtExtendedMessage>(std::move(extmsg));
}
};