More code cleanups

pull/119/head
Nils Maier 2013-08-21 06:06:10 +02:00
parent 8526ceeb45
commit cf6f58ceec
9 changed files with 311 additions and 322 deletions

View File

@ -37,6 +37,7 @@
#include <CommonCrypto/CommonDigest.h> #include <CommonCrypto/CommonDigest.h>
#include "array_fun.h" #include "array_fun.h"
#include "a2functional.h"
#include "HashFuncEntry.h" #include "HashFuncEntry.h"
namespace aria2 { namespace aria2 {
@ -117,29 +118,31 @@ std::unique_ptr<MessageDigestImpl> MessageDigestImpl::create
(const std::string& hashType) (const std::string& hashType)
{ {
if (hashType == "sha-1") { if (hashType == "sha-1") {
return std::unique_ptr<MessageDigestImpl>(new MessageDigestSHA1()); return make_unique<MessageDigestSHA1>();
} }
if (hashType == "sha-224") { if (hashType == "sha-224") {
return std::unique_ptr<MessageDigestImpl>(new MessageDigestSHA224()); return make_unique<MessageDigestSHA224>();
} }
if (hashType == "sha-256") { if (hashType == "sha-256") {
return std::unique_ptr<MessageDigestImpl>(new MessageDigestSHA256()); return make_unique<MessageDigestSHA256>();
} }
if (hashType == "sha-384") { if (hashType == "sha-384") {
return std::unique_ptr<MessageDigestImpl>(new MessageDigestSHA384()); return make_unique<MessageDigestSHA384>();
} }
if (hashType == "sha-512") { if (hashType == "sha-512") {
return std::unique_ptr<MessageDigestImpl>(new MessageDigestSHA512()); return make_unique<MessageDigestSHA512>();
} }
if (hashType == "md5") { if (hashType == "md5") {
return std::unique_ptr<MessageDigestImpl>(new MessageDigestMD5()); return make_unique<MessageDigestMD5>();
} }
return nullptr; return nullptr;
} }
bool MessageDigestImpl::supports(const std::string& hashType) bool MessageDigestImpl::supports(const std::string& hashType)
{ {
return hashType == "sha-1" || hashType == "sha-224" || hashType == "sha-256" || hashType == "sha-384" || hashType == "sha-512" || hashType == "md5"; return hashType == "sha-1" || hashType == "sha-224" ||
hashType == "sha-256" || hashType == "sha-384" ||
hashType == "sha-512" || hashType == "md5";
} }
size_t MessageDigestImpl::getDigestLength(const std::string& hashType) size_t MessageDigestImpl::getDigestLength(const std::string& hashType)

View File

@ -58,11 +58,11 @@ namespace {
}; };
#endif // defined(__MAC_10_7) #endif // defined(__MAC_10_7)
class cfrelease { class CFReleaser {
const void *ptr_; const void *ptr_;
public: public:
inline cfrelease(const void *ptr) : ptr_(ptr) {} inline CFReleaser(const void *ptr) : ptr_(ptr) {}
inline ~cfrelease() { if (ptr_) CFRelease(ptr_); } inline ~CFReleaser() { if (ptr_) CFRelease(ptr_); }
}; };
static inline bool isWhitespace(char c) static inline bool isWhitespace(char c)
@ -123,13 +123,13 @@ namespace {
A2_LOG_ERROR("Failed to get a certref!"); A2_LOG_ERROR("Failed to get a certref!");
return false; return false;
} }
cfrelease del_ref(ref); CFReleaser del_ref(ref);
CFDataRef data = SecCertificateCopyData(ref); CFDataRef data = SecCertificateCopyData(ref);
if (!data) { if (!data) {
A2_LOG_ERROR("Failed to get a data!"); A2_LOG_ERROR("Failed to get a data!");
return false; return false;
} }
cfrelease del_data(data); CFReleaser del_data(data);
// Do try all supported hash algorithms. // Do try all supported hash algorithms.
// Usually the fingerprint would be sha1 or md5, however this is more // Usually the fingerprint would be sha1 or md5, however this is more
@ -183,12 +183,12 @@ SecIdentityRef AppleTLSContext::getCredentials()
bool AppleTLSContext::tryAsFingerprint(const std::string& fingerprint) bool AppleTLSContext::tryAsFingerprint(const std::string& fingerprint)
{ {
std::string fp = stripWhitespace(fingerprint); auto fp = stripWhitespace(fingerprint);
// Verify this is a valid hex representation and normalize. // Verify this is a valid hex representation and normalize.
fp = util::toHex(util::fromHex(fp.begin(), fp.end())); fp = util::toHex(util::fromHex(fp.begin(), fp.end()));
// Verify this can represent a hash // Verify this can represent a hash
std::vector<std::string> ht = MessageDigest::getSupportedHashTypes(); auto ht = MessageDigest::getSupportedHashTypes();
if (std::find_if(ht.begin(), ht.end(), hash_validator(fp)) == ht.end()) { if (std::find_if(ht.begin(), ht.end(), hash_validator(fp)) == ht.end()) {
A2_LOG_INFO(fmt("%s is not a fingerprint, invalid hash representation", fingerprint.c_str())); A2_LOG_INFO(fmt("%s is not a fingerprint, invalid hash representation", fingerprint.c_str()));
return false; return false;
@ -198,25 +198,25 @@ bool AppleTLSContext::tryAsFingerprint(const std::string& fingerprint)
A2_LOG_DEBUG(fmt("Looking for cert with fingerprint %s", fp.c_str())); A2_LOG_DEBUG(fmt("Looking for cert with fingerprint %s", fp.c_str()));
// Build and run the KeyChain the query. // Build and run the KeyChain the query.
SecPolicyRef policy = SecPolicyCreateSSL(true, nullptr); auto policy = SecPolicyCreateSSL(true, nullptr);
if (!policy) { if (!policy) {
A2_LOG_ERROR("Failed to create SecPolicy"); A2_LOG_ERROR("Failed to create SecPolicy");
return false; return false;
} }
cfrelease del_policy(policy); CFReleaser del_policy(policy);
const void *query_values[] = { const void *query_values[] = {
kSecClassIdentity, kSecClassIdentity,
kCFBooleanTrue, kCFBooleanTrue,
policy, policy,
kSecMatchLimitAll kSecMatchLimitAll
}; };
CFDictionaryRef query = CFDictionaryCreate(nullptr, query_keys, query_values, auto query = CFDictionaryCreate(nullptr, query_keys, query_values, 4,
4, nullptr, nullptr); nullptr, nullptr);
if (!query) { if (!query) {
A2_LOG_ERROR("Failed to create identity query"); A2_LOG_ERROR("Failed to create identity query");
return false; return false;
} }
cfrelease del_query(query); CFReleaser del_query(query);
CFArrayRef identities; CFArrayRef identities;
OSStatus err = SecItemCopyMatching(query, (CFTypeRef*)&identities); OSStatus err = SecItemCopyMatching(query, (CFTypeRef*)&identities);
if (err != errSecSuccess) { if (err != errSecSuccess) {
@ -254,7 +254,7 @@ bool AppleTLSContext::tryAsFingerprint(const std::string& fingerprint)
if (err != errSecSuccess) { if (err != errSecSuccess) {
A2_LOG_ERROR("Certificate search failed: " + errToString(err)); A2_LOG_ERROR("Certificate search failed: " + errToString(err));
} }
cfrelease del_search(search); CFReleaser del_search(search);
SecIdentityRef id; SecIdentityRef id;
while (SecIdentitySearchCopyNext(search, &id) == errSecSuccess) { while (SecIdentitySearchCopyNext(search, &id) == errSecSuccess) {

View File

@ -326,8 +326,8 @@ AppleTLSSession::AppleTLSSession(AppleTLSContext* ctx)
state_ = st_error; state_ = st_error;
return; return;
} }
for (SSLCipherSuiteList::iterator i = enabled.begin(), e = enabled.end(); i != e; ++i) { for (const auto& suite: enabled) {
A2_LOG_INFO(fmt("AppleTLS: Enabled suite %s", suiteToString(*i))); A2_LOG_INFO(fmt("AppleTLS: Enabled suite %s", suiteToString(suite)));
} }
if (SSLSetEnabledCiphers(sslCtx_, &enabled[0], enabled.size()) != noErr) { if (SSLSetEnabledCiphers(sslCtx_, &enabled[0], enabled.size()) != noErr) {
A2_LOG_ERROR("AppleTLS: Failed to set enabled ciphers list"); A2_LOG_ERROR("AppleTLS: Failed to set enabled ciphers list");
@ -336,7 +336,11 @@ AppleTLSSession::AppleTLSSession(AppleTLSContext* ctx)
} }
#endif #endif
if (ctx->getSide() == TLS_SERVER) { if (ctx->getSide() != TLS_SERVER) {
// Done with client-only initialization
return;
}
SecIdentityRef creds = ctx->getCredentials(); SecIdentityRef creds = ctx->getCredentials();
if (!creds) { if (!creds) {
A2_LOG_ERROR("AppleTLS: No credentials"); A2_LOG_ERROR("AppleTLS: No credentials");
@ -365,8 +369,6 @@ AppleTLSSession::AppleTLSSession(AppleTLSContext* ctx)
// it will take longer. // it will take longer.
} }
#endif // CIPHER_NO_DHPARAM #endif // CIPHER_NO_DHPARAM
}
} }
AppleTLSSession::~AppleTLSSession() AppleTLSSession::~AppleTLSSession()

View File

@ -124,27 +124,28 @@ void executeCommand(std::deque<std::unique_ptr<Command>>& commands,
for(size_t i = 0; i < max; ++i) { for(size_t i = 0; i < max; ++i) {
auto com = std::move(commands.front()); auto com = std::move(commands.front());
commands.pop_front(); commands.pop_front();
if(com->statusMatch(statusFilter)) { if (!com->statusMatch(statusFilter)) {
com->transitStatus();
if(com->execute()) {
com.reset();
} else {
com->clearIOEvents();
com.release();
}
} else {
com->clearIOEvents(); com->clearIOEvents();
commands.push_back(std::move(com)); commands.push_back(std::move(com));
continue;
}
com->transitStatus();
if (com->execute()) {
com.reset();
}
else {
com->clearIOEvents();
com.release();
} }
} }
} }
} // namespace } // namespace
namespace { namespace {
class GHR { class GlobalHaltRequestedFinalizer {
public: public:
GHR() {} GlobalHaltRequestedFinalizer() {}
~GHR() ~GlobalHaltRequestedFinalizer()
{ {
global::globalHaltRequested = 5; global::globalHaltRequested = 5;
} }
@ -153,7 +154,7 @@ public:
int DownloadEngine::run(bool oneshot) int DownloadEngine::run(bool oneshot)
{ {
GHR ghr; GlobalHaltRequestedFinalizer ghrf;
while(!commands_.empty() || !routineCommands_.empty()) { while(!commands_.empty() || !routineCommands_.empty()) {
if(!commands_.empty()) { if(!commands_.empty()) {
waitData(); waitData();
@ -243,12 +244,16 @@ void DownloadEngine::afterEachIteration()
global::globalHaltRequested = 2; global::globalHaltRequested = 2;
setNoWait(true); setNoWait(true);
setRefreshInterval(0); setRefreshInterval(0);
} else if(global::globalHaltRequested == 3) { return;
}
if(global::globalHaltRequested == 3) {
A2_LOG_NOTICE(_("Emergency shutdown sequence commencing...")); A2_LOG_NOTICE(_("Emergency shutdown sequence commencing..."));
requestForceHalt(); requestForceHalt();
global::globalHaltRequested = 4; global::globalHaltRequested = 4;
setNoWait(true); setNoWait(true);
setRefreshInterval(0); setRefreshInterval(0);
return;
} }
} }
@ -300,7 +305,9 @@ void DownloadEngine::poolSocket(const std::string& key,
std::multimap<std::string, SocketPoolEntry>::value_type p(key, entry); std::multimap<std::string, SocketPoolEntry>::value_type p(key, entry);
socketPool_.insert(p); socketPool_.insert(p);
if(lastSocketPoolScan_.difference(global::wallclock()) >= 60) { if(lastSocketPoolScan_.difference(global::wallclock()) < 60) {
return;
}
std::multimap<std::string, SocketPoolEntry> newPool; std::multimap<std::string, SocketPoolEntry> newPool;
A2_LOG_DEBUG("Scaning SocketPool and erasing timed out entry."); A2_LOG_DEBUG("Scaning SocketPool and erasing timed out entry.");
lastSocketPoolScan_ = global::wallclock(); lastSocketPoolScan_ = global::wallclock();
@ -313,7 +320,6 @@ void DownloadEngine::poolSocket(const std::string& key,
static_cast<unsigned long> static_cast<unsigned long>
(socketPool_.size()-newPool.size()))); (socketPool_.size()-newPool.size())));
socketPool_ = newPool; socketPool_ = newPool;
}
} }
namespace { namespace {
@ -382,17 +388,18 @@ void DownloadEngine::poolSocket(const std::shared_ptr<Request>& request,
const std::shared_ptr<SocketCore>& socket, const std::shared_ptr<SocketCore>& socket,
time_t timeout) time_t timeout)
{ {
if(!proxyRequest) { if(proxyRequest) {
std::pair<std::string, uint16_t> peerInfo;
if(getPeerInfo(peerInfo, socket)) {
poolSocket(peerInfo.first, peerInfo.second,
A2STR::NIL, 0, socket, timeout);
}
} else {
// If proxy is defined, then pool socket with its hostname. // If proxy is defined, then pool socket with its hostname.
poolSocket(request->getHost(), request->getPort(), poolSocket(request->getHost(), request->getPort(),
proxyRequest->getHost(), proxyRequest->getPort(), proxyRequest->getHost(), proxyRequest->getPort(),
socket, timeout); socket, timeout);
return;
}
std::pair<std::string, uint16_t> peerInfo;
if(getPeerInfo(peerInfo, socket)) {
poolSocket(peerInfo.first, peerInfo.second,
A2STR::NIL, 0, socket, timeout);
} }
} }
@ -404,17 +411,18 @@ void DownloadEngine::poolSocket
const std::string& options, const std::string& options,
time_t timeout) time_t timeout)
{ {
if(!proxyRequest) { if(proxyRequest) {
std::pair<std::string, uint16_t> peerInfo;
if(getPeerInfo(peerInfo, socket)) {
poolSocket(peerInfo.first, peerInfo.second, username,
A2STR::NIL, 0, socket, options, timeout);
}
} else {
// If proxy is defined, then pool socket with its hostname. // If proxy is defined, then pool socket with its hostname.
poolSocket(request->getHost(), request->getPort(), username, poolSocket(request->getHost(), request->getPort(), username,
proxyRequest->getHost(), proxyRequest->getPort(), proxyRequest->getHost(), proxyRequest->getPort(),
socket, options, timeout); socket, options, timeout);
return;
}
std::pair<std::string, uint16_t> peerInfo;
if(getPeerInfo(peerInfo, socket)) {
poolSocket(peerInfo.first, peerInfo.second, username,
A2STR::NIL, 0, socket, options, timeout);
} }
} }

View File

@ -91,46 +91,42 @@ std::unique_ptr<EventPoll> createEventPoll(Option* op)
#ifdef HAVE_LIBUV #ifdef HAVE_LIBUV
if (pollMethod == V_LIBUV) { if (pollMethod == V_LIBUV) {
auto ep = make_unique<LibuvEventPoll>(); auto ep = make_unique<LibuvEventPoll>();
if(ep->good()) { if(!ep->good()) {
return std::move(ep);
} else {
throw DL_ABORT_EX("Initializing LibuvEventPoll failed." throw DL_ABORT_EX("Initializing LibuvEventPoll failed."
" Try --event-poll=select"); " Try --event-poll=select");
} }
return std::move(ep);
} }
else else
#endif // HAVE_LIBUV #endif // HAVE_LIBUV
#ifdef HAVE_EPOLL #ifdef HAVE_EPOLL
if(pollMethod == V_EPOLL) { if(pollMethod == V_EPOLL) {
auto ep = make_unique<EpollEventPoll>(); auto ep = make_unique<EpollEventPoll>();
if(ep->good()) { if(!ep->good()) {
return std::move(ep);
} else {
throw DL_ABORT_EX("Initializing EpollEventPoll failed." throw DL_ABORT_EX("Initializing EpollEventPoll failed."
" Try --event-poll=select"); " Try --event-poll=select");
} }
return std::move(ep);
} else } else
#endif // HAVE_EPLL #endif // HAVE_EPLL
#ifdef HAVE_KQUEUE #ifdef HAVE_KQUEUE
if(pollMethod == V_KQUEUE) { if(pollMethod == V_KQUEUE) {
auto kp = make_unique<KqueueEventPoll>(); auto kp = make_unique<KqueueEventPoll>();
if(kp->good()) { if(!kp->good()) {
return std::move(kp);
} else {
throw DL_ABORT_EX("Initializing KqueueEventPoll failed." throw DL_ABORT_EX("Initializing KqueueEventPoll failed."
" Try --event-poll=select"); " Try --event-poll=select");
} }
return std::move(kp);
} else } else
#endif // HAVE_KQUEUE #endif // HAVE_KQUEUE
#ifdef HAVE_PORT_ASSOCIATE #ifdef HAVE_PORT_ASSOCIATE
if(pollMethod == V_PORT) { if(pollMethod == V_PORT) {
auto pp = make_unique<PortEventPoll>(); auto pp = make_unique<PortEventPoll>();
if(pp->good()) { if(!pp->good()) {
return std::move(pp);
} else {
throw DL_ABORT_EX("Initializing PortEventPoll failed." throw DL_ABORT_EX("Initializing PortEventPoll failed."
" Try --event-poll=select"); " Try --event-poll=select");
} }
return std::move(pp);
} else } else
#endif // HAVE_PORT_ASSOCIATE #endif // HAVE_PORT_ASSOCIATE
#ifdef HAVE_POLL #ifdef HAVE_POLL
@ -140,9 +136,9 @@ std::unique_ptr<EventPoll> createEventPoll(Option* op)
#endif // HAVE_POLL #endif // HAVE_POLL
if(pollMethod == V_SELECT) { if(pollMethod == V_SELECT) {
return make_unique<SelectEventPoll>(); return make_unique<SelectEventPoll>();
} else {
assert(0);
} }
assert(0);
return nullptr;
} }
} // namespace } // namespace

