Use std::unique_ptr to receive BtMessage

pull/103/head
Tatsuhiro Tsujikawa 2013-07-04 22:35:41 +09:00
parent 0cdeaa8177
commit c6a733378f
7 changed files with 51 additions and 46 deletions

View File

@ -41,7 +41,8 @@
namespace aria2 { namespace aria2 {
class BtMessage;
class BtHandshakeMessage;
class BtInteractive { class BtInteractive {
public: public:
@ -49,9 +50,10 @@ public:
virtual void initiateHandshake() = 0; virtual void initiateHandshake() = 0;
virtual std::shared_ptr<BtMessage> receiveHandshake(bool quickReply = false) = 0; virtual std::unique_ptr<BtHandshakeMessage> receiveHandshake
(bool quickReply = false) = 0;
virtual std::shared_ptr<BtMessage> receiveAndSendHandshake() = 0; virtual std::unique_ptr<BtHandshakeMessage> receiveAndSendHandshake() = 0;
virtual void doPostHandshakeProcessing() = 0; virtual void doPostHandshakeProcessing() = 0;

View File

@ -48,11 +48,12 @@ class BtMessageReceiver {
public: public:
virtual ~BtMessageReceiver() {} virtual ~BtMessageReceiver() {}
virtual std::shared_ptr<BtHandshakeMessage> receiveHandshake(bool quickReply = false) = 0; virtual std::unique_ptr<BtHandshakeMessage> receiveHandshake
(bool quickReply = false) = 0;
virtual std::shared_ptr<BtHandshakeMessage> receiveAndSendHandshake() = 0; virtual std::unique_ptr<BtHandshakeMessage> receiveAndSendHandshake() = 0;
virtual std::shared_ptr<BtMessage> receiveMessage() = 0; virtual std::unique_ptr<BtMessage> receiveMessage() = 0;
}; };
} // namespace aria2 } // namespace aria2

View File

@ -119,9 +119,9 @@ void DefaultBtInteractive::initiateHandshake() {
dispatcher_->sendMessages(); dispatcher_->sendMessages();
} }
std::shared_ptr<BtMessage> DefaultBtInteractive::receiveHandshake(bool quickReply) { std::unique_ptr<BtHandshakeMessage> DefaultBtInteractive::receiveHandshake
std::shared_ptr<BtHandshakeMessage> message = (bool quickReply) {
btMessageReceiver_->receiveHandshake(quickReply); auto message = btMessageReceiver_->receiveHandshake(quickReply);
if(!message) { if(!message) {
return nullptr; return nullptr;
} }
@ -131,11 +131,9 @@ std::shared_ptr<BtMessage> DefaultBtInteractive::receiveHandshake(bool quickRepl
(fmt("CUID#%" PRId64 " - Drop connection from the same Peer ID", (fmt("CUID#%" PRId64 " - Drop connection from the same Peer ID",
cuid_)); cuid_));
} }
const PeerSet& usedPeers = peerStorage_->getUsedPeers(); for(auto& peer : peerStorage_->getUsedPeers()) {
for(PeerSet::const_iterator i = usedPeers.begin(), eoi = usedPeers.end(); if(peer->isActive() &&
i != eoi; ++i) { memcmp(peer->getPeerId(), message->getPeerId(), PEER_ID_LENGTH) == 0) {
if((*i)->isActive() &&
memcmp((*i)->getPeerId(), message->getPeerId(), PEER_ID_LENGTH) == 0) {
throw DL_ABORT_EX throw DL_ABORT_EX
(fmt("CUID#%" PRId64 " - Same Peer ID has been already seen.", (fmt("CUID#%" PRId64 " - Same Peer ID has been already seen.",
cuid_)); cuid_));
@ -166,7 +164,9 @@ std::shared_ptr<BtMessage> DefaultBtInteractive::receiveHandshake(bool quickRepl
return message; return message;
} }
std::shared_ptr<BtMessage> DefaultBtInteractive::receiveAndSendHandshake() { std::unique_ptr<BtHandshakeMessage>
DefaultBtInteractive::receiveAndSendHandshake()
{
return receiveHandshake(true); return receiveHandshake(true);
} }
@ -297,7 +297,7 @@ size_t DefaultBtInteractive::receiveMessages() {
downloadContext_->getOwnerRequestGroup()->doesDownloadSpeedExceed()) { downloadContext_->getOwnerRequestGroup()->doesDownloadSpeedExceed()) {
break; break;
} }
std::shared_ptr<BtMessage> message = btMessageReceiver_->receiveMessage(); auto message = btMessageReceiver_->receiveMessage();
if(!message) { if(!message) {
break; break;
} }

View File

@ -173,9 +173,10 @@ public:
virtual void initiateHandshake(); virtual void initiateHandshake();
virtual std::shared_ptr<BtMessage> receiveHandshake(bool quickReply = false); virtual std::unique_ptr<BtHandshakeMessage> receiveHandshake
(bool quickReply = false);
virtual std::shared_ptr<BtMessage> receiveAndSendHandshake(); virtual std::unique_ptr<BtHandshakeMessage> receiveAndSendHandshake();
virtual void doPostHandshakeProcessing(); virtual void doPostHandshakeProcessing();

View File

@ -54,14 +54,14 @@
namespace aria2 { namespace aria2 {
DefaultBtMessageReceiver::DefaultBtMessageReceiver(): DefaultBtMessageReceiver::DefaultBtMessageReceiver():
handshakeSent_(false), handshakeSent_{false},
downloadContext_{0}, downloadContext_{nullptr},
peerConnection_(0), peerConnection_{nullptr},
dispatcher_(0), dispatcher_{nullptr},
messageFactory_(0) messageFactory_{nullptr}
{} {}
std::shared_ptr<BtHandshakeMessage> std::unique_ptr<BtHandshakeMessage>
DefaultBtMessageReceiver::receiveHandshake(bool quickReply) DefaultBtMessageReceiver::receiveHandshake(bool quickReply)
{ {
A2_LOG_DEBUG A2_LOG_DEBUG
@ -69,15 +69,14 @@ DefaultBtMessageReceiver::receiveHandshake(bool quickReply)
static_cast<unsigned long>(peerConnection_->getBufferLength()))); static_cast<unsigned long>(peerConnection_->getBufferLength())));
unsigned char data[BtHandshakeMessage::MESSAGE_LENGTH]; unsigned char data[BtHandshakeMessage::MESSAGE_LENGTH];
size_t dataLength = BtHandshakeMessage::MESSAGE_LENGTH; size_t dataLength = BtHandshakeMessage::MESSAGE_LENGTH;
std::shared_ptr<BtHandshakeMessage> msg;
if(handshakeSent_ || !quickReply || peerConnection_->getBufferLength() < 48) { if(handshakeSent_ || !quickReply || peerConnection_->getBufferLength() < 48) {
if(peerConnection_->receiveHandshake(data, dataLength)) { if(peerConnection_->receiveHandshake(data, dataLength)) {
msg = messageFactory_->createHandshakeMessage(data, dataLength); auto msg = messageFactory_->createHandshakeMessage(data, dataLength);
msg->validate(); msg->validate();
return msg;
} }
} } else {
// Handle tracker's NAT-checking feature // Handle tracker's NAT-checking feature
if(!handshakeSent_ && quickReply && peerConnection_->getBufferLength() >= 48){
handshakeSent_ = true; handshakeSent_ = true;
// check info_hash // check info_hash
if(memcmp(bittorrent::getInfoHash(downloadContext_), if(memcmp(bittorrent::getInfoHash(downloadContext_),
@ -90,24 +89,25 @@ DefaultBtMessageReceiver::receiveHandshake(bool quickReply)
util::toHex(peerConnection_->getBuffer()+28, util::toHex(peerConnection_->getBuffer()+28,
INFO_HASH_LENGTH).c_str())); INFO_HASH_LENGTH).c_str()));
} }
if(!msg && if(peerConnection_->getBufferLength() ==
peerConnection_->getBufferLength() ==
BtHandshakeMessage::MESSAGE_LENGTH && BtHandshakeMessage::MESSAGE_LENGTH &&
peerConnection_->receiveHandshake(data, dataLength)) { peerConnection_->receiveHandshake(data, dataLength)) {
msg = messageFactory_->createHandshakeMessage(data, dataLength); auto msg = messageFactory_->createHandshakeMessage(data, dataLength);
msg->validate(); msg->validate();
return msg;
} }
} }
return msg; return nullptr;
} }
std::shared_ptr<BtHandshakeMessage> std::unique_ptr<BtHandshakeMessage>
DefaultBtMessageReceiver::receiveAndSendHandshake() DefaultBtMessageReceiver::receiveAndSendHandshake()
{ {
return receiveHandshake(true); return receiveHandshake(true);
} }
void DefaultBtMessageReceiver::sendHandshake() { void DefaultBtMessageReceiver::sendHandshake()
{
dispatcher_->addMessageToQueue dispatcher_->addMessageToQueue
(messageFactory_->createHandshakeMessage (messageFactory_->createHandshakeMessage
(bittorrent::getInfoHash(downloadContext_), (bittorrent::getInfoHash(downloadContext_),
@ -115,18 +115,19 @@ void DefaultBtMessageReceiver::sendHandshake() {
dispatcher_->sendMessages(); dispatcher_->sendMessages();
} }
std::shared_ptr<BtMessage> DefaultBtMessageReceiver::receiveMessage() { std::unique_ptr<BtMessage> DefaultBtMessageReceiver::receiveMessage()
{
size_t dataLength = 0; size_t dataLength = 0;
// Give 0 to PeerConnection::receiveMessage() to prevent memcpy. // Give 0 to PeerConnection::receiveMessage() to prevent memcpy.
if(!peerConnection_->receiveMessage(0, dataLength)) { if(!peerConnection_->receiveMessage(0, dataLength)) {
return nullptr; return nullptr;
} }
std::shared_ptr<BtMessage> msg = auto msg =
messageFactory_->createBtMessage(peerConnection_->getMsgPayloadBuffer(), messageFactory_->createBtMessage(peerConnection_->getMsgPayloadBuffer(),
dataLength); dataLength);
msg->validate(); msg->validate();
if(msg->getId() == BtPieceMessage::ID) { if(msg->getId() == BtPieceMessage::ID) {
auto piecemsg = std::static_pointer_cast<BtPieceMessage>(msg); auto piecemsg = static_cast<BtPieceMessage*>(msg.get());
piecemsg->setMsgPayload(peerConnection_->getMsgPayloadBuffer()); piecemsg->setMsgPayload(peerConnection_->getMsgPayloadBuffer());
} }
return msg; return msg;
@ -138,7 +139,8 @@ void DefaultBtMessageReceiver::setDownloadContext
downloadContext_ = downloadContext; downloadContext_ = downloadContext;
} }
void DefaultBtMessageReceiver::setPeerConnection(PeerConnection* peerConnection) void DefaultBtMessageReceiver::setPeerConnection
(PeerConnection* peerConnection)
{ {
peerConnection_ = peerConnection; peerConnection_ = peerConnection;
} }

View File

@ -58,12 +58,12 @@ private:
public: public:
DefaultBtMessageReceiver(); DefaultBtMessageReceiver();
virtual std::shared_ptr<BtHandshakeMessage> receiveHandshake virtual std::unique_ptr<BtHandshakeMessage> receiveHandshake
(bool quickReply = false); (bool quickReply = false);
virtual std::shared_ptr<BtHandshakeMessage> receiveAndSendHandshake(); virtual std::unique_ptr<BtHandshakeMessage> receiveAndSendHandshake();
virtual std::shared_ptr<BtMessage> receiveMessage(); virtual std::unique_ptr<BtMessage> receiveMessage();
void setDownloadContext(DownloadContext* downloadContext); void setDownloadContext(DownloadContext* downloadContext);

View File

@ -47,6 +47,7 @@
#include "DownloadContext.h" #include "DownloadContext.h"
#include "Peer.h" #include "Peer.h"
#include "BtMessage.h" #include "BtMessage.h"
#include "BtHandshakeMessage.h"
#include "BtRuntime.h" #include "BtRuntime.h"
#include "PeerStorage.h" #include "PeerStorage.h"
#include "DefaultBtMessageDispatcher.h" #include "DefaultBtMessageDispatcher.h"
@ -327,8 +328,7 @@ bool PeerInteractionCommand::executeInternal() {
break; break;
} }
} }
std::shared_ptr<BtMessage> handshakeMessage = auto handshakeMessage = btInteractive_->receiveHandshake();
btInteractive_->receiveHandshake();
if(!handshakeMessage) { if(!handshakeMessage) {
done = true; done = true;
break; break;
@ -338,8 +338,7 @@ bool PeerInteractionCommand::executeInternal() {
break; break;
} }
case RECEIVER_WAIT_HANDSHAKE: { case RECEIVER_WAIT_HANDSHAKE: {
std::shared_ptr<BtMessage> handshakeMessage = auto handshakeMessage = btInteractive_->receiveAndSendHandshake();
btInteractive_->receiveAndSendHandshake();
if(!handshakeMessage) { if(!handshakeMessage) {
done = true; done = true;
break; break;