Use open_to_public to store the status of the http(s) server

pull/38/head
Sheng 2018-10-20 17:26:33 +08:00
parent e31e9be433
commit 40cf1095ff
7 changed files with 51 additions and 59 deletions

View File

@ -4,7 +4,9 @@ import paramiko
from tornado.httputil import HTTPServerRequest
from tornado.options import options
from tests.utils import read_file, make_tests_data_path
from webssh.handler import MixinHandler, IndexHandler, InvalidValueError
from webssh.handler import (
MixinHandler, IndexHandler, InvalidValueError, open_to_public
)
try:
from unittest.mock import Mock
@ -16,6 +18,7 @@ class TestMixinHandler(unittest.TestCase):
def test_is_forbidden(self):
handler = MixinHandler()
open_to_public['http'] = True
options.fbidhttp = True
context = Mock(

View File

@ -10,7 +10,7 @@ from tests.utils import make_tests_data_path
from webssh.policy import load_host_keys
from webssh.settings import (
get_host_keys_settings, get_policy_setting, base_dir, print_version,
get_ssl_context, get_trusted_downstream, detect_is_open_to_public,
get_ssl_context, get_trusted_downstream
)
from webssh.utils import UnicodeType
from webssh._version import __version__
@ -137,24 +137,3 @@ class TestSettings(unittest.TestCase):
tdstream = '1.1.1.1, 2.2.2.'
with self.assertRaises(ValueError):
get_trusted_downstream(tdstream)
def test_detect_is_open_to_public(self):
options.fbidhttp = True
options.address = '127.0.0.1'
detect_is_open_to_public(options)
self.assertFalse(options.fbidhttp)
options.fbidhttp = False
options.address = '127.0.0.1'
detect_is_open_to_public(options)
self.assertFalse(options.fbidhttp)
options.fbidhttp = False
options.address = '0.0.0.0'
detect_is_open_to_public(options)
self.assertFalse(options.fbidhttp)
options.fbidhttp = True
options.address = '0.0.0.0'
detect_is_open_to_public(options)
self.assertTrue(options.fbidhttp)

View File

@ -2,8 +2,8 @@ import unittest
from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
to_int, on_public_network_interface, on_public_network_interfaces,
get_ips_by_name
to_int, on_public_network_interface, get_ips_by_name,
is_name_open_to_public
)
@ -67,10 +67,9 @@ class TestUitls(unittest.TestCase):
self.assertTrue(on_public_network_interface('2:2:2:2:2:2:2:2'))
self.assertIsNone(on_public_network_interface('127.0.0.1'))
def test_on_public_network_interfaces(self):
self.assertTrue(
on_public_network_interfaces(['0.0.0.0', '127.0.0.1'])
)
self.assertIsNone(
on_public_network_interfaces(['192.168.1.1', '127.0.0.1'])
)
def test_is_name_open_to_public(self):
self.assertTrue(is_name_open_to_public('0.0.0.0'))
self.assertTrue(is_name_open_to_public('::'))
self.assertIsNone(is_name_open_to_public('192.168.1.1'))
self.assertIsNone(is_name_open_to_public('127.0.0.1'))
self.assertIsNone(is_name_open_to_public('localhost'))

View File

@ -12,8 +12,8 @@ import tornado.web
from tornado.ioloop import IOLoop
from tornado.options import options
from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname,
to_bytes, to_str, to_int, to_ip_address, UnicodeType
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
to_int, to_ip_address, UnicodeType, is_name_open_to_public
)
from webssh.worker import Worker, recycle_worker, workers
@ -34,6 +34,17 @@ DEFAULT_PORT = 22
swallow_http_errors = True
# status of the http(s) server
open_to_public = {
'http': None,
'https': None
}
def config_open_to_public(address, server_type):
status = True if is_name_open_to_public(address) else False
open_to_public[server_type] = status
class InvalidValueError(Exception):
pass
@ -70,11 +81,12 @@ class MixinHandler(object):
)
return True
if options.fbidhttp and context._orig_protocol == 'http':
ipaddr = to_ip_address(ip)
if not ipaddr.is_private:
logging.warning('Public plain http request is forbidden.')
return True
if open_to_public['http'] and options.fbidhttp:
if context._orig_protocol == 'http':
ipaddr = to_ip_address(ip)
if not ipaddr.is_private:
logging.warning('Public plain http request is forbidden.')
return True
def set_default_headers(self):
for header in self.custom_headers.items():

View File

@ -3,10 +3,12 @@ import tornado.web
import tornado.ioloop
from tornado.options import options
from webssh.handler import IndexHandler, WsockHandler, NotFoundHandler
from webssh.handler import (
IndexHandler, WsockHandler, NotFoundHandler, config_open_to_public
)
from webssh.settings import (
get_app_settings, get_host_keys_settings, get_policy_setting,
get_ssl_context, get_server_settings, detect_is_open_to_public
get_ssl_context, get_server_settings
)
@ -27,20 +29,26 @@ def make_app(handlers, settings):
return tornado.web.Application(handlers, **settings)
def app_listen(app, port, address, server_settings):
app.listen(port, address, **server_settings)
server_type = 'https' if server_settings.get('ssl_options') else 'http'
logging.info(
'Started a {} server listening on {}:{}'.format(
server_type, address, port)
)
config_open_to_public(address, server_type)
def main():
options.parse_command_line()
loop = tornado.ioloop.IOLoop.current()
app = make_app(make_handlers(loop, options), get_app_settings(options))
ssl_ctx = get_ssl_context(options)
server_settings = get_server_settings(options)
app.listen(options.port, options.address, **server_settings)
logging.info('Listening on {}:{}'.format(options.address, options.port))
app_listen(app, options.port, options.address, server_settings)
if ssl_ctx:
server_settings.update(ssl_options=ssl_ctx)
app.listen(options.sslport, options.ssladdress, **server_settings)
logging.info('Listening on ssl {}:{}'.format(options.ssladdress,
options.sslport))
detect_is_open_to_public(options)
app_listen(app, options.sslport, options.ssladdress, server_settings)
loop.start()

View File

@ -7,9 +7,7 @@ from tornado.options import define
from webssh.policy import (
load_host_keys, get_policy_class, check_policy_setting
)
from webssh.utils import (
to_ip_address, get_ips_by_name, on_public_network_interfaces
)
from webssh.utils import to_ip_address
from webssh._version import __version__
@ -116,10 +114,3 @@ def get_trusted_downstream(tdstream):
to_ip_address(ip)
result.add(ip)
return result
def detect_is_open_to_public(options):
result = on_public_network_interfaces(get_ips_by_name(options.address))
if not result and options.fbidhttp:
options.fbidhttp = False
logging.info('Forbid public plain http: {}'.format(options.fbidhttp))

View File

@ -82,7 +82,7 @@ def on_public_network_interface(ip):
return True
def on_public_network_interfaces(ips):
for ip in ips:
def is_name_open_to_public(name):
for ip in get_ips_by_name(name):
if on_public_network_interface(ip):
return True