diff --git a/tests/test_handler.py b/tests/test_handler.py index e7117b4..4770695 100644 --- a/tests/test_handler.py +++ b/tests/test_handler.py @@ -4,8 +4,9 @@ import paramiko from tornado.httputil import HTTPServerRequest from tornado.options import options from tests.utils import read_file, make_tests_data_path +from webssh import handler from webssh.handler import ( - MixinHandler, IndexHandler, WsockHandler, InvalidValueError, open_to_public + MixinHandler, IndexHandler, WsockHandler, InvalidValueError ) try: @@ -17,9 +18,8 @@ except ImportError: class TestMixinHandler(unittest.TestCase): def test_is_forbidden(self): - handler = MixinHandler() - open_to_public['http'] = True - open_to_public['https'] = True + mhandler = MixinHandler() + handler.https_server_enabled = True options.fbidhttp = True options.redirect = True @@ -28,7 +28,7 @@ class TestMixinHandler(unittest.TestCase): trusted_downstream=['127.0.0.1'], _orig_protocol='http' ) - self.assertTrue(handler.is_forbidden(context, '')) + self.assertTrue(mhandler.is_forbidden(context, '')) context = Mock( address=('8.8.8.8', 8888), @@ -37,21 +37,27 @@ class TestMixinHandler(unittest.TestCase): ) hostname = 'www.google.com' - self.assertEqual(handler.is_forbidden(context, hostname), False) + self.assertEqual(mhandler.is_forbidden(context, hostname), False) + + handler.https_server_enabled = False + self.assertTrue(mhandler.is_forbidden(context, hostname)) + + options.redirect = False + self.assertTrue(mhandler.is_forbidden(context, hostname)) context = Mock( address=('192.168.1.1', 8888), trusted_downstream=[], _orig_protocol='http' ) - self.assertIsNone(handler.is_forbidden(context, '')) + self.assertIsNone(mhandler.is_forbidden(context, '')) context = Mock( address=('8.8.8.8', 8888), trusted_downstream=[], _orig_protocol='https' ) - self.assertIsNone(handler.is_forbidden(context, '')) + self.assertIsNone(mhandler.is_forbidden(context, '')) context = Mock( address=('8.8.8.8', 8888), @@ -59,43 +65,46 @@ class TestMixinHandler(unittest.TestCase): _orig_protocol='http' ) hostname = '8.8.8.8' - self.assertTrue(handler.is_forbidden(context, hostname)) + self.assertTrue(mhandler.is_forbidden(context, hostname)) + + options.fbidhttp = False + self.assertIsNone(mhandler.is_forbidden(context, hostname)) def test_get_redirect_url(self): - handler = MixinHandler() + mhandler = MixinHandler() hostname = 'www.example.com' uri = '/' port = 443 self.assertEqual( - handler.get_redirect_url(hostname, port, uri=uri), + mhandler.get_redirect_url(hostname, port, uri=uri), 'https://www.example.com/' ) port = 4433 self.assertEqual( - handler.get_redirect_url(hostname, port, uri), + mhandler.get_redirect_url(hostname, port, uri), 'https://www.example.com:4433/' ) def test_get_client_addr(self): - handler = MixinHandler() + mhandler = MixinHandler() client_addr = ('8.8.8.8', 8888) context_addr = ('127.0.0.1', 1234) options.xheaders = True - handler.context = Mock(address=context_addr) - handler.get_real_client_addr = lambda: None - self.assertEqual(handler.get_client_addr(), context_addr) + mhandler.context = Mock(address=context_addr) + mhandler.get_real_client_addr = lambda: None + self.assertEqual(mhandler.get_client_addr(), context_addr) - handler.context = Mock(address=context_addr) - handler.get_real_client_addr = lambda: client_addr - self.assertEqual(handler.get_client_addr(), client_addr) + mhandler.context = Mock(address=context_addr) + mhandler.get_real_client_addr = lambda: client_addr + self.assertEqual(mhandler.get_client_addr(), client_addr) options.xheaders = False - handler.context = Mock(address=context_addr) - handler.get_real_client_addr = lambda: client_addr - self.assertEqual(handler.get_client_addr(), context_addr) + mhandler.context = Mock(address=context_addr) + mhandler.get_real_client_addr = lambda: client_addr + self.assertEqual(mhandler.get_client_addr(), context_addr) def test_get_real_client_addr(self): x_forwarded_for = '1.1.1.1' @@ -104,36 +113,36 @@ class TestMixinHandler(unittest.TestCase): x_real_port = 2222 fake_port = 65535 - handler = MixinHandler() - handler.request = HTTPServerRequest(uri='/') - handler.request.remote_ip = x_forwarded_for + mhandler = MixinHandler() + mhandler.request = HTTPServerRequest(uri='/') + mhandler.request.remote_ip = x_forwarded_for - self.assertIsNone(handler.get_real_client_addr()) + self.assertIsNone(mhandler.get_real_client_addr()) - handler.request.headers.add('X-Forwarded-For', x_forwarded_for) - self.assertEqual(handler.get_real_client_addr(), + mhandler.request.headers.add('X-Forwarded-For', x_forwarded_for) + self.assertEqual(mhandler.get_real_client_addr(), (x_forwarded_for, fake_port)) - handler.request.headers.add('X-Forwarded-Port', fake_port + 1) - self.assertEqual(handler.get_real_client_addr(), + mhandler.request.headers.add('X-Forwarded-Port', fake_port + 1) + self.assertEqual(mhandler.get_real_client_addr(), (x_forwarded_for, fake_port)) - handler.request.headers['X-Forwarded-Port'] = x_forwarded_port - self.assertEqual(handler.get_real_client_addr(), + mhandler.request.headers['X-Forwarded-Port'] = x_forwarded_port + self.assertEqual(mhandler.get_real_client_addr(), (x_forwarded_for, x_forwarded_port)) - handler.request.remote_ip = x_real_ip + mhandler.request.remote_ip = x_real_ip - handler.request.headers.add('X-Real-Ip', x_real_ip) - self.assertEqual(handler.get_real_client_addr(), + mhandler.request.headers.add('X-Real-Ip', x_real_ip) + self.assertEqual(mhandler.get_real_client_addr(), (x_real_ip, fake_port)) - handler.request.headers.add('X-Real-Port', fake_port + 1) - self.assertEqual(handler.get_real_client_addr(), + mhandler.request.headers.add('X-Real-Port', fake_port + 1) + self.assertEqual(mhandler.get_real_client_addr(), (x_real_ip, fake_port)) - handler.request.headers['X-Real-Port'] = x_real_port - self.assertEqual(handler.get_real_client_addr(), + mhandler.request.headers['X-Real-Port'] = x_real_port + self.assertEqual(mhandler.get_real_client_addr(), (x_real_ip, x_real_port)) diff --git a/tests/test_main.py b/tests/test_main.py index 0dedbbd..d0f72a4 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,7 +1,7 @@ import unittest from tornado.web import Application -from webssh.handler import open_to_public +from webssh import handler from webssh.main import app_listen @@ -10,36 +10,13 @@ class TestMain(unittest.TestCase): def test_app_listen(self): app = Application() app.listen = lambda x, y, **kwargs: 1 - open_to_public['https'] = None - open_to_public['http'] = None - server_settings = dict(ssl_options=False) + handler.https_server_enabled = False + server_settings = dict() app_listen(app, 80, '127.0.0.1', server_settings) - self.assertEqual(open_to_public['http'], False) - self.assertIsNone(open_to_public['https']) - open_to_public['http'] = None + self.assertFalse(handler.https_server_enabled) - server_settings = dict(ssl_options=False) - app_listen(app, 80, '0.0.0.0', server_settings) - self.assertEqual(open_to_public['http'], True) - self.assertIsNone(open_to_public['https']) - open_to_public['http'] = None - - server_settings = dict(ssl_options=True) - app_listen(app, 443, '127.0.0.1', server_settings) - self.assertEqual(open_to_public['https'], False) - self.assertIsNone(open_to_public['http']) - open_to_public['https'] = None - - server_settings = dict(ssl_options=True) - app_listen(app, 443, '0.0.0.0', server_settings) - self.assertEqual(open_to_public['https'], True) - self.assertIsNone(open_to_public['http']) - open_to_public['https'] = None - - server_settings = dict(ssl_options=False) - app_listen(app, 80, '0.0.0.0', server_settings) - server_settings = dict(ssl_options=True) - app_listen(app, 443, '0.0.0.0', server_settings) - self.assertEqual(open_to_public['https'], True) - self.assertEqual(open_to_public['http'], True) + handler.https_server_enabled = False + server_settings = dict(ssl_options='enabled') + app_listen(app, 80, '127.0.0.1', server_settings) + self.assertTrue(handler.https_server_enabled) diff --git a/tests/test_utils.py b/tests/test_utils.py index f9007b9..9e87b34 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,8 +2,7 @@ import unittest from webssh.utils import ( is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes, - to_int, on_public_network_interface, get_ips_by_name, is_ip_hostname, - is_name_open_to_public, is_same_primary_domain + to_int, is_ip_hostname, is_same_primary_domain ) @@ -53,27 +52,6 @@ class TestUitls(unittest.TestCase): self.assertFalse(is_valid_hostname('127.0.0.1')) self.assertFalse(is_valid_hostname('::1')) - def test_get_ips_by_name(self): - self.assertTrue(get_ips_by_name(''), {'0.0.0.0', '::'}) - self.assertTrue(get_ips_by_name('localhost'), {'127.0.0.1'}) - self.assertTrue(get_ips_by_name('192.68.1.1'), {'192.168.1.1'}) - self.assertTrue(get_ips_by_name('2.2.2.2'), {'2.2.2.2'}) - - def test_on_public_network_interface(self): - self.assertTrue(on_public_network_interface('0.0.0.0')) - self.assertTrue(on_public_network_interface('::')) - self.assertTrue(on_public_network_interface('0:0:0:0:0:0:0:0')) - self.assertTrue(on_public_network_interface('2.2.2.2')) - self.assertTrue(on_public_network_interface('2:2:2:2:2:2:2:2')) - self.assertIsNone(on_public_network_interface('127.0.0.1')) - - def test_is_name_open_to_public(self): - self.assertTrue(is_name_open_to_public('0.0.0.0')) - self.assertTrue(is_name_open_to_public('::')) - self.assertIsNone(is_name_open_to_public('192.168.1.1')) - self.assertIsNone(is_name_open_to_public('127.0.0.1')) - self.assertIsNone(is_name_open_to_public('localhost')) - def test_is_ip_hostname(self): self.assertTrue(is_ip_hostname('[::1]')) self.assertTrue(is_ip_hostname('127.0.0.1')) diff --git a/webssh/handler.py b/webssh/handler.py index 6b4603d..9a0c850 100644 --- a/webssh/handler.py +++ b/webssh/handler.py @@ -13,8 +13,7 @@ from tornado.ioloop import IOLoop from tornado.options import options from webssh.utils import ( is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str, - to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname, - is_same_primary_domain + to_int, to_ip_address, UnicodeType, is_ip_hostname, is_same_primary_domain ) from webssh.worker import Worker, recycle_worker, clients @@ -39,18 +38,7 @@ KEY_MAX_SIZE = 16384 DEFAULT_PORT = 22 swallow_http_errors = True - -# set by config_open_to_public -open_to_public = { - 'http': None, - 'https': None -} - - -def config_open_to_public(address, server_type): - status = True if is_name_open_to_public(address) else False - logging.debug('{} server open to public: {}'.format(server_type, status)) - open_to_public[server_type] = status +https_server_enabled = False class InvalidValueError(Exception): @@ -97,15 +85,15 @@ class MixinHandler(object): ) return True - if open_to_public['http'] and context._orig_protocol == 'http': - if not to_ip_address(ip).is_private: - if open_to_public['https'] and options.redirect: - if not is_ip_hostname(hostname): - # redirecting - return False - if options.fbidhttp: - logging.warning('Public plain http request is forbidden.') - return True + if context._orig_protocol == 'http' and \ + not to_ip_address(ip).is_private: + if options.redirect and https_server_enabled: + if not is_ip_hostname(hostname): + # redirecting + return False + if options.fbidhttp: + logging.warning('Public plain http request is forbidden.') + return True def get_redirect_url(self, hostname, port, uri): port = '' if port == 443 else ':%s' % port diff --git a/webssh/main.py b/webssh/main.py index 9a62156..0bfac0e 100644 --- a/webssh/main.py +++ b/webssh/main.py @@ -3,9 +3,8 @@ import tornado.web import tornado.ioloop from tornado.options import options -from webssh.handler import ( - IndexHandler, WsockHandler, NotFoundHandler, config_open_to_public -) +from webssh import handler +from webssh.handler import IndexHandler, WsockHandler, NotFoundHandler from webssh.settings import ( get_app_settings, get_host_keys_settings, get_policy_setting, get_ssl_context, get_server_settings @@ -31,11 +30,14 @@ def make_app(handlers, settings): def app_listen(app, port, address, server_settings): app.listen(port, address, **server_settings) - server_type = 'https' if server_settings.get('ssl_options') else 'http' + if not server_settings.get('ssl_options'): + server_type = 'http' + else: + server_type = 'https' + handler.https_server_enabled = True logging.info( 'Listening on {}:{} ({})'.format(address, port, server_type) ) - config_open_to_public(address, server_type) def main(): diff --git a/webssh/utils.py b/webssh/utils.py index 59f0735..d95fb74 100644 --- a/webssh/utils.py +++ b/webssh/utils.py @@ -1,6 +1,5 @@ import ipaddress import re -import socket try: from types import UnicodeType @@ -11,9 +10,6 @@ except ImportError: numeric = re.compile(r'[0-9]+$') allowed = re.compile(r'(?!-)[a-z0-9-]{1,63}(?