diff --git a/src/HttpServer.cc b/src/HttpServer.cc index b48ecebd..4c41896d 100644 --- a/src/HttpServer.cc +++ b/src/HttpServer.cc @@ -69,6 +69,58 @@ HttpServer::HttpServer HttpServer::~HttpServer() {} +namespace { +const char* getStatusString(int status) +{ + switch(status) { + case 100: return "100 Continue"; + case 101: return "101 Switching Protocols"; + case 200: return "200 OK"; + case 201: return "201 Created"; + case 202: return "202 Accepted"; + case 203: return "203 Non-Authoritative Information"; + case 204: return "204 No Content"; + case 205: return "205 Reset Content"; + case 206: return "206 Partial Content"; + case 300: return "300 Multiple Choices"; + case 301: return "301 Moved Permanently"; + case 302: return "302 Found"; + case 303: return "303 See Other"; + case 304: return "304 Not Modified"; + case 305: return "305 Use Proxy"; + // case 306: return "306 (Unused)"; + case 307: return "307 Temporary Redirect"; + case 400: return "400 Bad Request"; + case 401: return "401 Unauthorized"; + case 402: return "402 Payment Required"; + case 403: return "403 Forbidden"; + case 404: return "404 Not Found"; + case 405: return "405 Method Not Allowed"; + case 406: return "406 Not Acceptable"; + case 407: return "407 Proxy Authentication Required"; + case 408: return "408 Request Timeout"; + case 409: return "409 Conflict"; + case 410: return "410 Gone"; + case 411: return "411 Length Required"; + case 412: return "412 Precondition Failed"; + case 413: return "413 Request Entity Too Large"; + case 414: return "414 Request-URI Too Long"; + case 415: return "415 Unsupported Media Type"; + case 416: return "416 Requested Range Not Satisfiable"; + case 417: return "417 Expectation Failed"; + // RFC 2817 defines 426 status code. + case 426: return "426 Upgrade Required"; + case 500: return "500 Internal Server Error"; + case 501: return "501 Not Implemented"; + case 502: return "502 Bad Gateway"; + case 503: return "503 Service Unavailable"; + case 504: return "504 Gateway Timeout"; + case 505: return "505 HTTP Version Not Supported"; + default: return ""; + } +} +} // namespace + SharedHandle HttpServer::receiveRequest() { if(socketRecvBuffer_->bufferEmpty()) { @@ -167,45 +219,41 @@ const std::string& HttpServer::getRequestPath() const void HttpServer::feedResponse(std::string& text, const std::string& contentType) { - feedResponse("200 OK", "", text, contentType); + feedResponse(200, "", text, contentType); } -void HttpServer::feedResponse(const std::string& status, +void HttpServer::feedResponse(int status, const std::string& headers, - std::string& text, + const std::string& text, const std::string& contentType) { std::string httpDate = Time().toHTTPDate(); std::string header= fmt("HTTP/1.1 %s\r\n" "Date: %s\r\n" - "Content-Type: %s\r\n" "Content-Length: %lu\r\n" "Expires: %s\r\n" - "Cache-Control: no-cache\r\n" - "%s%s", - status.c_str(), + "Cache-Control: no-cache\r\n", + getStatusString(status), httpDate.c_str(), - contentType.c_str(), static_cast(text.size()), - httpDate.c_str(), - supportsGZip() ? - "Content-Encoding: gzip\r\n" : "", - !supportsPersistentConnection() ? - "Connection: close\r\n" : ""); + httpDate.c_str()); + if(!contentType.empty()) { + header += "Content-Type: "; + header += contentType; + header += "\r\n"; + } if(!allowOrigin_.empty()) { header += "Access-Control-Allow-Origin: "; header += allowOrigin_; header += "\r\n"; } - if(!headers.empty()) { - header += headers; - if(headers.size() < 2 || - (headers[headers.size()-2] != '\r' && - headers[headers.size()-1] != '\n')) { - header += "\r\n"; - } + if(supportsGZip()) { + header += "Content-Encoding: gzip\r\n"; } - + if(!supportsPersistentConnection()) { + header += "Connection: close\r\n"; + } + header += headers; header += "\r\n"; A2_LOG_DEBUG(fmt("HTTP Server sends response:\n%s", header.c_str())); socketBuffer_.pushStr(header); diff --git a/src/HttpServer.h b/src/HttpServer.h index e7d7c2dc..a0017376 100644 --- a/src/HttpServer.h +++ b/src/HttpServer.h @@ -86,10 +86,14 @@ public: void feedResponse(std::string& text, const std::string& contentType); - void feedResponse(const std::string& status, - const std::string& headers, - std::string& text, - const std::string& contentType); + // Feeds HTTP response with the status code |status| (e.g., + // 200). The |headers| is zero or more lines of HTTP header field + // and each line must end with "\r\n". The |text| is the response + // body. The |contentType" is the content-type of the response body. + void feedResponse(int status, + const std::string& headers = "", + const std::string& text = "", + const std::string& contentType = ""); // Feeds "101 Switching Protocols" response. The |protocol| will // appear in Upgrade header field. The |headers| is zero or more diff --git a/src/HttpServerBodyCommand.cc b/src/HttpServerBodyCommand.cc index 48581532..d420d788 100644 --- a/src/HttpServerBodyCommand.cc +++ b/src/HttpServerBodyCommand.cc @@ -104,16 +104,16 @@ void HttpServerBodyCommand::sendJsonRpcResponse getJsonRpcContentType(!callback.empty())); } else { httpServer_->disableKeepAlive(); - std::string httpCode; + int httpCode; switch(res.code) { case -32600: - httpCode = "400 Bad Request"; + httpCode = 400; break; case -32601: - httpCode = "404 Not Found"; + httpCode = 404; break; default: - httpCode = "500 Internal Server Error"; + httpCode = 500; }; httpServer_->feedResponse(httpCode, A2STR::NIL, responseData, diff --git a/src/HttpServerCommand.cc b/src/HttpServerCommand.cc index 75abb543..48fef909 100644 --- a/src/HttpServerCommand.cc +++ b/src/HttpServerCommand.cc @@ -108,6 +108,7 @@ void HttpServerCommand::checkSocketRecvBuffer() } } +namespace { // Creates server's WebSocket accept key which will be sent in // Sec-WebSocket-Accept header field. The |clientKey| is the value // found in Sec-WebSocket-Key header field in the request. @@ -120,6 +121,23 @@ std::string createWebSocketServerKey(const std::string& clientKey) src.c_str(), src.size()); return base64::encode(&digest[0], &digest[sizeof(digest)]); } +} // namespace + +namespace { +int websocketHandshake(const SharedHandle& header) +{ + if(header->getMethod() != "GET" || + header->find("sec-websocket-key").empty()) { + return 400; + } else if(header->find("sec-websocket-version") != "13") { + return 426; + } else if(header->getRequestPath() != "/jsonrpc") { + return 404; + } else { + return 101; + } +} +} // namespace bool HttpServerCommand::execute() { @@ -140,10 +158,8 @@ bool HttpServerCommand::execute() } if(!httpServer_->authenticate()) { httpServer_->disableKeepAlive(); - std::string text; - httpServer_->feedResponse("401 Unauthorized", - "WWW-Authenticate: Basic realm=\"aria2\"", - text,"text/html"); + httpServer_->feedResponse + (401, "WWW-Authenticate: Basic realm=\"aria2\"\r\n"); Command* command = new HttpServerResponseCommand(getCuid(), httpServer_, e_, socket_); e_->addCommand(command); @@ -152,21 +168,28 @@ bool HttpServerCommand::execute() } const std::string& upgradeHd = header->find("upgrade"); const std::string& connectionHd = header->find("connection"); - if(httpServer_->getRequestPath() == "/jsonrpc" && - httpServer_->getMethod() == "GET" && - util::strieq(upgradeHd.begin(), upgradeHd.end(), "websocket") && - util::strieq(connectionHd.begin(), connectionHd.end(), "upgrade") && - header->find("sec-websocket-version") == "13" && - header->defined("sec-websocket-key")) { - std::string serverKey = - createWebSocketServerKey(header->find("sec-websocket-key")); - httpServer_->feedUpgradeResponse("websocket", - fmt("Sec-WebSocket-Accept: %s\r\n", - serverKey.c_str())); - httpServer_->getSocket()->setTcpNodelay(true); - Command* command = - new rpc::WebSocketResponseCommand(getCuid(), httpServer_, e_, - socket_); + if(util::strieq(upgradeHd.begin(), upgradeHd.end(), "websocket") && + util::strieq(connectionHd.begin(), connectionHd.end(), "upgrade")) { + int status = websocketHandshake(header); + Command* command; + if(status == 101) { + std::string serverKey = + createWebSocketServerKey(header->find("sec-websocket-key")); + httpServer_->feedUpgradeResponse("websocket", + fmt("Sec-WebSocket-Accept: %s\r\n", + serverKey.c_str())); + httpServer_->getSocket()->setTcpNodelay(true); + command = new rpc::WebSocketResponseCommand(getCuid(), httpServer_, + e_, socket_); + } else { + if(status == 426) { + httpServer_->feedResponse(426, "Sec-WebSocket-Version: 13\r\n"); + } else { + httpServer_->feedResponse(status); + } + command = new HttpServerResponseCommand(getCuid(), httpServer_, e_, + socket_); + } e_->addCommand(command); e_->setNoWait(true); return true;