View File

@ -107,8 +107,8 @@ LibuvEventPoll::LibuvEventPoll()
LibuvEventPoll::~LibuvEventPoll() LibuvEventPoll::~LibuvEventPoll()
{ {
for (KPolls::iterator i = polls_.begin(), e = polls_.end(); i != e; ++i) { for (auto& p: polls_) {
i->second->close(); p.second->close();
} }
// Actually kill the polls, and timers, if any. // Actually kill the polls, and timers, if any.
uv_run(loop_, (uv_run_mode)(UV_RUN_ONCE | UV_RUN_NOWAIT)); uv_run(loop_, (uv_run_mode)(UV_RUN_ONCE | UV_RUN_NOWAIT));
@ -256,8 +256,8 @@ bool LibuvEventPoll::addEvents(sock_t socket, Command* command, int events,
bool LibuvEventPoll::deleteEvents(sock_t socket, bool LibuvEventPoll::deleteEvents(sock_t socket,
const LibuvEventPoll::KEvent& event) const LibuvEventPoll::KEvent& event)
{ {
std::shared_ptr<KSocketEntry> socketEntry(new KSocketEntry(socket)); auto socketEntry = std::make_shared<KSocketEntry>(socket);
KSocketEntrySet::iterator i = socketEntries_.find(socketEntry); auto i = socketEntries_.find(socketEntry);
if (i == socketEntries_.end()) { if (i == socketEntries_.end()) {
A2_LOG_DEBUG(fmt("Socket %d is not found in SocketEntries.", socket)); A2_LOG_DEBUG(fmt("Socket %d is not found in SocketEntries.", socket));
@ -266,7 +266,7 @@ bool LibuvEventPoll::deleteEvents(sock_t socket,
event.removeSelf(*i); event.removeSelf(*i);
KPolls::iterator poll = polls_.find(socket); auto poll = polls_.find(socket);
if (poll == polls_.end()) { if (poll == polls_.end()) {
return false; return false;
} }

View File

@ -76,7 +76,7 @@ private:
uv_poll_t handle_; uv_poll_t handle_;
static void poll_callback(uv_poll_t* handle, int status, int events) { static void poll_callback(uv_poll_t* handle, int status, int events) {
KPoll* poll = static_cast<KPoll*>(handle->data); auto poll = static_cast<KPoll*>(handle->data);
poll->eventer_->pollCallback(poll, status, events); poll->eventer_->pollCallback(poll, status, events);
} }
static void close_callback(uv_handle_t* handle) { static void close_callback(uv_handle_t* handle) {

View File

@ -134,14 +134,12 @@ std::unique_ptr<StatCalc> getStatCalc(const std::shared_ptr<Option>& op)
{ {
if(op->getAsBool(PREF_QUIET)) { if(op->getAsBool(PREF_QUIET)) {
return make_unique<NullStatCalc>(); return make_unique<NullStatCalc>();
} else { }
auto impl = make_unique<ConsoleStatCalc> auto impl = make_unique<ConsoleStatCalc>(op->getAsInt(PREF_SUMMARY_INTERVAL),
(op->getAsInt(PREF_SUMMARY_INTERVAL),
op->getAsBool(PREF_HUMAN_READABLE)); op->getAsBool(PREF_HUMAN_READABLE));
impl->setReadoutVisibility(op->getAsBool(PREF_SHOW_CONSOLE_READOUT)); impl->setReadoutVisibility(op->getAsBool(PREF_SHOW_CONSOLE_READOUT));
impl->setTruncate(op->getAsBool(PREF_TRUNCATE_CONSOLE_READOUT)); impl->setTruncate(op->getAsBool(PREF_TRUNCATE_CONSOLE_READOUT));
return std::move(impl); return std::move(impl);
}
} }
} // namespace } // namespace
@ -183,21 +181,20 @@ int MultiUrlRequestInfo::prepare()
#ifdef ENABLE_SSL #ifdef ENABLE_SSL
if(option_->getAsBool(PREF_ENABLE_RPC) && if(option_->getAsBool(PREF_ENABLE_RPC) &&
option_->getAsBool(PREF_RPC_SECURE)) { option_->getAsBool(PREF_RPC_SECURE)) {
if(!option_->blank(PREF_RPC_CERTIFICATE) if(option_->blank(PREF_RPC_CERTIFICATE)
#ifndef HAVE_APPLETLS #ifndef HAVE_APPLETLS
&& !option_->blank(PREF_RPC_PRIVATE_KEY) || option_->blank(PREF_RPC_PRIVATE_KEY)
#endif // HAVE_APPLETLS #endif // HAVE_APPLETLS
) { ) {
throw DL_ABORT_EX("Specify --rpc-certificate and --rpc-private-key "
"options in order to use secure RPC.");
}
// We set server TLS context to the SocketCore before creating // We set server TLS context to the SocketCore before creating
// DownloadEngine instance. // DownloadEngine instance.
std::shared_ptr<TLSContext> svTlsContext(TLSContext::make(TLS_SERVER)); std::shared_ptr<TLSContext> svTlsContext(TLSContext::make(TLS_SERVER));
svTlsContext->addCredentialFile(option_->get(PREF_RPC_CERTIFICATE), svTlsContext->addCredentialFile(option_->get(PREF_RPC_CERTIFICATE),
option_->get(PREF_RPC_PRIVATE_KEY)); option_->get(PREF_RPC_PRIVATE_KEY));
SocketCore::setServerTLSContext(svTlsContext); SocketCore::setServerTLSContext(svTlsContext);
} else {
throw DL_ABORT_EX("Specify --rpc-certificate and --rpc-private-key "
"options in order to use secure RPC.");
}
} }
#endif // ENABLE_SSL #endif // ENABLE_SSL

View File

@ -270,9 +270,8 @@ void SocketCore::bindWithFamily(uint16_t port, int family, int flags)
sock_t fd = bindTo(nullptr, port, family, sockType_, flags, error); sock_t fd = bindTo(nullptr, port, family, sockType_, flags, error);
if(fd == (sock_t) -1) { if(fd == (sock_t) -1) {
throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str())); throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str()));
} else {
sockfd_ = fd;
} }
sockfd_ = fd;
} }
void SocketCore::bind void SocketCore::bind
@ -288,16 +287,17 @@ void SocketCore::bind
} }
if(!(flags&AI_PASSIVE) || bindAddrs_.empty()) { if(!(flags&AI_PASSIVE) || bindAddrs_.empty()) {
sock_t fd = bindTo(addrp, port, family, sockType_, flags, error); sock_t fd = bindTo(addrp, port, family, sockType_, flags, error);
if(fd != (sock_t) -1) { if(fd == (sock_t) -1) {
sockfd_ = fd; throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str()));
} }
} else { sockfd_ = fd;
for(std::vector<std::pair<sockaddr_union, socklen_t> >:: return;
const_iterator i = bindAddrs_.begin(), eoi = bindAddrs_.end(); }
i != eoi; ++i) {
for (const auto& a : bindAddrs_) {
char host[NI_MAXHOST]; char host[NI_MAXHOST];
int s; int s;
s = getnameinfo(&(*i).first.sa, (*i).second, host, NI_MAXHOST, nullptr, 0, s = getnameinfo(&a.first.sa, a.second, host, NI_MAXHOST, nullptr, 0,
NI_NUMERICHOST); NI_NUMERICHOST);
if(s) { if(s) {
error = gai_strerror(s); error = gai_strerror(s);
@ -313,7 +313,6 @@ void SocketCore::bind
break; break;
} }
} }
}
if(sockfd_ == (sock_t) -1) { if(sockfd_ == (sock_t) -1) {
throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str())); throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str()));
} }
@ -329,11 +328,10 @@ void SocketCore::bind(const struct sockaddr* addr, socklen_t addrlen)
closeConnection(); closeConnection();
std::string error; std::string error;
sock_t fd = bindInternal(addr->sa_family, sockType_, 0, addr, addrlen, error); sock_t fd = bindInternal(addr->sa_family, sockType_, 0, addr, addrlen, error);
if(fd != (sock_t)-1) { if(fd == (sock_t)-1) {
sockfd_ = fd;
} else {
throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str())); throw DL_ABORT_EX(fmt(EX_SOCKET_BIND, error.c_str()));
} }
sockfd_ = fd;
} }
void SocketCore::beginListen() void SocketCore::beginListen()
@ -486,12 +484,11 @@ void SocketCore::setMulticastInterface(const std::string& localAddr)
in_addr addr; in_addr addr;
if(localAddr.empty()) { if(localAddr.empty()) {
addr.s_addr = htonl(INADDR_ANY); addr.s_addr = htonl(INADDR_ANY);
} else { }
if(inetPton(AF_INET, localAddr.c_str(), &addr) != 0) { else if(inetPton(AF_INET, localAddr.c_str(), &addr) != 0) {
throw DL_ABORT_EX(fmt("%s is not valid IPv4 numeric address", throw DL_ABORT_EX(fmt("%s is not valid IPv4 numeric address",
localAddr.c_str())); localAddr.c_str()));
} }
}
setSockOpt(IPPROTO_IP, IP_MULTICAST_IF, &addr, sizeof(addr)); setSockOpt(IPPROTO_IP, IP_MULTICAST_IF, &addr, sizeof(addr));
} }
@ -517,12 +514,11 @@ void SocketCore::joinMulticastGroup
in_addr ifAddr; in_addr ifAddr;
if(localAddr.empty()) { if(localAddr.empty()) {
ifAddr.s_addr = htonl(INADDR_ANY); ifAddr.s_addr = htonl(INADDR_ANY);
} else { }
if(inetPton(AF_INET, localAddr.c_str(), &ifAddr) != 0) { else if(inetPton(AF_INET, localAddr.c_str(), &ifAddr) != 0) {
throw DL_ABORT_EX(fmt("%s is not valid IPv4 numeric address", throw DL_ABORT_EX(fmt("%s is not valid IPv4 numeric address",
localAddr.c_str())); localAddr.c_str()));
} }
}
struct ip_mreq mreq; struct ip_mreq mreq;
memset(&mreq, 0, sizeof(mreq)); memset(&mreq, 0, sizeof(mreq));
mreq.imr_multiaddr = multiAddr; mreq.imr_multiaddr = multiAddr;
@ -605,11 +601,11 @@ bool SocketCore::isWritable(time_t timeout)
int errNum = SOCKET_ERRNO; int errNum = SOCKET_ERRNO;
if(r > 0) { if(r > 0) {
return p.revents&(POLLOUT|POLLHUP|POLLERR); return p.revents&(POLLOUT|POLLHUP|POLLERR);
} else if(r == 0) {
return false;
} else {
throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_WRITABLE, errorMsg(errNum).c_str()));
} }
if(r == 0) {
return false;
}
throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_WRITABLE, errorMsg(errNum).c_str()));
#else // !HAVE_POLL #else // !HAVE_POLL
# ifndef __MINGW32__ # ifndef __MINGW32__
CHECK_FD(sockfd_); CHECK_FD(sockfd_);
@ -626,17 +622,15 @@ bool SocketCore::isWritable(time_t timeout)
int errNum = SOCKET_ERRNO; int errNum = SOCKET_ERRNO;
if(r == 1) { if(r == 1) {
return true; return true;
} else if(r == 0) { }
if(r == 0) {
// time out // time out
return false; return false;
} else { }
if(errNum == A2_EINPROGRESS || errNum == A2_EINTR) { if(errNum == A2_EINPROGRESS || errNum == A2_EINTR) {
return false; return false;
} else {
throw DL_RETRY_EX
(fmt(EX_SOCKET_CHECK_WRITABLE, errorMsg(errNum).c_str()));
}
} }
throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_WRITABLE, errorMsg(errNum).c_str()));
#endif // !HAVE_POLL #endif // !HAVE_POLL
} }
@ -651,11 +645,11 @@ bool SocketCore::isReadable(time_t timeout)
int errNum = SOCKET_ERRNO; int errNum = SOCKET_ERRNO;
if(r > 0) { if(r > 0) {
return p.revents&(POLLIN|POLLHUP|POLLERR); return p.revents&(POLLIN|POLLHUP|POLLERR);
} else if(r == 0) {
return false;
} else {
throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_READABLE, errorMsg(errNum).c_str()));
} }
if(r == 0) {
return false;
}
throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_READABLE, errorMsg(errNum).c_str()));
#else // !HAVE_POLL #else // !HAVE_POLL
# ifndef __MINGW32__ # ifndef __MINGW32__
CHECK_FD(sockfd_); CHECK_FD(sockfd_);
@ -672,17 +666,15 @@ bool SocketCore::isReadable(time_t timeout)
int errNum = SOCKET_ERRNO; int errNum = SOCKET_ERRNO;
if(r == 1) { if(r == 1) {
return true; return true;
} else if(r == 0) { }
if(r == 0) {
// time out // time out
return false; return false;
} else { }
if(errNum == A2_EINPROGRESS || errNum == A2_EINTR) { if(errNum == A2_EINPROGRESS || errNum == A2_EINTR) {
return false; return false;
} else {
throw DL_RETRY_EX
(fmt(EX_SOCKET_CHECK_READABLE, errorMsg(errNum).c_str()));
}
} }
throw DL_RETRY_EX(fmt(EX_SOCKET_CHECK_READABLE, errorMsg(errNum).c_str()));
#endif // !HAVE_POLL #endif // !HAVE_POLL
} }
@ -706,12 +698,11 @@ ssize_t SocketCore::writeVector(a2iovec *iov, size_t iovcnt)
#endif // !__MINGW32__ #endif // !__MINGW32__
int errNum = SOCKET_ERRNO; int errNum = SOCKET_ERRNO;
if(ret == -1) { if(ret == -1) {
if(A2_WOULDBLOCK(errNum)) { if(!A2_WOULDBLOCK(errNum)) {
wantWrite_ = true;
ret = 0;
} else {
throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, errorMsg(errNum).c_str())); throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, errorMsg(errNum).c_str()));
} }
wantWrite_ = true;
ret = 0;
} }
} else { } else {
// For SSL/TLS, we could not use writev, so just iterate vector // For SSL/TLS, we could not use writev, so just iterate vector
@ -739,28 +730,26 @@ ssize_t SocketCore::writeData(const void* data, size_t len)
len, 0)) == -1 && SOCKET_ERRNO == A2_EINTR); len, 0)) == -1 && SOCKET_ERRNO == A2_EINTR);
int errNum = SOCKET_ERRNO; int errNum = SOCKET_ERRNO;
if(ret == -1) { if(ret == -1) {
if(A2_WOULDBLOCK(errNum)) { if(!A2_WOULDBLOCK(errNum)) {
wantWrite_ = true;
ret = 0;
} else {
throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, errorMsg(errNum).c_str())); throw DL_RETRY_EX(fmt(EX_SOCKET_SEND, errorMsg(errNum).c_str()));
} }
wantWrite_ = true;
ret = 0;
} }
} else { } else {
#ifdef ENABLE_SSL #ifdef ENABLE_SSL
ret = tlsSession_->writeData(data, len); ret = tlsSession_->writeData(data, len);
if(ret < 0) { if(ret < 0) {
if(ret == TLS_ERR_WOULDBLOCK) { if(ret != TLS_ERR_WOULDBLOCK) {
throw DL_RETRY_EX(fmt(EX_SOCKET_SEND,
tlsSession_->getLastErrorString().c_str()));
}
if(tlsSession_->checkDirection() == TLS_WANT_READ) { if(tlsSession_->checkDirection() == TLS_WANT_READ) {
wantRead_ = true; wantRead_ = true;
} else { } else {
wantWrite_ = true; wantWrite_ = true;
} }
ret = 0; ret = 0;
} else {
throw DL_RETRY_EX(fmt(EX_SOCKET_SEND,
tlsSession_->getLastErrorString().c_str()));
}
} }
#endif // ENABLE_SSL #endif // ENABLE_SSL
} }
@ -779,28 +768,26 @@ void SocketCore::readData(void* data, size_t& len)
SOCKET_ERRNO == A2_EINTR); SOCKET_ERRNO == A2_EINTR);
int errNum = SOCKET_ERRNO; int errNum = SOCKET_ERRNO;
if(ret == -1) { if(ret == -1) {
if(A2_WOULDBLOCK(errNum)) { if(!A2_WOULDBLOCK(errNum)) {
wantRead_ = true;
ret = 0;
} else {
throw DL_RETRY_EX(fmt(EX_SOCKET_RECV, errorMsg(errNum).c_str())); throw DL_RETRY_EX(fmt(EX_SOCKET_RECV, errorMsg(errNum).c_str()));
} }
wantRead_ = true;
ret = 0;
} }
} else { } else {
#ifdef ENABLE_SSL #ifdef ENABLE_SSL
ret = tlsSession_->readData(data, len); ret = tlsSession_->readData(data, len);
if(ret < 0) { if(ret < 0) {
if(ret == TLS_ERR_WOULDBLOCK) { if(ret != TLS_ERR_WOULDBLOCK) {
throw DL_RETRY_EX(fmt(EX_SOCKET_SEND,
tlsSession_->getLastErrorString().c_str()));
}
if(tlsSession_->checkDirection() == TLS_WANT_READ) { if(tlsSession_->checkDirection() == TLS_WANT_READ) {
wantRead_ = true; wantRead_ = true;
} else { } else {
wantWrite_ = true; wantWrite_ = true;
} }
ret = 0; ret = 0;
} else {
throw DL_RETRY_EX(fmt(EX_SOCKET_SEND,
tlsSession_->getLastErrorString().c_str()));
}
} }
#endif // ENABLE_SSL #endif // ENABLE_SSL
} }
@ -855,20 +842,20 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
} }
if(rv == TLS_ERR_OK) { if(rv == TLS_ERR_OK) {
secure_ = A2_TLS_CONNECTED; secure_ = A2_TLS_CONNECTED;
} else if(rv == TLS_ERR_WOULDBLOCK) { break;
}
if(rv != TLS_ERR_WOULDBLOCK) {
throw DL_ABORT_EX(fmt("SSL/TLS handshake failure: %s",
handshakeError.empty() ?
tlsSession_->getLastErrorString().c_str() :
handshakeError.c_str()));
}
if(tlsSession_->checkDirection() == TLS_WANT_READ) { if(tlsSession_->checkDirection() == TLS_WANT_READ) {
wantRead_ = true; wantRead_ = true;
} else { } else {
wantWrite_ = true; wantWrite_ = true;
} }
return false; return false;
} else {
throw DL_ABORT_EX(fmt("SSL/TLS handshake failure: %s",
handshakeError.empty() ?
tlsSession_->getLastErrorString().c_str() :
handshakeError.c_str()));
}
break;
default: default:
break; break;
} }
@ -931,12 +918,11 @@ ssize_t SocketCore::readDataFrom(void* data, size_t len,
&& A2_EINTR == SOCKET_ERRNO); && A2_EINTR == SOCKET_ERRNO);
int errNum = SOCKET_ERRNO; int errNum = SOCKET_ERRNO;
if(r == -1) { if(r == -1) {
if(A2_WOULDBLOCK(errNum)) { if(!A2_WOULDBLOCK(errNum)) {
wantRead_ = true;
r = 0;
} else {
throw DL_RETRY_EX(fmt(EX_SOCKET_RECV, errorMsg(errNum).c_str())); throw DL_RETRY_EX(fmt(EX_SOCKET_RECV, errorMsg(errNum).c_str()));
} }
wantRead_ = true;
r = 0;
} else { } else {
sender = util::getNumericNameInfo(&sockaddr.sa, sockaddrlen); sender = util::getNumericNameInfo(&sockaddr.sa, sockaddrlen);
} }
@ -957,9 +943,8 @@ std::string SocketCore::getSocketError() const
} }
if(error != 0) { if(error != 0) {
return errorMsg(error); return errorMsg(error);
} else {
return "";
} }
return "";
} }
bool SocketCore::wantRead() const bool SocketCore::wantRead() const
@ -977,22 +962,19 @@ void SocketCore::bindAddress(const std::string& iface)
std::vector<std::pair<sockaddr_union, socklen_t> > bindAddrs; std::vector<std::pair<sockaddr_union, socklen_t> > bindAddrs;
getInterfaceAddress(bindAddrs, iface, protocolFamily_); getInterfaceAddress(bindAddrs, iface, protocolFamily_);
if(bindAddrs.empty()) { if(bindAddrs.empty()) {
throw DL_ABORT_EX throw DL_ABORT_EX(fmt(MSG_INTERFACE_NOT_FOUND, iface.c_str(),
(fmt(MSG_INTERFACE_NOT_FOUND, iface.c_str(), "not available")); "not available"));
} else { }
bindAddrs_.swap(bindAddrs); bindAddrs_.swap(bindAddrs);
for(std::vector<std::pair<sockaddr_union, socklen_t> >:: for (const auto& a: bindAddrs_) {
const_iterator i = bindAddrs_.begin(), eoi = bindAddrs_.end();
i != eoi; ++i) {
char host[NI_MAXHOST]; char host[NI_MAXHOST];
int s; int s;
s = getnameinfo(&(*i).first.sa, (*i).second, host, NI_MAXHOST, nullptr, 0, s = getnameinfo(&a.first.sa, a.second, host, NI_MAXHOST, nullptr, 0,
NI_NUMERICHOST); NI_NUMERICHOST);
if(s == 0) { if(s == 0) {
A2_LOG_DEBUG(fmt("Sockets will bind to %s", host)); A2_LOG_DEBUG(fmt("Sockets will bind to %s", host));
} }
} }
}
} }
void getInterfaceAddress void getInterfaceAddress
@ -1101,7 +1083,6 @@ int callGetaddrinfo
int inetNtop(int af, const void* src, char* dst, socklen_t size) int inetNtop(int af, const void* src, char* dst, socklen_t size)
{ {
int s;
sockaddr_union su; sockaddr_union su;
memset(&su, 0, sizeof(su)); memset(&su, 0, sizeof(su));
if(af == AF_INET) { if(af == AF_INET) {
@ -1110,20 +1091,19 @@ int inetNtop(int af, const void* src, char* dst, socklen_t size)
su.in.sin_len = sizeof(su.in); su.in.sin_len = sizeof(su.in);
#endif // HAVE_SOCKADDR_IN_SIN_LEN #endif // HAVE_SOCKADDR_IN_SIN_LEN
memcpy(&su.in.sin_addr, src, sizeof(su.in.sin_addr)); memcpy(&su.in.sin_addr, src, sizeof(su.in.sin_addr));
s = getnameinfo(&su.sa, sizeof(su.in), return getnameinfo(&su.sa, sizeof(su.in), dst, size, nullptr, 0,
dst, size, nullptr, 0, NI_NUMERICHOST); NI_NUMERICHOST);
} else if(af == AF_INET6) { }
if(af == AF_INET6) {
su.in6.sin6_family = AF_INET6; su.in6.sin6_family = AF_INET6;
#ifdef HAVE_SOCKADDR_IN6_SIN6_LEN #ifdef HAVE_SOCKADDR_IN6_SIN6_LEN
su.in6.sin6_len = sizeof(su.in6); su.in6.sin6_len = sizeof(su.in6);
#endif // HAVE_SOCKADDR_IN6_SIN6_LEN #endif // HAVE_SOCKADDR_IN6_SIN6_LEN
memcpy(&su.in6.sin6_addr, src, sizeof(su.in6.sin6_addr)); memcpy(&su.in6.sin6_addr, src, sizeof(su.in6.sin6_addr));
s = getnameinfo(&su.sa, sizeof(su.in6), return getnameinfo(&su.sa, sizeof(su.in6), dst, size, nullptr, 0,
dst, size, nullptr, 0, NI_NUMERICHOST); NI_NUMERICHOST);
} else {
s = EAI_FAMILY;
} }
return s; return EAI_FAMILY;
} }
int inetPton(int af, const char* src, void* dst) int inetPton(int af, const char* src, void* dst)
@ -1139,16 +1119,17 @@ int inetPton(int af, const char* src, void* dst)
} }
in_addr* addr = reinterpret_cast<in_addr*>(dst); in_addr* addr = reinterpret_cast<in_addr*>(dst);
addr->s_addr = binaddr.ipv4_addr; addr->s_addr = binaddr.ipv4_addr;
} else if(af == AF_INET6) { return 0;
}
if(af == AF_INET6) {
if(len != 16) { if(len != 16) {
return -1; return -1;
} }
in6_addr* addr = reinterpret_cast<in6_addr*>(dst); in6_addr* addr = reinterpret_cast<in6_addr*>(dst);
memcpy(addr->s6_addr, binaddr.ipv6_addr, sizeof(addr->s6_addr)); memcpy(addr->s6_addr, binaddr.ipv6_addr, sizeof(addr->s6_addr));
} else {
return -1;
}
return 0; return 0;
}
return -1;
} }
namespace net { namespace net {
@ -1200,7 +1181,9 @@ bool verifyHostname(const std::string& hostname,
return true; return true;
} }
} }
} else { return false;
}
if(dnsNames.empty()) { if(dnsNames.empty()) {
return util::tlsHostnameMatch(commonName, hostname); return util::tlsHostnameMatch(commonName, hostname);
} }
@ -1209,7 +1192,6 @@ bool verifyHostname(const std::string& hostname,
return true; return true;
} }
} }
}
return false; return false;
} }
@ -1237,12 +1219,11 @@ void checkAddrconfig()
do { do {
buf = reinterpret_cast<IP_ADAPTER_ADDRESSES*>(malloc(bufsize)); buf = reinterpret_cast<IP_ADAPTER_ADDRESSES*>(malloc(bufsize));
retval = GetAdaptersAddresses(AF_UNSPEC, 0, 0, buf, &bufsize); retval = GetAdaptersAddresses(AF_UNSPEC, 0, 0, buf, &bufsize);
if(retval == ERROR_BUFFER_OVERFLOW) { if(retval != ERROR_BUFFER_OVERFLOW) {
free(buf);
buf = 0;
} else {
break; break;
} }
free(buf);
buf = 0;
} while(retval == ERROR_BUFFER_OVERFLOW && numTry < MAX_TRY); } while(retval == ERROR_BUFFER_OVERFLOW && numTry < MAX_TRY);
if(retval != NO_ERROR) { if(retval != NO_ERROR) {
A2_LOG_INFO("GetAdaptersAddresses failed. Assume both IPv4 and IPv6 " A2_LOG_INFO("GetAdaptersAddresses failed. Assume both IPv4 and IPv6 "
@ -1261,7 +1242,9 @@ void checkAddrconfig()
continue; continue;
} }
PIP_ADAPTER_UNICAST_ADDRESS ucaddr = p->FirstUnicastAddress; PIP_ADAPTER_UNICAST_ADDRESS ucaddr = p->FirstUnicastAddress;
if(ucaddr) { if(!ucaddr) {
continue;
}
for(PIP_ADAPTER_UNICAST_ADDRESS i = ucaddr; i; i = i->Next) { for(PIP_ADAPTER_UNICAST_ADDRESS i = ucaddr; i; i = i->Next) {
bool found = false; bool found = false;
switch(i->Address.iSockaddrLength) { switch(i->Address.iSockaddrLength) {
@ -1297,8 +1280,8 @@ void checkAddrconfig()
} }
} }
} }
}
free(buf); free(buf);
A2_LOG_INFO(fmt("IPv4 configured=%d, IPv6 configured=%d", A2_LOG_INFO(fmt("IPv4 configured=%d, IPv6 configured=%d",
ipv4AddrConfigured, ipv6AddrConfigured)); ipv4AddrConfigured, ipv6AddrConfigured));
#elif defined(HAVE_GETIFADDRS) #elif defined(HAVE_GETIFADDRS)