mirror of https://github.com/huashengdun/webssh
Support redirecting http to https
parent
40cf1095ff
commit
8e4039a24a
|
@ -19,35 +19,64 @@ class TestMixinHandler(unittest.TestCase):
|
||||||
def test_is_forbidden(self):
|
def test_is_forbidden(self):
|
||||||
handler = MixinHandler()
|
handler = MixinHandler()
|
||||||
open_to_public['http'] = True
|
open_to_public['http'] = True
|
||||||
|
open_to_public['https'] = True
|
||||||
options.fbidhttp = True
|
options.fbidhttp = True
|
||||||
|
options.redirect = True
|
||||||
|
|
||||||
context = Mock(
|
context = Mock(
|
||||||
address=('8.8.8.8', 8888),
|
address=('8.8.8.8', 8888),
|
||||||
trusted_downstream=['127.0.0.1'],
|
trusted_downstream=['127.0.0.1'],
|
||||||
_orig_protocol='http'
|
_orig_protocol='http'
|
||||||
)
|
)
|
||||||
self.assertTrue(handler.is_forbidden(context))
|
self.assertTrue(handler.is_forbidden(context, ''))
|
||||||
|
|
||||||
context = Mock(
|
context = Mock(
|
||||||
address=('8.8.8.8', 8888),
|
address=('8.8.8.8', 8888),
|
||||||
trusted_downstream=[],
|
trusted_downstream=[],
|
||||||
_orig_protocol='http'
|
_orig_protocol='http'
|
||||||
)
|
)
|
||||||
self.assertTrue(handler.is_forbidden(context))
|
|
||||||
|
hostname = 'www.google.com'
|
||||||
|
self.assertEqual(handler.is_forbidden(context, hostname), False)
|
||||||
|
|
||||||
context = Mock(
|
context = Mock(
|
||||||
address=('192.168.1.1', 8888),
|
address=('192.168.1.1', 8888),
|
||||||
trusted_downstream=[],
|
trusted_downstream=[],
|
||||||
_orig_protocol='http'
|
_orig_protocol='http'
|
||||||
)
|
)
|
||||||
self.assertIsNone(handler.is_forbidden(context))
|
self.assertIsNone(handler.is_forbidden(context, ''))
|
||||||
|
|
||||||
context = Mock(
|
context = Mock(
|
||||||
address=('8.8.8.8', 8888),
|
address=('8.8.8.8', 8888),
|
||||||
trusted_downstream=[],
|
trusted_downstream=[],
|
||||||
_orig_protocol='https'
|
_orig_protocol='https'
|
||||||
)
|
)
|
||||||
self.assertIsNone(handler.is_forbidden(context))
|
self.assertIsNone(handler.is_forbidden(context, ''))
|
||||||
|
|
||||||
|
context = Mock(
|
||||||
|
address=('8.8.8.8', 8888),
|
||||||
|
trusted_downstream=[],
|
||||||
|
_orig_protocol='http'
|
||||||
|
)
|
||||||
|
hostname = '8.8.8.8'
|
||||||
|
self.assertTrue(handler.is_forbidden(context, hostname))
|
||||||
|
|
||||||
|
def test_get_redirect_url(self):
|
||||||
|
handler = MixinHandler()
|
||||||
|
hostname = 'www.example.com'
|
||||||
|
uri = '/'
|
||||||
|
port = 443
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
handler.get_redirect_url(hostname, port, uri=uri),
|
||||||
|
'https://www.example.com/'
|
||||||
|
)
|
||||||
|
|
||||||
|
port = 4433
|
||||||
|
self.assertTrue(
|
||||||
|
handler.get_redirect_url(hostname, port, uri),
|
||||||
|
'https://www.example.com:4433/'
|
||||||
|
)
|
||||||
|
|
||||||
def test_get_client_addr(self):
|
def test_get_client_addr(self):
|
||||||
handler = MixinHandler()
|
handler = MixinHandler()
|
||||||
|
|
|
@ -2,7 +2,7 @@ import unittest
|
||||||
|
|
||||||
from webssh.utils import (
|
from webssh.utils import (
|
||||||
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
|
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
|
||||||
to_int, on_public_network_interface, get_ips_by_name,
|
to_int, on_public_network_interface, get_ips_by_name, is_ip_hostname,
|
||||||
is_name_open_to_public
|
is_name_open_to_public
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -73,3 +73,9 @@ class TestUitls(unittest.TestCase):
|
||||||
self.assertIsNone(is_name_open_to_public('192.168.1.1'))
|
self.assertIsNone(is_name_open_to_public('192.168.1.1'))
|
||||||
self.assertIsNone(is_name_open_to_public('127.0.0.1'))
|
self.assertIsNone(is_name_open_to_public('127.0.0.1'))
|
||||||
self.assertIsNone(is_name_open_to_public('localhost'))
|
self.assertIsNone(is_name_open_to_public('localhost'))
|
||||||
|
|
||||||
|
def test_is_ip_hostname(self):
|
||||||
|
self.assertTrue(is_ip_hostname('[::1]'))
|
||||||
|
self.assertTrue(is_ip_hostname('127.0.0.1'))
|
||||||
|
self.assertFalse(is_ip_hostname('localhost'))
|
||||||
|
self.assertFalse(is_ip_hostname('www.google.com'))
|
||||||
|
|
|
@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
|
||||||
from tornado.options import options
|
from tornado.options import options
|
||||||
from webssh.utils import (
|
from webssh.utils import (
|
||||||
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
|
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
|
||||||
to_int, to_ip_address, UnicodeType, is_name_open_to_public
|
to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname
|
||||||
)
|
)
|
||||||
from webssh.worker import Worker, recycle_worker, workers
|
from webssh.worker import Worker, recycle_worker, workers
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ DEFAULT_PORT = 22
|
||||||
|
|
||||||
swallow_http_errors = True
|
swallow_http_errors = True
|
||||||
|
|
||||||
# status of the http(s) server
|
# set by config_open_to_public
|
||||||
open_to_public = {
|
open_to_public = {
|
||||||
'http': None,
|
'http': None,
|
||||||
'https': None
|
'https': None
|
||||||
|
@ -56,22 +56,28 @@ class MixinHandler(object):
|
||||||
'Server': 'TornadoServer'
|
'Server': 'TornadoServer'
|
||||||
}
|
}
|
||||||
|
|
||||||
def initialize(self, loop=None):
|
html = ('<html><head><title>{code} {reason}</title></head><body>{code} '
|
||||||
conn = self.request.connection
|
'{reason}</body></html>')
|
||||||
if self.is_forbidden(conn.context):
|
|
||||||
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
|
|
||||||
conn.stream.write(to_bytes(result))
|
|
||||||
conn.close()
|
|
||||||
raise ValueError('Accesss denied')
|
|
||||||
self.loop = loop
|
|
||||||
self.context = conn.context
|
|
||||||
|
|
||||||
def is_forbidden(self, context):
|
def initialize(self, loop=None):
|
||||||
"""
|
context = self.request.connection.context
|
||||||
Following requests are forbidden:
|
result = self.is_forbidden(context, self.request.host_name)
|
||||||
* requests not come from trusted_downstream (if set).
|
self._transforms = []
|
||||||
* plain http requests from a public network.
|
if result:
|
||||||
"""
|
self.set_status(403)
|
||||||
|
self.finish(
|
||||||
|
self.html.format(code=self._status_code, reason=self._reason)
|
||||||
|
)
|
||||||
|
elif result is False:
|
||||||
|
to_url = self.get_redirect_url(
|
||||||
|
self.request.host_name, options.sslport, self.request.uri
|
||||||
|
)
|
||||||
|
self.redirect(to_url, permanent=True)
|
||||||
|
else:
|
||||||
|
self.loop = loop
|
||||||
|
self.context = context
|
||||||
|
|
||||||
|
def is_forbidden(self, context, hostname):
|
||||||
ip = context.address[0]
|
ip = context.address[0]
|
||||||
lst = context.trusted_downstream
|
lst = context.trusted_downstream
|
||||||
|
|
||||||
|
@ -81,13 +87,20 @@ class MixinHandler(object):
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if open_to_public['http'] and options.fbidhttp:
|
if open_to_public['http'] and context._orig_protocol == 'http':
|
||||||
if context._orig_protocol == 'http':
|
if not to_ip_address(ip).is_private:
|
||||||
ipaddr = to_ip_address(ip)
|
if open_to_public['https'] and options.redirect:
|
||||||
if not ipaddr.is_private:
|
if not is_ip_hostname(hostname):
|
||||||
|
# redirecting
|
||||||
|
return False
|
||||||
|
if options.fbidhttp:
|
||||||
logging.warning('Public plain http request is forbidden.')
|
logging.warning('Public plain http request is forbidden.')
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def get_redirect_url(self, hostname, port, uri):
|
||||||
|
port = '' if port == 443 else ':%s' % port
|
||||||
|
return 'https://{}{}{}'.format(hostname, port, uri)
|
||||||
|
|
||||||
def set_default_headers(self):
|
def set_default_headers(self):
|
||||||
for header in self.custom_headers.items():
|
for header in self.custom_headers.items():
|
||||||
self.set_header(*header)
|
self.set_header(*header)
|
||||||
|
|
|
@ -33,8 +33,7 @@ def app_listen(app, port, address, server_settings):
|
||||||
app.listen(port, address, **server_settings)
|
app.listen(port, address, **server_settings)
|
||||||
server_type = 'https' if server_settings.get('ssl_options') else 'http'
|
server_type = 'https' if server_settings.get('ssl_options') else 'http'
|
||||||
logging.info(
|
logging.info(
|
||||||
'Started a {} server listening on {}:{}'.format(
|
'Listening on {}:{} ({})'.format(address, port, server_type)
|
||||||
server_type, address, port)
|
|
||||||
)
|
)
|
||||||
config_open_to_public(address, server_type)
|
config_open_to_public(address, server_type)
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ def print_version(flag):
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
define('address', default='127.0.0.1', help='Listen address')
|
define('address', default='0.0.0.0', help='Listen address')
|
||||||
define('port', type=int, default=8888, help='Listen port')
|
define('port', type=int, default=8888, help='Listen port')
|
||||||
define('ssladdress', default='0.0.0.0', help='SSL listen address')
|
define('ssladdress', default='0.0.0.0', help='SSL listen address')
|
||||||
define('sslport', type=int, default=4433, help='SSL listen port')
|
define('sslport', type=int, default=4433, help='SSL listen port')
|
||||||
|
@ -29,6 +29,7 @@ define('policy', default='warning',
|
||||||
define('hostfile', default='', help='User defined host keys file')
|
define('hostfile', default='', help='User defined host keys file')
|
||||||
define('syshostfile', default='', help='System wide host keys file')
|
define('syshostfile', default='', help='System wide host keys file')
|
||||||
define('tdstream', default='', help='Trusted downstream, separated by comma')
|
define('tdstream', default='', help='Trusted downstream, separated by comma')
|
||||||
|
define('redirect', type=bool, default=True, help='Redirecting http to https')
|
||||||
define('fbidhttp', type=bool, default=True,
|
define('fbidhttp', type=bool, default=True,
|
||||||
help='Forbid public plain http incoming requests')
|
help='Forbid public plain http incoming requests')
|
||||||
define('xheaders', type=bool, default=True, help='Support xheaders')
|
define('xheaders', type=bool, default=True, help='Support xheaders')
|
||||||
|
|
|
@ -50,6 +50,16 @@ def is_valid_port(port):
|
||||||
return 0 < port < 65536
|
return 0 < port < 65536
|
||||||
|
|
||||||
|
|
||||||
|
def is_ip_hostname(hostname):
|
||||||
|
it = iter(hostname)
|
||||||
|
if next(it) == '[':
|
||||||
|
return True
|
||||||
|
for ch in it:
|
||||||
|
if ch != '.' and not ch.isdigit():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def is_valid_hostname(hostname):
|
def is_valid_hostname(hostname):
|
||||||
if hostname[-1] == '.':
|
if hostname[-1] == '.':
|
||||||
# strip exactly one dot from the right, if present
|
# strip exactly one dot from the right, if present
|
||||||
|
|
Loading…
Reference in New Issue