From c35f801235703f581093ce35912d6f896a33ec45 Mon Sep 17 00:00:00 2001 From: Sheng Date: Sat, 19 Jan 2019 16:46:25 +0800 Subject: [PATCH] Support custom origin configuration --- tests/test_handler.py | 19 ++++++++++++++++--- tests/test_settings.py | 31 ++++++++++++++++++++++++++++++- tests/test_utils.py | 33 ++++++++++++++++++++++++++++++++- webssh/handler.py | 23 ++++++++++++++--------- webssh/settings.py | 40 ++++++++++++++++++++++++++++++++++------ webssh/utils.py | 31 +++++++++++++++++++++++++++++++ 6 files changed, 157 insertions(+), 20 deletions(-) diff --git a/tests/test_handler.py b/tests/test_handler.py index 9d20e0e..a35b858 100644 --- a/tests/test_handler.py +++ b/tests/test_handler.py @@ -215,7 +215,7 @@ class TestWsockHandler(unittest.TestCase): request = HTTPServerRequest(uri='/') obj = Mock(spec=WsockHandler, request=request) - options.cows = 0 + obj.origin_policy = 'same' request.headers['Host'] = 'www.example.com:4433' origin = 'https://www.example.com:4433' self.assertTrue(WsockHandler.check_origin(obj, origin)) @@ -223,7 +223,7 @@ class TestWsockHandler(unittest.TestCase): origin = 'https://www.example.com' self.assertFalse(WsockHandler.check_origin(obj, origin)) - options.cows = 1 + obj.origin_policy = 'primary' self.assertTrue(WsockHandler.check_origin(obj, origin)) origin = 'https://blog.example.com' @@ -232,5 +232,18 @@ class TestWsockHandler(unittest.TestCase): origin = 'https://blog.example.org' self.assertFalse(WsockHandler.check_origin(obj, origin)) - options.cows = 2 + origin = 'https://blog.example.org' + obj.origin_policy = {'https://blog.example.org'} + self.assertTrue(WsockHandler.check_origin(obj, origin)) + + origin = 'http://blog.example.org' + obj.origin_policy = {'http://blog.example.org'} + self.assertTrue(WsockHandler.check_origin(obj, origin)) + + origin = 'http://blog.example.org' + obj.origin_policy = {'https://blog.example.org'} + self.assertFalse(WsockHandler.check_origin(obj, origin)) + + obj.origin_policy = '*' + origin = 'https://blog.example.org' self.assertTrue(WsockHandler.check_origin(obj, origin)) diff --git a/tests/test_settings.py b/tests/test_settings.py index c76acfc..bcc8f0d 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,4 +1,5 @@ import io +import random import ssl import sys import os.path @@ -10,7 +11,7 @@ from tests.utils import make_tests_data_path from webssh.policy import load_host_keys from webssh.settings import ( get_host_keys_settings, get_policy_setting, base_dir, print_version, - get_ssl_context, get_trusted_downstream + get_ssl_context, get_trusted_downstream, get_origin_setting ) from webssh.utils import UnicodeType from webssh._version import __version__ @@ -137,3 +138,31 @@ class TestSettings(unittest.TestCase): tdstream = '1.1.1.1, 2.2.2.' with self.assertRaises(ValueError): get_trusted_downstream(tdstream) + + def test_get_origin_setting(self): + options.debug = False + options.origin = '*' + with self.assertRaises(ValueError): + get_origin_setting(options) + + options.debug = True + self.assertEqual(get_origin_setting(options), '*') + + options.origin = random.choice(['Same', 'Primary']) + self.assertEqual(get_origin_setting(options), options.origin.lower()) + + options.origin = '' + with self.assertRaises(ValueError): + get_origin_setting(options) + + options.origin = ',' + with self.assertRaises(ValueError): + get_origin_setting(options) + + options.origin = 'www.example.com, https://www.example.org' + result = {'http://www.example.com', 'https://www.example.org'} + self.assertEqual(get_origin_setting(options), result) + + options.origin = 'www.example.com:80, www.example.org:443' + result = {'http://www.example.com', 'https://www.example.org'} + self.assertEqual(get_origin_setting(options), result) diff --git a/tests/test_utils.py b/tests/test_utils.py index 9e87b34..9916574 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ import unittest from webssh.utils import ( is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes, - to_int, is_ip_hostname, is_same_primary_domain + to_int, is_ip_hostname, is_same_primary_domain, parse_origin_from_url ) @@ -90,3 +90,34 @@ class TestUitls(unittest.TestCase): domain1 = 'xxx.www.example.com' domain2 = 'xxx.www2.example.com' self.assertTrue(is_same_primary_domain(domain1, domain2)) + + def test_parse_origin_from_url(self): + url = '' + self.assertIsNone(parse_origin_from_url(url)) + + url = 'www.example.com' + self.assertEqual(parse_origin_from_url(url), 'http://www.example.com') + + url = 'http://www.example.com' + self.assertEqual(parse_origin_from_url(url), 'http://www.example.com') + + url = 'www.example.com:80' + self.assertEqual(parse_origin_from_url(url), 'http://www.example.com') + + url = 'http://www.example.com:80' + self.assertEqual(parse_origin_from_url(url), 'http://www.example.com') + + url = 'www.example.com:443' + self.assertEqual(parse_origin_from_url(url), 'https://www.example.com') + + url = 'https://www.example.com' + self.assertEqual(parse_origin_from_url(url), 'https://www.example.com') + + url = 'https://www.example.com:443' + self.assertEqual(parse_origin_from_url(url), 'https://www.example.com') + + url = 'https://www.example.com:80' + self.assertEqual(parse_origin_from_url(url), url) + + url = 'http://www.example.com:443' + self.assertEqual(parse_origin_from_url(url), url) diff --git a/webssh/handler.py b/webssh/handler.py index f6a2404..a536a82 100644 --- a/webssh/handler.py +++ b/webssh/handler.py @@ -57,6 +57,7 @@ class MixinHandler(object): def initialize(self, loop=None): self.check_request() self.loop = loop + self.origin_policy = self.settings.get('origin_policy') def check_request(self): context = self.request.connection.context @@ -364,22 +365,26 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): self.worker_ref = None def check_origin(self, origin): - cows = options.cows + if self.origin_policy == '*': + return True + parsed_origin = urlparse(origin) - origin = parsed_origin.netloc - origin = origin.lower() - logging.debug('origin: {}'.format(origin)) + netloc = parsed_origin.netloc.lower() + logging.debug('netloc: {}'.format(netloc)) host = self.request.headers.get('Host') logging.debug('host: {}'.format(host)) - if cows == 0: - return origin == host - elif cows == 1: - return is_same_primary_domain(origin.rsplit(':', 1)[0], + if netloc == host: + return True + + if self.origin_policy == 'same': + return False + elif self.origin_policy == 'primary': + return is_same_primary_domain(netloc.rsplit(':', 1)[0], host.rsplit(':', 1)[0]) else: - return True + return origin in self.origin_policy def open(self): self.src_addr = self.get_client_addr() diff --git a/webssh/settings.py b/webssh/settings.py index 74ad704..cfb4d24 100644 --- a/webssh/settings.py +++ b/webssh/settings.py @@ -7,7 +7,7 @@ from tornado.options import define from webssh.policy import ( load_host_keys, get_policy_class, check_policy_setting ) -from webssh.utils import to_ip_address +from webssh.utils import to_ip_address, parse_origin_from_url from webssh._version import __version__ @@ -34,10 +34,12 @@ define('fbidhttp', type=bool, default=True, help='Forbid public plain http incoming requests') define('xheaders', type=bool, default=True, help='Support xheaders') define('xsrf', type=bool, default=True, help='CSRF protection') -define('cows', type=int, default=0, help='Cross origin websocket, ' - '0: matches host name and port number, ' - '1: matches primary domain only, ' - '?: matches nothing, allow all cross-origin websockets') +define('origin', default='same', help='''Origin policy, +'same': same origin policy, matches host name and port number; +'primary': primary domain policy, matches primary domain only; +'': custom domains policy, matches any domain in the list +separated by comma; +'*': wildcard policy, matches any domain, allowed in debug mode only.''') define('wpintvl', type=int, default=0, help='Websocket ping interval') define('maxconn', type=int, default=20, help='Maximum connections per client') define('version', type=bool, help='Show version information', @@ -54,7 +56,8 @@ def get_app_settings(options): static_path=os.path.join(base_dir, 'webssh', 'static'), websocket_ping_interval=options.wpintvl, debug=options.debug, - xsrf_cookies=options.xsrf + xsrf_cookies=options.xsrf, + origin_policy=get_origin_setting(options) ) return settings @@ -121,3 +124,28 @@ def get_trusted_downstream(tdstream): to_ip_address(ip) result.add(ip) return result + + +def get_origin_setting(options): + if options.origin == '*': + if not options.debug: + raise ValueError( + 'Wildcard origin policy is only allowed in debug mode.' + ) + else: + return '*' + + origin = options.origin.lower() + if origin in ['same', 'primary']: + return origin + + origins = set() + for url in origin.split(','): + orig = parse_origin_from_url(url) + if orig: + origins.add(orig) + + if not origins: + raise ValueError('Empty origin list') + + return origins diff --git a/webssh/utils.py b/webssh/utils.py index d95fb74..9e73f12 100644 --- a/webssh/utils.py +++ b/webssh/utils.py @@ -6,6 +6,11 @@ try: except ImportError: UnicodeType = str +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + numeric = re.compile(r'[0-9]+$') allowed = re.compile(r'(?!-)[a-z0-9-]{1,63}(? m else domain2[i] return c == '.' + + +def parse_origin_from_url(url): + url = url.strip() + if not url: + return + + if not (url.startswith('http://') or url.startswith('https://') or + url.startswith('//')): + url = '//' + url + + parsed = urlparse(url) + port = parsed.port + scheme = parsed.scheme + + if scheme == '': + scheme = 'https' if port == 443 else 'http' + + if port == 443 and scheme == 'https': + netloc = parsed.netloc.replace(':443', '') + elif port == 80 and scheme == 'http': + netloc = parsed.netloc.replace(':80', '') + else: + netloc = parsed.netloc + + return '{}://{}'.format(scheme, netloc)