mirror of https://github.com/huashengdun/webssh
Block requests not come from trusted_downstream and public non-https requests
parent
db3ee2b784
commit
77b6fbfd85
|
@ -15,6 +15,7 @@ matrix:
|
|||
install:
|
||||
- pip install -r requirements.txt
|
||||
- pip install pytest pytest-cov codecov flake8
|
||||
- if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then pip install mock; fi
|
||||
|
||||
script:
|
||||
- pytest --cov=webssh
|
||||
|
|
|
@ -1,13 +1,57 @@
|
|||
import unittest
|
||||
import paramiko
|
||||
|
||||
from tornado.httpclient import HTTPRequest
|
||||
from tornado.httputil import HTTPServerRequest
|
||||
from tornado.web import HTTPError
|
||||
from tests.utils import read_file, make_tests_data_path
|
||||
from webssh.handler import MixinHandler, IndexHandler, InvalidValueError
|
||||
|
||||
try:
|
||||
from unittest.mock import Mock
|
||||
except ImportError:
|
||||
from mock import Mock
|
||||
|
||||
|
||||
class TestMixinHandler(unittest.TestCase):
|
||||
|
||||
def test_is_forbidden(self):
|
||||
handler = MixinHandler()
|
||||
request = HTTPRequest('http://example.com/')
|
||||
handler.request = request
|
||||
|
||||
context = Mock(
|
||||
address=('8.8.8.8', 8888),
|
||||
trusted_downstream=['127.0.0.1'],
|
||||
_orig_protocol='http'
|
||||
)
|
||||
request.connection = Mock(context=context)
|
||||
self.assertTrue(handler.is_forbidden())
|
||||
|
||||
context = Mock(
|
||||
address=('8.8.8.8', 8888),
|
||||
trusted_downstream=[],
|
||||
_orig_protocol='http'
|
||||
)
|
||||
request.connection = Mock(context=context)
|
||||
self.assertTrue(handler.is_forbidden())
|
||||
|
||||
context = Mock(
|
||||
address=('192.168.1.1', 8888),
|
||||
trusted_downstream=[],
|
||||
_orig_protocol='http'
|
||||
)
|
||||
request.connection = Mock(context=context)
|
||||
self.assertIsNone(handler.is_forbidden())
|
||||
|
||||
context = Mock(
|
||||
address=('8.8.8.8', 8888),
|
||||
trusted_downstream=[],
|
||||
_orig_protocol='https'
|
||||
)
|
||||
request.connection = Mock(context=context)
|
||||
self.assertIsNone(handler.is_forbidden())
|
||||
|
||||
def test_get_real_client_addr(self):
|
||||
x_forwarded_for = '1.1.1.1'
|
||||
x_forwarded_port = 1111
|
||||
|
|
|
@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
|
|||
from webssh.settings import swallow_http_errors
|
||||
from webssh.utils import (
|
||||
is_valid_ip_address, is_valid_port, is_valid_hostname,
|
||||
to_bytes, to_str, to_int, UnicodeType
|
||||
to_bytes, to_str, to_int, to_ip_address, UnicodeType
|
||||
)
|
||||
from webssh.worker import Worker, recycle_worker, workers
|
||||
|
||||
|
@ -39,6 +39,28 @@ class InvalidValueError(Exception):
|
|||
|
||||
class MixinHandler(object):
|
||||
|
||||
def prepare(self):
|
||||
if self.is_forbidden():
|
||||
raise tornado.web.HTTPError(403)
|
||||
|
||||
def is_forbidden(self):
|
||||
"""
|
||||
Following requests are forbidden:
|
||||
* requests not come from trusted_downstream (if set).
|
||||
* non-https requests from a public network.
|
||||
"""
|
||||
context = self.request.connection.context
|
||||
ip = context.address[0]
|
||||
lst = context.trusted_downstream
|
||||
|
||||
if lst and ip not in lst:
|
||||
return True
|
||||
|
||||
if context._orig_protocol == 'http':
|
||||
ipaddr = to_ip_address(ip)
|
||||
if ipaddr.is_global:
|
||||
return True
|
||||
|
||||
def set_default_headers(self):
|
||||
self.set_header('Server', 'TornadoServer')
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from tornado.options import options
|
|||
from webssh.handler import IndexHandler, WsockHandler
|
||||
from webssh.settings import (
|
||||
get_app_settings, get_host_keys_settings, get_policy_setting,
|
||||
get_ssl_context, max_body_size, xheaders
|
||||
get_ssl_context, get_server_settings
|
||||
)
|
||||
|
||||
|
||||
|
@ -31,12 +31,12 @@ def main():
|
|||
loop = tornado.ioloop.IOLoop.current()
|
||||
app = make_app(make_handlers(loop, options), get_app_settings(options))
|
||||
ssl_ctx = get_ssl_context(options)
|
||||
kwargs = dict(xheaders=xheaders, max_body_size=max_body_size)
|
||||
app.listen(options.port, options.address, **kwargs)
|
||||
server_settings = get_server_settings(options)
|
||||
app.listen(options.port, options.address, **server_settings)
|
||||
logging.info('Listening on {}:{}'.format(options.address, options.port))
|
||||
if ssl_ctx:
|
||||
kwargs.update(ssl_options=ssl_ctx)
|
||||
app.listen(options.sslPort, options.sslAddress, **kwargs)
|
||||
server_settings.update(ssl_options=ssl_ctx)
|
||||
app.listen(options.sslPort, options.sslAddress, **server_settings)
|
||||
logging.info('Listening on ssl {}:{}'.format(options.sslAddress,
|
||||
options.sslPort))
|
||||
loop.start()
|
||||
|
|
|
@ -51,6 +51,15 @@ def get_app_settings(options):
|
|||
return settings
|
||||
|
||||
|
||||
def get_server_settings(options):
|
||||
settings = dict(
|
||||
xheaders=xheaders,
|
||||
max_body_size=max_body_size,
|
||||
trusted_downstream=get_trusted_downstream(options)
|
||||
)
|
||||
return settings
|
||||
|
||||
|
||||
def get_host_keys_settings(options):
|
||||
if not options.hostFile:
|
||||
host_keys_filename = os.path.join(base_dir, 'known_hosts')
|
||||
|
|
Loading…
Reference in New Issue