diff --git a/tests/test_handler.py b/tests/test_handler.py index 1f09c07..e7117b4 100644 --- a/tests/test_handler.py +++ b/tests/test_handler.py @@ -5,7 +5,7 @@ from tornado.httputil import HTTPServerRequest from tornado.options import options from tests.utils import read_file, make_tests_data_path from webssh.handler import ( - MixinHandler, IndexHandler, InvalidValueError, open_to_public + MixinHandler, IndexHandler, WsockHandler, InvalidValueError, open_to_public ) try: @@ -202,3 +202,30 @@ class TestIndexHandler(unittest.TestCase): with self.assertRaises(paramiko.PasswordRequiredException): pkey = IndexHandler.get_pkey_obj(key, '', fname) + + +class TestWsockHandler(unittest.TestCase): + + def test_check_origin(self): + request = HTTPServerRequest(uri='/') + obj = Mock(spec=WsockHandler, request=request) + + options.cows = 0 + request.headers['Host'] = 'www.example.com:4433' + origin = 'https://www.example.com:4433' + self.assertTrue(WsockHandler.check_origin(obj, origin)) + + origin = 'https://www.example.com' + self.assertFalse(WsockHandler.check_origin(obj, origin)) + + options.cows = 1 + self.assertTrue(WsockHandler.check_origin(obj, origin)) + + origin = 'https://blog.example.com' + self.assertTrue(WsockHandler.check_origin(obj, origin)) + + origin = 'https://blog.example.org' + self.assertFalse(WsockHandler.check_origin(obj, origin)) + + options.cows = 2 + self.assertTrue(WsockHandler.check_origin(obj, origin)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 24b393c..24d5ec8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,7 @@ import unittest from webssh.utils import ( is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes, to_int, on_public_network_interface, get_ips_by_name, is_ip_hostname, - is_name_open_to_public + is_name_open_to_public, is_same_primary_domain ) @@ -79,3 +79,32 @@ class TestUitls(unittest.TestCase): self.assertTrue(is_ip_hostname('127.0.0.1')) self.assertFalse(is_ip_hostname('localhost')) self.assertFalse(is_ip_hostname('www.google.com')) + + def test_is_same_primary_domain(self): + domain1 = 'localhost' + domain2 = 'localhost' + self.assertTrue(is_same_primary_domain(domain1, domain2)) + + domain1 = 'localhost' + domain2 = 'test' + self.assertFalse(is_same_primary_domain(domain1, domain2)) + + domain1 = 'example.com' + domain2 = 'example.com' + self.assertTrue(is_same_primary_domain(domain1, domain2)) + + domain1 = 'www.example.com' + domain2 = 'example.com' + self.assertTrue(is_same_primary_domain(domain1, domain2)) + + domain1 = 'wwwexample.com' + domain2 = 'example.com' + self.assertFalse(is_same_primary_domain(domain1, domain2)) + + domain1 = 'www.example.com' + domain2 = 'www2.example.com' + self.assertTrue(is_same_primary_domain(domain1, domain2)) + + domain1 = 'xxx.www.example.com' + domain2 = 'xxx.www2.example.com' + self.assertTrue(is_same_primary_domain(domain1, domain2)) diff --git a/webssh/handler.py b/webssh/handler.py index 3fad100..6b4603d 100644 --- a/webssh/handler.py +++ b/webssh/handler.py @@ -13,7 +13,8 @@ from tornado.ioloop import IOLoop from tornado.options import options from webssh.utils import ( is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str, - to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname + to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname, + is_same_primary_domain ) from webssh.worker import Worker, recycle_worker, clients @@ -27,6 +28,11 @@ try: except ImportError: JSONDecodeError = ValueError +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + DELAY = 3 KEY_MAX_SIZE = 16384 @@ -364,6 +370,24 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): super(WsockHandler, self).initialize(loop) self.worker_ref = None + def check_origin(self, origin): + cows = options.cows + parsed_origin = urlparse(origin) + origin = parsed_origin.netloc + origin = origin.lower() + logging.debug('origin: {}'.format(origin)) + + host = self.request.headers.get('Host') + logging.debug('host: {}'.format(host)) + + if cows == 0: + return origin == host + elif cows == 1: + return is_same_primary_domain(origin.rsplit(':', 1)[0], + host.rsplit(':', 1)[0]) + else: + return True + def open(self): self.src_addr = self.get_client_addr() logging.info('Connected from {}:{}'.format(*self.src_addr)) diff --git a/webssh/settings.py b/webssh/settings.py index 158e13e..2bd4a64 100644 --- a/webssh/settings.py +++ b/webssh/settings.py @@ -34,6 +34,10 @@ define('fbidhttp', type=bool, default=True, help='Forbid public plain http incoming requests') define('xheaders', type=bool, default=True, help='Support xheaders') define('xsrf', type=bool, default=True, help='CSRF protection') +define('cows', type=int, default=0, help='Cross origin websocket, ' + '0: matches host name and port number' + '1: matches primary domain only' + '?: matches nothing, allow all cross-origin websockets') define('wpintvl', type=int, default=0, help='Websocket ping interval') define('maxconn', type=int, default=20, help='Maximum connections per client') define('version', type=bool, help='Show version information', diff --git a/webssh/utils.py b/webssh/utils.py index e71bb98..6e89551 100644 --- a/webssh/utils.py +++ b/webssh/utils.py @@ -96,3 +96,31 @@ def is_name_open_to_public(name): for ip in get_ips_by_name(name): if on_public_network_interface(ip): return True + + +def is_same_primary_domain(domain1, domain2): + i = -1 + dots = 0 + l1 = len(domain1) + l2 = len(domain2) + m = 0 - min(l1, l2) + + while i >= m: + c1 = domain1[i] + c2 = domain2[i] + + if c1 == c2: + if c1 == '.': + dots += 1 + if dots == 2: + return True + else: + return False + + i -= 1 + + if l1 == l2: + return True + + c = domain1[i] if l1 > m else domain2[i] + return c == '.'