From 77b6fbfd8573b298de0a16dbbc175bffe151492f Mon Sep 17 00:00:00 2001
From: Sheng <webmaster0115@gmail.com>
Date: Mon, 15 Oct 2018 20:13:11 +0800
Subject: [PATCH] Block requests not come from trusted_downstream and public
 non-https requests

---
 .travis.yml           |  1 +
 tests/test_handler.py | 44 +++++++++++++++++++++++++++++++++++++++++++
 webssh/handler.py     | 24 ++++++++++++++++++++++-
 webssh/main.py        | 10 +++++-----
 webssh/settings.py    |  9 +++++++++
 5 files changed, 82 insertions(+), 6 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index 4af2747..296c162 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -15,6 +15,7 @@ matrix:
 install:
   - pip install -r requirements.txt
   - pip install pytest pytest-cov codecov flake8
+  - if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then pip install mock; fi
 
 script:
   - pytest --cov=webssh
diff --git a/tests/test_handler.py b/tests/test_handler.py
index 51486e1..fde909f 100644
--- a/tests/test_handler.py
+++ b/tests/test_handler.py
@@ -1,13 +1,57 @@
 import unittest
 import paramiko
 
+from tornado.httpclient import HTTPRequest
 from tornado.httputil import HTTPServerRequest
+from tornado.web import HTTPError
 from tests.utils import read_file, make_tests_data_path
 from webssh.handler import MixinHandler, IndexHandler, InvalidValueError
 
+try:
+    from unittest.mock import Mock
+except ImportError:
+    from mock import Mock
+
 
 class TestMixinHandler(unittest.TestCase):
 
+    def test_is_forbidden(self):
+        handler = MixinHandler()
+        request = HTTPRequest('http://example.com/')
+        handler.request = request
+
+        context = Mock(
+            address=('8.8.8.8', 8888),
+            trusted_downstream=['127.0.0.1'],
+            _orig_protocol='http'
+        )
+        request.connection = Mock(context=context)
+        self.assertTrue(handler.is_forbidden())
+
+        context = Mock(
+            address=('8.8.8.8', 8888),
+            trusted_downstream=[],
+            _orig_protocol='http'
+        )
+        request.connection = Mock(context=context)
+        self.assertTrue(handler.is_forbidden())
+
+        context = Mock(
+            address=('192.168.1.1', 8888),
+            trusted_downstream=[],
+            _orig_protocol='http'
+        )
+        request.connection = Mock(context=context)
+        self.assertIsNone(handler.is_forbidden())
+
+        context = Mock(
+            address=('8.8.8.8', 8888),
+            trusted_downstream=[],
+            _orig_protocol='https'
+        )
+        request.connection = Mock(context=context)
+        self.assertIsNone(handler.is_forbidden())
+
     def test_get_real_client_addr(self):
         x_forwarded_for = '1.1.1.1'
         x_forwarded_port = 1111
diff --git a/webssh/handler.py b/webssh/handler.py
index 8563325..33b7ec6 100644
--- a/webssh/handler.py
+++ b/webssh/handler.py
@@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
 from webssh.settings import swallow_http_errors
 from webssh.utils import (
     is_valid_ip_address, is_valid_port, is_valid_hostname,
-    to_bytes, to_str, to_int, UnicodeType
+    to_bytes, to_str, to_int, to_ip_address, UnicodeType
 )
 from webssh.worker import Worker, recycle_worker, workers
 
@@ -39,6 +39,28 @@ class InvalidValueError(Exception):
 
 class MixinHandler(object):
 
+    def prepare(self):
+        if self.is_forbidden():
+            raise tornado.web.HTTPError(403)
+
+    def is_forbidden(self):
+        """
+        Following requests are forbidden:
+        * requests not come from trusted_downstream (if set).
+        * non-https requests from a public network.
+        """
+        context = self.request.connection.context
+        ip = context.address[0]
+        lst = context.trusted_downstream
+
+        if lst and ip not in lst:
+            return True
+
+        if context._orig_protocol == 'http':
+            ipaddr = to_ip_address(ip)
+            if ipaddr.is_global:
+                return True
+
     def set_default_headers(self):
         self.set_header('Server', 'TornadoServer')
 
diff --git a/webssh/main.py b/webssh/main.py
index 64b2fdb..c914279 100644
--- a/webssh/main.py
+++ b/webssh/main.py
@@ -6,7 +6,7 @@ from tornado.options import options
 from webssh.handler import IndexHandler, WsockHandler
 from webssh.settings import (
     get_app_settings,  get_host_keys_settings, get_policy_setting,
-    get_ssl_context, max_body_size, xheaders
+    get_ssl_context, get_server_settings
 )
 
 
@@ -31,12 +31,12 @@ def main():
     loop = tornado.ioloop.IOLoop.current()
     app = make_app(make_handlers(loop, options), get_app_settings(options))
     ssl_ctx = get_ssl_context(options)
-    kwargs = dict(xheaders=xheaders, max_body_size=max_body_size)
-    app.listen(options.port, options.address, **kwargs)
+    server_settings = get_server_settings(options)
+    app.listen(options.port, options.address, **server_settings)
     logging.info('Listening on {}:{}'.format(options.address, options.port))
     if ssl_ctx:
-        kwargs.update(ssl_options=ssl_ctx)
-        app.listen(options.sslPort, options.sslAddress, **kwargs)
+        server_settings.update(ssl_options=ssl_ctx)
+        app.listen(options.sslPort, options.sslAddress, **server_settings)
         logging.info('Listening on ssl {}:{}'.format(options.sslAddress,
                                                      options.sslPort))
     loop.start()
diff --git a/webssh/settings.py b/webssh/settings.py
index 4a299ac..f2f38ef 100644
--- a/webssh/settings.py
+++ b/webssh/settings.py
@@ -51,6 +51,15 @@ def get_app_settings(options):
     return settings
 
 
+def get_server_settings(options):
+    settings = dict(
+        xheaders=xheaders,
+        max_body_size=max_body_size,
+        trusted_downstream=get_trusted_downstream(options)
+    )
+    return settings
+
+
 def get_host_keys_settings(options):
     if not options.hostFile:
         host_keys_filename = os.path.join(base_dir, 'known_hosts')