mirror of https://github.com/huashengdun/webssh
Added a command line option xheaders
parent
a8a444d7ed
commit
68468585ee
|
@ -1,7 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
import paramiko
|
import paramiko
|
||||||
|
|
||||||
from tornado.httpclient import HTTPRequest
|
|
||||||
from tornado.httputil import HTTPServerRequest
|
from tornado.httputil import HTTPServerRequest
|
||||||
from tornado.options import options
|
from tornado.options import options
|
||||||
from tests.utils import read_file, make_tests_data_path
|
from tests.utils import read_file, make_tests_data_path
|
||||||
|
@ -17,42 +16,55 @@ class TestMixinHandler(unittest.TestCase):
|
||||||
|
|
||||||
def test_is_forbidden(self):
|
def test_is_forbidden(self):
|
||||||
handler = MixinHandler()
|
handler = MixinHandler()
|
||||||
request = HTTPRequest('http://example.com/')
|
|
||||||
handler.request = request
|
|
||||||
options.fbidhttp = True
|
options.fbidhttp = True
|
||||||
|
|
||||||
context = Mock(
|
handler.context = Mock(
|
||||||
address=('8.8.8.8', 8888),
|
address=('8.8.8.8', 8888),
|
||||||
trusted_downstream=['127.0.0.1'],
|
trusted_downstream=['127.0.0.1'],
|
||||||
_orig_protocol='http'
|
_orig_protocol='http'
|
||||||
)
|
)
|
||||||
request.connection = Mock(context=context)
|
|
||||||
self.assertTrue(handler.is_forbidden())
|
self.assertTrue(handler.is_forbidden())
|
||||||
|
|
||||||
context = Mock(
|
handler.context = Mock(
|
||||||
address=('8.8.8.8', 8888),
|
address=('8.8.8.8', 8888),
|
||||||
trusted_downstream=[],
|
trusted_downstream=[],
|
||||||
_orig_protocol='http'
|
_orig_protocol='http'
|
||||||
)
|
)
|
||||||
request.connection = Mock(context=context)
|
|
||||||
self.assertTrue(handler.is_forbidden())
|
self.assertTrue(handler.is_forbidden())
|
||||||
|
|
||||||
context = Mock(
|
handler.context = Mock(
|
||||||
address=('192.168.1.1', 8888),
|
address=('192.168.1.1', 8888),
|
||||||
trusted_downstream=[],
|
trusted_downstream=[],
|
||||||
_orig_protocol='http'
|
_orig_protocol='http'
|
||||||
)
|
)
|
||||||
request.connection = Mock(context=context)
|
|
||||||
self.assertIsNone(handler.is_forbidden())
|
self.assertIsNone(handler.is_forbidden())
|
||||||
|
|
||||||
context = Mock(
|
handler.context = Mock(
|
||||||
address=('8.8.8.8', 8888),
|
address=('8.8.8.8', 8888),
|
||||||
trusted_downstream=[],
|
trusted_downstream=[],
|
||||||
_orig_protocol='https'
|
_orig_protocol='https'
|
||||||
)
|
)
|
||||||
request.connection = Mock(context=context)
|
|
||||||
self.assertIsNone(handler.is_forbidden())
|
self.assertIsNone(handler.is_forbidden())
|
||||||
|
|
||||||
|
def test_get_client_addr(self):
|
||||||
|
handler = MixinHandler()
|
||||||
|
client_addr = ('8.8.8.8', 8888)
|
||||||
|
context_addr = ('127.0.0.1', 1234)
|
||||||
|
options.xheaders = True
|
||||||
|
|
||||||
|
handler.context = Mock(address=context_addr)
|
||||||
|
handler.get_real_client_addr = lambda: None
|
||||||
|
self.assertEqual(handler.get_client_addr(), context_addr)
|
||||||
|
|
||||||
|
handler.context = Mock(address=context_addr)
|
||||||
|
handler.get_real_client_addr = lambda: client_addr
|
||||||
|
self.assertEqual(handler.get_client_addr(), client_addr)
|
||||||
|
|
||||||
|
options.xheaders = False
|
||||||
|
handler.context = Mock(address=context_addr)
|
||||||
|
handler.get_real_client_addr = lambda: client_addr
|
||||||
|
self.assertEqual(handler.get_client_addr(), context_addr)
|
||||||
|
|
||||||
def test_get_real_client_addr(self):
|
def test_get_real_client_addr(self):
|
||||||
x_forwarded_for = '1.1.1.1'
|
x_forwarded_for = '1.1.1.1'
|
||||||
x_forwarded_port = 1111
|
x_forwarded_port = 1111
|
||||||
|
|
|
@ -46,19 +46,21 @@ class MixinHandler(object):
|
||||||
}
|
}
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
|
conn = self.request.connection
|
||||||
|
self.context = conn.context
|
||||||
if self.is_forbidden():
|
if self.is_forbidden():
|
||||||
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
|
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
|
||||||
self.request.connection.stream.write(to_bytes(result))
|
conn.stream.write(to_bytes(result))
|
||||||
self.request.connection.close()
|
conn.close()
|
||||||
raise ValueError('Accesss denied')
|
raise ValueError('Accesss denied')
|
||||||
|
|
||||||
def is_forbidden(self):
|
def is_forbidden(self):
|
||||||
"""
|
"""
|
||||||
Following requests are forbidden:
|
Following requests are forbidden:
|
||||||
* requests not come from trusted_downstream (if set).
|
* requests not come from trusted_downstream (if set).
|
||||||
* non-https requests from a public network.
|
* plain http requests from a public network.
|
||||||
"""
|
"""
|
||||||
context = self.request.connection.context
|
context = self.context
|
||||||
ip = context.address[0]
|
ip = context.address[0]
|
||||||
lst = context.trusted_downstream
|
lst = context.trusted_downstream
|
||||||
|
|
||||||
|
@ -71,7 +73,7 @@ class MixinHandler(object):
|
||||||
if options.fbidhttp and context._orig_protocol == 'http':
|
if options.fbidhttp and context._orig_protocol == 'http':
|
||||||
ipaddr = to_ip_address(ip)
|
ipaddr = to_ip_address(ip)
|
||||||
if not ipaddr.is_private:
|
if not ipaddr.is_private:
|
||||||
logging.warning('Public non-https request is forbidden.')
|
logging.warning('Public plain http request is forbidden.')
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def set_default_headers(self):
|
def set_default_headers(self):
|
||||||
|
@ -85,8 +87,10 @@ class MixinHandler(object):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def get_client_addr(self):
|
def get_client_addr(self):
|
||||||
return self.get_real_client_addr() or self.request.connection.context.\
|
if options.xheaders:
|
||||||
address
|
return self.get_real_client_addr() or self.context.address
|
||||||
|
else:
|
||||||
|
return self.context.address
|
||||||
|
|
||||||
def get_real_client_addr(self):
|
def get_real_client_addr(self):
|
||||||
ip = self.request.remote_ip
|
ip = self.request.remote_ip
|
||||||
|
|
|
@ -30,8 +30,10 @@ define('policy', default='warning',
|
||||||
help='Missing host key policy, reject|autoadd|warning')
|
help='Missing host key policy, reject|autoadd|warning')
|
||||||
define('hostfile', default='', help='User defined host keys file')
|
define('hostfile', default='', help='User defined host keys file')
|
||||||
define('syshostfile', default='', help='System wide host keys file')
|
define('syshostfile', default='', help='System wide host keys file')
|
||||||
define('tdstream', default='', help='trusted downstream, separated by comma')
|
define('tdstream', default='', help='Trusted downstream, separated by comma')
|
||||||
define('fbidhttp', type=bool, default=True, help='forbid public http request')
|
define('fbidhttp', type=bool, default=True,
|
||||||
|
help='Forbid public plain http incoming requests')
|
||||||
|
define('xheaders', type=bool, default=True, help='Support xheaders')
|
||||||
define('wpintvl', type=int, default=0, help='Websocket ping interval')
|
define('wpintvl', type=int, default=0, help='Websocket ping interval')
|
||||||
define('version', type=bool, help='Show version information',
|
define('version', type=bool, help='Show version information',
|
||||||
callback=print_version)
|
callback=print_version)
|
||||||
|
@ -39,7 +41,6 @@ define('version', type=bool, help='Show version information',
|
||||||
|
|
||||||
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
max_body_size = 1 * 1024 * 1024
|
max_body_size = 1 * 1024 * 1024
|
||||||
xheaders = True
|
|
||||||
|
|
||||||
|
|
||||||
def get_app_settings(options):
|
def get_app_settings(options):
|
||||||
|
@ -55,7 +56,7 @@ def get_app_settings(options):
|
||||||
|
|
||||||
def get_server_settings(options):
|
def get_server_settings(options):
|
||||||
settings = dict(
|
settings = dict(
|
||||||
xheaders=xheaders,
|
xheaders=options.xheaders,
|
||||||
max_body_size=max_body_size,
|
max_body_size=max_body_size,
|
||||||
trusted_downstream=get_trusted_downstream(options)
|
trusted_downstream=get_trusted_downstream(options)
|
||||||
)
|
)
|
||||||
|
@ -121,4 +122,4 @@ def detect_is_open_to_public(options):
|
||||||
result = on_public_network_interfaces(get_ips_by_name(options.address))
|
result = on_public_network_interfaces(get_ips_by_name(options.address))
|
||||||
if not result and options.fbidhttp:
|
if not result and options.fbidhttp:
|
||||||
options.fbidhttp = False
|
options.fbidhttp = False
|
||||||
logging.info('Forbid public http: {}'.format(options.fbidhttp))
|
logging.info('Forbid public plain http: {}'.format(options.fbidhttp))
|
||||||
|
|
Loading…
Reference in New Issue