Added a command line option xheaders

pull/38/head
Sheng 6 years ago
parent a8a444d7ed
commit 68468585ee

@ -1,7 +1,6 @@
import unittest
import paramiko
from tornado.httpclient import HTTPRequest
from tornado.httputil import HTTPServerRequest
from tornado.options import options
from tests.utils import read_file, make_tests_data_path
@ -17,42 +16,55 @@ class TestMixinHandler(unittest.TestCase):
def test_is_forbidden(self):
handler = MixinHandler()
request = HTTPRequest('http://example.com/')
handler.request = request
options.fbidhttp = True
context = Mock(
handler.context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=['127.0.0.1'],
_orig_protocol='http'
)
request.connection = Mock(context=context)
self.assertTrue(handler.is_forbidden())
context = Mock(
handler.context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
request.connection = Mock(context=context)
self.assertTrue(handler.is_forbidden())
context = Mock(
handler.context = Mock(
address=('192.168.1.1', 8888),
trusted_downstream=[],
_orig_protocol='http'
)
request.connection = Mock(context=context)
self.assertIsNone(handler.is_forbidden())
context = Mock(
handler.context = Mock(
address=('8.8.8.8', 8888),
trusted_downstream=[],
_orig_protocol='https'
)
request.connection = Mock(context=context)
self.assertIsNone(handler.is_forbidden())
def test_get_client_addr(self):
handler = MixinHandler()
client_addr = ('8.8.8.8', 8888)
context_addr = ('127.0.0.1', 1234)
options.xheaders = True
handler.context = Mock(address=context_addr)
handler.get_real_client_addr = lambda: None
self.assertEqual(handler.get_client_addr(), context_addr)
handler.context = Mock(address=context_addr)
handler.get_real_client_addr = lambda: client_addr
self.assertEqual(handler.get_client_addr(), client_addr)
options.xheaders = False
handler.context = Mock(address=context_addr)
handler.get_real_client_addr = lambda: client_addr
self.assertEqual(handler.get_client_addr(), context_addr)
def test_get_real_client_addr(self):
x_forwarded_for = '1.1.1.1'
x_forwarded_port = 1111

@ -46,19 +46,21 @@ class MixinHandler(object):
}
def initialize(self):
conn = self.request.connection
self.context = conn.context
if self.is_forbidden():
result = '{} 403 Forbidden\r\n\r\n'.format(self.request.version)
self.request.connection.stream.write(to_bytes(result))
self.request.connection.close()
conn.stream.write(to_bytes(result))
conn.close()
raise ValueError('Accesss denied')
def is_forbidden(self):
"""
Following requests are forbidden:
* requests not come from trusted_downstream (if set).
* non-https requests from a public network.
* plain http requests from a public network.
"""
context = self.request.connection.context
context = self.context
ip = context.address[0]
lst = context.trusted_downstream
@ -71,7 +73,7 @@ class MixinHandler(object):
if options.fbidhttp and context._orig_protocol == 'http':
ipaddr = to_ip_address(ip)
if not ipaddr.is_private:
logging.warning('Public non-https request is forbidden.')
logging.warning('Public plain http request is forbidden.')
return True
def set_default_headers(self):
@ -85,8 +87,10 @@ class MixinHandler(object):
return value
def get_client_addr(self):
return self.get_real_client_addr() or self.request.connection.context.\
address
if options.xheaders:
return self.get_real_client_addr() or self.context.address
else:
return self.context.address
def get_real_client_addr(self):
ip = self.request.remote_ip

@ -30,8 +30,10 @@ define('policy', default='warning',
help='Missing host key policy, reject|autoadd|warning')
define('hostfile', default='', help='User defined host keys file')
define('syshostfile', default='', help='System wide host keys file')
define('tdstream', default='', help='trusted downstream, separated by comma')
define('fbidhttp', type=bool, default=True, help='forbid public http request')
define('tdstream', default='', help='Trusted downstream, separated by comma')
define('fbidhttp', type=bool, default=True,
help='Forbid public plain http incoming requests')
define('xheaders', type=bool, default=True, help='Support xheaders')
define('wpintvl', type=int, default=0, help='Websocket ping interval')
define('version', type=bool, help='Show version information',
callback=print_version)
@ -39,7 +41,6 @@ define('version', type=bool, help='Show version information',
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
max_body_size = 1 * 1024 * 1024
xheaders = True
def get_app_settings(options):
@ -55,7 +56,7 @@ def get_app_settings(options):
def get_server_settings(options):
settings = dict(
xheaders=xheaders,
xheaders=options.xheaders,
max_body_size=max_body_size,
trusted_downstream=get_trusted_downstream(options)
)
@ -121,4 +122,4 @@ def detect_is_open_to_public(options):
result = on_public_network_interfaces(get_ips_by_name(options.address))
if not result and options.fbidhttp:
options.fbidhttp = False
logging.info('Forbid public http: {}'.format(options.fbidhttp))
logging.info('Forbid public plain http: {}'.format(options.fbidhttp))

Loading…
Cancel
Save