diff --git a/tests/test_handler.py b/tests/test_handler.py index 4770695..9d20e0e 100644 --- a/tests/test_handler.py +++ b/tests/test_handler.py @@ -19,57 +19,53 @@ class TestMixinHandler(unittest.TestCase): def test_is_forbidden(self): mhandler = MixinHandler() - handler.https_server_enabled = True + handler.redirecting = True options.fbidhttp = True - options.redirect = True context = Mock( address=('8.8.8.8', 8888), trusted_downstream=['127.0.0.1'], _orig_protocol='http' ) - self.assertTrue(mhandler.is_forbidden(context, '')) + hostname = '4.4.4.4' + self.assertTrue(mhandler.is_forbidden(context, hostname)) context = Mock( address=('8.8.8.8', 8888), trusted_downstream=[], _orig_protocol='http' ) - hostname = 'www.google.com' self.assertEqual(mhandler.is_forbidden(context, hostname), False) - handler.https_server_enabled = False - self.assertTrue(mhandler.is_forbidden(context, hostname)) - - options.redirect = False - self.assertTrue(mhandler.is_forbidden(context, hostname)) - - context = Mock( - address=('192.168.1.1', 8888), - trusted_downstream=[], - _orig_protocol='http' - ) - self.assertIsNone(mhandler.is_forbidden(context, '')) - context = Mock( address=('8.8.8.8', 8888), trusted_downstream=[], - _orig_protocol='https' + _orig_protocol='http' ) - self.assertIsNone(mhandler.is_forbidden(context, '')) + hostname = '4.4.4.4' + self.assertTrue(mhandler.is_forbidden(context, hostname)) context = Mock( - address=('8.8.8.8', 8888), + address=('192.168.1.1', 8888), trusted_downstream=[], _orig_protocol='http' ) - hostname = '8.8.8.8' - self.assertTrue(mhandler.is_forbidden(context, hostname)) + hostname = 'www.google.com' + self.assertIsNone(mhandler.is_forbidden(context, hostname)) options.fbidhttp = False self.assertIsNone(mhandler.is_forbidden(context, hostname)) + hostname = '4.4.4.4' + self.assertIsNone(mhandler.is_forbidden(context, hostname)) + + handler.redirecting = False + self.assertIsNone(mhandler.is_forbidden(context, hostname)) + + context._orig_protocol = 'https' + self.assertIsNone(mhandler.is_forbidden(context, hostname)) + def test_get_redirect_url(self): mhandler = MixinHandler() hostname = 'www.example.com' diff --git a/tests/test_main.py b/tests/test_main.py index d0f72a4..6ed89fc 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -11,12 +11,12 @@ class TestMain(unittest.TestCase): app = Application() app.listen = lambda x, y, **kwargs: 1 - handler.https_server_enabled = False + handler.redirecting = None server_settings = dict() app_listen(app, 80, '127.0.0.1', server_settings) - self.assertFalse(handler.https_server_enabled) + self.assertFalse(handler.redirecting) - handler.https_server_enabled = False + handler.redirecting = None server_settings = dict(ssl_options='enabled') app_listen(app, 80, '127.0.0.1', server_settings) - self.assertTrue(handler.https_server_enabled) + self.assertTrue(handler.redirecting) diff --git a/webssh/handler.py b/webssh/handler.py index 9a0c850..f6a2404 100644 --- a/webssh/handler.py +++ b/webssh/handler.py @@ -38,7 +38,7 @@ KEY_MAX_SIZE = 16384 DEFAULT_PORT = 22 swallow_http_errors = True -https_server_enabled = False +redirecting = None class InvalidValueError(Exception): @@ -78,6 +78,7 @@ class MixinHandler(object): def is_forbidden(self, context, hostname): ip = context.address[0] lst = context.trusted_downstream + ip_address = None if lst and ip not in lst: logging.warning( @@ -85,15 +86,19 @@ class MixinHandler(object): ) return True - if context._orig_protocol == 'http' and \ - not to_ip_address(ip).is_private: - if options.redirect and https_server_enabled: - if not is_ip_hostname(hostname): + if context._orig_protocol == 'http': + if redirecting and not is_ip_hostname(hostname): + ip_address = to_ip_address(ip) + if not ip_address.is_private: # redirecting return False + if options.fbidhttp: - logging.warning('Public plain http request is forbidden.') - return True + if ip_address is None: + ip_address = to_ip_address(ip) + if not ip_address.is_private: + logging.warning('Public plain http request is forbidden.') + return True def get_redirect_url(self, hostname, port, uri): port = '' if port == 443 else ':%s' % port diff --git a/webssh/main.py b/webssh/main.py index 0bfac0e..8370ff4 100644 --- a/webssh/main.py +++ b/webssh/main.py @@ -34,7 +34,7 @@ def app_listen(app, port, address, server_settings): server_type = 'http' else: server_type = 'https' - handler.https_server_enabled = True + handler.redirecting = True if options.redirect else False logging.info( 'Listening on {}:{} ({})'.format(address, port, server_type) )