diff --git a/tests/test_handler.py b/tests/test_handler.py
index b03f24b..dcffb34 100644
--- a/tests/test_handler.py
+++ b/tests/test_handler.py
@@ -19,35 +19,64 @@ class TestMixinHandler(unittest.TestCase):
def test_is_forbidden(self):
handler = MixinHandler()
open_to_public['http'] = True
+ open_to_public['https'] = True
options.fbidhttp = True
+ options.redirect = True
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=['127.0.0.1'],
_orig_protocol='http'
)
- self.assertTrue(handler.is_forbidden(context))
+ self.assertTrue(handler.is_forbidden(context, ''))
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
- self.assertTrue(handler.is_forbidden(context))
+
+ hostname = 'www.google.com'
+ self.assertEqual(handler.is_forbidden(context, hostname), False)
context = Mock(
address=('192.168.1.1', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
- self.assertIsNone(handler.is_forbidden(context))
+ self.assertIsNone(handler.is_forbidden(context, ''))
context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='https'
)
- self.assertIsNone(handler.is_forbidden(context))
+ self.assertIsNone(handler.is_forbidden(context, ''))
+
+ context = Mock(
+ address=('8.8.8.8', 8888),
+ trusted_downstream=[],
+ _orig_protocol='http'
+ )
+ hostname = '8.8.8.8'
+ self.assertTrue(handler.is_forbidden(context, hostname))
+
+ def test_get_redirect_url(self):
+ handler = MixinHandler()
+ hostname = 'www.example.com'
+ uri = '/'
+ port = 443
+
+ self.assertTrue(
+ handler.get_redirect_url(hostname, port, uri=uri),
+ 'https://www.example.com/'
+ )
+
+ port = 4433
+ self.assertTrue(
+ handler.get_redirect_url(hostname, port, uri),
+ 'https://www.example.com:4433/'
+ )
def test_get_client_addr(self):
handler = MixinHandler()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index a8c4e80..24b393c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2,7 +2,7 @@ import unittest
from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
- to_int, on_public_network_interface, get_ips_by_name,
+ to_int, on_public_network_interface, get_ips_by_name, is_ip_hostname,
is_name_open_to_public
)
@@ -73,3 +73,9 @@ class TestUitls(unittest.TestCase):
self.assertIsNone(is_name_open_to_public('192.168.1.1'))
self.assertIsNone(is_name_open_to_public('127.0.0.1'))
self.assertIsNone(is_name_open_to_public('localhost'))
+
+ def test_is_ip_hostname(self):
+ self.assertTrue(is_ip_hostname('[::1]'))
+ self.assertTrue(is_ip_hostname('127.0.0.1'))
+ self.assertFalse(is_ip_hostname('localhost'))
+ self.assertFalse(is_ip_hostname('www.google.com'))
diff --git a/webssh/handler.py b/webssh/handler.py
index 2423c71..09665f8 100644
--- a/webssh/handler.py
+++ b/webssh/handler.py
@@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
from tornado.options import options
from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
- to_int, to_ip_address, UnicodeType, is_name_open_to_public
+ to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname
)
from webssh.worker import Worker, recycle_worker, workers
@@ -34,7 +34,7 @@ DEFAULT_PORT = 22
swallow_http_errors = True
-# status of the http(s) server
+# set by config_open_to_public
open_to_public = {
'http': None,
'https': None
@@ -56,22 +56,28 @@ class MixinHandler(object):
'Server': 'TornadoServer'
}
+ html = ('
{code} {reason}{code} '
+ '{reason}')
+
def initialize(self, loop=None):
- conn = self.request.connection
- if self.is_forbidden(conn.context):
- result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
- conn.stream.write(to_bytes(result))
- conn.close()
- raise ValueError('Accesss denied')
- self.loop = loop
- self.context = conn.context
-
- def is_forbidden(self, context):
- """
- Following requests are forbidden:
- * requests not come from trusted_downstream (if set).
- * plain http requests from a public network.
- """
+ context = self.request.connection.context
+ result = self.is_forbidden(context, self.request.host_name)
+ self._transforms = []
+ if result:
+ self.set_status(403)
+ self.finish(
+ self.html.format(code=self._status_code, reason=self._reason)
+ )
+ elif result is False:
+ to_url = self.get_redirect_url(
+ self.request.host_name, options.sslport, self.request.uri
+ )
+ self.redirect(to_url, permanent=True)
+ else:
+ self.loop = loop
+ self.context = context
+
+ def is_forbidden(self, context, hostname):
ip = context.address[0]
lst = context.trusted_downstream
@@ -81,13 +87,20 @@ class MixinHandler(object):
)
return True
- if open_to_public['http'] and options.fbidhttp:
- if context._orig_protocol == 'http':
- ipaddr = to_ip_address(ip)
- if not ipaddr.is_private:
+ if open_to_public['http'] and context._orig_protocol == 'http':
+ if not to_ip_address(ip).is_private:
+ if open_to_public['https'] and options.redirect:
+ if not is_ip_hostname(hostname):
+ # redirecting
+ return False
+ if options.fbidhttp:
logging.warning('Public plain http request is forbidden.')
return True
+ def get_redirect_url(self, hostname, port, uri):
+ port = '' if port == 443 else ':%s' % port
+ return 'https://{}{}{}'.format(hostname, port, uri)
+
def set_default_headers(self):
for header in self.custom_headers.items():
self.set_header(*header)
diff --git a/webssh/main.py b/webssh/main.py
index 2cfb442..9a62156 100644
--- a/webssh/main.py
+++ b/webssh/main.py
@@ -33,8 +33,7 @@ def app_listen(app, port, address, server_settings):
app.listen(port, address, **server_settings)
server_type = 'https' if server_settings.get('ssl_options') else 'http'
logging.info(
- 'Started a {} server listening on {}:{}'.format(
- server_type, address, port)
+ 'Listening on {}:{} ({})'.format(address, port, server_type)
)
config_open_to_public(address, server_type)
diff --git a/webssh/settings.py b/webssh/settings.py
index 11bdeac..4054fba 100644
--- a/webssh/settings.py
+++ b/webssh/settings.py
@@ -17,7 +17,7 @@ def print_version(flag):
sys.exit(0)
-define('address', default='127.0.0.1', help='Listen address')
+define('address', default='0.0.0.0', help='Listen address')
define('port', type=int, default=8888, help='Listen port')
define('ssladdress', default='0.0.0.0', help='SSL listen address')
define('sslport', type=int, default=4433, help='SSL listen port')
@@ -29,6 +29,7 @@ define('policy', default='warning',
define('hostfile', default='', help='User defined host keys file')
define('syshostfile', default='', help='System wide host keys file')
define('tdstream', default='', help='Trusted downstream, separated by comma')
+define('redirect', type=bool, default=True, help='Redirecting http to https')
define('fbidhttp', type=bool, default=True,
help='Forbid public plain http incoming requests')
define('xheaders', type=bool, default=True, help='Support xheaders')
diff --git a/webssh/utils.py b/webssh/utils.py
index bc4799e..e71bb98 100644
--- a/webssh/utils.py
+++ b/webssh/utils.py
@@ -50,6 +50,16 @@ def is_valid_port(port):
return 0 < port < 65536
+def is_ip_hostname(hostname):
+ it = iter(hostname)
+ if next(it) == '[':
+ return True
+ for ch in it:
+ if ch != '.' and not ch.isdigit():
+ return False
+ return True
+
+
def is_valid_hostname(hostname):
if hostname[-1] == '.':
# strip exactly one dot from the right, if present