diff --git a/tests/test_handler.py b/tests/test_handler.py index b03f24b..dcffb34 100644 --- a/tests/test_handler.py +++ b/tests/test_handler.py @@ -19,35 +19,64 @@ class TestMixinHandler(unittest.TestCase): def test_is_forbidden(self): handler = MixinHandler() open_to_public['http'] = True + open_to_public['https'] = True options.fbidhttp = True + options.redirect = True context = Mock( address=('8.8.8.8', 8888), trusted_downstream=['127.0.0.1'], _orig_protocol='http' ) - self.assertTrue(handler.is_forbidden(context)) + self.assertTrue(handler.is_forbidden(context, '')) context = Mock( address=('8.8.8.8', 8888), trusted_downstream=[], _orig_protocol='http' ) - self.assertTrue(handler.is_forbidden(context)) + + hostname = 'www.google.com' + self.assertEqual(handler.is_forbidden(context, hostname), False) context = Mock( address=('192.168.1.1', 8888), trusted_downstream=[], _orig_protocol='http' ) - self.assertIsNone(handler.is_forbidden(context)) + self.assertIsNone(handler.is_forbidden(context, '')) context = Mock( address=('8.8.8.8', 8888), trusted_downstream=[], _orig_protocol='https' ) - self.assertIsNone(handler.is_forbidden(context)) + self.assertIsNone(handler.is_forbidden(context, '')) + + context = Mock( + address=('8.8.8.8', 8888), + trusted_downstream=[], + _orig_protocol='http' + ) + hostname = '8.8.8.8' + self.assertTrue(handler.is_forbidden(context, hostname)) + + def test_get_redirect_url(self): + handler = MixinHandler() + hostname = 'www.example.com' + uri = '/' + port = 443 + + self.assertTrue( + handler.get_redirect_url(hostname, port, uri=uri), + 'https://www.example.com/' + ) + + port = 4433 + self.assertTrue( + handler.get_redirect_url(hostname, port, uri), + 'https://www.example.com:4433/' + ) def test_get_client_addr(self): handler = MixinHandler() diff --git a/tests/test_utils.py b/tests/test_utils.py index a8c4e80..24b393c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ import unittest from webssh.utils import ( is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes, - to_int, on_public_network_interface, get_ips_by_name, + to_int, on_public_network_interface, get_ips_by_name, is_ip_hostname, is_name_open_to_public ) @@ -73,3 +73,9 @@ class TestUitls(unittest.TestCase): self.assertIsNone(is_name_open_to_public('192.168.1.1')) self.assertIsNone(is_name_open_to_public('127.0.0.1')) self.assertIsNone(is_name_open_to_public('localhost')) + + def test_is_ip_hostname(self): + self.assertTrue(is_ip_hostname('[::1]')) + self.assertTrue(is_ip_hostname('127.0.0.1')) + self.assertFalse(is_ip_hostname('localhost')) + self.assertFalse(is_ip_hostname('www.google.com')) diff --git a/webssh/handler.py b/webssh/handler.py index 2423c71..09665f8 100644 --- a/webssh/handler.py +++ b/webssh/handler.py @@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop from tornado.options import options from webssh.utils import ( is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str, - to_int, to_ip_address, UnicodeType, is_name_open_to_public + to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname ) from webssh.worker import Worker, recycle_worker, workers @@ -34,7 +34,7 @@ DEFAULT_PORT = 22 swallow_http_errors = True -# status of the http(s) server +# set by config_open_to_public open_to_public = { 'http': None, 'https': None @@ -56,22 +56,28 @@ class MixinHandler(object): 'Server': 'TornadoServer' } + html = ('{code} {reason}{code} ' + '{reason}') + def initialize(self, loop=None): - conn = self.request.connection - if self.is_forbidden(conn.context): - result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version) - conn.stream.write(to_bytes(result)) - conn.close() - raise ValueError('Accesss denied') - self.loop = loop - self.context = conn.context - - def is_forbidden(self, context): - """ - Following requests are forbidden: - * requests not come from trusted_downstream (if set). - * plain http requests from a public network. - """ + context = self.request.connection.context + result = self.is_forbidden(context, self.request.host_name) + self._transforms = [] + if result: + self.set_status(403) + self.finish( + self.html.format(code=self._status_code, reason=self._reason) + ) + elif result is False: + to_url = self.get_redirect_url( + self.request.host_name, options.sslport, self.request.uri + ) + self.redirect(to_url, permanent=True) + else: + self.loop = loop + self.context = context + + def is_forbidden(self, context, hostname): ip = context.address[0] lst = context.trusted_downstream @@ -81,13 +87,20 @@ class MixinHandler(object): ) return True - if open_to_public['http'] and options.fbidhttp: - if context._orig_protocol == 'http': - ipaddr = to_ip_address(ip) - if not ipaddr.is_private: + if open_to_public['http'] and context._orig_protocol == 'http': + if not to_ip_address(ip).is_private: + if open_to_public['https'] and options.redirect: + if not is_ip_hostname(hostname): + # redirecting + return False + if options.fbidhttp: logging.warning('Public plain http request is forbidden.') return True + def get_redirect_url(self, hostname, port, uri): + port = '' if port == 443 else ':%s' % port + return 'https://{}{}{}'.format(hostname, port, uri) + def set_default_headers(self): for header in self.custom_headers.items(): self.set_header(*header) diff --git a/webssh/main.py b/webssh/main.py index 2cfb442..9a62156 100644 --- a/webssh/main.py +++ b/webssh/main.py @@ -33,8 +33,7 @@ def app_listen(app, port, address, server_settings): app.listen(port, address, **server_settings) server_type = 'https' if server_settings.get('ssl_options') else 'http' logging.info( - 'Started a {} server listening on {}:{}'.format( - server_type, address, port) + 'Listening on {}:{} ({})'.format(address, port, server_type) ) config_open_to_public(address, server_type) diff --git a/webssh/settings.py b/webssh/settings.py index 11bdeac..4054fba 100644 --- a/webssh/settings.py +++ b/webssh/settings.py @@ -17,7 +17,7 @@ def print_version(flag): sys.exit(0) -define('address', default='127.0.0.1', help='Listen address') +define('address', default='0.0.0.0', help='Listen address') define('port', type=int, default=8888, help='Listen port') define('ssladdress', default='0.0.0.0', help='SSL listen address') define('sslport', type=int, default=4433, help='SSL listen port') @@ -29,6 +29,7 @@ define('policy', default='warning', define('hostfile', default='', help='User defined host keys file') define('syshostfile', default='', help='System wide host keys file') define('tdstream', default='', help='Trusted downstream, separated by comma') +define('redirect', type=bool, default=True, help='Redirecting http to https') define('fbidhttp', type=bool, default=True, help='Forbid public plain http incoming requests') define('xheaders', type=bool, default=True, help='Support xheaders') diff --git a/webssh/utils.py b/webssh/utils.py index bc4799e..e71bb98 100644 --- a/webssh/utils.py +++ b/webssh/utils.py @@ -50,6 +50,16 @@ def is_valid_port(port): return 0 < port < 65536 +def is_ip_hostname(hostname): + it = iter(hostname) + if next(it) == '[': + return True + for ch in it: + if ch != '.' and not ch.isdigit(): + return False + return True + + def is_valid_hostname(hostname): if hostname[-1] == '.': # strip exactly one dot from the right, if present