mirror of https://github.com/huashengdun/webssh
Added an option for blocking public non-https requests
parent
746982b001
commit
c06bf5311a
|
@ -3,7 +3,6 @@ import paramiko
|
||||||
|
|
||||||
from tornado.httpclient import HTTPRequest
|
from tornado.httpclient import HTTPRequest
|
||||||
from tornado.httputil import HTTPServerRequest
|
from tornado.httputil import HTTPServerRequest
|
||||||
from tornado.web import HTTPError
|
|
||||||
from tests.utils import read_file, make_tests_data_path
|
from tests.utils import read_file, make_tests_data_path
|
||||||
from webssh.handler import MixinHandler, IndexHandler, InvalidValueError
|
from webssh.handler import MixinHandler, IndexHandler, InvalidValueError
|
||||||
|
|
||||||
|
@ -17,6 +16,8 @@ class TestMixinHandler(unittest.TestCase):
|
||||||
|
|
||||||
def test_is_forbidden(self):
|
def test_is_forbidden(self):
|
||||||
handler = MixinHandler()
|
handler = MixinHandler()
|
||||||
|
handler.is_open_to_public = True
|
||||||
|
handler.forbid_public_http = True
|
||||||
request = HTTPRequest('http://example.com/')
|
request = HTTPRequest('http://example.com/')
|
||||||
handler.request = request
|
handler.request = request
|
||||||
|
|
||||||
|
|
|
@ -7,10 +7,11 @@ import paramiko
|
||||||
import tornado.options as options
|
import tornado.options as options
|
||||||
|
|
||||||
from tests.utils import make_tests_data_path
|
from tests.utils import make_tests_data_path
|
||||||
|
from webssh import settings
|
||||||
from webssh.policy import load_host_keys
|
from webssh.policy import load_host_keys
|
||||||
from webssh.settings import (
|
from webssh.settings import (
|
||||||
get_host_keys_settings, get_policy_setting, base_dir, print_version,
|
get_host_keys_settings, get_policy_setting, base_dir, print_version,
|
||||||
get_ssl_context, get_trusted_downstream
|
get_ssl_context, get_trusted_downstream, detect_is_open_to_public,
|
||||||
)
|
)
|
||||||
from webssh.utils import UnicodeType
|
from webssh.utils import UnicodeType
|
||||||
from webssh._version import __version__
|
from webssh._version import __version__
|
||||||
|
@ -137,3 +138,29 @@ class TestSettings(unittest.TestCase):
|
||||||
options.tdstream = '1.1.1.1, 2.2.2.'
|
options.tdstream = '1.1.1.1, 2.2.2.'
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
get_trusted_downstream(options), tdstream
|
get_trusted_downstream(options), tdstream
|
||||||
|
|
||||||
|
def test_detect_is_open_to_public(self):
|
||||||
|
options.fbidhttp = True
|
||||||
|
options.address = 'localhost'
|
||||||
|
detect_is_open_to_public(options)
|
||||||
|
self.assertFalse(settings.is_open_to_public)
|
||||||
|
|
||||||
|
options.address = '127.0.0.1'
|
||||||
|
detect_is_open_to_public(options)
|
||||||
|
self.assertFalse(settings.is_open_to_public)
|
||||||
|
|
||||||
|
options.address = '192.168.1.1'
|
||||||
|
detect_is_open_to_public(options)
|
||||||
|
self.assertFalse(settings.is_open_to_public)
|
||||||
|
|
||||||
|
options.address = ''
|
||||||
|
detect_is_open_to_public(options)
|
||||||
|
self.assertTrue(settings.is_open_to_public)
|
||||||
|
|
||||||
|
options.address = '0.0.0.0'
|
||||||
|
detect_is_open_to_public(options)
|
||||||
|
self.assertTrue(settings.is_open_to_public)
|
||||||
|
|
||||||
|
options.address = '::'
|
||||||
|
detect_is_open_to_public(options)
|
||||||
|
self.assertTrue(settings.is_open_to_public)
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from webssh.utils import (
|
from webssh.utils import (
|
||||||
is_valid_ip_address, is_valid_port, is_valid_hostname,
|
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
|
||||||
to_str, to_bytes, to_int
|
to_int, on_public_network_interface, on_public_network_interfaces,
|
||||||
|
get_ips_by_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,3 +52,25 @@ class TestUitls(unittest.TestCase):
|
||||||
self.assertFalse(is_valid_hostname('https://www.google.com'))
|
self.assertFalse(is_valid_hostname('https://www.google.com'))
|
||||||
self.assertFalse(is_valid_hostname('127.0.0.1'))
|
self.assertFalse(is_valid_hostname('127.0.0.1'))
|
||||||
self.assertFalse(is_valid_hostname('::1'))
|
self.assertFalse(is_valid_hostname('::1'))
|
||||||
|
|
||||||
|
def test_get_ips_by_name(self):
|
||||||
|
self.assertTrue(get_ips_by_name(''), {'0.0.0.0', '::'})
|
||||||
|
self.assertTrue(get_ips_by_name('localhost'), {'127.0.0.1'})
|
||||||
|
self.assertTrue(get_ips_by_name('192.68.1.1'), {'192.168.1.1'})
|
||||||
|
self.assertTrue(get_ips_by_name('2.2.2.2'), {'2.2.2.2'})
|
||||||
|
|
||||||
|
def test_on_public_network_interface(self):
|
||||||
|
self.assertTrue(on_public_network_interface('0.0.0.0'))
|
||||||
|
self.assertTrue(on_public_network_interface('::'))
|
||||||
|
self.assertTrue(on_public_network_interface('0:0:0:0:0:0:0:0'))
|
||||||
|
self.assertTrue(on_public_network_interface('2.2.2.2'))
|
||||||
|
self.assertTrue(on_public_network_interface('2:2:2:2:2:2:2:2'))
|
||||||
|
self.assertIsNone(on_public_network_interface('127.0.0.1'))
|
||||||
|
|
||||||
|
def test_on_public_network_interfaces(self):
|
||||||
|
self.assertTrue(
|
||||||
|
on_public_network_interfaces(['0.0.0.0', '127.0.0.1'])
|
||||||
|
)
|
||||||
|
self.assertIsNone(
|
||||||
|
on_public_network_interfaces(['192.168.1.1', '127.0.0.1'])
|
||||||
|
)
|
||||||
|
|
|
@ -10,7 +10,8 @@ import paramiko
|
||||||
import tornado.web
|
import tornado.web
|
||||||
|
|
||||||
from tornado.ioloop import IOLoop
|
from tornado.ioloop import IOLoop
|
||||||
from webssh.settings import swallow_http_errors
|
from tornado.options import options
|
||||||
|
from webssh import settings
|
||||||
from webssh.utils import (
|
from webssh.utils import (
|
||||||
is_valid_ip_address, is_valid_port, is_valid_hostname,
|
is_valid_ip_address, is_valid_port, is_valid_hostname,
|
||||||
to_bytes, to_str, to_int, to_ip_address, UnicodeType
|
to_bytes, to_str, to_int, to_ip_address, UnicodeType
|
||||||
|
@ -39,11 +40,20 @@ class InvalidValueError(Exception):
|
||||||
|
|
||||||
class MixinHandler(object):
|
class MixinHandler(object):
|
||||||
|
|
||||||
|
is_open_to_public = None
|
||||||
|
forbid_public_http = None
|
||||||
|
|
||||||
custom_headers = {
|
custom_headers = {
|
||||||
'Server': 'TornadoServer'
|
'Server': 'TornadoServer'
|
||||||
}
|
}
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
|
if self.is_open_to_public is None:
|
||||||
|
MixinHandler.is_open_to_public = settings.is_open_to_public
|
||||||
|
|
||||||
|
if self.forbid_public_http is None:
|
||||||
|
MixinHandler.forbid_public_http = options.fbidhttp
|
||||||
|
|
||||||
if self.is_forbidden():
|
if self.is_forbidden():
|
||||||
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
|
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
|
||||||
self.request.connection.stream.write(to_bytes(result))
|
self.request.connection.stream.write(to_bytes(result))
|
||||||
|
@ -66,6 +76,7 @@ class MixinHandler(object):
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if self.is_open_to_public and self.forbid_public_http:
|
||||||
if context._orig_protocol == 'http':
|
if context._orig_protocol == 'http':
|
||||||
ipaddr = to_ip_address(ip)
|
ipaddr = to_ip_address(ip)
|
||||||
if not ipaddr.is_private:
|
if not ipaddr.is_private:
|
||||||
|
@ -127,7 +138,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
super(IndexHandler, self).initialize()
|
super(IndexHandler, self).initialize()
|
||||||
|
|
||||||
def write_error(self, status_code, **kwargs):
|
def write_error(self, status_code, **kwargs):
|
||||||
if self.request.method != 'POST' or not swallow_http_errors:
|
if self.request.method != 'POST' or not settings.swallow_http_errors:
|
||||||
super(IndexHandler, self).write_error(status_code, **kwargs)
|
super(IndexHandler, self).write_error(status_code, **kwargs)
|
||||||
else:
|
else:
|
||||||
exc_info = kwargs.get('exc_info')
|
exc_info = kwargs.get('exc_info')
|
||||||
|
|
|
@ -6,7 +6,7 @@ from tornado.options import options
|
||||||
from webssh.handler import IndexHandler, WsockHandler, NotFoundHandler
|
from webssh.handler import IndexHandler, WsockHandler, NotFoundHandler
|
||||||
from webssh.settings import (
|
from webssh.settings import (
|
||||||
get_app_settings, get_host_keys_settings, get_policy_setting,
|
get_app_settings, get_host_keys_settings, get_policy_setting,
|
||||||
get_ssl_context, get_server_settings
|
get_ssl_context, get_server_settings, detect_is_open_to_public
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,6 +40,7 @@ def main():
|
||||||
app.listen(options.sslport, options.ssladdress, **server_settings)
|
app.listen(options.sslport, options.ssladdress, **server_settings)
|
||||||
logging.info('Listening on ssl {}:{}'.format(options.ssladdress,
|
logging.info('Listening on ssl {}:{}'.format(options.ssladdress,
|
||||||
options.sslport))
|
options.sslport))
|
||||||
|
detect_is_open_to_public(options)
|
||||||
loop.start()
|
loop.start()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,9 @@ from tornado.options import define
|
||||||
from webssh.policy import (
|
from webssh.policy import (
|
||||||
load_host_keys, get_policy_class, check_policy_setting
|
load_host_keys, get_policy_class, check_policy_setting
|
||||||
)
|
)
|
||||||
from webssh.utils import to_ip_address
|
from webssh.utils import (
|
||||||
|
to_ip_address, get_ips_by_name, on_public_network_interfaces
|
||||||
|
)
|
||||||
from webssh._version import __version__
|
from webssh._version import __version__
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,6 +31,7 @@ define('policy', default='warning',
|
||||||
define('hostfile', default='', help='User defined host keys file')
|
define('hostfile', default='', help='User defined host keys file')
|
||||||
define('syshostfile', default='', help='System wide host keys file')
|
define('syshostfile', default='', help='System wide host keys file')
|
||||||
define('tdstream', default='', help='trusted downstream, separated by comma')
|
define('tdstream', default='', help='trusted downstream, separated by comma')
|
||||||
|
define('fbidhttp', type=bool, default=True, help='forbid public http request')
|
||||||
define('wpintvl', type=int, default=0, help='Websocket ping interval')
|
define('wpintvl', type=int, default=0, help='Websocket ping interval')
|
||||||
define('version', type=bool, help='Show version information',
|
define('version', type=bool, help='Show version information',
|
||||||
callback=print_version)
|
callback=print_version)
|
||||||
|
@ -38,6 +41,7 @@ base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
max_body_size = 1 * 1024 * 1024
|
max_body_size = 1 * 1024 * 1024
|
||||||
swallow_http_errors = True
|
swallow_http_errors = True
|
||||||
xheaders = True
|
xheaders = True
|
||||||
|
is_open_to_public = False
|
||||||
|
|
||||||
|
|
||||||
def get_app_settings(options):
|
def get_app_settings(options):
|
||||||
|
@ -113,3 +117,12 @@ def get_trusted_downstream(options):
|
||||||
to_ip_address(ip)
|
to_ip_address(ip)
|
||||||
tdstream.add(ip)
|
tdstream.add(ip)
|
||||||
return tdstream
|
return tdstream
|
||||||
|
|
||||||
|
|
||||||
|
def detect_is_open_to_public(options):
|
||||||
|
global is_open_to_public
|
||||||
|
if on_public_network_interfaces(get_ips_by_name(options.address)):
|
||||||
|
is_open_to_public = True
|
||||||
|
logging.info('Forbid public http: {}'.format(options.fbidhttp))
|
||||||
|
else:
|
||||||
|
is_open_to_public = False
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import re
|
import re
|
||||||
|
import socket
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from types import UnicodeType
|
from types import UnicodeType
|
||||||
|
@ -10,6 +11,9 @@ except ImportError:
|
||||||
numeric = re.compile(r'[0-9]+$')
|
numeric = re.compile(r'[0-9]+$')
|
||||||
allowed = re.compile(r'(?!-)[a-z0-9-]{1,63}(?<!-)$', re.IGNORECASE)
|
allowed = re.compile(r'(?!-)[a-z0-9-]{1,63}(?<!-)$', re.IGNORECASE)
|
||||||
|
|
||||||
|
default_public_ipv4addr = ipaddress.ip_address(u'0.0.0.0')
|
||||||
|
default_public_ipv6addr = ipaddress.ip_address(u'::')
|
||||||
|
|
||||||
|
|
||||||
def to_str(bstr, encoding='utf-8'):
|
def to_str(bstr, encoding='utf-8'):
|
||||||
if isinstance(bstr, bytes):
|
if isinstance(bstr, bytes):
|
||||||
|
@ -60,3 +64,25 @@ def is_valid_hostname(hostname):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return all(allowed.match(label) for label in labels)
|
return all(allowed.match(label) for label in labels)
|
||||||
|
|
||||||
|
|
||||||
|
def get_ips_by_name(name):
|
||||||
|
if name == '':
|
||||||
|
return {'0.0.0.0', '::'}
|
||||||
|
ret = socket.getaddrinfo(name, 0, socket.AF_UNSPEC, socket.SOCK_STREAM)
|
||||||
|
return {t[4][0] for t in ret}
|
||||||
|
|
||||||
|
|
||||||
|
def on_public_network_interface(ip):
|
||||||
|
ipaddr = to_ip_address(ip)
|
||||||
|
if ipaddr == default_public_ipv4addr or ipaddr == default_public_ipv6addr:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not ipaddr.is_private:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def on_public_network_interfaces(ips):
|
||||||
|
for ip in ips:
|
||||||
|
if on_public_network_interface(ip):
|
||||||
|
return True
|
||||||
|
|
Loading…
Reference in New Issue