mirror of https://github.com/aria2/aria2
Return ExtensionMessage subclass create return raw pointer
parent
3258614033
commit
4b94ede268
|
@ -78,11 +78,11 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
|
||||||
uint8_t extensionMessageID = *data;
|
uint8_t extensionMessageID = *data;
|
||||||
if(extensionMessageID == 0) {
|
if(extensionMessageID == 0) {
|
||||||
// handshake
|
// handshake
|
||||||
SharedHandle<HandshakeExtensionMessage> m =
|
HandshakeExtensionMessage* m =
|
||||||
HandshakeExtensionMessage::create(data, length);
|
HandshakeExtensionMessage::create(data, length);
|
||||||
m->setPeer(peer_);
|
m->setPeer(peer_);
|
||||||
m->setDownloadContext(dctx_);
|
m->setDownloadContext(dctx_);
|
||||||
return m;
|
return SharedHandle<ExtensionMessage>(m);
|
||||||
} else {
|
} else {
|
||||||
const char* extensionName = registry_->getExtensionName(extensionMessageID);
|
const char* extensionName = registry_->getExtensionName(extensionMessageID);
|
||||||
if(!extensionName) {
|
if(!extensionName) {
|
||||||
|
@ -92,10 +92,9 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
|
||||||
}
|
}
|
||||||
if(strcmp(extensionName, "ut_pex") == 0) {
|
if(strcmp(extensionName, "ut_pex") == 0) {
|
||||||
// uTorrent compatible Peer-Exchange
|
// uTorrent compatible Peer-Exchange
|
||||||
SharedHandle<UTPexExtensionMessage> m =
|
UTPexExtensionMessage* m = UTPexExtensionMessage::create(data, length);
|
||||||
UTPexExtensionMessage::create(data, length);
|
|
||||||
m->setPeerStorage(peerStorage_);
|
m->setPeerStorage(peerStorage_);
|
||||||
return m;
|
return SharedHandle<ExtensionMessage>(m);
|
||||||
} else if(strcmp(extensionName, "ut_metadata") == 0) {
|
} else if(strcmp(extensionName, "ut_metadata") == 0) {
|
||||||
if(length == 0) {
|
if(length == 0) {
|
||||||
throw DL_ABORT_EX
|
throw DL_ABORT_EX
|
||||||
|
@ -120,14 +119,14 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
|
||||||
}
|
}
|
||||||
switch(msgType->i()) {
|
switch(msgType->i()) {
|
||||||
case 0: {
|
case 0: {
|
||||||
SharedHandle<UTMetadataRequestExtensionMessage> m
|
UTMetadataRequestExtensionMessage* m
|
||||||
(new UTMetadataRequestExtensionMessage(extensionMessageID));
|
(new UTMetadataRequestExtensionMessage(extensionMessageID));
|
||||||
m->setIndex(index->i());
|
m->setIndex(index->i());
|
||||||
m->setDownloadContext(dctx_);
|
m->setDownloadContext(dctx_);
|
||||||
m->setPeer(peer_);
|
m->setPeer(peer_);
|
||||||
m->setBtMessageFactory(messageFactory_);
|
m->setBtMessageFactory(messageFactory_);
|
||||||
m->setBtMessageDispatcher(dispatcher_);
|
m->setBtMessageDispatcher(dispatcher_);
|
||||||
return m;
|
return SharedHandle<ExtensionMessage>(m);
|
||||||
}
|
}
|
||||||
case 1: {
|
case 1: {
|
||||||
if(end == length) {
|
if(end == length) {
|
||||||
|
@ -137,7 +136,7 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
|
||||||
if(!totalSize) {
|
if(!totalSize) {
|
||||||
throw DL_ABORT_EX("Bad ut_metadata data: total_size not found");
|
throw DL_ABORT_EX("Bad ut_metadata data: total_size not found");
|
||||||
}
|
}
|
||||||
SharedHandle<UTMetadataDataExtensionMessage> m
|
UTMetadataDataExtensionMessage* m
|
||||||
(new UTMetadataDataExtensionMessage(extensionMessageID));
|
(new UTMetadataDataExtensionMessage(extensionMessageID));
|
||||||
m->setIndex(index->i());
|
m->setIndex(index->i());
|
||||||
m->setTotalSize(totalSize->i());
|
m->setTotalSize(totalSize->i());
|
||||||
|
@ -145,14 +144,14 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
|
||||||
m->setUTMetadataRequestTracker(tracker_);
|
m->setUTMetadataRequestTracker(tracker_);
|
||||||
m->setPieceStorage(dctx_->getOwnerRequestGroup()->getPieceStorage());
|
m->setPieceStorage(dctx_->getOwnerRequestGroup()->getPieceStorage());
|
||||||
m->setDownloadContext(dctx_);
|
m->setDownloadContext(dctx_);
|
||||||
return m;
|
return SharedHandle<ExtensionMessage>(m);
|
||||||
}
|
}
|
||||||
case 2: {
|
case 2: {
|
||||||
SharedHandle<UTMetadataRejectExtensionMessage> m
|
UTMetadataRejectExtensionMessage* m
|
||||||
(new UTMetadataRejectExtensionMessage(extensionMessageID));
|
(new UTMetadataRejectExtensionMessage(extensionMessageID));
|
||||||
m->setIndex(index->i());
|
m->setIndex(index->i());
|
||||||
// No need to inject tracker because peer will be disconnected.
|
// No need to inject tracker because peer will be disconnected.
|
||||||
return m;
|
return SharedHandle<ExtensionMessage>(m);
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
throw DL_ABORT_EX
|
throw DL_ABORT_EX
|
||||||
|
|
|
@ -164,7 +164,7 @@ uint8_t HandshakeExtensionMessage::getExtensionMessageID(int key) const
|
||||||
return extreg_.getExtensionMessageID(key);
|
return extreg_.getExtensionMessageID(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
SharedHandle<HandshakeExtensionMessage>
|
HandshakeExtensionMessage*
|
||||||
HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
|
HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
|
||||||
{
|
{
|
||||||
if(length < 1) {
|
if(length < 1) {
|
||||||
|
@ -172,7 +172,6 @@ HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
|
||||||
(fmt(MSG_TOO_SMALL_PAYLOAD_SIZE,
|
(fmt(MSG_TOO_SMALL_PAYLOAD_SIZE,
|
||||||
EXTENSION_NAME, static_cast<unsigned long>(length)));
|
EXTENSION_NAME, static_cast<unsigned long>(length)));
|
||||||
}
|
}
|
||||||
SharedHandle<HandshakeExtensionMessage> msg(new HandshakeExtensionMessage());
|
|
||||||
A2_LOG_DEBUG(fmt("Creating HandshakeExtensionMessage from %s",
|
A2_LOG_DEBUG(fmt("Creating HandshakeExtensionMessage from %s",
|
||||||
util::percentEncode(data, length).c_str()));
|
util::percentEncode(data, length).c_str()));
|
||||||
SharedHandle<ValueBase> decoded = bencode2::decode(data+1, length - 1);
|
SharedHandle<ValueBase> decoded = bencode2::decode(data+1, length - 1);
|
||||||
|
@ -181,6 +180,7 @@ HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
|
||||||
throw DL_ABORT_EX
|
throw DL_ABORT_EX
|
||||||
("Unexpected payload format for extended message handshake");
|
("Unexpected payload format for extended message handshake");
|
||||||
}
|
}
|
||||||
|
HandshakeExtensionMessage* msg(new HandshakeExtensionMessage());
|
||||||
const Integer* port = downcast<Integer>(dict->get("p"));
|
const Integer* port = downcast<Integer>(dict->get("p"));
|
||||||
if(port && 0 < port->i() && port->i() < 65536) {
|
if(port && 0 < port->i() && port->i() < 65536) {
|
||||||
msg->tcpPort_ = port->i();
|
msg->tcpPort_ = port->i();
|
||||||
|
|
|
@ -125,7 +125,7 @@ public:
|
||||||
|
|
||||||
void setPeer(const SharedHandle<Peer>& peer);
|
void setPeer(const SharedHandle<Peer>& peer);
|
||||||
|
|
||||||
static SharedHandle<HandshakeExtensionMessage>
|
static HandshakeExtensionMessage*
|
||||||
create(const unsigned char* data, size_t dataLength);
|
create(const unsigned char* data, size_t dataLength);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -181,7 +181,7 @@ void UTPexExtensionMessage::setPeerStorage
|
||||||
peerStorage_ = peerStorage;
|
peerStorage_ = peerStorage;
|
||||||
}
|
}
|
||||||
|
|
||||||
SharedHandle<UTPexExtensionMessage>
|
UTPexExtensionMessage*
|
||||||
UTPexExtensionMessage::create(const unsigned char* data, size_t len)
|
UTPexExtensionMessage::create(const unsigned char* data, size_t len)
|
||||||
{
|
{
|
||||||
if(len < 1) {
|
if(len < 1) {
|
||||||
|
@ -189,7 +189,7 @@ UTPexExtensionMessage::create(const unsigned char* data, size_t len)
|
||||||
EXTENSION_NAME,
|
EXTENSION_NAME,
|
||||||
static_cast<unsigned long>(len)));
|
static_cast<unsigned long>(len)));
|
||||||
}
|
}
|
||||||
SharedHandle<UTPexExtensionMessage> msg(new UTPexExtensionMessage(*data));
|
UTPexExtensionMessage* msg(new UTPexExtensionMessage(*data));
|
||||||
|
|
||||||
SharedHandle<ValueBase> decoded = bencode2::decode(data+1, len - 1);
|
SharedHandle<ValueBase> decoded = bencode2::decode(data+1, len - 1);
|
||||||
const Dict* dict = downcast<Dict>(decoded);
|
const Dict* dict = downcast<Dict>(decoded);
|
||||||
|
|
|
@ -111,7 +111,7 @@ public:
|
||||||
|
|
||||||
void setPeerStorage(const SharedHandle<PeerStorage>& peerStorage);
|
void setPeerStorage(const SharedHandle<PeerStorage>& peerStorage);
|
||||||
|
|
||||||
static SharedHandle<UTPexExtensionMessage>
|
static UTPexExtensionMessage*
|
||||||
create(const unsigned char* data, size_t len);
|
create(const unsigned char* data, size_t len);
|
||||||
|
|
||||||
void setMaxFreshPeer(size_t maxFreshPeer);
|
void setMaxFreshPeer(size_t maxFreshPeer);
|
||||||
|
|
|
@ -138,9 +138,10 @@ void HandshakeExtensionMessageTest::testCreate()
|
||||||
{
|
{
|
||||||
std::string in =
|
std::string in =
|
||||||
"0d1:pi6881e1:v5:aria21:md5:a2dhti2e6:ut_pexi1ee13:metadata_sizei1024ee";
|
"0d1:pi6881e1:v5:aria21:md5:a2dhti2e6:ut_pexi1ee13:metadata_sizei1024ee";
|
||||||
SharedHandle<HandshakeExtensionMessage> m =
|
SharedHandle<HandshakeExtensionMessage> m
|
||||||
HandshakeExtensionMessage::create(reinterpret_cast<const unsigned char*>(in.c_str()),
|
(HandshakeExtensionMessage::create
|
||||||
in.size());
|
(reinterpret_cast<const unsigned char*>(in.c_str()),
|
||||||
|
in.size()));
|
||||||
CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion());
|
CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion());
|
||||||
CPPUNIT_ASSERT_EQUAL((uint16_t)6881, m->getTCPPort());
|
CPPUNIT_ASSERT_EQUAL((uint16_t)6881, m->getTCPPort());
|
||||||
CPPUNIT_ASSERT_EQUAL((uint8_t)1,
|
CPPUNIT_ASSERT_EQUAL((uint8_t)1,
|
||||||
|
@ -179,10 +180,10 @@ void HandshakeExtensionMessageTest::testCreate()
|
||||||
void HandshakeExtensionMessageTest::testCreate_stringnum()
|
void HandshakeExtensionMessageTest::testCreate_stringnum()
|
||||||
{
|
{
|
||||||
std::string in = "0d1:p4:68811:v5:aria21:md6:ut_pex1:1ee";
|
std::string in = "0d1:p4:68811:v5:aria21:md6:ut_pex1:1ee";
|
||||||
SharedHandle<HandshakeExtensionMessage> m =
|
SharedHandle<HandshakeExtensionMessage> m
|
||||||
HandshakeExtensionMessage::create
|
(HandshakeExtensionMessage::create
|
||||||
(reinterpret_cast<const unsigned char*>(in.c_str()),
|
(reinterpret_cast<const unsigned char*>(in.c_str()),
|
||||||
in.size());
|
in.size()));
|
||||||
CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion());
|
CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion());
|
||||||
// port number in string is not allowed
|
// port number in string is not allowed
|
||||||
CPPUNIT_ASSERT_EQUAL((uint16_t)0, m->getTCPPort());
|
CPPUNIT_ASSERT_EQUAL((uint16_t)0, m->getTCPPort());
|
||||||
|
|
|
@ -203,9 +203,9 @@ void UTPexExtensionMessageTest::testCreate()
|
||||||
"8:dropped618:"+std::string(&c6[0], &c6[COMPACT_LEN_IPV6])+
|
"8:dropped618:"+std::string(&c6[0], &c6[COMPACT_LEN_IPV6])+
|
||||||
"e";
|
"e";
|
||||||
|
|
||||||
SharedHandle<UTPexExtensionMessage> msg =
|
SharedHandle<UTPexExtensionMessage> msg
|
||||||
UTPexExtensionMessage::create
|
(UTPexExtensionMessage::create
|
||||||
(reinterpret_cast<const unsigned char*>(data.c_str()), data.size());
|
(reinterpret_cast<const unsigned char*>(data.c_str()), data.size()));
|
||||||
CPPUNIT_ASSERT_EQUAL((uint8_t)1, msg->getExtensionMessageID());
|
CPPUNIT_ASSERT_EQUAL((uint8_t)1, msg->getExtensionMessageID());
|
||||||
CPPUNIT_ASSERT_EQUAL((size_t)3, msg->getFreshPeers().size());
|
CPPUNIT_ASSERT_EQUAL((size_t)3, msg->getFreshPeers().size());
|
||||||
CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.1"),
|
CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.1"),
|
||||||
|
|
Loading…
Reference in New Issue