mirror of https://github.com/huashengdun/webssh
Block requests not come from trusted_downstream and public non-https requests
parent
db3ee2b784
commit
77b6fbfd85
|
@ -15,6 +15,7 @@ matrix:
|
||||||
install:
|
install:
|
||||||
- pip install -r requirements.txt
|
- pip install -r requirements.txt
|
||||||
- pip install pytest pytest-cov codecov flake8
|
- pip install pytest pytest-cov codecov flake8
|
||||||
|
- if [[ $TRAVIS_PYTHON_VERSION == '2.7' ]]; then pip install mock; fi
|
||||||
|
|
||||||
script:
|
script:
|
||||||
- pytest --cov=webssh
|
- pytest --cov=webssh
|
||||||
|
|
|
@ -1,13 +1,57 @@
|
||||||
import unittest
|
import unittest
|
||||||
import paramiko
|
import paramiko
|
||||||
|
|
||||||
|
from tornado.httpclient import HTTPRequest
|
||||||
from tornado.httputil import HTTPServerRequest
|
from tornado.httputil import HTTPServerRequest
|
||||||
|
from tornado.web import HTTPError
|
||||||
from tests.utils import read_file, make_tests_data_path
|
from tests.utils import read_file, make_tests_data_path
|
||||||
from webssh.handler import MixinHandler, IndexHandler, InvalidValueError
|
from webssh.handler import MixinHandler, IndexHandler, InvalidValueError
|
||||||
|
|
||||||
|
try:
|
||||||
|
from unittest.mock import Mock
|
||||||
|
except ImportError:
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
|
|
||||||
class TestMixinHandler(unittest.TestCase):
|
class TestMixinHandler(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_is_forbidden(self):
|
||||||
|
handler = MixinHandler()
|
||||||
|
request = HTTPRequest('http://example.com/')
|
||||||
|
handler.request = request
|
||||||
|
|
||||||
|
context = Mock(
|
||||||
|
address=('8.8.8.8', 8888),
|
||||||
|
trusted_downstream=['127.0.0.1'],
|
||||||
|
_orig_protocol='http'
|
||||||
|
)
|
||||||
|
request.connection = Mock(context=context)
|
||||||
|
self.assertTrue(handler.is_forbidden())
|
||||||
|
|
||||||
|
context = Mock(
|
||||||
|
address=('8.8.8.8', 8888),
|
||||||
|
trusted_downstream=[],
|
||||||
|
_orig_protocol='http'
|
||||||
|
)
|
||||||
|
request.connection = Mock(context=context)
|
||||||
|
self.assertTrue(handler.is_forbidden())
|
||||||
|
|
||||||
|
context = Mock(
|
||||||
|
address=('192.168.1.1', 8888),
|
||||||
|
trusted_downstream=[],
|
||||||
|
_orig_protocol='http'
|
||||||
|
)
|
||||||
|
request.connection = Mock(context=context)
|
||||||
|
self.assertIsNone(handler.is_forbidden())
|
||||||
|
|
||||||
|
context = Mock(
|
||||||
|
address=('8.8.8.8', 8888),
|
||||||
|
trusted_downstream=[],
|
||||||
|
_orig_protocol='https'
|
||||||
|
)
|
||||||
|
request.connection = Mock(context=context)
|
||||||
|
self.assertIsNone(handler.is_forbidden())
|
||||||
|
|
||||||
def test_get_real_client_addr(self):
|
def test_get_real_client_addr(self):
|
||||||
x_forwarded_for = '1.1.1.1'
|
x_forwarded_for = '1.1.1.1'
|
||||||
x_forwarded_port = 1111
|
x_forwarded_port = 1111
|
||||||
|
|
|
@ -13,7 +13,7 @@ from tornado.ioloop import IOLoop
|
||||||
from webssh.settings import swallow_http_errors
|
from webssh.settings import swallow_http_errors
|
||||||
from webssh.utils import (
|
from webssh.utils import (
|
||||||
is_valid_ip_address, is_valid_port, is_valid_hostname,
|
is_valid_ip_address, is_valid_port, is_valid_hostname,
|
||||||
to_bytes, to_str, to_int, UnicodeType
|
to_bytes, to_str, to_int, to_ip_address, UnicodeType
|
||||||
)
|
)
|
||||||
from webssh.worker import Worker, recycle_worker, workers
|
from webssh.worker import Worker, recycle_worker, workers
|
||||||
|
|
||||||
|
@ -39,6 +39,28 @@ class InvalidValueError(Exception):
|
||||||
|
|
||||||
class MixinHandler(object):
|
class MixinHandler(object):
|
||||||
|
|
||||||
|
def prepare(self):
|
||||||
|
if self.is_forbidden():
|
||||||
|
raise tornado.web.HTTPError(403)
|
||||||
|
|
||||||
|
def is_forbidden(self):
|
||||||
|
"""
|
||||||
|
Following requests are forbidden:
|
||||||
|
* requests not come from trusted_downstream (if set).
|
||||||
|
* non-https requests from a public network.
|
||||||
|
"""
|
||||||
|
context = self.request.connection.context
|
||||||
|
ip = context.address[0]
|
||||||
|
lst = context.trusted_downstream
|
||||||
|
|
||||||
|
if lst and ip not in lst:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if context._orig_protocol == 'http':
|
||||||
|
ipaddr = to_ip_address(ip)
|
||||||
|
if ipaddr.is_global:
|
||||||
|
return True
|
||||||
|
|
||||||
def set_default_headers(self):
|
def set_default_headers(self):
|
||||||
self.set_header('Server', 'TornadoServer')
|
self.set_header('Server', 'TornadoServer')
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from tornado.options import options
|
||||||
from webssh.handler import IndexHandler, WsockHandler
|
from webssh.handler import IndexHandler, WsockHandler
|
||||||
from webssh.settings import (
|
from webssh.settings import (
|
||||||
get_app_settings, get_host_keys_settings, get_policy_setting,
|
get_app_settings, get_host_keys_settings, get_policy_setting,
|
||||||
get_ssl_context, max_body_size, xheaders
|
get_ssl_context, get_server_settings
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,12 +31,12 @@ def main():
|
||||||
loop = tornado.ioloop.IOLoop.current()
|
loop = tornado.ioloop.IOLoop.current()
|
||||||
app = make_app(make_handlers(loop, options), get_app_settings(options))
|
app = make_app(make_handlers(loop, options), get_app_settings(options))
|
||||||
ssl_ctx = get_ssl_context(options)
|
ssl_ctx = get_ssl_context(options)
|
||||||
kwargs = dict(xheaders=xheaders, max_body_size=max_body_size)
|
server_settings = get_server_settings(options)
|
||||||
app.listen(options.port, options.address, **kwargs)
|
app.listen(options.port, options.address, **server_settings)
|
||||||
logging.info('Listening on {}:{}'.format(options.address, options.port))
|
logging.info('Listening on {}:{}'.format(options.address, options.port))
|
||||||
if ssl_ctx:
|
if ssl_ctx:
|
||||||
kwargs.update(ssl_options=ssl_ctx)
|
server_settings.update(ssl_options=ssl_ctx)
|
||||||
app.listen(options.sslPort, options.sslAddress, **kwargs)
|
app.listen(options.sslPort, options.sslAddress, **server_settings)
|
||||||
logging.info('Listening on ssl {}:{}'.format(options.sslAddress,
|
logging.info('Listening on ssl {}:{}'.format(options.sslAddress,
|
||||||
options.sslPort))
|
options.sslPort))
|
||||||
loop.start()
|
loop.start()
|
||||||
|
|
|
@ -51,6 +51,15 @@ def get_app_settings(options):
|
||||||
return settings
|
return settings
|
||||||
|
|
||||||
|
|
||||||
|
def get_server_settings(options):
|
||||||
|
settings = dict(
|
||||||
|
xheaders=xheaders,
|
||||||
|
max_body_size=max_body_size,
|
||||||
|
trusted_downstream=get_trusted_downstream(options)
|
||||||
|
)
|
||||||
|
return settings
|
||||||
|
|
||||||
|
|
||||||
def get_host_keys_settings(options):
|
def get_host_keys_settings(options):
|
||||||
if not options.hostFile:
|
if not options.hostFile:
|
||||||
host_keys_filename = os.path.join(base_dir, 'known_hosts')
|
host_keys_filename = os.path.join(base_dir, 'known_hosts')
|
||||||
|
|
Loading…
Reference in New Issue