Added a command line option xheaders

pull/38/head
Sheng 6 years ago
parent a8a444d7ed
commit 68468585ee

@ -1,7 +1,6 @@
import unittest import unittest
import paramiko import paramiko
from tornado.httpclient import HTTPRequest
from tornado.httputil import HTTPServerRequest from tornado.httputil import HTTPServerRequest
from tornado.options import options from tornado.options import options
from tests.utils import read_file, make_tests_data_path from tests.utils import read_file, make_tests_data_path
@ -17,42 +16,55 @@ class TestMixinHandler(unittest.TestCase):
def test_is_forbidden(self): def test_is_forbidden(self):
handler = MixinHandler() handler = MixinHandler()
request = HTTPRequest('http://example.com/')
handler.request = request
options.fbidhttp = True options.fbidhttp = True
context = Mock( handler.context = Mock(
address=('8.8.8.8', 8888), address=('8.8.8.8', 8888),
trusted_downstream=['127.0.0.1'], trusted_downstream=['127.0.0.1'],
_orig_protocol='http' _orig_protocol='http'
) )
request.connection = Mock(context=context)
self.assertTrue(handler.is_forbidden()) self.assertTrue(handler.is_forbidden())
context = Mock( handler.context = Mock(
address=('8.8.8.8', 8888), address=('8.8.8.8', 8888),
trusted_downstream=[], trusted_downstream=[],
_orig_protocol='http' _orig_protocol='http'
) )
request.connection = Mock(context=context)
self.assertTrue(handler.is_forbidden()) self.assertTrue(handler.is_forbidden())
context = Mock( handler.context = Mock(
address=('192.168.1.1', 8888), address=('192.168.1.1', 8888),
trusted_downstream=[], trusted_downstream=[],
_orig_protocol='http' _orig_protocol='http'
) )
request.connection = Mock(context=context)
self.assertIsNone(handler.is_forbidden()) self.assertIsNone(handler.is_forbidden())
context = Mock( handler.context = Mock(
address=('8.8.8.8', 8888), address=('8.8.8.8', 8888),
trusted_downstream=[], trusted_downstream=[],
_orig_protocol='https' _orig_protocol='https'
) )
request.connection = Mock(context=context)
self.assertIsNone(handler.is_forbidden()) self.assertIsNone(handler.is_forbidden())
def test_get_client_addr(self):
handler = MixinHandler()
client_addr = ('8.8.8.8', 8888)
context_addr = ('127.0.0.1', 1234)
options.xheaders = True
handler.context = Mock(address=context_addr)
handler.get_real_client_addr = lambda: None
self.assertEqual(handler.get_client_addr(), context_addr)
handler.context = Mock(address=context_addr)
handler.get_real_client_addr = lambda: client_addr
self.assertEqual(handler.get_client_addr(), client_addr)
options.xheaders = False
handler.context = Mock(address=context_addr)
handler.get_real_client_addr = lambda: client_addr
self.assertEqual(handler.get_client_addr(), context_addr)
def test_get_real_client_addr(self): def test_get_real_client_addr(self):
x_forwarded_for = '1.1.1.1' x_forwarded_for = '1.1.1.1'
x_forwarded_port = 1111 x_forwarded_port = 1111

@ -46,19 +46,21 @@ class MixinHandler(object):
} }
def initialize(self): def initialize(self):
conn = self.request.connection
self.context = conn.context
if self.is_forbidden(): if self.is_forbidden():
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version) result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
self.request.connection.stream.write(to_bytes(result)) conn.stream.write(to_bytes(result))
self.request.connection.close() conn.close()
raise ValueError('Accesss denied') raise ValueError('Accesss denied')
def is_forbidden(self): def is_forbidden(self):
""" """
Following requests are forbidden: Following requests are forbidden:
* requests not come from trusted_downstream (if set). * requests not come from trusted_downstream (if set).
* non-https requests from a public network. * plain http requests from a public network.
""" """
context = self.request.connection.context context = self.context
ip = context.address[0] ip = context.address[0]
lst = context.trusted_downstream lst = context.trusted_downstream
@ -71,7 +73,7 @@ class MixinHandler(object):
if options.fbidhttp and context._orig_protocol == 'http': if options.fbidhttp and context._orig_protocol == 'http':
ipaddr = to_ip_address(ip) ipaddr = to_ip_address(ip)
if not ipaddr.is_private: if not ipaddr.is_private:
logging.warning('Public non-https request is forbidden.') logging.warning('Public plain http request is forbidden.')
return True return True
def set_default_headers(self): def set_default_headers(self):
@ -85,8 +87,10 @@ class MixinHandler(object):
return value return value
def get_client_addr(self): def get_client_addr(self):
return self.get_real_client_addr() or self.request.connection.context.\ if options.xheaders:
address return self.get_real_client_addr() or self.context.address
else:
return self.context.address
def get_real_client_addr(self): def get_real_client_addr(self):
ip = self.request.remote_ip ip = self.request.remote_ip

@ -30,8 +30,10 @@ define('policy', default='warning',
help='Missing host key policy, reject|autoadd|warning') help='Missing host key policy, reject|autoadd|warning')
define('hostfile', default='', help='User defined host keys file') define('hostfile', default='', help='User defined host keys file')
define('syshostfile', default='', help='System wide host keys file') define('syshostfile', default='', help='System wide host keys file')
define('tdstream', default='', help='trusted downstream, separated by comma') define('tdstream', default='', help='Trusted downstream, separated by comma')
define('fbidhttp', type=bool, default=True, help='forbid public http request') define('fbidhttp', type=bool, default=True,
help='Forbid public plain http incoming requests')
define('xheaders', type=bool, default=True, help='Support xheaders')
define('wpintvl', type=int, default=0, help='Websocket ping interval') define('wpintvl', type=int, default=0, help='Websocket ping interval')
define('version', type=bool, help='Show version information', define('version', type=bool, help='Show version information',
callback=print_version) callback=print_version)
@ -39,7 +41,6 @@ define('version', type=bool, help='Show version information',
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
max_body_size = 1 * 1024 * 1024 max_body_size = 1 * 1024 * 1024
xheaders = True
def get_app_settings(options): def get_app_settings(options):
@ -55,7 +56,7 @@ def get_app_settings(options):
def get_server_settings(options): def get_server_settings(options):
settings = dict( settings = dict(
xheaders=xheaders, xheaders=options.xheaders,
max_body_size=max_body_size, max_body_size=max_body_size,
trusted_downstream=get_trusted_downstream(options) trusted_downstream=get_trusted_downstream(options)
) )
@ -121,4 +122,4 @@ def detect_is_open_to_public(options):
result = on_public_network_interfaces(get_ips_by_name(options.address)) result = on_public_network_interfaces(get_ips_by_name(options.address))
if not result and options.fbidhttp: if not result and options.fbidhttp:
options.fbidhttp = False options.fbidhttp = False
logging.info('Forbid public http: {}'.format(options.fbidhttp)) logging.info('Forbid public plain http: {}'.format(options.fbidhttp))

Loading…
Cancel
Save