Delay auth failures instead of PBKDF2

Closes GH-256
pull/257/merge
Nils Maier 2014-07-15 08:15:46 +02:00
parent 24ae459523
commit 8f2af33b6d
21 changed files with 250 additions and 162 deletions

75
src/DelayedCommand.h Normal file
View File

@ -0,0 +1,75 @@
/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2014 Nils Maier
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#ifndef D_DELAYED_COMMAND_H
#define D_DELAYED_COMMAND_H
#include "TimeBasedCommand.h"
namespace aria2
{
class DelayedCommand : public TimeBasedCommand
{
private:
std::unique_ptr<Command> command_;
bool noWait_;
public:
virtual void process() CXX11_OVERRIDE
{
auto e = getDownloadEngine();
e->addCommand(std::move(command_));
if (noWait_) {
e->setNoWait(true);
}
enableExit();
}
public:
DelayedCommand(cuid_t cuid, DownloadEngine* e, time_t delay,
std::unique_ptr<Command> command, bool noWait)
: TimeBasedCommand(cuid, e, delay),
command_{std::move(command)},
noWait_{noWait}
{
}
virtual ~DelayedCommand() {}
};
} // namespace aria2
#endif // D_DELAYED_COMMAND_H

View File

@ -76,13 +76,6 @@
#include "Option.h" #include "Option.h"
#include "util_security.h" #include "util_security.h"
// Lower time limit for PBKDF2 operations in validateToken.
static const double kTokenTimeLower = 0.025;
// Upper time limit for PBKDF2 operations in validateToken.
static const double kTokenTimeUpper = 0.5;
// Sweet spot time for PBKDF2 operations in validateToken.
static const double kTokenTimeSweetspot = 0.035;
namespace aria2 { namespace aria2 {
namespace global { namespace global {
@ -111,9 +104,7 @@ DownloadEngine::DownloadEngine(std::unique_ptr<EventPoll> eventPoll)
asyncDNSServers_(nullptr), asyncDNSServers_(nullptr),
#endif // HAVE_ARES_ADDR_NODE #endif // HAVE_ARES_ADDR_NODE
dnsCache_(make_unique<DNSCache>()), dnsCache_(make_unique<DNSCache>()),
option_(nullptr), option_(nullptr)
tokenIterations_(5000),
tokenAverageDuration_(0.0)
{ {
unsigned char sessionId[20]; unsigned char sessionId[20];
util::generateRandomKey(sessionId); util::generateRandomKey(sessionId);
@ -650,76 +641,17 @@ bool DownloadEngine::validateToken(const std::string& token)
return true; return true;
} }
if (!tokenHMAC_ || tokenAverageDuration_ > kTokenTimeUpper || if (!tokenHMAC_) {
tokenAverageDuration_ < kTokenTimeLower) {
// Setup our stuff.
if (tokenHMAC_) {
A2_LOG_INFO(fmt("Recalculating iterations because avg. duration is %.4f",
tokenAverageDuration_));
}
tokenHMAC_ = HMAC::createRandom(); tokenHMAC_ = HMAC::createRandom();
if (!tokenHMAC_) { if (!tokenHMAC_) {
A2_LOG_ERROR("Failed to create HMAC"); A2_LOG_ERROR("Failed to create HMAC");
return false; return false;
} }
tokenExpected_ = make_unique<HMACResult>(tokenHMAC_->getResult(
// This should still be pretty fast on a modern system... Well, too fast option_->get(PREF_RPC_SECRET)));
// with the initial 5000 iterations, and that is why we adjust it.
// XXX We should run this setup high priorty, so that other processes on the
// system don't mess up our results and let us underestimate the iterations.
std::deque<double> mm;
for (auto i = 0; i < 10; ++i) {
auto c = std::clock();
tokenExpected_ = make_unique<HMACResult>
(PBKDF2(tokenHMAC_.get(), option_->get(PREF_RPC_SECRET),
tokenIterations_));
mm.push_back((std::clock() - c) / (double)CLOCKS_PER_SEC);
}
std::sort(mm.begin(), mm.end());
// Pop outliers.
mm.pop_front();
mm.pop_back();
mm.pop_back();
auto duration = std::accumulate(mm.begin(), mm.end(), 0.0) / mm.size();
A2_LOG_INFO(fmt("Took us %.4f secs on average to perform PBKDF2 with %zu "
"iterations during setup",
duration, tokenIterations_));
// Adjust iterations so that an op takes about |kTokenTimeSpeetspot| sec,
// which would allow for a couple attempts per second (instead of
// potentially thousands without PBKDF2).
// We might overestimate the performance a bit, but should not perform
// worse than |kTokenTimeUpper| secs per attempt on a normally loaded system
// and no better than |kTokenTimeLower|. If this does not hold true anymore,
// the |tokenAverageDuration_| checks will force a re-calcuation.
tokenIterations_ *= kTokenTimeSweetspot / duration;
auto c = std::clock();
tokenExpected_ = make_unique<HMACResult>
(PBKDF2(tokenHMAC_.get(), option_->get(PREF_RPC_SECRET),
tokenIterations_));
duration = (std::clock() - c) / (double)CLOCKS_PER_SEC;
A2_LOG_INFO(fmt("Took us %.4f secs to perform PBKDF2 with %zu iterations",
duration, tokenIterations_));
// Seed average duration.
tokenAverageDuration_ = duration;
} }
auto c = std::clock(); return *tokenExpected_ == tokenHMAC_->getResult(token);
bool rv = *tokenExpected_ == PBKDF2(tokenHMAC_.get(), token,
tokenIterations_);
auto duration = (std::clock() - c) / (double)CLOCKS_PER_SEC;
A2_LOG_DEBUG(fmt("Took us %.4f secs to perform token compare with %zu "
"iterations",
duration, tokenIterations_));
// Update rolling hash.
tokenAverageDuration_ = tokenAverageDuration_ * 0.9 + duration * 0.1;
return rv;
} }
} // namespace aria2 } // namespace aria2

View File

@ -183,9 +183,6 @@ private:
std::unique_ptr<util::security::HMAC> tokenHMAC_; std::unique_ptr<util::security::HMAC> tokenHMAC_;
std::unique_ptr<util::security::HMACResult> tokenExpected_; std::unique_ptr<util::security::HMACResult> tokenExpected_;
size_t tokenIterations_;
double tokenAverageDuration_;
public: public:
DownloadEngine(std::unique_ptr<EventPoll> eventPoll); DownloadEngine(std::unique_ptr<EventPoll> eventPoll);

View File

@ -43,6 +43,7 @@
#include "RequestGroupMan.h" #include "RequestGroupMan.h"
#include "RecoverableException.h" #include "RecoverableException.h"
#include "HttpServerResponseCommand.h" #include "HttpServerResponseCommand.h"
#include "DelayedCommand.h"
#include "OptionParser.h" #include "OptionParser.h"
#include "OptionHandler.h" #include "OptionHandler.h"
#include "wallclock.h" #include "wallclock.h"
@ -104,6 +105,7 @@ void HttpServerBodyCommand::sendJsonRpcResponse
(const rpc::RpcResponse& res, (const rpc::RpcResponse& res,
const std::string& callback) const std::string& callback)
{ {
bool notauthorized = rpc::not_authorized(res);
bool gzip = httpServer_->supportsGZip(); bool gzip = httpServer_->supportsGZip();
std::string responseData = rpc::toJson(res, callback, gzip); std::string responseData = rpc::toJson(res, callback, gzip);
if(res.code == 0) { if(res.code == 0) {
@ -126,24 +128,32 @@ void HttpServerBodyCommand::sendJsonRpcResponse
std::move(responseData), std::move(responseData),
getJsonRpcContentType(!callback.empty())); getJsonRpcContentType(!callback.empty()));
} }
addHttpServerResponseCommand(); addHttpServerResponseCommand(notauthorized);
} }
void HttpServerBodyCommand::sendJsonRpcBatchResponse void HttpServerBodyCommand::sendJsonRpcBatchResponse
(const std::vector<rpc::RpcResponse>& results, (const std::vector<rpc::RpcResponse>& results,
const std::string& callback) const std::string& callback)
{ {
bool notauthorized = rpc::any_not_authorized(results.begin(), results.end());
bool gzip = httpServer_->supportsGZip(); bool gzip = httpServer_->supportsGZip();
std::string responseData = rpc::toJsonBatch(results, callback, gzip); std::string responseData = rpc::toJsonBatch(results, callback, gzip);
httpServer_->feedResponse(std::move(responseData), httpServer_->feedResponse(std::move(responseData),
getJsonRpcContentType(!callback.empty())); getJsonRpcContentType(!callback.empty()));
addHttpServerResponseCommand(); addHttpServerResponseCommand(notauthorized);
} }
void HttpServerBodyCommand::addHttpServerResponseCommand() void HttpServerBodyCommand::addHttpServerResponseCommand(bool delayed)
{ {
e_->addCommand(make_unique<HttpServerResponseCommand> auto resp =
(getCuid(), httpServer_, e_, socket_)); make_unique<HttpServerResponseCommand>(getCuid(), httpServer_, e_, socket_);
if (delayed) {
e_->addCommand(
make_unique<DelayedCommand>(getCuid(), e_, 1, std::move(resp), true));
return;
}
e_->addCommand(std::move(resp));
e_->setNoWait(true); e_->setNoWait(true);
} }
@ -201,7 +211,7 @@ bool HttpServerBodyCommand::execute()
} }
} }
httpServer_->feedResponse(200, accessControlHeaders); httpServer_->feedResponse(200, accessControlHeaders);
addHttpServerResponseCommand(); addHttpServerResponseCommand(false);
return true; return true;
} }
@ -223,7 +233,7 @@ bool HttpServerBodyCommand::execute()
(fmt("CUID#%" PRId64 " - Failed to parse XML-RPC request", (fmt("CUID#%" PRId64 " - Failed to parse XML-RPC request",
getCuid())); getCuid()));
httpServer_->feedResponse(400); httpServer_->feedResponse(400);
addHttpServerResponseCommand(); addHttpServerResponseCommand(false);
return true; return true;
} }
A2_LOG_INFO(fmt("Executing RPC method %s", req.methodName.c_str())); A2_LOG_INFO(fmt("Executing RPC method %s", req.methodName.c_str()));
@ -232,10 +242,10 @@ bool HttpServerBodyCommand::execute()
bool gzip = httpServer_->supportsGZip(); bool gzip = httpServer_->supportsGZip();
std::string responseData = rpc::toXml(res, gzip); std::string responseData = rpc::toXml(res, gzip);
httpServer_->feedResponse(std::move(responseData), "text/xml"); httpServer_->feedResponse(std::move(responseData), "text/xml");
addHttpServerResponseCommand(); addHttpServerResponseCommand(false);
#else // !ENABLE_XML_RPC #else // !ENABLE_XML_RPC
httpServer_->feedResponse(404); httpServer_->feedResponse(404);
addHttpServerResponseCommand(); addHttpServerResponseCommand(false);
#endif // !ENABLE_XML_RPC #endif // !ENABLE_XML_RPC
return true; return true;
} }
@ -274,8 +284,7 @@ bool HttpServerBodyCommand::execute()
} }
Dict* jsondict = downcast<Dict>(json); Dict* jsondict = downcast<Dict>(json);
if(jsondict) { if(jsondict) {
rpc::RpcResponse res = auto res = rpc::processJsonRpcRequest(jsondict, e_, preauthorized);
rpc::processJsonRpcRequest(jsondict, e_, preauthorized);
sendJsonRpcResponse(res, callback); sendJsonRpcResponse(res, callback);
} else { } else {
List* jsonlist = downcast<List>(json); List* jsonlist = downcast<List>(json);
@ -306,7 +315,7 @@ bool HttpServerBodyCommand::execute()
} }
default: default:
httpServer_->feedResponse(404); httpServer_->feedResponse(404);
addHttpServerResponseCommand(); addHttpServerResponseCommand(false);
return true; return true;
} }
} else { } else {

View File

@ -63,7 +63,7 @@ private:
void sendJsonRpcBatchResponse void sendJsonRpcBatchResponse
(const std::vector<rpc::RpcResponse>& results, (const std::vector<rpc::RpcResponse>& results,
const std::string& callback); const std::string& callback);
void addHttpServerResponseCommand(); void addHttpServerResponseCommand(bool delayed);
void updateWriteCheck(); void updateWriteCheck();
public: public:
HttpServerBodyCommand(cuid_t cuid, HttpServerBodyCommand(cuid_t cuid,

View File

@ -66,6 +66,7 @@ SRCS = \
DefaultDiskWriterFactory.cc DefaultDiskWriterFactory.h\ DefaultDiskWriterFactory.cc DefaultDiskWriterFactory.h\
DefaultPieceStorage.cc DefaultPieceStorage.h\ DefaultPieceStorage.cc DefaultPieceStorage.h\
DefaultStreamPieceSelector.cc DefaultStreamPieceSelector.h\ DefaultStreamPieceSelector.cc DefaultStreamPieceSelector.h\
DelayedCommand.h\
Dependency.h\ Dependency.h\
DirectDiskAdaptor.cc DirectDiskAdaptor.h\ DirectDiskAdaptor.cc DirectDiskAdaptor.h\
DiskAdaptor.cc DiskAdaptor.h\ DiskAdaptor.cc DiskAdaptor.h\

View File

@ -94,13 +94,16 @@ void RpcMethod::authorize(RpcRequest& req, DownloadEngine* e)
RpcResponse RpcMethod::execute(RpcRequest req, DownloadEngine* e) RpcResponse RpcMethod::execute(RpcRequest req, DownloadEngine* e)
{ {
auto authorized = RpcResponse::NOTAUTHORIZED;
try { try {
authorize(req, e); authorize(req, e);
authorized = RpcResponse::AUTHORIZED;
auto r = process(req, e); auto r = process(req, e);
return RpcResponse(0, std::move(r), std::move(req.id)); return RpcResponse(0, authorized, std::move(r), std::move(req.id));
} catch(RecoverableException& ex) { } catch(RecoverableException& ex) {
A2_LOG_DEBUG_EX(EX_EXCEPTION_CAUGHT, ex); A2_LOG_DEBUG_EX(EX_EXCEPTION_CAUGHT, ex);
return RpcResponse(1, createErrorResponse(ex, req), std::move(req.id)); return RpcResponse(1, authorized, createErrorResponse(ex, req),
std::move(req.id));
} }
} }

View File

@ -97,7 +97,7 @@ public:
// Do work to fulfill RpcRequest req and returns its result as // Do work to fulfill RpcRequest req and returns its result as
// RpcResponse. This method delegates to process() method. // RpcResponse. This method delegates to process() method.
RpcResponse execute(RpcRequest req, DownloadEngine* e); virtual RpcResponse execute(RpcRequest req, DownloadEngine* e);
}; };
} // namespace rpc } // namespace rpc

