diff --git a/ChangeLog b/ChangeLog index d3e8d740..4dc68a8a 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,20 @@ +2008-10-05 Tatsuhiro Tsujikawa + + Made socket for dht connections non-blocking + * src/DHTAbstractMessage.cc + * src/DHTAbstractMessage.h + * src/DHTConnection.h + * src/DHTConnectionImpl.cc + * src/DHTConnectionImpl.h + * src/DHTMessage.h + * src/DHTMessageDispatcherImpl.cc + * src/DHTMessageDispatcherImpl.h + * src/DHTUnknownMessage.cc + * src/DHTUnknownMessage.h + * src/SocketCore.cc + * src/SocketCore.h + * test/MockDHTMessage.h + 2008-10-05 Tatsuhiro Tsujikawa Changed the type of offset to int. diff --git a/src/DHTAbstractMessage.cc b/src/DHTAbstractMessage.cc index fe6b46b9..3a25ef63 100644 --- a/src/DHTAbstractMessage.cc +++ b/src/DHTAbstractMessage.cc @@ -33,6 +33,9 @@ */ /* copyright --> */ #include "DHTAbstractMessage.h" + +#include + #include "DHTNode.h" #include "BencodeVisitor.h" #include "DHTConnection.h" @@ -64,13 +67,16 @@ std::string DHTAbstractMessage::getBencodedMessage() return v.getBencodedData(); } -void DHTAbstractMessage::send() +bool DHTAbstractMessage::send() { std::string message = getBencodedMessage(); - _connection->sendMessage(reinterpret_cast(message.c_str()), - message.size(), - _remoteNode->getIPAddress(), - _remoteNode->getPort()); + ssize_t r = _connection->sendMessage + (reinterpret_cast(message.c_str()), + message.size(), + _remoteNode->getIPAddress(), + _remoteNode->getPort()); + assert(r >= 0); + return r == static_cast(message.size()); } void DHTAbstractMessage::setConnection(const WeakHandle& connection) diff --git a/src/DHTAbstractMessage.h b/src/DHTAbstractMessage.h index 16643a57..fc0203bb 100644 --- a/src/DHTAbstractMessage.h +++ b/src/DHTAbstractMessage.h @@ -62,7 +62,7 @@ public: virtual ~DHTAbstractMessage(); - virtual void send(); + virtual bool send(); virtual std::string getType() const = 0; diff --git a/src/DHTConnection.h b/src/DHTConnection.h index c9773ab4..c2a8e8f6 100644 --- a/src/DHTConnection.h +++ b/src/DHTConnection.h @@ -44,9 +44,11 @@ class DHTConnection { public: virtual ~DHTConnection() {} - virtual ssize_t receiveMessage(unsigned char* data, size_t len, std::string& host, uint16_t& port) = 0; + virtual ssize_t receiveMessage(unsigned char* data, size_t len, + std::string& host, uint16_t& port) = 0; - virtual void sendMessage(const unsigned char* data, size_t len, const std::string& host, uint16_t port) = 0; + virtual ssize_t sendMessage(const unsigned char* data, size_t len, + const std::string& host, uint16_t port) = 0; }; } // namespace aria2 diff --git a/src/DHTConnectionImpl.cc b/src/DHTConnectionImpl.cc index 6b285356..aa7638f8 100644 --- a/src/DHTConnectionImpl.cc +++ b/src/DHTConnectionImpl.cc @@ -33,12 +33,14 @@ */ /* copyright --> */ #include "DHTConnectionImpl.h" + +#include + #include "LogFactory.h" #include "Logger.h" #include "RecoverableException.h" #include "Util.h" #include "Socket.h" -#include namespace aria2 { @@ -66,6 +68,7 @@ bool DHTConnectionImpl::bind(uint16_t& port) { try { _socket->bind(port); + _socket->setNonBlockingMode(); std::pair svaddr; _socket->getAddrInfo(svaddr); port = svaddr.second; @@ -77,22 +80,24 @@ bool DHTConnectionImpl::bind(uint16_t& port) return false; } -ssize_t DHTConnectionImpl::receiveMessage(unsigned char* data, size_t len, std::string& host, uint16_t& port) +ssize_t DHTConnectionImpl::receiveMessage(unsigned char* data, size_t len, + std::string& host, uint16_t& port) { - if(_socket->isReadable(0)) { - std::pair remoteHost; - ssize_t length = _socket->readDataFrom(data, len, remoteHost); + std::pair remoteHost; + ssize_t length = _socket->readDataFrom(data, len, remoteHost); + if(length == 0) { + return length; + } else { host = remoteHost.first; port = remoteHost.second; return length; - } else { - return -1; } } -void DHTConnectionImpl::sendMessage(const unsigned char* data, size_t len, const std::string& host, uint16_t port) +ssize_t DHTConnectionImpl::sendMessage(const unsigned char* data, size_t len, + const std::string& host, uint16_t port) { - _socket->writeData(data, len, host, port); + return _socket->writeData(data, len, host, port); } SharedHandle DHTConnectionImpl::getSocket() const diff --git a/src/DHTConnectionImpl.h b/src/DHTConnectionImpl.h index 8806d6cd..d34c17b7 100644 --- a/src/DHTConnectionImpl.h +++ b/src/DHTConnectionImpl.h @@ -71,9 +71,11 @@ public: */ bool bind(uint16_t& port); - virtual ssize_t receiveMessage(unsigned char* data, size_t len, std::string& host, uint16_t& port); + virtual ssize_t receiveMessage(unsigned char* data, size_t len, + std::string& host, uint16_t& port); - virtual void sendMessage(const unsigned char* data, size_t len, const std::string& host, uint16_t port); + virtual ssize_t sendMessage(const unsigned char* data, size_t len, + const std::string& host, uint16_t port); SharedHandle getSocket() const; }; diff --git a/src/DHTMessage.h b/src/DHTMessage.h index bc018fdd..ae7bfd28 100644 --- a/src/DHTMessage.h +++ b/src/DHTMessage.h @@ -36,9 +36,11 @@ #define _D_DHT_MESSAGE_H_ #include "common.h" + +#include + #include "SharedHandle.h" #include "A2STR.h" -#include namespace aria2 { @@ -71,7 +73,7 @@ public: virtual void doReceivedAction() = 0; - virtual void send() = 0; + virtual bool send() = 0; virtual bool isReply() const = 0; diff --git a/src/DHTMessageDispatcherImpl.cc b/src/DHTMessageDispatcherImpl.cc index 1abe373e..e06456a1 100644 --- a/src/DHTMessageDispatcherImpl.cc +++ b/src/DHTMessageDispatcherImpl.cc @@ -67,28 +67,38 @@ DHTMessageDispatcherImpl::addMessageToQueue(const SharedHandle& mess addMessageToQueue(message, DHT_MESSAGE_TIMEOUT, callback); } -void +bool DHTMessageDispatcherImpl::sendMessage(const SharedHandle& entry) { try { - entry->_message->send(); - if(!entry->_message->isReply()) { - _tracker->addMessage(entry->_message, entry->_timeout, entry->_callback); + if(entry->_message->send()) { + if(!entry->_message->isReply()) { + _tracker->addMessage(entry->_message, entry->_timeout, entry->_callback); + } + _logger->info("Message sent: %s", entry->_message->toString().c_str()); + } else { + return false; } - _logger->info("Message sent: %s", entry->_message->toString().c_str()); } catch(RecoverableException& e) { _logger->error("Failed to send message: %s", e, entry->_message->toString().c_str()); } + return true; } void DHTMessageDispatcherImpl::sendMessages() { // TODO I can't use bind1st and mem_fun here because bind1st cannot bind a // function which takes a reference as an argument.. - for(std::deque >::iterator itr = _messageQueue.begin(); itr != _messageQueue.end(); ++itr) { - sendMessage(*itr); + std::deque >::iterator itr = + _messageQueue.begin(); + for(; itr != _messageQueue.end(); ++itr) { + if(!sendMessage(*itr)) { + break; + } } - _messageQueue.clear(); + _messageQueue.erase(_messageQueue.begin(), itr); + _logger->debug("%lu dht messages remaining in the queue.", + static_cast(_messageQueue.size())); } size_t DHTMessageDispatcherImpl::countMessageInQueue() const diff --git a/src/DHTMessageDispatcherImpl.h b/src/DHTMessageDispatcherImpl.h index d18d04a3..428998b9 100644 --- a/src/DHTMessageDispatcherImpl.h +++ b/src/DHTMessageDispatcherImpl.h @@ -52,7 +52,7 @@ private: Logger* _logger; - void sendMessage(const SharedHandle& msg); + bool sendMessage(const SharedHandle& msg); public: DHTMessageDispatcherImpl(const SharedHandle& tracker); diff --git a/src/DHTUnknownMessage.cc b/src/DHTUnknownMessage.cc index ab2b0fa2..fdc1aca0 100644 --- a/src/DHTUnknownMessage.cc +++ b/src/DHTUnknownMessage.cc @@ -33,9 +33,12 @@ */ /* copyright --> */ #include "DHTUnknownMessage.h" + +#include +#include + #include "DHTNode.h" #include "Util.h" -#include namespace aria2 { @@ -66,7 +69,7 @@ DHTUnknownMessage::~DHTUnknownMessage() void DHTUnknownMessage::doReceivedAction() {} -void DHTUnknownMessage::send() {} +bool DHTUnknownMessage::send() { return true; } bool DHTUnknownMessage::isReply() const { diff --git a/src/DHTUnknownMessage.h b/src/DHTUnknownMessage.h index 47c37584..08ec0ce8 100644 --- a/src/DHTUnknownMessage.h +++ b/src/DHTUnknownMessage.h @@ -57,7 +57,7 @@ public: virtual void doReceivedAction(); // do nothing; we don't use this message as outgoing message. - virtual void send(); + virtual bool send(); // always return false virtual bool isReply() const; diff --git a/src/SocketCore.cc b/src/SocketCore.cc index c7025fd9..0ba6ed20 100644 --- a/src/SocketCore.cc +++ b/src/SocketCore.cc @@ -839,8 +839,11 @@ bool SocketCore::initiateSecureConnection() #endif // __MINGW32__ } -void SocketCore::writeData(const char* data, size_t len, const std::string& host, uint16_t port) +ssize_t SocketCore::writeData(const char* data, size_t len, + const std::string& host, uint16_t port) { + _wantRead = false; + _wantWrite = false; struct addrinfo hints; struct addrinfo* res; @@ -861,17 +864,25 @@ void SocketCore::writeData(const char* data, size_t len, const std::string& host if(r == static_cast(len)) { break; } + if(r == -1 && errno == EAGAIN) { + _wantWrite = true; + r = 0; + break; + } } freeaddrinfo(res); if(r == -1) { throw DlAbortEx(StringFormat(EX_SOCKET_SEND, errorMsg()).str()); } + return r; } ssize_t SocketCore::readDataFrom(char* data, size_t len, std::pair& sender) { + _wantRead = false; + _wantWrite = false; struct sockaddr_storage sockaddr; socklen_t sockaddrlen = sizeof(struct sockaddr_storage); struct sockaddr* addrp = reinterpret_cast(&sockaddr); @@ -879,9 +890,15 @@ ssize_t SocketCore::readDataFrom(char* data, size_t len, while((r = recvfrom(sockfd, data, len, 0, addrp, &sockaddrlen)) == -1 && EINTR == errno); if(r == -1) { - throw DlAbortEx(StringFormat(EX_SOCKET_RECV, errorMsg()).str()); + if(errno == EAGAIN) { + _wantRead = true; + r = 0; + } else { + throw DlRetryEx(StringFormat(EX_SOCKET_RECV, errorMsg()).str()); + } + } else { + sender = Util::getNumericNameInfo(addrp, sockaddrlen); } - sender = Util::getNumericNameInfo(addrp, sockaddrlen); return r; } diff --git a/src/SocketCore.h b/src/SocketCore.h index 1392e456..fd1a3396 100644 --- a/src/SocketCore.h +++ b/src/SocketCore.h @@ -222,12 +222,14 @@ public: return writeData(reinterpret_cast(data), len); } - void writeData(const char* data, size_t len, const std::string& host, uint16_t port); + ssize_t writeData(const char* data, size_t len, + const std::string& host, uint16_t port); - void writeData(const unsigned char* data, size_t len, const std::string& host, - uint16_t port) + ssize_t writeData(const unsigned char* data, size_t len, + const std::string& host, + uint16_t port) { - writeData(reinterpret_cast(data), len, host, port); + return writeData(reinterpret_cast(data), len, host, port); } /** diff --git a/test/MockDHTMessage.h b/test/MockDHTMessage.h index 32ed9647..3c6b0f50 100644 --- a/test/MockDHTMessage.h +++ b/test/MockDHTMessage.h @@ -2,9 +2,11 @@ #define _D_MOCK_DHT_MESSAGE_H_ #include "DHTMessage.h" + +#include + #include "DHTNode.h" #include "Peer.h" -#include namespace aria2 { @@ -30,7 +32,7 @@ public: virtual void doReceivedAction() {} - virtual void send() {} + virtual bool send() { return true; } virtual bool isReply() const { return _isReply; }