mirror of https://github.com/huashengdun/webssh
Support redirecting http to https
parent
40cf1095ff
commit
8e4039a24a
|
@ -19,35 +19,64 @@ class TestMixinHandler(unittest.TestCase):
|
|||
def test_is_forbidden(self):
|
||||
handler = MixinHandler()
|
||||
open_to_public['http'] = True
|
||||
open_to_public['https'] = True
|
||||
options.fbidhttp = True
|
||||
options.redirect = True
|
||||
|
||||
context = Mock(
|
||||
address=('8.8.8.8', 8888),
|
||||
trusted_downstream=['127.0.0.1'],
|
||||
_orig_protocol='http'
|
||||
)
|
||||
self.assertTrue(handler.is_forbidden(context))
|
||||
self.assertTrue(handler.is_forbidden(context, ''))
|
||||
|
||||
context = Mock(
|
||||
address=('8.8.8.8', 8888),
|
||||
trusted_downstream=[],
|
||||
_orig_protocol='http'
|
||||
)
|
||||
self.assertTrue(handler.is_forbidden(context))
|
||||
|
||||
hostname = 'www.google.com'
|
||||
self.assertEqual(handler.is_forbidden(context, hostname), False)
|
||||
|
||||
context = Mock(
|
||||
address=('192.168.1.1', 8888),
|
||||
trusted_downstream=[],
|
||||
_orig_protocol='http'
|
||||
)
|
||||
self.assertIsNone(handler.is_forbidden(context))
|
||||
self.assertIsNone(handler.is_forbidden(context, ''))
|
||||
|
||||
context = Mock(
|
||||
address=('8.8.8.8', 8888),
|
||||
trusted_downstream=[],
|
||||
_orig_protocol='https'
|
||||
)
|
||||
self.assertIsNone(handler.is_forbidden(context))
|
||||
self.assertIsNone(handler.is_forbidden(context, ''))
|
||||
|
||||
context = Mock(
|
||||
address=('8.8.8.8', 8888),
|
||||
trusted_downstream=[],
|
||||
_orig_protocol='http'
|
||||
)
|
||||
hostname = '8.8.8.8'
|
||||
self.assertTrue(handler.is_forbidden(context, hostname))
|
||||
|
||||
def test_get_redirect_url(self):
|
||||
handler = MixinHandler()
|
||||
hostname = 'www.example.com'
|
||||
uri = '/'
|
||||
port = 443
|
||||
|
||||
self.assertTrue(
|
||||
handler.get_redirect_url(hostname, port, uri=uri),
|
||||
'https://www.example.com/'
|
||||
)
|
||||
|
||||
port = 4433
|
||||
self.assertTrue(
|
||||
handler.get_redirect_url(hostname, port, uri),
|
||||
'https://www.example.com:4433/'
|
||||
)
|
||||
|
||||
def test_get_client_addr(self):
|
||||
handler = MixinHandler()
|
||||
|
|
|
@ -2,7 +2,7 @@ import unittest
|
|||
|
||||
from webssh.utils import (
|
||||
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
|
||||
to_int, on_public_network_interface, get_ips_by_name,
|
||||
to_int, on_public_network_interface, get_ips_by_name, is_ip_hostname,
|
||||
is_name_open_to_public
|
||||
)
|
||||
|
||||
|
@ -73,3 +73,9 @@ class TestUitls(unittest.TestCase):
|
|||
self.assertIsNone(is_name_open_to_public('192.168.1.1'))
|
||||
self.assertIsNone(is_name_open_to_public('127.0.0.1'))
|
||||
self.assertIsNone(is_name_open_to_public('localhost'))
|
||||
|
||||
def test_is_ip_hostname(self):
|
||||
self.assertTrue(is_ip_hostname('[::1]'))
|
||||
self.assertTrue(is_ip_hostname('127.0.0.1'))
|
||||
self.assertFalse(is_ip_hostname('localhost'))
|
||||
self.assertFalse(is_ip_hostname('www.google.com'))
|
||||
|
|
|
@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
|
|||
from tornado.options import options
|
||||
from webssh.utils import (
|
||||
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
|
||||
to_int, to_ip_address, UnicodeType, is_name_open_to_public
|
||||
to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname
|
||||
)
|
||||
from webssh.worker import Worker, recycle_worker, workers
|
||||
|
||||
|
@ -34,7 +34,7 @@ DEFAULT_PORT = 22
|
|||
|
||||
swallow_http_errors = True
|
||||
|
||||
# status of the http(s) server
|
||||
# set by config_open_to_public
|
||||
open_to_public = {
|
||||
'http': None,
|
||||
'https': None
|
||||
|
@ -56,22 +56,28 @@ class MixinHandler(object):
|
|||
'Server': 'TornadoServer'
|
||||
}
|
||||
|
||||
def initialize(self, loop=None):
|
||||
conn = self.request.connection
|
||||
if self.is_forbidden(conn.context):
|
||||
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
|
||||
conn.stream.write(to_bytes(result))
|
||||
conn.close()
|
||||
raise ValueError('Accesss denied')
|
||||
self.loop = loop
|
||||
self.context = conn.context
|
||||
html = ('<html><head><title>{code} {reason}</title></head><body>{code} '
|
||||
'{reason}</body></html>')
|
||||
|
||||
def is_forbidden(self, context):
|
||||
"""
|
||||
Following requests are forbidden:
|
||||
* requests not come from trusted_downstream (if set).
|
||||
* plain http requests from a public network.
|
||||
"""
|
||||
def initialize(self, loop=None):
|
||||
context = self.request.connection.context
|
||||
result = self.is_forbidden(context, self.request.host_name)
|
||||
self._transforms = []
|
||||
if result:
|
||||
self.set_status(403)
|
||||
self.finish(
|
||||
self.html.format(code=self._status_code, reason=self._reason)
|
||||
)
|
||||
elif result is False:
|
||||
to_url = self.get_redirect_url(
|
||||
self.request.host_name, options.sslport, self.request.uri
|
||||
)
|
||||
self.redirect(to_url, permanent=True)
|
||||
else:
|
||||
self.loop = loop
|
||||
self.context = context
|
||||
|
||||
def is_forbidden(self, context, hostname):
|
||||
ip = context.address[0]
|
||||
lst = context.trusted_downstream
|
||||
|
||||
|
@ -81,13 +87,20 @@ class MixinHandler(object):
|
|||
)
|
||||
return True
|
||||
|
||||
if open_to_public['http'] and options.fbidhttp:
|
||||
if context._orig_protocol == 'http':
|
||||
ipaddr = to_ip_address(ip)
|
||||
if not ipaddr.is_private:
|
||||
if open_to_public['http'] and context._orig_protocol == 'http':
|
||||
if not to_ip_address(ip).is_private:
|
||||
if open_to_public['https'] and options.redirect:
|
||||
if not is_ip_hostname(hostname):
|
||||
# redirecting
|
||||
return False
|
||||
if options.fbidhttp:
|
||||
logging.warning('Public plain http request is forbidden.')
|
||||
return True
|
||||
|
||||
def get_redirect_url(self, hostname, port, uri):
|
||||
port = '' if port == 443 else ':%s' % port
|
||||
return 'https://{}{}{}'.format(hostname, port, uri)
|
||||
|
||||
def set_default_headers(self):
|
||||
for header in self.custom_headers.items():
|
||||
self.set_header(*header)
|
||||
|
|
|
@ -33,8 +33,7 @@ def app_listen(app, port, address, server_settings):
|
|||
app.listen(port, address, **server_settings)
|
||||
server_type = 'https' if server_settings.get('ssl_options') else 'http'
|
||||
logging.info(
|
||||
'Started a {} server listening on {}:{}'.format(
|
||||
server_type, address, port)
|
||||
'Listening on {}:{} ({})'.format(address, port, server_type)
|
||||
)
|
||||
config_open_to_public(address, server_type)
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ def print_version(flag):
|
|||
sys.exit(0)
|
||||
|
||||
|
||||
define('address', default='127.0.0.1', help='Listen address')
|
||||
define('address', default='0.0.0.0', help='Listen address')
|
||||
define('port', type=int, default=8888, help='Listen port')
|
||||
define('ssladdress', default='0.0.0.0', help='SSL listen address')
|
||||
define('sslport', type=int, default=4433, help='SSL listen port')
|
||||
|
@ -29,6 +29,7 @@ define('policy', default='warning',
|
|||
define('hostfile', default='', help='User defined host keys file')
|
||||
define('syshostfile', default='', help='System wide host keys file')
|
||||
define('tdstream', default='', help='Trusted downstream, separated by comma')
|
||||
define('redirect', type=bool, default=True, help='Redirecting http to https')
|
||||
define('fbidhttp', type=bool, default=True,
|
||||
help='Forbid public plain http incoming requests')
|
||||
define('xheaders', type=bool, default=True, help='Support xheaders')
|
||||
|
|
|
@ -50,6 +50,16 @@ def is_valid_port(port):
|
|||
return 0 < port < 65536
|
||||
|
||||
|
||||
def is_ip_hostname(hostname):
|
||||
it = iter(hostname)
|
||||
if next(it) == '[':
|
||||
return True
|
||||
for ch in it:
|
||||
if ch != '.' and not ch.isdigit():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_hostname(hostname):
|
||||
if hostname[-1] == '.':
|
||||
# strip exactly one dot from the right, if present
|
||||
|
|
Loading…
Reference in New Issue