Move some config variables to handler.py

pull/38/head
Sheng 2018-10-18 15:07:45 +08:00
parent 383f61d4cc
commit 5d6f92e529
5 changed files with 38 additions and 30 deletions

View File

@ -3,16 +3,16 @@ import random
import threading import threading
import tornado.websocket import tornado.websocket
import tornado.gen import tornado.gen
import webssh.handler as handler
from tornado.testing import AsyncHTTPTestCase from tornado.testing import AsyncHTTPTestCase
from tornado.httpclient import HTTPError from tornado.httpclient import HTTPError
from tornado.options import options from tornado.options import options
from tests.sshserver import run_ssh_server, banner from tests.sshserver import run_ssh_server, banner
from tests.utils import encode_multipart_formdata, read_file, make_tests_data_path # noqa from tests.utils import encode_multipart_formdata, read_file, make_tests_data_path # noqa
from webssh import handler
from webssh.main import make_app, make_handlers from webssh.main import make_app, make_handlers
from webssh.settings import ( from webssh.settings import (
get_app_settings, get_server_settings, max_body_size, swallow_http_errors get_app_settings, get_server_settings, max_body_size
) )
from webssh.utils import to_str from webssh.utils import to_str
@ -23,6 +23,7 @@ except ImportError:
handler.DELAY = 0.1 handler.DELAY = 0.1
swallow_http_errors = handler.swallow_http_errors
class TestAppBasic(AsyncHTTPTestCase): class TestAppBasic(AsyncHTTPTestCase):

View File

@ -1,5 +1,6 @@
import unittest import unittest
import paramiko import paramiko
import webssh.handler
from tornado.httpclient import HTTPRequest from tornado.httpclient import HTTPRequest
from tornado.httputil import HTTPServerRequest from tornado.httputil import HTTPServerRequest
@ -16,8 +17,8 @@ class TestMixinHandler(unittest.TestCase):
def test_is_forbidden(self): def test_is_forbidden(self):
handler = MixinHandler() handler = MixinHandler()
handler.is_open_to_public = True webssh.handler.is_open_to_public = True
handler.forbid_public_http = True webssh.handler.forbid_public_http = True
request = HTTPRequest('http://example.com/') request = HTTPRequest('http://example.com/')
handler.request = request handler.request = request

View File

