Support redirecting http to https

pull/38/head
Sheng 2018-10-21 14:07:44 +08:00
parent 40cf1095ff
commit 8e4039a24a
6 changed files with 87 additions and 29 deletions

View File

@ -19,35 +19,64 @@ class TestMixinHandler(unittest.TestCase):
def test_is_forbidden(self):
handler = MixinHandler()
open_to_public['http'] = True
open_to_public['https'] = True
options.fbidhttp = True
options.redirect = True
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=['127.0.0.1'],
_orig_protocol='http'
)
self.assertTrue(handler.is_forbidden(context))
self.assertTrue(handler.is_forbidden(context, ''))
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
self.assertTrue(handler.is_forbidden(context))
hostname = 'www.google.com'
self.assertEqual(handler.is_forbidden(context, hostname), False)
context = Mock(
address=('192.168.1.1', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
self.assertIsNone(handler.is_forbidden(context))
self.assertIsNone(handler.is_forbidden(context, ''))
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='https'
)
self.assertIsNone(handler.is_forbidden(context))
self.assertIsNone(handler.is_forbidden(context, ''))
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
hostname = '8.8.8.8'
self.assertTrue(handler.is_forbidden(context, hostname))
def test_get_redirect_url(self):
handler = MixinHandler()
hostname = 'www.example.com'
uri = '/'
port = 443
self.assertTrue(
handler.get_redirect_url(hostname, port, uri=uri),
'https://www.example.com/'
)
port = 4433
self.assertTrue(
handler.get_redirect_url(hostname, port, uri),
'https://www.example.com:4433/'
)
def test_get_client_addr(self):
handler = MixinHandler()

View File

@ -2,7 +2,7 @@ import unittest
from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
to_int, on_public_network_interface, get_ips_by_name,
to_int, on_public_network_interface, get_ips_by_name, is_ip_hostname,
is_name_open_to_public
)
@ -73,3 +73,9 @@ class TestUitls(unittest.TestCase):
self.assertIsNone(is_name_open_to_public('192.168.1.1'))
self.assertIsNone(is_name_open_to_public('127.0.0.1'))
self.assertIsNone(is_name_open_to_public('localhost'))
def test_is_ip_hostname(self):
self.assertTrue(is_ip_hostname('[::1]'))
self.assertTrue(is_ip_hostname('127.0.0.1'))
self.assertFalse(is_ip_hostname('localhost'))
self.assertFalse(is_ip_hostname('www.google.com'))

View File

@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
from tornado.options import options
from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
to_int, to_ip_address, UnicodeType, is_name_open_to_public
to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname
)
from webssh.worker import Worker, recycle_worker, workers
@ -34,7 +34,7 @@ DEFAULT_PORT = 22
swallow_http_errors = True
# status of the http(s) server
# set by config_open_to_public
open_to_public = {
'http': None,
'https': None
@ -56,22 +56,28 @@ class MixinHandler(object):
'Server': 'TornadoServer'
}
def initialize(self, loop=None):
conn = self.request.connection
if self.is_forbidden(conn.context):
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
conn.stream.write(to_bytes(result))
conn.close()
raise ValueError('Accesss denied')
self.loop = loop
self.context = conn.context
html = ('<html><head><title>{code} {reason}</title></head><body>{code} '
'{reason}</body></html>')
def is_forbidden(self, context):
"""
Following requests are forbidden:
* requests not come from trusted_downstream (if set).
* plain http requests from a public network.
"""
def initialize(self, loop=None):
context = self.request.connection.context
result = self.is_forbidden(context, self.request.host_name)
self._transforms = []
if result:
self.set_status(403)
self.finish(
self.html.format(code=self._status_code, reason=self._reason)
)
elif result is False:
to_url = self.get_redirect_url(
self.request.host_name, options.sslport, self.request.uri
)
self.redirect(to_url, permanent=True)
else:
self.loop = loop
self.context = context
def is_forbidden(self, context, hostname):
ip = context.address[0]
lst = context.trusted_downstream
@ -81,13 +87,20 @@ class MixinHandler(object):
)
return True
if open_to_public['http'] and options.fbidhttp:
if context._orig_protocol == 'http':
ipaddr = to_ip_address(ip)
if not ipaddr.is_private:
if open_to_public['http'] and context._orig_protocol == 'http':
if not to_ip_address(ip).is_private:
if open_to_public['https'] and options.redirect:
if not is_ip_hostname(hostname):
# redirecting
return False
if options.fbidhttp:
logging.warning('Public plain http request is forbidden.')
return True
def get_redirect_url(self, hostname, port, uri):
port = '' if port == 443 else ':%s' % port
return 'https://{}{}{}'.format(hostname, port, uri)
def set_default_headers(self):
for header in self.custom_headers.items():
self.set_header(*header)

View File

@ -33,8 +33,7 @@ def app_listen(app, port, address, server_settings):
app.listen(port, address, **server_settings)
server_type = 'https' if server_settings.get('ssl_options') else 'http'
logging.info(
'Started a {} server listening on {}:{}'.format(
server_type, address, port)
'Listening on {}:{} ({})'.format(address, port, server_type)
)
config_open_to_public(address, server_type)

View File

@ -17,7 +17,7 @@ def print_version(flag):
sys.exit(0)
define('address', default='127.0.0.1', help='Listen address')
define('address', default='0.0.0.0', help='Listen address')
define('port', type=int, default=8888, help='Listen port')
define('ssladdress', default='0.0.0.0', help='SSL listen address')
define('sslport', type=int, default=4433, help='SSL listen port')
@ -29,6 +29,7 @@ define('policy', default='warning',
define('hostfile', default='', help='User defined host keys file')
define('syshostfile', default='', help='System wide host keys file')
define('tdstream', default='', help='Trusted downstream, separated by comma')
define('redirect', type=bool, default=True, help='Redirecting http to https')
define('fbidhttp', type=bool, default=True,
help='Forbid public plain http incoming requests')
define('xheaders', type=bool, default=True, help='Support xheaders')

View File

@ -50,6 +50,16 @@ def is_valid_port(port):
return 0 < port < 65536
def is_ip_hostname(hostname):
it = iter(hostname)
if next(it) == '[':
return True
for ch in it:
if ch != '.' and not ch.isdigit():
return False
return True
def is_valid_hostname(hostname):
if hostname[-1] == '.':
# strip exactly one dot from the right, if present