diff --git a/src/HttpHeader.cc b/src/HttpHeader.cc index 8513cb56..1bba0ad2 100644 --- a/src/HttpHeader.cc +++ b/src/HttpHeader.cc @@ -293,4 +293,28 @@ const std::string& HttpHeader::getRequestPath() const return requestPath_; } +bool HttpHeader::fieldContains(const std::string& name, + const std::string& value) +{ + std::pair::const_iterator, + std::multimap::const_iterator> range = + equalRange(name); + for(std::multimap::const_iterator i = range.first; + i != range.second; ++i) { + std::vector values; + util::splitIter((*i).second.begin(), (*i).second.end(), + std::back_inserter(values), + ',', + true // doStrip + ); + for(std::vector::const_iterator j = values.begin(), + eoj = values.end(); j != eoj; ++j) { + if(util::strieq((*j).first, (*j).second, value.begin(), value.end())) { + return true; + } + } + } + return false; +} + } // namespace aria2 diff --git a/src/HttpHeader.h b/src/HttpHeader.h index b337e089..8791a4c5 100644 --- a/src/HttpHeader.h +++ b/src/HttpHeader.h @@ -66,6 +66,7 @@ public: HttpHeader(); ~HttpHeader(); + // For all methods, use lowercased header field name. void put(const std::string& name, const std::string& value); bool defined(const std::string& name) const; const std::string& find(const std::string& name) const; @@ -121,6 +122,10 @@ public: // Clears table_. responseStatus_ and version_ are unchanged. void clearField(); + // Returns true if heder field |name| contains |value|. This method + // assumes the values of the header field is delimited by ','. + bool fieldContains(const std::string& name, const std::string& value); + static const std::string LOCATION; static const std::string TRANSFER_ENCODING; static const std::string CONTENT_ENCODING; diff --git a/src/HttpServerCommand.cc b/src/HttpServerCommand.cc index 026c2627..0e13094e 100644 --- a/src/HttpServerCommand.cc +++ b/src/HttpServerCommand.cc @@ -174,10 +174,8 @@ bool HttpServerCommand::execute() e_->setNoWait(true); return true; } - const std::string& upgradeHd = header->find("upgrade"); - const std::string& connectionHd = header->find("connection"); - if(util::strieq(upgradeHd.begin(), upgradeHd.end(), "websocket") && - util::strieq(connectionHd.begin(), connectionHd.end(), "upgrade")) { + if(header->fieldContains("upgrade", "websocket") && + header->fieldContains("connection", "upgrade")) { #ifdef ENABLE_WEBSOCKET int status = websocketHandshake(header); Command* command; diff --git a/test/HttpHeaderTest.cc b/test/HttpHeaderTest.cc index 690f3818..f1add642 100644 --- a/test/HttpHeaderTest.cc +++ b/test/HttpHeaderTest.cc @@ -14,6 +14,7 @@ class HttpHeaderTest:public CppUnit::TestFixture { CPPUNIT_TEST(testFindAll); CPPUNIT_TEST(testClearField); CPPUNIT_TEST(testFill); + CPPUNIT_TEST(testFieldContains); CPPUNIT_TEST_SUITE_END(); public: @@ -21,6 +22,7 @@ public: void testFindAll(); void testClearField(); void testFill(); + void testFieldContains(); }; @@ -175,4 +177,21 @@ void HttpHeaderTest::testFill() h.findAll("duplicate")[1]); } +void HttpHeaderTest::testFieldContains() +{ + HttpHeader h; + h.put("connection", "Keep-Alive, Upgrade"); + h.put("upgrade", "WebSocket"); + h.put("sec-websocket-version", "13"); + h.put("sec-websocket-version", "8, 7"); + CPPUNIT_ASSERT(h.fieldContains("connection", "upgrade")); + CPPUNIT_ASSERT(h.fieldContains("connection", "keep-alive")); + CPPUNIT_ASSERT(!h.fieldContains("connection", "close")); + CPPUNIT_ASSERT(h.fieldContains("upgrade", "websocket")); + CPPUNIT_ASSERT(!h.fieldContains("upgrade", "spdy")); + CPPUNIT_ASSERT(h.fieldContains("sec-websocket-version", "13")); + CPPUNIT_ASSERT(h.fieldContains("sec-websocket-version", "8")); + CPPUNIT_ASSERT(!h.fieldContains("sec-websocket-version", "6")); +} + } // namespace aria2