Made WebSocket handshake more strict.

Refactored HttpServer as well.
pull/14/head
Tatsuhiro Tsujikawa 2012-03-25 22:10:36 +09:00
parent 8499a47d21
commit 06b6bef860
4 changed files with 123 additions and 48 deletions

View File

@ -69,6 +69,58 @@ HttpServer::HttpServer
HttpServer::~HttpServer() {} HttpServer::~HttpServer() {}
namespace {
const char* getStatusString(int status)
{
switch(status) {
case 100: return "100 Continue";
case 101: return "101 Switching Protocols";
case 200: return "200 OK";
case 201: return "201 Created";
case 202: return "202 Accepted";
case 203: return "203 Non-Authoritative Information";
case 204: return "204 No Content";
case 205: return "205 Reset Content";
case 206: return "206 Partial Content";
case 300: return "300 Multiple Choices";
case 301: return "301 Moved Permanently";
case 302: return "302 Found";
case 303: return "303 See Other";
case 304: return "304 Not Modified";
case 305: return "305 Use Proxy";
// case 306: return "306 (Unused)";
case 307: return "307 Temporary Redirect";
case 400: return "400 Bad Request";
case 401: return "401 Unauthorized";
case 402: return "402 Payment Required";
case 403: return "403 Forbidden";
case 404: return "404 Not Found";
case 405: return "405 Method Not Allowed";
case 406: return "406 Not Acceptable";
case 407: return "407 Proxy Authentication Required";
case 408: return "408 Request Timeout";
case 409: return "409 Conflict";
case 410: return "410 Gone";
case 411: return "411 Length Required";
case 412: return "412 Precondition Failed";
case 413: return "413 Request Entity Too Large";
case 414: return "414 Request-URI Too Long";
case 415: return "415 Unsupported Media Type";
case 416: return "416 Requested Range Not Satisfiable";
case 417: return "417 Expectation Failed";
// RFC 2817 defines 426 status code.
case 426: return "426 Upgrade Required";
case 500: return "500 Internal Server Error";
case 501: return "501 Not Implemented";
case 502: return "502 Bad Gateway";
case 503: return "503 Service Unavailable";
case 504: return "504 Gateway Timeout";
case 505: return "505 HTTP Version Not Supported";
default: return "";
}
}
} // namespace
SharedHandle<HttpHeader> HttpServer::receiveRequest() SharedHandle<HttpHeader> HttpServer::receiveRequest()
{ {
if(socketRecvBuffer_->bufferEmpty()) { if(socketRecvBuffer_->bufferEmpty()) {
@ -167,45 +219,41 @@ const std::string& HttpServer::getRequestPath() const
void HttpServer::feedResponse(std::string& text, const std::string& contentType) void HttpServer::feedResponse(std::string& text, const std::string& contentType)
{ {
feedResponse("200 OK", "", text, contentType); feedResponse(200, "", text, contentType);
} }
void HttpServer::feedResponse(const std::string& status, void HttpServer::feedResponse(int status,
const std::string& headers, const std::string& headers,
std::string& text, const std::string& text,
const std::string& contentType) const std::string& contentType)
{ {
std::string httpDate = Time().toHTTPDate(); std::string httpDate = Time().toHTTPDate();
std::string header= fmt("HTTP/1.1 %s\r\n" std::string header= fmt("HTTP/1.1 %s\r\n"
"Date: %s\r\n" "Date: %s\r\n"
"Content-Type: %s\r\n"
"Content-Length: %lu\r\n" "Content-Length: %lu\r\n"
"Expires: %s\r\n" "Expires: %s\r\n"
"Cache-Control: no-cache\r\n" "Cache-Control: no-cache\r\n",
"%s%s", getStatusString(status),
status.c_str(),
httpDate.c_str(), httpDate.c_str(),
contentType.c_str(),
static_cast<unsigned long>(text.size()), static_cast<unsigned long>(text.size()),
httpDate.c_str(), httpDate.c_str());
supportsGZip() ? if(!contentType.empty()) {
"Content-Encoding: gzip\r\n" : "", header += "Content-Type: ";
!supportsPersistentConnection() ? header += contentType;
"Connection: close\r\n" : ""); header += "\r\n";
}
if(!allowOrigin_.empty()) { if(!allowOrigin_.empty()) {
header += "Access-Control-Allow-Origin: "; header += "Access-Control-Allow-Origin: ";
header += allowOrigin_; header += allowOrigin_;
header += "\r\n"; header += "\r\n";
} }
if(!headers.empty()) { if(supportsGZip()) {
header += headers; header += "Content-Encoding: gzip\r\n";
if(headers.size() < 2 ||
(headers[headers.size()-2] != '\r' &&
headers[headers.size()-1] != '\n')) {
header += "\r\n";
}
} }
if(!supportsPersistentConnection()) {
header += "Connection: close\r\n";
}
header += headers;
header += "\r\n"; header += "\r\n";
A2_LOG_DEBUG(fmt("HTTP Server sends response:\n%s", header.c_str())); A2_LOG_DEBUG(fmt("HTTP Server sends response:\n%s", header.c_str()));
socketBuffer_.pushStr(header); socketBuffer_.pushStr(header);

View File

@ -86,10 +86,14 @@ public:
void feedResponse(std::string& text, const std::string& contentType); void feedResponse(std::string& text, const std::string& contentType);
void feedResponse(const std::string& status, // Feeds HTTP response with the status code |status| (e.g.,
const std::string& headers, // 200). The |headers| is zero or more lines of HTTP header field
std::string& text, // and each line must end with "\r\n". The |text| is the response
const std::string& contentType); // body. The |contentType" is the content-type of the response body.
void feedResponse(int status,
const std::string& headers = "",
const std::string& text = "",
const std::string& contentType = "");
// Feeds "101 Switching Protocols" response. The |protocol| will // Feeds "101 Switching Protocols" response. The |protocol| will
// appear in Upgrade header field. The |headers| is zero or more // appear in Upgrade header field. The |headers| is zero or more

View File

@ -104,16 +104,16 @@ void HttpServerBodyCommand::sendJsonRpcResponse
getJsonRpcContentType(!callback.empty())); getJsonRpcContentType(!callback.empty()));
} else { } else {
httpServer_->disableKeepAlive(); httpServer_->disableKeepAlive();
std::string httpCode; int httpCode;
switch(res.code) { switch(res.code) {
case -32600: case -32600:
httpCode = "400 Bad Request"; httpCode = 400;
break; break;
case -32601: case -32601:
httpCode = "404 Not Found"; httpCode = 404;
break; break;
default: default:
httpCode = "500 Internal Server Error"; httpCode = 500;
}; };
httpServer_->feedResponse(httpCode, A2STR::NIL, httpServer_->feedResponse(httpCode, A2STR::NIL,
responseData, responseData,

View File

@ -108,6 +108,7 @@ void HttpServerCommand::checkSocketRecvBuffer()
} }
} }
namespace {
// Creates server's WebSocket accept key which will be sent in // Creates server's WebSocket accept key which will be sent in
// Sec-WebSocket-Accept header field. The |clientKey| is the value // Sec-WebSocket-Accept header field. The |clientKey| is the value
// found in Sec-WebSocket-Key header field in the request. // found in Sec-WebSocket-Key header field in the request.
@ -120,6 +121,23 @@ std::string createWebSocketServerKey(const std::string& clientKey)
src.c_str(), src.size()); src.c_str(), src.size());
return base64::encode(&digest[0], &digest[sizeof(digest)]); return base64::encode(&digest[0], &digest[sizeof(digest)]);
} }
} // namespace
namespace {
int websocketHandshake(const SharedHandle<HttpHeader>& header)
{
if(header->getMethod() != "GET" ||
header->find("sec-websocket-key").empty()) {
return 400;
} else if(header->find("sec-websocket-version") != "13") {
return 426;
} else if(header->getRequestPath() != "/jsonrpc") {
return 404;
} else {
return 101;
}
}
} // namespace
bool HttpServerCommand::execute() bool HttpServerCommand::execute()
{ {
@ -140,10 +158,8 @@ bool HttpServerCommand::execute()
} }
if(!httpServer_->authenticate()) { if(!httpServer_->authenticate()) {
httpServer_->disableKeepAlive(); httpServer_->disableKeepAlive();
std::string text; httpServer_->feedResponse
httpServer_->feedResponse("401 Unauthorized", (401, "WWW-Authenticate: Basic realm=\"aria2\"\r\n");
"WWW-Authenticate: Basic realm=\"aria2\"",
text,"text/html");
Command* command = Command* command =
new HttpServerResponseCommand(getCuid(), httpServer_, e_, socket_); new HttpServerResponseCommand(getCuid(), httpServer_, e_, socket_);
e_->addCommand(command); e_->addCommand(command);
@ -152,21 +168,28 @@ bool HttpServerCommand::execute()
} }
const std::string& upgradeHd = header->find("upgrade"); const std::string& upgradeHd = header->find("upgrade");
const std::string& connectionHd = header->find("connection"); const std::string& connectionHd = header->find("connection");
if(httpServer_->getRequestPath() == "/jsonrpc" && if(util::strieq(upgradeHd.begin(), upgradeHd.end(), "websocket") &&
httpServer_->getMethod() == "GET" && util::strieq(connectionHd.begin(), connectionHd.end(), "upgrade")) {
util::strieq(upgradeHd.begin(), upgradeHd.end(), "websocket") && int status = websocketHandshake(header);
util::strieq(connectionHd.begin(), connectionHd.end(), "upgrade") && Command* command;
header->find("sec-websocket-version") == "13" && if(status == 101) {
header->defined("sec-websocket-key")) { std::string serverKey =
std::string serverKey = createWebSocketServerKey(header->find("sec-websocket-key"));
createWebSocketServerKey(header->find("sec-websocket-key")); httpServer_->feedUpgradeResponse("websocket",
httpServer_->feedUpgradeResponse("websocket", fmt("Sec-WebSocket-Accept: %s\r\n",
fmt("Sec-WebSocket-Accept: %s\r\n", serverKey.c_str()));
serverKey.c_str())); httpServer_->getSocket()->setTcpNodelay(true);
httpServer_->getSocket()->setTcpNodelay(true); command = new rpc::WebSocketResponseCommand(getCuid(), httpServer_,
Command* command = e_, socket_);
new rpc::WebSocketResponseCommand(getCuid(), httpServer_, e_, } else {
socket_); if(status == 426) {
httpServer_->feedResponse(426, "Sec-WebSocket-Version: 13\r\n");
} else {
httpServer_->feedResponse(status);
}
command = new HttpServerResponseCommand(getCuid(), httpServer_, e_,
socket_);
}
e_->addCommand(command); e_->addCommand(command);
e_->setNoWait(true); e_->setNoWait(true);
return true; return true;