@ -1,4 +1,5 @@
import io import io
import random
import ssl import ssl
import sys import sys
import os.path import os.path
@ -7,7 +8,7 @@ import paramiko
import tornado.options as options import tornado.options as options
from tests.utils import make_tests_data_path from tests.utils import make_tests_data_path
from webssh import settings from webssh import handler
from webssh.policy import load_host_keys from webssh.policy import load_host_keys
from webssh.settings import ( from webssh.settings import (
get_host_keys_settings, get_policy_setting, base_dir, print_version, get_host_keys_settings, get_policy_setting, base_dir, print_version,
@ -140,27 +141,39 @@ class TestSettings(unittest.TestCase):
get_trusted_downstream(options), tdstream get_trusted_downstream(options), tdstream
def test_detect_is_open_to_public(self): def test_detect_is_open_to_public(self):
options.fbidhttp = True options.fbidhttp = random.choice([True, False])
options.address = 'localhost' options.address = 'localhost'
detect_is_open_to_public(options) detect_is_open_to_public(options)
self.assertFalse(settings.is_open_to_public) self.assertFalse(handler.is_open_to_public)
self.assertEqual(handler.forbid_public_http, options.fbidhttp)
options.fbidhttp = random.choice([True, False])
options.fbidhttp = False
options.address = '127.0.0.1' options.address = '127.0.0.1'
detect_is_open_to_public(options) detect_is_open_to_public(options)
self.assertFalse(settings.is_open_to_public) self.assertFalse(handler.is_open_to_public)
self.assertEqual(handler.forbid_public_http, options.fbidhttp)
options.fbidhttp = random.choice([True, False])
options.address = '192.168.1.1' options.address = '192.168.1.1'
detect_is_open_to_public(options) detect_is_open_to_public(options)
self.assertFalse(settings.is_open_to_public) self.assertFalse(handler.is_open_to_public)
self.assertEqual(handler.forbid_public_http, options.fbidhttp)
options.fbidhttp = random.choice([True, False])
options.address = '' options.address = ''
detect_is_open_to_public(options) detect_is_open_to_public(options)
self.assertTrue(settings.is_open_to_public) self.assertTrue(handler.is_open_to_public)
self.assertEqual(handler.forbid_public_http, options.fbidhttp)
options.fbidhttp = random.choice([True, False])
options.address = '0.0.0.0' options.address = '0.0.0.0'
detect_is_open_to_public(options) detect_is_open_to_public(options)
self.assertTrue(settings.is_open_to_public) self.assertTrue(handler.is_open_to_public)
self.assertEqual(handler.forbid_public_http, options.fbidhttp)
options.fbidhttp = random.choice([True, False])
options.address = '::' options.address = '::'
detect_is_open_to_public(options) detect_is_open_to_public(options)
self.assertTrue(settings.is_open_to_public) self.assertTrue(handler.is_open_to_public)
self.assertEqual(handler.forbid_public_http, options.fbidhttp)

View File

@ -10,8 +10,6 @@ import paramiko
import tornado.web import tornado.web
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from tornado.options import options
from webssh import settings
from webssh.utils import ( from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, is_valid_ip_address, is_valid_port, is_valid_hostname,
to_bytes, to_str, to_int, to_ip_address, UnicodeType to_bytes, to_str, to_int, to_ip_address, UnicodeType
@ -33,6 +31,10 @@ DELAY = 3
KEY_MAX_SIZE = 16384 KEY_MAX_SIZE = 16384
DEFAULT_PORT = 22 DEFAULT_PORT = 22
swallow_http_errors = True
is_open_to_public = None
forbid_public_http = None
class InvalidValueError(Exception): class InvalidValueError(Exception):
pass pass
@ -40,20 +42,11 @@ class InvalidValueError(Exception):
class MixinHandler(object): class MixinHandler(object):
is_open_to_public = None
forbid_public_http = None
custom_headers = { custom_headers = {
'Server': 'TornadoServer' 'Server': 'TornadoServer'
} }
def initialize(self): def initialize(self):
if self.is_open_to_public is None:
MixinHandler.is_open_to_public = settings.is_open_to_public
if self.forbid_public_http is None:
MixinHandler.forbid_public_http = options.fbidhttp
if self.is_forbidden(): if self.is_forbidden():
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version) result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
self.request.connection.stream.write(to_bytes(result)) self.request.connection.stream.write(to_bytes(result))
@ -76,7 +69,7 @@ class MixinHandler(object):
) )
return True return True
if self.is_open_to_public and self.forbid_public_http: if is_open_to_public and forbid_public_http:
if context._orig_protocol == 'http': if context._orig_protocol == 'http':
ipaddr = to_ip_address(ip) ipaddr = to_ip_address(ip)
if not ipaddr.is_private: if not ipaddr.is_private:
@ -138,7 +131,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
super(IndexHandler, self).initialize() super(IndexHandler, self).initialize()
def write_error(self, status_code, **kwargs): def write_error(self, status_code, **kwargs):
if self.request.method != 'POST' or not settings.swallow_http_errors: if self.request.method != 'POST' or not swallow_http_errors:
super(IndexHandler, self).write_error(status_code, **kwargs) super(IndexHandler, self).write_error(status_code, **kwargs)
else: else:
exc_info = kwargs.get('exc_info') exc_info = kwargs.get('exc_info')

View File

@ -4,6 +4,7 @@ import ssl
import sys import sys
from tornado.options import define from tornado.options import define
from webssh import handler
from webssh.policy import ( from webssh.policy import (
load_host_keys, get_policy_class, check_policy_setting load_host_keys, get_policy_class, check_policy_setting
) )
@ -39,9 +40,7 @@ define('version', type=bool, help='Show version information',
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
max_body_size = 1 * 1024 * 1024 max_body_size = 1 * 1024 * 1024
swallow_http_errors = True
xheaders = True xheaders = True
is_open_to_public = False
def get_app_settings(options): def get_app_settings(options):
@ -120,9 +119,10 @@ def get_trusted_downstream(options):
def detect_is_open_to_public(options): def detect_is_open_to_public(options):
global is_open_to_public handler.forbid_public_http = options.fbidhttp
if on_public_network_interfaces(get_ips_by_name(options.address)): if on_public_network_interfaces(get_ips_by_name(options.address)):
is_open_to_public = True handler.is_open_to_public = True
logging.info('Forbid public http: {}'.format(options.fbidhttp)) logging.info('Forbid public http: {}'.format(options.fbidhttp))
else: else:
is_open_to_public = False handler.is_open_to_public = False