View File

@ -1360,54 +1360,73 @@ std::unique_ptr<ValueBase> SaveSessionRpcMethod::process
std::unique_ptr<ValueBase> SystemMulticallRpcMethod::process std::unique_ptr<ValueBase> SystemMulticallRpcMethod::process
(const RpcRequest& req, DownloadEngine* e) (const RpcRequest& req, DownloadEngine* e)
{ {
const List* methodSpecs = checkRequiredParam<List>(req, 0); // Should never get here, since SystemMulticallRpcMethod overrides execute().
auto list = List::g(); assert(false);
auto auth = RpcRequest::MUST_AUTHORIZE; return nullptr;
for(auto & methodSpec : *methodSpecs) { }
Dict* methodDict = downcast<Dict>(methodSpec);
if(!methodDict) { RpcResponse SystemMulticallRpcMethod::execute(RpcRequest req, DownloadEngine *e)
list->append(createErrorResponse {
(DL_ABORT_EX("system.multicall expected struct."), req)); auto preauthorized = RpcRequest::MUST_AUTHORIZE;
continue; auto authorized = RpcResponse::AUTHORIZED;
} try {
const String* methodName = downcast<String>(methodDict->get(KEY_METHOD_NAME)); const List* methodSpecs = checkRequiredParam<List>(req, 0);
if(!methodName) { auto list = List::g();
list->append(createErrorResponse for(auto & methodSpec : *methodSpecs) {
(DL_ABORT_EX("Missing methodName."), req)); Dict* methodDict = downcast<Dict>(methodSpec);
continue; if(!methodDict) {
} list->append(createErrorResponse
if(methodName->s() == getMethodName()) { (DL_ABORT_EX("system.multicall expected struct."), req));
list->append(createErrorResponse continue;
(DL_ABORT_EX("Recursive system.multicall forbidden."), req)); }
continue; const String* methodName = downcast<String>(methodDict->get(KEY_METHOD_NAME));
} if(!methodName) {
// TODO what if params missing? list->append(createErrorResponse
auto tempParamsList = methodDict->get(KEY_PARAMS); (DL_ABORT_EX("Missing methodName."), req));
std::unique_ptr<List> paramsList; continue;
if(downcast<List>(tempParamsList)) { }
paramsList.reset(static_cast<List*>(methodDict->popValue(KEY_PARAMS) if(methodName->s() == getMethodName()) {
.release())); list->append(createErrorResponse
} else { (DL_ABORT_EX("Recursive system.multicall forbidden."), req));
paramsList = List::g(); continue;
} }
RpcRequest r = { // TODO what if params missing?
methodName->s(), auto tempParamsList = methodDict->get(KEY_PARAMS);
std::move(paramsList), std::unique_ptr<List> paramsList;
nullptr, if(downcast<List>(tempParamsList)) {
auth, paramsList.reset(static_cast<List*>(methodDict->popValue(KEY_PARAMS)
req.jsonRpc .release()));
}; } else {
RpcResponse res = getMethod(methodName->s())->execute(std::move(r), e); paramsList = List::g();
if(res.code == 0) { }
auto l = List::g(); RpcRequest r = {
l->append(std::move(res.param)); methodName->s(),
list->append(std::move(l)); std::move(paramsList),
auth = RpcRequest::PREAUTHORIZED; nullptr,
} else { preauthorized,
list->append(std::move(res.param)); req.jsonRpc
};
RpcResponse res = getMethod(methodName->s())->execute(std::move(r), e);
if(rpc::not_authorized(res)) {
authorized = RpcResponse::NOTAUTHORIZED;
} else {
preauthorized = RpcRequest::PREAUTHORIZED;
}
if(res.code == 0) {
auto l = List::g();
l->append(std::move(res.param));
list->append(std::move(l));
} else {
list->append(std::move(res.param));
}
} }
return RpcResponse(0, authorized, std::move(list), std::move(req.id));
} catch(RecoverableException& ex) {
A2_LOG_DEBUG_EX(EX_EXCEPTION_CAUGHT, ex);
return RpcResponse(1, authorized, createErrorResponse(ex, req),
std::move(req.id));
} }
return std::move(list);
} }
std::unique_ptr<ValueBase> NoSuchMethodRpcMethod::process std::unique_ptr<ValueBase> NoSuchMethodRpcMethod::process

