Made WebSocket handshake more strict.

Refactored HttpServer as well.
pull/14/head
Tatsuhiro Tsujikawa 2012-03-25 22:10:36 +09:00
parent 8499a47d21
commit 06b6bef860
4 changed files with 123 additions and 48 deletions

View File

@ -69,6 +69,58 @@ HttpServer::HttpServer
HttpServer::~HttpServer() {}
namespace {
const char* getStatusString(int status)
{
switch(status) {
case 100: return "100 Continue";
case 101: return "101 Switching Protocols";
case 200: return "200 OK";
case 201: return "201 Created";
case 202: return "202 Accepted";
case 203: return "203 Non-Authoritative Information";
case 204: return "204 No Content";
case 205: return "205 Reset Content";
case 206: return "206 Partial Content";
case 300: return "300 Multiple Choices";
case 301: return "301 Moved Permanently";
case 302: return "302 Found";
case 303: return "303 See Other";
case 304: return "304 Not Modified";
case 305: return "305 Use Proxy";
// case 306: return "306 (Unused)";
case 307: return "307 Temporary Redirect";
case 400: return "400 Bad Request";
case 401: return "401 Unauthorized";
case 402: return "402 Payment Required";
case 403: return "403 Forbidden";
case 404: return "404 Not Found";
case 405: return "405 Method Not Allowed";
case 406: return "406 Not Acceptable";
case 407: return "407 Proxy Authentication Required";
case 408: return "408 Request Timeout";
case 409: return "409 Conflict";
case 410: return "410 Gone";
case 411: return "411 Length Required";
case 412: return "412 Precondition Failed";
case 413: return "413 Request Entity Too Large";
case 414: return "414 Request-URI Too Long";
case 415: return "415 Unsupported Media Type";
case 416: return "416 Requested Range Not Satisfiable";
case 417: return "417 Expectation Failed";
// RFC 2817 defines 426 status code.
case 426: return "426 Upgrade Required";
case 500: return "500 Internal Server Error";
case 501: return "501 Not Implemented";
case 502: return "502 Bad Gateway";
case 503: return "503 Service Unavailable";
case 504: return "504 Gateway Timeout";
case 505: return "505 HTTP Version Not Supported";
default: return "";
}
}
} // namespace
SharedHandle<HttpHeader> HttpServer::receiveRequest()
{
if(socketRecvBuffer_->bufferEmpty()) {
@ -167,45 +219,41 @@ const std::string& HttpServer::getRequestPath() const
void HttpServer::feedResponse(std::string& text, const std::string& contentType)
{
feedResponse("200 OK", "", text, contentType);
feedResponse(200, "", text, contentType);
}
void HttpServer::feedResponse(const std::string& status,
void HttpServer::feedResponse(int status,
const std::string& headers,
std::string& text,
const std::string& text,
const std::string& contentType)
{
std::string httpDate = Time().toHTTPDate();
std::string header= fmt("HTTP/1.1 %s\r\n"
"Date: %s\r\n"
"Content-Type: %s\r\n"
"Content-Length: %lu\r\n"
"Expires: %s\r\n"
"Cache-Control: no-cache\r\n"
"%s%s",
status.c_str(),
"Cache-Control: no-cache\r\n",
getStatusString(status),
httpDate.c_str(),
contentType.c_str(),
static_cast<unsigned long>(text.size()),
httpDate.c_str(),
supportsGZip() ?
"Content-Encoding: gzip\r\n" : "",
!supportsPersistentConnection() ?
"Connection: close\r\n" : "");
httpDate.c_str());
if(!contentType.empty()) {
header += "Content-Type: ";
header += contentType;
header += "\r\n";
}
if(!allowOrigin_.empty()) {
header += "Access-Control-Allow-Origin: ";
header += allowOrigin_;
header += "\r\n";
}
if(!headers.empty()) {
header += headers;
if(headers.size() < 2 ||
(headers[headers.size()-2] != '\r' &&
headers[headers.size()-1] != '\n')) {
header += "\r\n";
}
if(supportsGZip()) {
header += "Content-Encoding: gzip\r\n";
}
if(!supportsPersistentConnection()) {
header += "Connection: close\r\n";
}
header += headers;
header += "\r\n";
A2_LOG_DEBUG(fmt("HTTP Server sends response:\n%s", header.c_str()));
socketBuffer_.pushStr(header);

View File

@ -86,10 +86,14 @@ public:
void feedResponse(std::string& text, const std::string& contentType);
void feedResponse(const std::string& status,
const std::string& headers,
std::string& text,
const std::string& contentType);
// Feeds HTTP response with the status code |status| (e.g.,
// 200). The |headers| is zero or more lines of HTTP header field
// and each line must end with "\r\n". The |text| is the response
// body. The |contentType" is the content-type of the response body.
void feedResponse(int status,
const std::string& headers = "",
const std::string& text = "",
const std::string& contentType = "");
// Feeds "101 Switching Protocols" response. The |protocol| will
// appear in Upgrade header field. The |headers| is zero or more

View File

@ -104,16 +104,16 @@ void HttpServerBodyCommand::sendJsonRpcResponse
getJsonRpcContentType(!callback.empty()));
} else {
httpServer_->disableKeepAlive();
std::string httpCode;
int httpCode;
switch(res.code) {
case -32600:
httpCode = "400 Bad Request";
httpCode = 400;
break;
case -32601:
httpCode = "404 Not Found";
httpCode = 404;
break;
default:
httpCode = "500 Internal Server Error";
httpCode = 500;
};
httpServer_->feedResponse(httpCode, A2STR::NIL,
responseData,

View File

@ -108,6 +108,7 @@ void HttpServerCommand::checkSocketRecvBuffer()
}
}
namespace {
// Creates server's WebSocket accept key which will be sent in
// Sec-WebSocket-Accept header field. The |clientKey| is the value
// found in Sec-WebSocket-Key header field in the request.
@ -120,6 +121,23 @@ std::string createWebSocketServerKey(const std::string& clientKey)
src.c_str(), src.size());
return base64::encode(&digest[0], &digest[sizeof(digest)]);
}
} // namespace
namespace {
int websocketHandshake(const SharedHandle<HttpHeader>& header)
{
if(header->getMethod() != "GET" ||
header->find("sec-websocket-key").empty()) {
return 400;
} else if(header->find("sec-websocket-version") != "13") {
return 426;
} else if(header->getRequestPath() != "/jsonrpc") {
return 404;
} else {
return 101;
}
}
} // namespace
bool HttpServerCommand::execute()
{
@ -140,10 +158,8 @@ bool HttpServerCommand::execute()
}
if(!httpServer_->authenticate()) {
httpServer_->disableKeepAlive();
std::string text;
httpServer_->feedResponse("401 Unauthorized",
"WWW-Authenticate: Basic realm=\"aria2\"",
text,"text/html");
httpServer_->feedResponse
(401, "WWW-Authenticate: Basic realm=\"aria2\"\r\n");
Command* command =
new HttpServerResponseCommand(getCuid(), httpServer_, e_, socket_);
e_->addCommand(command);
@ -152,21 +168,28 @@ bool HttpServerCommand::execute()
}
const std::string& upgradeHd = header->find("upgrade");
const std::string& connectionHd = header->find("connection");
if(httpServer_->getRequestPath() == "/jsonrpc" &&
httpServer_->getMethod() == "GET" &&
util::strieq(upgradeHd.begin(), upgradeHd.end(), "websocket") &&
util::strieq(connectionHd.begin(), connectionHd.end(), "upgrade") &&
header->find("sec-websocket-version") == "13" &&
header->defined("sec-websocket-key")) {
std::string serverKey =
createWebSocketServerKey(header->find("sec-websocket-key"));
httpServer_->feedUpgradeResponse("websocket",
fmt("Sec-WebSocket-Accept: %s\r\n",
serverKey.c_str()));
httpServer_->getSocket()->setTcpNodelay(true);
Command* command =
new rpc::WebSocketResponseCommand(getCuid(), httpServer_, e_,
socket_);
if(util::strieq(upgradeHd.begin(), upgradeHd.end(), "websocket") &&
util::strieq(connectionHd.begin(), connectionHd.end(), "upgrade")) {
int status = websocketHandshake(header);
Command* command;
if(status == 101) {
std::string serverKey =
createWebSocketServerKey(header->find("sec-websocket-key"));
httpServer_->feedUpgradeResponse("websocket",
fmt("Sec-WebSocket-Accept: %s\r\n",
serverKey.c_str()));
httpServer_->getSocket()->setTcpNodelay(true);
command = new rpc::WebSocketResponseCommand(getCuid(), httpServer_,
e_, socket_);
} else {
if(status == 426) {
httpServer_->feedResponse(426, "Sec-WebSocket-Version: 13\r\n");
} else {
httpServer_->feedResponse(status);
}
command = new HttpServerResponseCommand(getCuid(), httpServer_, e_,
socket_);
}
e_->addCommand(command);
e_->setNoWait(true);
return true;