From c06bf5311a1e450d4060cf500a63b4f524994f85 Mon Sep 17 00:00:00 2001 From: Sheng Date: Wed, 17 Oct 2018 22:39:53 +0800 Subject: [PATCH] Added an option for blocking public non-https requests --- tests/test_handler.py | 3 ++- tests/test_settings.py | 29 ++++++++++++++++++++++++++++- tests/test_utils.py | 27 +++++++++++++++++++++++++-- webssh/handler.py | 25 ++++++++++++++++++------- webssh/main.py | 3 ++- webssh/settings.py | 15 ++++++++++++++- webssh/utils.py | 26 ++++++++++++++++++++++++++ 7 files changed, 115 insertions(+), 13 deletions(-) diff --git a/tests/test_handler.py b/tests/test_handler.py index fde909f..9e28717 100644 --- a/tests/test_handler.py +++ b/tests/test_handler.py @@ -3,7 +3,6 @@ import paramiko from tornado.httpclient import HTTPRequest from tornado.httputil import HTTPServerRequest -from tornado.web import HTTPError from tests.utils import read_file, make_tests_data_path from webssh.handler import MixinHandler, IndexHandler, InvalidValueError @@ -17,6 +16,8 @@ class TestMixinHandler(unittest.TestCase): def test_is_forbidden(self): handler = MixinHandler() + handler.is_open_to_public = True + handler.forbid_public_http = True request = HTTPRequest('http://example.com/') handler.request = request diff --git a/tests/test_settings.py b/tests/test_settings.py index 0303d0f..8f7b21e 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -7,10 +7,11 @@ import paramiko import tornado.options as options from tests.utils import make_tests_data_path +from webssh import settings from webssh.policy import load_host_keys from webssh.settings import ( get_host_keys_settings, get_policy_setting, base_dir, print_version, - get_ssl_context, get_trusted_downstream + get_ssl_context, get_trusted_downstream, detect_is_open_to_public, ) from webssh.utils import UnicodeType from webssh._version import __version__ @@ -137,3 +138,29 @@ class TestSettings(unittest.TestCase): options.tdstream = '1.1.1.1, 2.2.2.' with self.assertRaises(ValueError): get_trusted_downstream(options), tdstream + + def test_detect_is_open_to_public(self): + options.fbidhttp = True + options.address = 'localhost' + detect_is_open_to_public(options) + self.assertFalse(settings.is_open_to_public) + + options.address = '127.0.0.1' + detect_is_open_to_public(options) + self.assertFalse(settings.is_open_to_public) + + options.address = '192.168.1.1' + detect_is_open_to_public(options) + self.assertFalse(settings.is_open_to_public) + + options.address = '' + detect_is_open_to_public(options) + self.assertTrue(settings.is_open_to_public) + + options.address = '0.0.0.0' + detect_is_open_to_public(options) + self.assertTrue(settings.is_open_to_public) + + options.address = '::' + detect_is_open_to_public(options) + self.assertTrue(settings.is_open_to_public) diff --git a/tests/test_utils.py b/tests/test_utils.py index f94546e..9eec2b5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,9 @@ import unittest from webssh.utils import ( - is_valid_ip_address, is_valid_port, is_valid_hostname, - to_str, to_bytes, to_int + is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes, + to_int, on_public_network_interface, on_public_network_interfaces, + get_ips_by_name ) @@ -51,3 +52,25 @@ class TestUitls(unittest.TestCase): self.assertFalse(is_valid_hostname('https://www.google.com')) self.assertFalse(is_valid_hostname('127.0.0.1')) self.assertFalse(is_valid_hostname('::1')) + + def test_get_ips_by_name(self): + self.assertTrue(get_ips_by_name(''), {'0.0.0.0', '::'}) + self.assertTrue(get_ips_by_name('localhost'), {'127.0.0.1'}) + self.assertTrue(get_ips_by_name('192.68.1.1'), {'192.168.1.1'}) + self.assertTrue(get_ips_by_name('2.2.2.2'), {'2.2.2.2'}) + + def test_on_public_network_interface(self): + self.assertTrue(on_public_network_interface('0.0.0.0')) + self.assertTrue(on_public_network_interface('::')) + self.assertTrue(on_public_network_interface('0:0:0:0:0:0:0:0')) + self.assertTrue(on_public_network_interface('2.2.2.2')) + self.assertTrue(on_public_network_interface('2:2:2:2:2:2:2:2')) + self.assertIsNone(on_public_network_interface('127.0.0.1')) + + def test_on_public_network_interfaces(self): + self.assertTrue( + on_public_network_interfaces(['0.0.0.0', '127.0.0.1']) + ) + self.assertIsNone( + on_public_network_interfaces(['192.168.1.1', '127.0.0.1']) + ) diff --git a/webssh/handler.py b/webssh/handler.py index 6f43074..83751d2 100644 --- a/webssh/handler.py +++ b/webssh/handler.py @@ -10,7 +10,8 @@ import paramiko import tornado.web from tornado.ioloop import IOLoop -from webssh.settings import swallow_http_errors +from tornado.options import options +from webssh import settings from webssh.utils import ( is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str, to_int, to_ip_address, UnicodeType @@ -39,11 +40,20 @@ class InvalidValueError(Exception): class MixinHandler(object): + is_open_to_public = None + forbid_public_http = None + custom_headers = { 'Server': 'TornadoServer' } def initialize(self): + if self.is_open_to_public is None: + MixinHandler.is_open_to_public = settings.is_open_to_public + + if self.forbid_public_http is None: + MixinHandler.forbid_public_http = options.fbidhttp + if self.is_forbidden(): result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version) self.request.connection.stream.write(to_bytes(result)) @@ -66,11 +76,12 @@ class MixinHandler(object): ) return True - if context._orig_protocol == 'http': - ipaddr = to_ip_address(ip) - if not ipaddr.is_private: - logging.warning('Public non-https request is forbidden.') - return True + if self.is_open_to_public and self.forbid_public_http: + if context._orig_protocol == 'http': + ipaddr = to_ip_address(ip) + if not ipaddr.is_private: + logging.warning('Public non-https request is forbidden.') + return True def set_default_headers(self): for header in self.custom_headers.items(): @@ -127,7 +138,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): super(IndexHandler, self).initialize() def write_error(self, status_code, **kwargs): - if self.request.method != 'POST' or not swallow_http_errors: + if self.request.method != 'POST' or not settings.swallow_http_errors: super(IndexHandler, self).write_error(status_code, **kwargs) else: exc_info = kwargs.get('exc_info') diff --git a/webssh/main.py b/webssh/main.py index 849fe47..50abf14 100644 --- a/webssh/main.py +++ b/webssh/main.py @@ -6,7 +6,7 @@ from tornado.options import options from webssh.handler import IndexHandler, WsockHandler, NotFoundHandler from webssh.settings import ( get_app_settings, get_host_keys_settings, get_policy_setting, - get_ssl_context, get_server_settings + get_ssl_context, get_server_settings, detect_is_open_to_public ) @@ -40,6 +40,7 @@ def main(): app.listen(options.sslport, options.ssladdress, **server_settings) logging.info('Listening on ssl {}:{}'.format(options.ssladdress, options.sslport)) + detect_is_open_to_public(options) loop.start() diff --git a/webssh/settings.py b/webssh/settings.py index 10db49e..ef3dd87 100644 --- a/webssh/settings.py +++ b/webssh/settings.py @@ -7,7 +7,9 @@ from tornado.options import define from webssh.policy import ( load_host_keys, get_policy_class, check_policy_setting ) -from webssh.utils import to_ip_address +from webssh.utils import ( + to_ip_address, get_ips_by_name, on_public_network_interfaces +) from webssh._version import __version__ @@ -29,6 +31,7 @@ define('policy', default='warning', define('hostfile', default='', help='User defined host keys file') define('syshostfile', default='', help='System wide host keys file') define('tdstream', default='', help='trusted downstream, separated by comma') +define('fbidhttp', type=bool, default=True, help='forbid public http request') define('wpintvl', type=int, default=0, help='Websocket ping interval') define('version', type=bool, help='Show version information', callback=print_version) @@ -38,6 +41,7 @@ base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) max_body_size = 1 * 1024 * 1024 swallow_http_errors = True xheaders = True +is_open_to_public = False def get_app_settings(options): @@ -113,3 +117,12 @@ def get_trusted_downstream(options): to_ip_address(ip) tdstream.add(ip) return tdstream + + +def detect_is_open_to_public(options): + global is_open_to_public + if on_public_network_interfaces(get_ips_by_name(options.address)): + is_open_to_public = True + logging.info('Forbid public http: {}'.format(options.fbidhttp)) + else: + is_open_to_public = False diff --git a/webssh/utils.py b/webssh/utils.py index 5c832c6..48d6d97 100644 --- a/webssh/utils.py +++ b/webssh/utils.py @@ -1,5 +1,6 @@ import ipaddress import re +import socket try: from types import UnicodeType @@ -10,6 +11,9 @@ except ImportError: numeric = re.compile(r'[0-9]+$') allowed = re.compile(r'(?!-)[a-z0-9-]{1,63}(?