Updated get_trusted_downstream

pull/38/head
Sheng 2018-10-20 15:30:11 +08:00
parent f52b2f5156
commit e31e9be433
2 changed files with 17 additions and 17 deletions

View File

@ -122,21 +122,21 @@ class TestSettings(unittest.TestCase):
self.assertIsNotNone(ssl_ctx) self.assertIsNotNone(ssl_ctx)
def test_get_trusted_downstream(self): def test_get_trusted_downstream(self):
options.tdstream = '' tdstream = ''
tdstream = set() result = set()
self.assertEqual(get_trusted_downstream(options), tdstream) self.assertEqual(get_trusted_downstream(tdstream), result)
options.tdstream = '1.1.1.1, 2.2.2.2' tdstream = '1.1.1.1, 2.2.2.2'
tdstream = set(['1.1.1.1', '2.2.2.2']) result = set(['1.1.1.1', '2.2.2.2'])
self.assertEqual(get_trusted_downstream(options), tdstream) self.assertEqual(get_trusted_downstream(tdstream), result)
options.tdstream = '1.1.1.1, 2.2.2.2, 2.2.2.2' tdstream = '1.1.1.1, 2.2.2.2, 2.2.2.2'
tdstream = set(['1.1.1.1', '2.2.2.2']) result = set(['1.1.1.1', '2.2.2.2'])
self.assertEqual(get_trusted_downstream(options), tdstream) self.assertEqual(get_trusted_downstream(tdstream), result)
options.tdstream = '1.1.1.1, 2.2.2.' tdstream = '1.1.1.1, 2.2.2.'
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
get_trusted_downstream(options), tdstream get_trusted_downstream(tdstream)
def test_detect_is_open_to_public(self): def test_detect_is_open_to_public(self):
options.fbidhttp = True options.fbidhttp = True

View File

@ -58,7 +58,7 @@ def get_server_settings(options):
settings = dict( settings = dict(
xheaders=options.xheaders, xheaders=options.xheaders,
max_body_size=max_body_size, max_body_size=max_body_size,
trusted_downstream=get_trusted_downstream(options) trusted_downstream=get_trusted_downstream(options.tdstream)
) )
return settings return settings
@ -108,14 +108,14 @@ def get_ssl_context(options):
return ssl_ctx return ssl_ctx
def get_trusted_downstream(options): def get_trusted_downstream(tdstream):
tdstream = set() result = set()
for ip in options.tdstream.split(','): for ip in tdstream.split(','):
ip = ip.strip() ip = ip.strip()
if ip: if ip:
to_ip_address(ip) to_ip_address(ip)
tdstream.add(ip) result.add(ip)
return tdstream return result
def detect_is_open_to_public(options): def detect_is_open_to_public(options):