diff --git a/tests/test_handler.py b/tests/test_handler.py index f0fa759..ac157d8 100644 --- a/tests/test_handler.py +++ b/tests/test_handler.py @@ -1,9 +1,9 @@ import unittest import paramiko -import webssh.handler from tornado.httpclient import HTTPRequest from tornado.httputil import HTTPServerRequest +from tornado.options import options from tests.utils import read_file, make_tests_data_path from webssh.handler import MixinHandler, IndexHandler, InvalidValueError @@ -17,10 +17,9 @@ class TestMixinHandler(unittest.TestCase): def test_is_forbidden(self): handler = MixinHandler() - webssh.handler.is_open_to_public = True - webssh.handler.forbid_public_http = True request = HTTPRequest('http://example.com/') handler.request = request + options.fbidhttp = True context = Mock( address=('8.8.8.8', 8888), diff --git a/tests/test_settings.py b/tests/test_settings.py index 3fac353..100a032 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,5 +1,4 @@ import io -import random import ssl import sys import os.path @@ -8,7 +7,6 @@ import paramiko import tornado.options as options from tests.utils import make_tests_data_path -from webssh import handler from webssh.policy import load_host_keys from webssh.settings import ( get_host_keys_settings, get_policy_setting, base_dir, print_version, @@ -141,39 +139,22 @@ class TestSettings(unittest.TestCase): get_trusted_downstream(options), tdstream def test_detect_is_open_to_public(self): - options.fbidhttp = random.choice([True, False]) - options.address = 'localhost' + options.fbidhttp = True + options.address = '127.0.0.1' detect_is_open_to_public(options) - self.assertFalse(handler.is_open_to_public) - self.assertEqual(handler.forbid_public_http, options.fbidhttp) + self.assertFalse(options.fbidhttp) - options.fbidhttp = random.choice([True, False]) options.fbidhttp = False options.address = '127.0.0.1' detect_is_open_to_public(options) - self.assertFalse(handler.is_open_to_public) - self.assertEqual(handler.forbid_public_http, options.fbidhttp) - - options.fbidhttp = random.choice([True, False]) - options.address = '192.168.1.1' - detect_is_open_to_public(options) - self.assertFalse(handler.is_open_to_public) - self.assertEqual(handler.forbid_public_http, options.fbidhttp) + self.assertFalse(options.fbidhttp) - options.fbidhttp = random.choice([True, False]) - options.address = '' - detect_is_open_to_public(options) - self.assertTrue(handler.is_open_to_public) - self.assertEqual(handler.forbid_public_http, options.fbidhttp) - - options.fbidhttp = random.choice([True, False]) + options.fbidhttp = False options.address = '0.0.0.0' detect_is_open_to_public(options) - self.assertTrue(handler.is_open_to_public) - self.assertEqual(handler.forbid_public_http, options.fbidhttp) + self.assertFalse(options.fbidhttp) - options.fbidhttp = random.choice([True, False]) - options.address = '::' + options.fbidhttp = True + options.address = '0.0.0.0' detect_is_open_to_public(options) - self.assertTrue(handler.is_open_to_public) - self.assertEqual(handler.forbid_public_http, options.fbidhttp) + self.assertTrue(options.fbidhttp) diff --git a/webssh/handler.py b/webssh/handler.py index 129847c..c25c36f 100644 --- a/webssh/handler.py +++ b/webssh/handler.py @@ -10,6 +10,7 @@ import paramiko import tornado.web from tornado.ioloop import IOLoop +from tornado.options import options from webssh.utils import ( is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str, to_int, to_ip_address, UnicodeType @@ -32,8 +33,6 @@ KEY_MAX_SIZE = 16384 DEFAULT_PORT = 22 swallow_http_errors = True -is_open_to_public = None -forbid_public_http = None class InvalidValueError(Exception): @@ -69,12 +68,11 @@ class MixinHandler(object): ) return True - if is_open_to_public and forbid_public_http: - if context._orig_protocol == 'http': - ipaddr = to_ip_address(ip) - if not ipaddr.is_private: - logging.warning('Public non-https request is forbidden.') - return True + if options.fbidhttp and context._orig_protocol == 'http': + ipaddr = to_ip_address(ip) + if not ipaddr.is_private: + logging.warning('Public non-https request is forbidden.') + return True def set_default_headers(self): for header in self.custom_headers.items(): diff --git a/webssh/settings.py b/webssh/settings.py index 923d78d..733b143 100644 --- a/webssh/settings.py +++ b/webssh/settings.py @@ -4,7 +4,6 @@ import ssl import sys from tornado.options import define -from webssh import handler from webssh.policy import ( load_host_keys, get_policy_class, check_policy_setting ) @@ -119,10 +118,7 @@ def get_trusted_downstream(options): def detect_is_open_to_public(options): - handler.forbid_public_http = options.fbidhttp - - if on_public_network_interfaces(get_ips_by_name(options.address)): - handler.is_open_to_public = True - logging.info('Forbid public http: {}'.format(options.fbidhttp)) - else: - handler.is_open_to_public = False + result = on_public_network_interfaces(get_ips_by_name(options.address)) + if not result and options.fbidhttp: + options.fbidhttp = False + logging.info('Forbid public http: {}'.format(options.fbidhttp))