Support CORS

pull/58/head
Sheng 2019-01-23 21:48:03 +08:00
parent a6663c408e
commit a1c9378048
2 changed files with 56 additions and 9 deletions

View File

@ -720,12 +720,12 @@ class TestAppWithTooManyConnections(OtherTestBase):
ws.close() ws.close()
class TestAppWithCrossOriginConnect(OtherTestBase): class TestAppWithCrossOriginOperation(OtherTestBase):
origin = 'http://www.example.com' origin = 'http://www.example.com'
@tornado.testing.gen_test @tornado.testing.gen_test
def test_app_with_cross_orgin_connect(self): def test_app_with_wrong_event_origin(self):
url = self.get_url('/') url = self.get_url('/')
client = self.get_http_client() client = self.get_http_client()
body = urlencode(dict(self.body, username='foo', _origin='localhost')) body = urlencode(dict(self.body, username='foo', _origin='localhost'))
@ -734,8 +734,29 @@ class TestAppWithCrossOriginConnect(OtherTestBase):
data = json.loads(to_str(response.body)) data = json.loads(to_str(response.body))
self.assertIsNone(data['id']) self.assertIsNone(data['id'])
self.assertIsNone(data['encoding']) self.assertIsNone(data['encoding'])
self.assertIn('Cross origin frame', data['status']) self.assertEqual(
'Cross origin operation is not allowed.', data['status']
)
@tornado.testing.gen_test
def test_app_with_wrong_header_origin(self):
url = self.get_url('/')
client = self.get_http_client()
body = urlencode(dict(self.body, username='foo'))
headers = dict(self.headers, Origin='localhost')
response = yield client.fetch(url, method='POST', body=body,
headers=headers)
data = json.loads(to_str(response.body))
self.assertIsNone(data['id'])
self.assertIsNone(data['encoding'])
self.assertEqual(
'Cross origin operation is not allowed.', data['status']
)
@tornado.testing.gen_test
def test_app_with_correct_event_origin(self):
url = self.get_url('/')
client = self.get_http_client()
body = urlencode(dict(self.body, username='foo', _origin=self.origin)) body = urlencode(dict(self.body, username='foo', _origin=self.origin))
response = yield client.fetch(url, method='POST', body=body, response = yield client.fetch(url, method='POST', body=body,
headers=self.headers) headers=self.headers)
@ -743,3 +764,20 @@ class TestAppWithCrossOriginConnect(OtherTestBase):
self.assertIsNotNone(data['id']) self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding']) self.assertIsNotNone(data['encoding'])
self.assertIsNone(data['status']) self.assertIsNone(data['status'])
self.assertIsNone(response.headers.get('Access-Control-Allow-Origin'))
@tornado.testing.gen_test
def test_app_with_correct_header_origin(self):
url = self.get_url('/')
client = self.get_http_client()
body = urlencode(dict(self.body, username='foo'))
headers = dict(self.headers, Origin=self.origin)
response = yield client.fetch(url, method='POST', body=body,
headers=headers)
data = json.loads(to_str(response.body))
self.assertIsNotNone(data['id'])
self.assertIsNotNone(data['encoding'])
self.assertIsNone(data['status'])
self.assertEqual(
response.headers.get('Access-Control-Allow-Origin'), self.origin
)

View File

@ -346,6 +346,20 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
else: else:
future.set_result(worker) future.set_result(worker)
def check_origin(self):
event_origin = self.get_argument('_origin', u'')
header_origin = self.request.headers.get('Origin')
origin = event_origin or header_origin
if origin:
if not super(IndexHandler, self).check_origin(origin):
raise tornado.web.HTTPError(
403, 'Cross origin operation is not allowed.'
)
if not event_origin and self.origin_policy != 'same':
self.set_header('Access-Control-Allow-Origin', origin)
def head(self): def head(self):
pass pass
@ -362,12 +376,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
if len(clients.get(self.src_addr[0], {})) >= options.maxconn: if len(clients.get(self.src_addr[0], {})) >= options.maxconn:
raise tornado.web.HTTPError(403, 'Too many connections.') raise tornado.web.HTTPError(403, 'Too many connections.')
origin = self.get_argument('_origin', u'') self.check_origin()
if origin:
if not self.check_origin(origin):
raise tornado.web.HTTPError(
403, 'Cross origin frame operation is not allowed.'
)
future = Future() future = Future()
t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,)) t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,))