Return ExtensionMessage subclass create return raw pointer

pull/28/head
Tatsuhiro Tsujikawa 2012-09-28 23:40:44 +09:00
parent 3258614033
commit 4b94ede268
7 changed files with 31 additions and 31 deletions

View File

@ -78,11 +78,11 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
uint8_t extensionMessageID = *data; uint8_t extensionMessageID = *data;
if(extensionMessageID == 0) { if(extensionMessageID == 0) {
// handshake // handshake
SharedHandle<HandshakeExtensionMessage> m = HandshakeExtensionMessage* m =
HandshakeExtensionMessage::create(data, length); HandshakeExtensionMessage::create(data, length);
m->setPeer(peer_); m->setPeer(peer_);
m->setDownloadContext(dctx_); m->setDownloadContext(dctx_);
return m; return SharedHandle<ExtensionMessage>(m);
} else { } else {
const char* extensionName = registry_->getExtensionName(extensionMessageID); const char* extensionName = registry_->getExtensionName(extensionMessageID);
if(!extensionName) { if(!extensionName) {
@ -92,10 +92,9 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
} }
if(strcmp(extensionName, "ut_pex") == 0) { if(strcmp(extensionName, "ut_pex") == 0) {
// uTorrent compatible Peer-Exchange // uTorrent compatible Peer-Exchange
SharedHandle<UTPexExtensionMessage> m = UTPexExtensionMessage* m = UTPexExtensionMessage::create(data, length);
UTPexExtensionMessage::create(data, length);
m->setPeerStorage(peerStorage_); m->setPeerStorage(peerStorage_);
return m; return SharedHandle<ExtensionMessage>(m);
} else if(strcmp(extensionName, "ut_metadata") == 0) { } else if(strcmp(extensionName, "ut_metadata") == 0) {
if(length == 0) { if(length == 0) {
throw DL_ABORT_EX throw DL_ABORT_EX
@ -120,14 +119,14 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
} }
switch(msgType->i()) { switch(msgType->i()) {
case 0: { case 0: {
SharedHandle<UTMetadataRequestExtensionMessage> m UTMetadataRequestExtensionMessage* m
(new UTMetadataRequestExtensionMessage(extensionMessageID)); (new UTMetadataRequestExtensionMessage(extensionMessageID));
m->setIndex(index->i()); m->setIndex(index->i());
m->setDownloadContext(dctx_); m->setDownloadContext(dctx_);
m->setPeer(peer_); m->setPeer(peer_);
m->setBtMessageFactory(messageFactory_); m->setBtMessageFactory(messageFactory_);
m->setBtMessageDispatcher(dispatcher_); m->setBtMessageDispatcher(dispatcher_);
return m; return SharedHandle<ExtensionMessage>(m);
} }
case 1: { case 1: {
if(end == length) { if(end == length) {
@ -137,7 +136,7 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
if(!totalSize) { if(!totalSize) {
throw DL_ABORT_EX("Bad ut_metadata data: total_size not found"); throw DL_ABORT_EX("Bad ut_metadata data: total_size not found");
} }
SharedHandle<UTMetadataDataExtensionMessage> m UTMetadataDataExtensionMessage* m
(new UTMetadataDataExtensionMessage(extensionMessageID)); (new UTMetadataDataExtensionMessage(extensionMessageID));
m->setIndex(index->i()); m->setIndex(index->i());
m->setTotalSize(totalSize->i()); m->setTotalSize(totalSize->i());
@ -145,14 +144,14 @@ DefaultExtensionMessageFactory::createMessage(const unsigned char* data, size_t
m->setUTMetadataRequestTracker(tracker_); m->setUTMetadataRequestTracker(tracker_);
m->setPieceStorage(dctx_->getOwnerRequestGroup()->getPieceStorage()); m->setPieceStorage(dctx_->getOwnerRequestGroup()->getPieceStorage());
m->setDownloadContext(dctx_); m->setDownloadContext(dctx_);
return m; return SharedHandle<ExtensionMessage>(m);
} }
case 2: { case 2: {
SharedHandle<UTMetadataRejectExtensionMessage> m UTMetadataRejectExtensionMessage* m
(new UTMetadataRejectExtensionMessage(extensionMessageID)); (new UTMetadataRejectExtensionMessage(extensionMessageID));
m->setIndex(index->i()); m->setIndex(index->i());
// No need to inject tracker because peer will be disconnected. // No need to inject tracker because peer will be disconnected.
return m; return SharedHandle<ExtensionMessage>(m);
} }
default: default:
throw DL_ABORT_EX throw DL_ABORT_EX

View File

@ -164,7 +164,7 @@ uint8_t HandshakeExtensionMessage::getExtensionMessageID(int key) const
return extreg_.getExtensionMessageID(key); return extreg_.getExtensionMessageID(key);
} }
SharedHandle<HandshakeExtensionMessage> HandshakeExtensionMessage*
HandshakeExtensionMessage::create(const unsigned char* data, size_t length) HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
{ {
if(length < 1) { if(length < 1) {
@ -172,7 +172,6 @@ HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
(fmt(MSG_TOO_SMALL_PAYLOAD_SIZE, (fmt(MSG_TOO_SMALL_PAYLOAD_SIZE,
EXTENSION_NAME, static_cast<unsigned long>(length))); EXTENSION_NAME, static_cast<unsigned long>(length)));
} }
SharedHandle<HandshakeExtensionMessage> msg(new HandshakeExtensionMessage());
A2_LOG_DEBUG(fmt("Creating HandshakeExtensionMessage from %s", A2_LOG_DEBUG(fmt("Creating HandshakeExtensionMessage from %s",
util::percentEncode(data, length).c_str())); util::percentEncode(data, length).c_str()));
SharedHandle<ValueBase> decoded = bencode2::decode(data+1, length - 1); SharedHandle<ValueBase> decoded = bencode2::decode(data+1, length - 1);
@ -181,6 +180,7 @@ HandshakeExtensionMessage::create(const unsigned char* data, size_t length)
throw DL_ABORT_EX throw DL_ABORT_EX
("Unexpected payload format for extended message handshake"); ("Unexpected payload format for extended message handshake");
} }
HandshakeExtensionMessage* msg(new HandshakeExtensionMessage());
const Integer* port = downcast<Integer>(dict->get("p")); const Integer* port = downcast<Integer>(dict->get("p"));
if(port && 0 < port->i() && port->i() < 65536) { if(port && 0 < port->i() && port->i() < 65536) {
msg->tcpPort_ = port->i(); msg->tcpPort_ = port->i();

View File

@ -125,7 +125,7 @@ public:
void setPeer(const SharedHandle<Peer>& peer); void setPeer(const SharedHandle<Peer>& peer);
static SharedHandle<HandshakeExtensionMessage> static HandshakeExtensionMessage*
create(const unsigned char* data, size_t dataLength); create(const unsigned char* data, size_t dataLength);
}; };

View File

@ -181,7 +181,7 @@ void UTPexExtensionMessage::setPeerStorage
peerStorage_ = peerStorage; peerStorage_ = peerStorage;
} }
SharedHandle<UTPexExtensionMessage> UTPexExtensionMessage*
UTPexExtensionMessage::create(const unsigned char* data, size_t len) UTPexExtensionMessage::create(const unsigned char* data, size_t len)
{ {
if(len < 1) { if(len < 1) {
@ -189,7 +189,7 @@ UTPexExtensionMessage::create(const unsigned char* data, size_t len)
EXTENSION_NAME, EXTENSION_NAME,
static_cast<unsigned long>(len))); static_cast<unsigned long>(len)));
} }
SharedHandle<UTPexExtensionMessage> msg(new UTPexExtensionMessage(*data)); UTPexExtensionMessage* msg(new UTPexExtensionMessage(*data));
SharedHandle<ValueBase> decoded = bencode2::decode(data+1, len - 1); SharedHandle<ValueBase> decoded = bencode2::decode(data+1, len - 1);
const Dict* dict = downcast<Dict>(decoded); const Dict* dict = downcast<Dict>(decoded);

View File

@ -111,7 +111,7 @@ public:
void setPeerStorage(const SharedHandle<PeerStorage>& peerStorage); void setPeerStorage(const SharedHandle<PeerStorage>& peerStorage);
static SharedHandle<UTPexExtensionMessage> static UTPexExtensionMessage*
create(const unsigned char* data, size_t len); create(const unsigned char* data, size_t len);
void setMaxFreshPeer(size_t maxFreshPeer); void setMaxFreshPeer(size_t maxFreshPeer);

View File

@ -138,9 +138,10 @@ void HandshakeExtensionMessageTest::testCreate()
{ {
std::string in = std::string in =
"0d1:pi6881e1:v5:aria21:md5:a2dhti2e6:ut_pexi1ee13:metadata_sizei1024ee"; "0d1:pi6881e1:v5:aria21:md5:a2dhti2e6:ut_pexi1ee13:metadata_sizei1024ee";
SharedHandle<HandshakeExtensionMessage> m = SharedHandle<HandshakeExtensionMessage> m
HandshakeExtensionMessage::create(reinterpret_cast<const unsigned char*>(in.c_str()), (HandshakeExtensionMessage::create
in.size()); (reinterpret_cast<const unsigned char*>(in.c_str()),
in.size()));
CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion()); CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion());
CPPUNIT_ASSERT_EQUAL((uint16_t)6881, m->getTCPPort()); CPPUNIT_ASSERT_EQUAL((uint16_t)6881, m->getTCPPort());
CPPUNIT_ASSERT_EQUAL((uint8_t)1, CPPUNIT_ASSERT_EQUAL((uint8_t)1,
@ -179,10 +180,10 @@ void HandshakeExtensionMessageTest::testCreate()
void HandshakeExtensionMessageTest::testCreate_stringnum() void HandshakeExtensionMessageTest::testCreate_stringnum()
{ {
std::string in = "0d1:p4:68811:v5:aria21:md6:ut_pex1:1ee"; std::string in = "0d1:p4:68811:v5:aria21:md6:ut_pex1:1ee";
SharedHandle<HandshakeExtensionMessage> m = SharedHandle<HandshakeExtensionMessage> m
HandshakeExtensionMessage::create (HandshakeExtensionMessage::create
(reinterpret_cast<const unsigned char*>(in.c_str()), (reinterpret_cast<const unsigned char*>(in.c_str()),
in.size()); in.size()));
CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion()); CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion());
// port number in string is not allowed // port number in string is not allowed
CPPUNIT_ASSERT_EQUAL((uint16_t)0, m->getTCPPort()); CPPUNIT_ASSERT_EQUAL((uint16_t)0, m->getTCPPort());

View File

@ -203,9 +203,9 @@ void UTPexExtensionMessageTest::testCreate()
"8:dropped618:"+std::string(&c6[0], &c6[COMPACT_LEN_IPV6])+ "8:dropped618:"+std::string(&c6[0], &c6[COMPACT_LEN_IPV6])+
"e"; "e";
SharedHandle<UTPexExtensionMessage> msg = SharedHandle<UTPexExtensionMessage> msg
UTPexExtensionMessage::create (UTPexExtensionMessage::create
(reinterpret_cast<const unsigned char*>(data.c_str()), data.size()); (reinterpret_cast<const unsigned char*>(data.c_str()), data.size()));
CPPUNIT_ASSERT_EQUAL((uint8_t)1, msg->getExtensionMessageID()); CPPUNIT_ASSERT_EQUAL((uint8_t)1, msg->getExtensionMessageID());
CPPUNIT_ASSERT_EQUAL((size_t)3, msg->getFreshPeers().size()); CPPUNIT_ASSERT_EQUAL((size_t)3, msg->getFreshPeers().size());
CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.1"), CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.1"),