diff --git a/tests/test_handler.py b/tests/test_handler.py index ac157d8..bd8b023 100644 --- a/tests/test_handler.py +++ b/tests/test_handler.py @@ -1,7 +1,6 @@ import unittest import paramiko -from tornado.httpclient import HTTPRequest from tornado.httputil import HTTPServerRequest from tornado.options import options from tests.utils import read_file, make_tests_data_path @@ -17,42 +16,55 @@ class TestMixinHandler(unittest.TestCase): def test_is_forbidden(self): handler = MixinHandler() - request = HTTPRequest('http://example.com/') - handler.request = request options.fbidhttp = True - context = Mock( + handler.context = Mock( address=('8.8.8.8', 8888), trusted_downstream=['127.0.0.1'], _orig_protocol='http' ) - request.connection = Mock(context=context) self.assertTrue(handler.is_forbidden()) - context = Mock( + handler.context = Mock( address=('8.8.8.8', 8888), trusted_downstream=[], _orig_protocol='http' ) - request.connection = Mock(context=context) self.assertTrue(handler.is_forbidden()) - context = Mock( + handler.context = Mock( address=('192.168.1.1', 8888), trusted_downstream=[], _orig_protocol='http' ) - request.connection = Mock(context=context) self.assertIsNone(handler.is_forbidden()) - context = Mock( + handler.context = Mock( address=('8.8.8.8', 8888), trusted_downstream=[], _orig_protocol='https' ) - request.connection = Mock(context=context) self.assertIsNone(handler.is_forbidden()) + def test_get_client_addr(self): + handler = MixinHandler() + client_addr = ('8.8.8.8', 8888) + context_addr = ('127.0.0.1', 1234) + options.xheaders = True + + handler.context = Mock(address=context_addr) + handler.get_real_client_addr = lambda: None + self.assertEqual(handler.get_client_addr(), context_addr) + + handler.context = Mock(address=context_addr) + handler.get_real_client_addr = lambda: client_addr + self.assertEqual(handler.get_client_addr(), client_addr) + + options.xheaders = False + handler.context = Mock(address=context_addr) + handler.get_real_client_addr = lambda: client_addr + self.assertEqual(handler.get_client_addr(), context_addr) + def test_get_real_client_addr(self): x_forwarded_for = '1.1.1.1' x_forwarded_port = 1111 diff --git a/webssh/handler.py b/webssh/handler.py index c25c36f..55618ef 100644 --- a/webssh/handler.py +++ b/webssh/handler.py @@ -46,19 +46,21 @@ class MixinHandler(object): } def initialize(self): + conn = self.request.connection + self.context = conn.context if self.is_forbidden(): result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version) - self.request.connection.stream.write(to_bytes(result)) - self.request.connection.close() + conn.stream.write(to_bytes(result)) + conn.close() raise ValueError('Accesss denied') def is_forbidden(self): """ Following requests are forbidden: * requests not come from trusted_downstream (if set). - * non-https requests from a public network. + * plain http requests from a public network. """ - context = self.request.connection.context + context = self.context ip = context.address[0] lst = context.trusted_downstream @@ -71,7 +73,7 @@ class MixinHandler(object): if options.fbidhttp and context._orig_protocol == 'http': ipaddr = to_ip_address(ip) if not ipaddr.is_private: - logging.warning('Public non-https request is forbidden.') + logging.warning('Public plain http request is forbidden.') return True def set_default_headers(self): @@ -85,8 +87,10 @@ class MixinHandler(object): return value def get_client_addr(self): - return self.get_real_client_addr() or self.request.connection.context.\ - address + if options.xheaders: + return self.get_real_client_addr() or self.context.address + else: + return self.context.address def get_real_client_addr(self): ip = self.request.remote_ip diff --git a/webssh/settings.py b/webssh/settings.py index 733b143..1743287 100644 --- a/webssh/settings.py +++ b/webssh/settings.py @@ -30,8 +30,10 @@ define('policy', default='warning', help='Missing host key policy, reject|autoadd|warning') define('hostfile', default='', help='User defined host keys file') define('syshostfile', default='', help='System wide host keys file') -define('tdstream', default='', help='trusted downstream, separated by comma') -define('fbidhttp', type=bool, default=True, help='forbid public http request') +define('tdstream', default='', help='Trusted downstream, separated by comma') +define('fbidhttp', type=bool, default=True, + help='Forbid public plain http incoming requests') +define('xheaders', type=bool, default=True, help='Support xheaders') define('wpintvl', type=int, default=0, help='Websocket ping interval') define('version', type=bool, help='Show version information', callback=print_version) @@ -39,7 +41,6 @@ define('version', type=bool, help='Show version information', base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) max_body_size = 1 * 1024 * 1024 -xheaders = True def get_app_settings(options): @@ -55,7 +56,7 @@ def get_app_settings(options): def get_server_settings(options): settings = dict( - xheaders=xheaders, + xheaders=options.xheaders, max_body_size=max_body_size, trusted_downstream=get_trusted_downstream(options) ) @@ -121,4 +122,4 @@ def detect_is_open_to_public(options): result = on_public_network_interfaces(get_ips_by_name(options.address)) if not result and options.fbidhttp: options.fbidhttp = False - logging.info('Forbid public http: {}'.format(options.fbidhttp)) + logging.info('Forbid public plain http: {}'.format(options.fbidhttp))