View File

@ -586,14 +586,9 @@ class SystemMulticallRpcMethod:public RpcMethod {
protected: protected:
virtual std::unique_ptr<ValueBase> process virtual std::unique_ptr<ValueBase> process
(const RpcRequest& req, DownloadEngine* e) CXX11_OVERRIDE; (const RpcRequest& req, DownloadEngine* e) CXX11_OVERRIDE;
public: public:
virtual void authorize(RpcRequest& req, DownloadEngine* e) CXX11_OVERRIDE virtual RpcResponse execute(RpcRequest req, DownloadEngine* e) CXX11_OVERRIDE;
{
// Batch calls (e.g., system.multicall) authorizes only nested
// methods. This is because XML-RPC system.multicall only accpets
// methods array and there is no room for us to insert token
// parameter.
}
static const char* getMethodName() static const char* getMethodName()
{ {

View File

@ -51,7 +51,7 @@ RpcRequest::RpcRequest(std::string methodName,
RpcRequest::RpcRequest(std::string methodName, RpcRequest::RpcRequest(std::string methodName,
std::unique_ptr<List> params, std::unique_ptr<List> params,
std::unique_ptr<ValueBase> id, std::unique_ptr<ValueBase> id,
RpcRequest::authorization_t authorization, RpcRequest::preauthorization_t authorization,
bool jsonRpc) bool jsonRpc)
: methodName{std::move(methodName)}, params{std::move(params)}, : methodName{std::move(methodName)}, params{std::move(params)},
id{std::move(id)}, authorization{authorization}, jsonRpc{jsonRpc} id{std::move(id)}, authorization{authorization}, jsonRpc{jsonRpc}

View File

@ -46,7 +46,7 @@ namespace aria2 {
namespace rpc { namespace rpc {
struct RpcRequest { struct RpcRequest {
enum authorization_t { enum preauthorization_t {
MUST_AUTHORIZE, MUST_AUTHORIZE,
PREAUTHORIZED PREAUTHORIZED
}; };
@ -54,7 +54,7 @@ struct RpcRequest {
std::string methodName; std::string methodName;
std::unique_ptr<List> params; std::unique_ptr<List> params;
std::unique_ptr<ValueBase> id; std::unique_ptr<ValueBase> id;
authorization_t authorization; preauthorization_t authorization;
bool jsonRpc; bool jsonRpc;
RpcRequest(); RpcRequest();
@ -65,7 +65,7 @@ struct RpcRequest {
RpcRequest(std::string methodName, RpcRequest(std::string methodName,
std::unique_ptr<List> params, std::unique_ptr<List> params,
std::unique_ptr<ValueBase> id, std::unique_ptr<ValueBase> id,
authorization_t authorization, preauthorization_t authorization,
bool jsonRpc = false); bool jsonRpc = false);
}; };

View File

@ -121,9 +121,10 @@ std::string encodeAll
RpcResponse::RpcResponse RpcResponse::RpcResponse
(int code, (int code,
RpcResponse::authorization_t authorized,
std::unique_ptr<ValueBase> param, std::unique_ptr<ValueBase> param,
std::unique_ptr<ValueBase> id) std::unique_ptr<ValueBase> id)
: code{code}, param{std::move(param)}, id{std::move(id)} : param{std::move(param)}, id{std::move(id)}, code{code}, authorized{authorized}
{} {}
std::string toXml(const RpcResponse& res, bool gzip) std::string toXml(const RpcResponse& res, bool gzip)

View File

@ -47,17 +47,35 @@ namespace aria2 {
namespace rpc { namespace rpc {
struct RpcResponse { struct RpcResponse {
enum authorization_t {
NOTAUTHORIZED,
AUTHORIZED
};
// 0 for success, non-zero for error // 0 for success, non-zero for error
int code;
std::unique_ptr<ValueBase> param; std::unique_ptr<ValueBase> param;
std::unique_ptr<ValueBase> id; std::unique_ptr<ValueBase> id;
int code;
authorization_t authorized;
RpcResponse RpcResponse
(int code, (int code,
authorization_t authorized,
std::unique_ptr<ValueBase> param, std::unique_ptr<ValueBase> param,
std::unique_ptr<ValueBase> id); std::unique_ptr<ValueBase> id);
}; };
inline
bool not_authorized(const rpc::RpcResponse& res)
{
return res.authorized != rpc::RpcResponse::AUTHORIZED;
}
template<typename InputIterator>
bool any_not_authorized(const InputIterator begin, const InputIterator end) {
return std::any_of(begin, end, not_authorized);
}
std::string toXml(const RpcResponse& response, bool gzip = false); std::string toXml(const RpcResponse& response, bool gzip = false);
// Encodes RPC response in JSON. If callback is not empty, the // Encodes RPC response in JSON. If callback is not empty, the

View File

@ -64,6 +64,11 @@ public:
virtual bool execute() CXX11_OVERRIDE; virtual bool execute() CXX11_OVERRIDE;
std::shared_ptr<WebSocketSession>& getSession()
{
return wsSession_;
}
void updateWriteCheck(); void updateWriteCheck();
}; };

View File

@ -43,6 +43,8 @@
#include "RecoverableException.h" #include "RecoverableException.h"
#include "message.h" #include "message.h"
#include "DownloadEngine.h" #include "DownloadEngine.h"
#include "DelayedCommand.h"
#include "WebSocketInteractionCommand.h"
#include "rpc_helper.h" #include "rpc_helper.h"
#include "RpcResponse.h" #include "RpcResponse.h"
#include "json.h" #include "json.h"
@ -111,8 +113,9 @@ ssize_t recvCallback(wslay_event_context_ptr wsctx,
namespace { namespace {
void addResponse(WebSocketSession* wsSession, const RpcResponse& res) void addResponse(WebSocketSession* wsSession, const RpcResponse& res)
{ {
bool notauthorized = rpc::not_authorized(res);
std::string response = toJson(res, "", false); std::string response = toJson(res, "", false);
wsSession->addTextMessage(response); wsSession->addTextMessage(response, notauthorized);
} }
} // namespace } // namespace
@ -120,8 +123,9 @@ namespace {
void addResponse(WebSocketSession* wsSession, void addResponse(WebSocketSession* wsSession,
const std::vector<RpcResponse>& results) const std::vector<RpcResponse>& results)
{ {
bool notauthorized = rpc::any_not_authorized(results.begin(), results.end());
std::string response = toJsonBatch(results, "", false); std::string response = toJsonBatch(results, "", false);
wsSession->addTextMessage(response); wsSession->addTextMessage(response, notauthorized);
} }
} // namespace } // namespace
@ -264,8 +268,35 @@ int WebSocketSession::onWriteEvent()
} }
} }
void WebSocketSession::addTextMessage(const std::string& msg) namespace {
class TextMessageCommand : public Command
{ {
private:
std::shared_ptr<WebSocketSession> session_;
const std::string msg_;
public:
TextMessageCommand(cuid_t cuid, std::shared_ptr<WebSocketSession> session,
const std::string& msg)
: Command(cuid), session_{std::move(session)}, msg_{msg}
{}
virtual bool execute() CXX11_OVERRIDE
{
session_->addTextMessage(msg_, false);
return true;
}
};
} // namespace
void WebSocketSession::addTextMessage(const std::string& msg, bool delayed)
{
if (delayed) {
auto e = getDownloadEngine();
auto cuid = command_->getCuid();
auto c = make_unique<TextMessageCommand>(cuid, command_->getSession(), msg);
e->addCommand(make_unique<DelayedCommand>(cuid, e, 1, std::move(c), false));
return;
}
// TODO Don't add text message if the size of outbound queue in // TODO Don't add text message if the size of outbound queue in
// wsctx_ exceeds certain limit. // wsctx_ exceeds certain limit.
wslay_event_msg arg = { wslay_event_msg arg = {

View File

@ -74,7 +74,7 @@ public:
int onWriteEvent(); int onWriteEvent();
// Adds text message |msg|. The message is queued and will be sent // Adds text message |msg|. The message is queued and will be sent
// in onWriteEvent(). // in onWriteEvent().
void addTextMessage(const std::string& msg); void addTextMessage(const std::string& msg, bool delayed);
// Returns true if the close frame is received. // Returns true if the close frame is received.
bool closeReceived(); bool closeReceived();
// Returns true if the close frame is sent. // Returns true if the close frame is sent.

View File

@ -78,7 +78,7 @@ void WebSocketSessionMan::addNotification
dict->put("params", std::move(params)); dict->put("params", std::move(params));
std::string msg = json::encode(dict.get()); std::string msg = json::encode(dict.get());
for(auto& session : sessions_) { for(auto& session : sessions_) {
session->addTextMessage(msg); session->addTextMessage(msg, false);
session->getCommand()->updateWriteCheck(); session->getCommand()->updateWriteCheck();
} }
} }

View File

@ -73,11 +73,12 @@ RpcResponse createJsonRpcErrorResponse(int code,
auto params = Dict::g(); auto params = Dict::g();
params->put("code", Integer::g(code)); params->put("code", Integer::g(code));
params->put("message", msg); params->put("message", msg);
return rpc::RpcResponse{code, std::move(params), std::move(id)}; return rpc::RpcResponse{code, rpc::RpcResponse::AUTHORIZED, std::move(params),
std::move(id)};
} }
RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e, RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e,
RpcRequest::authorization_t authorization) RpcRequest::preauthorization_t authorization)
{ {
auto id = jsondict->popValue("id"); auto id = jsondict->popValue("id");
if(!id) { if(!id) {

View File

@ -65,7 +65,7 @@ RpcResponse createJsonRpcErrorResponse(int code,
// Processes JSON-RPC request |jsondict| and returns the result. // Processes JSON-RPC request |jsondict| and returns the result.
RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e, RpcResponse processJsonRpcRequest(Dict* jsondict, DownloadEngine* e,
RpcRequest::authorization_t authorization); RpcRequest::preauthorization_t authorization);
} // namespace rpc } // namespace rpc

View File

@ -28,7 +28,8 @@ void RpcResponseTest::testToJson()
{ {
auto param = List::g(); auto param = List::g();
param->append(Integer::g(1)); param->append(Integer::g(1));
RpcResponse res(0, std::move(param), String::g("9")); RpcResponse res(0, RpcResponse::AUTHORIZED, std::move(param),
String::g("9"));
results.push_back(std::move(res)); results.push_back(std::move(res));
std::string s = toJson(results.back(), "", false); std::string s = toJson(results.back(), "", false);
CPPUNIT_ASSERT_EQUAL(std::string("{\"id\":\"9\"," CPPUNIT_ASSERT_EQUAL(std::string("{\"id\":\"9\","
@ -47,7 +48,7 @@ void RpcResponseTest::testToJson()
auto param = Dict::g(); auto param = Dict::g();
param->put("code", Integer::g(1)); param->put("code", Integer::g(1));
param->put("message", "HELLO ERROR"); param->put("message", "HELLO ERROR");
RpcResponse res(1, std::move(param), Null::g()); RpcResponse res(1, RpcResponse::AUTHORIZED, std::move(param), Null::g());
results.push_back(std::move(res)); results.push_back(std::move(res));
std::string s = toJson(results.back(), "", false); std::string s = toJson(results.back(), "", false);
CPPUNIT_ASSERT_EQUAL(std::string("{\"id\":null," CPPUNIT_ASSERT_EQUAL(std::string("{\"id\":null,"
@ -101,7 +102,7 @@ void RpcResponseTest::testToXml()
auto param = Dict::g(); auto param = Dict::g();
param->put("faultCode", Integer::g(1)); param->put("faultCode", Integer::g(1));
param->put("faultString", "No such method: make.hamburger"); param->put("faultString", "No such method: make.hamburger");
RpcResponse res(1, std::move(param), Null::g()); RpcResponse res(1, RpcResponse::AUTHORIZED, std::move(param), Null::g());
std::string s = toXml(res, false); std::string s = toXml(res, false);
CPPUNIT_ASSERT_EQUAL CPPUNIT_ASSERT_EQUAL
(std::string("<?xml version=\"1.0\"?>" (std::string("<?xml version=\"1.0\"?>"