mirror of https://github.com/huashengdun/webssh
				
				
				
			Block requests not come from trusted_downstream and public non-https requests
							parent
							
								
									db3ee2b784
								
							
						
					
					
						commit
						77b6fbfd85
					
				| 
						 | 
				
			
			@ -15,6 +15,7 @@ matrix:
 | 
			
		|||
install:
 | 
			
		||||
  - pip install -r requirements.txt
 | 
			
		||||
  - pip install pytest pytest-cov codecov flake8
 | 
			
		||||
  - if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then pip install mock; fi
 | 
			
		||||
 | 
			
		||||
script:
 | 
			
		||||
  - pytest --cov=webssh
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,13 +1,57 @@
 | 
			
		|||
import unittest
 | 
			
		||||
import paramiko
 | 
			
		||||
 | 
			
		||||
from tornado.httpclient import HTTPRequest
 | 
			
		||||
from tornado.httputil import HTTPServerRequest
 | 
			
		||||
from tornado.web import HTTPError
 | 
			
		||||
from tests.utils import read_file, make_tests_data_path
 | 
			
		||||
from webssh.handler import MixinHandler, IndexHandler, InvalidValueError
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from unittest.mock import Mock
 | 
			
		||||
except ImportError:
 | 
			
		||||
    from mock import Mock
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestMixinHandler(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_is_forbidden(self):
 | 
			
		||||
        handler = MixinHandler()
 | 
			
		||||
        request = HTTPRequest('http://example.com/')
 | 
			
		||||
        handler.request = request
 | 
			
		||||
 | 
			
		||||
        context = Mock(
 | 
			
		||||
            address=('8.8.8.8', 8888),
 | 
			
		||||
            trusted_downstream=['127.0.0.1'],
 | 
			
		||||
            _orig_protocol='http'
 | 
			
		||||
        )
 | 
			
		||||
        request.connection = Mock(context=context)
 | 
			
		||||
        self.assertTrue(handler.is_forbidden())
 | 
			
		||||
 | 
			
		||||
        context = Mock(
 | 
			
		||||
            address=('8.8.8.8', 8888),
 | 
			
		||||
            trusted_downstream=[],
 | 
			
		||||
            _orig_protocol='http'
 | 
			
		||||
        )
 | 
			
		||||
        request.connection = Mock(context=context)
 | 
			
		||||
        self.assertTrue(handler.is_forbidden())
 | 
			
		||||
 | 
			
		||||
        context = Mock(
 | 
			
		||||
            address=('192.168.1.1', 8888),
 | 
			
		||||
            trusted_downstream=[],
 | 
			
		||||
            _orig_protocol='http'
 | 
			
		||||
        )
 | 
			
		||||
        request.connection = Mock(context=context)
 | 
			
		||||
        self.assertIsNone(handler.is_forbidden())
 | 
			
		||||
 | 
			
		||||
        context = Mock(
 | 
			
		||||
            address=('8.8.8.8', 8888),
 | 
			
		||||
            trusted_downstream=[],
 | 
			
		||||
            _orig_protocol='https'
 | 
			
		||||
        )
 | 
			
		||||
        request.connection = Mock(context=context)
 | 
			
		||||
        self.assertIsNone(handler.is_forbidden())
 | 
			
		||||
 | 
			
		||||
    def test_get_real_client_addr(self):
 | 
			
		||||
        x_forwarded_for = '1.1.1.1'
 | 
			
		||||
        x_forwarded_port = 1111
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
 | 
			
		|||
from webssh.settings import swallow_http_errors
 | 
			
		||||
from webssh.utils import (
 | 
			
		||||
    is_valid_ip_address, is_valid_port, is_valid_hostname,
 | 
			
		||||
    to_bytes, to_str, to_int, UnicodeType
 | 
			
		||||
    to_bytes, to_str, to_int, to_ip_address, UnicodeType
 | 
			
		||||
)
 | 
			
		||||
from webssh.worker import Worker, recycle_worker, workers
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -39,6 +39,28 @@ class InvalidValueError(Exception):
 | 
			
		|||
 | 
			
		||||
class MixinHandler(object):
 | 
			
		||||
 | 
			
		||||
    def prepare(self):
 | 
			
		||||
        if self.is_forbidden():
 | 
			
		||||
            raise tornado.web.HTTPError(403)
 | 
			
		||||
 | 
			
		||||
    def is_forbidden(self):
 | 
			
		||||
        """
 | 
			
		||||
        Following requests are forbidden:
 | 
			
		||||
        * requests not come from trusted_downstream (if set).
 | 
			
		||||
        * non-https requests from a public network.
 | 
			
		||||
        """
 | 
			
		||||
        context = self.request.connection.context
 | 
			
		||||
        ip = context.address[0]
 | 
			
		||||
        lst = context.trusted_downstream
 | 
			
		||||
 | 
			
		||||
        if lst and ip not in lst:
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        if context._orig_protocol == 'http':
 | 
			
		||||
            ipaddr = to_ip_address(ip)
 | 
			
		||||
            if ipaddr.is_global:
 | 
			
		||||
                return True
 | 
			
		||||
 | 
			
		||||
    def set_default_headers(self):
 | 
			
		||||
        self.set_header('Server', 'TornadoServer')
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,7 +6,7 @@ from tornado.options import options
 | 
			
		|||
from webssh.handler import IndexHandler, WsockHandler
 | 
			
		||||
from webssh.settings import (
 | 
			
		||||
    get_app_settings,  get_host_keys_settings, get_policy_setting,
 | 
			
		||||
    get_ssl_context, max_body_size, xheaders
 | 
			
		||||
    get_ssl_context, get_server_settings
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -31,12 +31,12 @@ def main():
 | 
			
		|||
    loop = tornado.ioloop.IOLoop.current()
 | 
			
		||||
    app = make_app(make_handlers(loop, options), get_app_settings(options))
 | 
			
		||||
    ssl_ctx = get_ssl_context(options)
 | 
			
		||||
    kwargs = dict(xheaders=xheaders, max_body_size=max_body_size)
 | 
			
		||||
    app.listen(options.port, options.address, **kwargs)
 | 
			
		||||
    server_settings = get_server_settings(options)
 | 
			
		||||
    app.listen(options.port, options.address, **server_settings)
 | 
			
		||||
    logging.info('Listening on {}:{}'.format(options.address, options.port))
 | 
			
		||||
    if ssl_ctx:
 | 
			
		||||
        kwargs.update(ssl_options=ssl_ctx)
 | 
			
		||||
        app.listen(options.sslPort, options.sslAddress, **kwargs)
 | 
			
		||||
        server_settings.update(ssl_options=ssl_ctx)
 | 
			
		||||
        app.listen(options.sslPort, options.sslAddress, **server_settings)
 | 
			
		||||
        logging.info('Listening on ssl {}:{}'.format(options.sslAddress,
 | 
			
		||||
                                                     options.sslPort))
 | 
			
		||||
    loop.start()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -51,6 +51,15 @@ def get_app_settings(options):
 | 
			
		|||
    return settings
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_server_settings(options):
 | 
			
		||||
    settings = dict(
 | 
			
		||||
        xheaders=xheaders,
 | 
			
		||||
        max_body_size=max_body_size,
 | 
			
		||||
        trusted_downstream=get_trusted_downstream(options)
 | 
			
		||||
    )
 | 
			
		||||
    return settings
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_host_keys_settings(options):
 | 
			
		||||
    if not options.hostFile:
 | 
			
		||||
        host_keys_filename = os.path.join(base_dir, 'known_hosts')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue