Support custom origin configuration

pull/58/head
Sheng 2019-01-19 16:46:25 +08:00
parent 8a8d741230
commit c35f801235
6 changed files with 157 additions and 20 deletions

View File

@ -215,7 +215,7 @@ class TestWsockHandler(unittest.TestCase):
request = HTTPServerRequest(uri='/') request = HTTPServerRequest(uri='/')
obj = Mock(spec=WsockHandler, request=request) obj = Mock(spec=WsockHandler, request=request)
options.cows = 0 obj.origin_policy = 'same'
request.headers['Host'] = 'www.example.com:4433' request.headers['Host'] = 'www.example.com:4433'
origin = 'https://www.example.com:4433' origin = 'https://www.example.com:4433'
self.assertTrue(WsockHandler.check_origin(obj, origin)) self.assertTrue(WsockHandler.check_origin(obj, origin))
@ -223,7 +223,7 @@ class TestWsockHandler(unittest.TestCase):
origin = 'https://www.example.com' origin = 'https://www.example.com'
self.assertFalse(WsockHandler.check_origin(obj, origin)) self.assertFalse(WsockHandler.check_origin(obj, origin))
options.cows = 1 obj.origin_policy = 'primary'
self.assertTrue(WsockHandler.check_origin(obj, origin)) self.assertTrue(WsockHandler.check_origin(obj, origin))
origin = 'https://blog.example.com' origin = 'https://blog.example.com'
@ -232,5 +232,18 @@ class TestWsockHandler(unittest.TestCase):
origin = 'https://blog.example.org' origin = 'https://blog.example.org'
self.assertFalse(WsockHandler.check_origin(obj, origin)) self.assertFalse(WsockHandler.check_origin(obj, origin))
options.cows = 2 origin = 'https://blog.example.org'
obj.origin_policy = {'https://blog.example.org'}
self.assertTrue(WsockHandler.check_origin(obj, origin))
origin = 'http://blog.example.org'
obj.origin_policy = {'http://blog.example.org'}
self.assertTrue(WsockHandler.check_origin(obj, origin))
origin = 'http://blog.example.org'
obj.origin_policy = {'https://blog.example.org'}
self.assertFalse(WsockHandler.check_origin(obj, origin))
obj.origin_policy = '*'
origin = 'https://blog.example.org'
self.assertTrue(WsockHandler.check_origin(obj, origin)) self.assertTrue(WsockHandler.check_origin(obj, origin))

View File

@ -1,4 +1,5 @@
import io import io
import random
import ssl import ssl
import sys import sys
import os.path import os.path
@ -10,7 +11,7 @@ from tests.utils import make_tests_data_path
from webssh.policy import load_host_keys from webssh.policy import load_host_keys
from webssh.settings import ( from webssh.settings import (
get_host_keys_settings, get_policy_setting, base_dir, print_version, get_host_keys_settings, get_policy_setting, base_dir, print_version,
get_ssl_context, get_trusted_downstream get_ssl_context, get_trusted_downstream, get_origin_setting
) )
from webssh.utils import UnicodeType from webssh.utils import UnicodeType
from webssh._version import __version__ from webssh._version import __version__
@ -137,3 +138,31 @@ class TestSettings(unittest.TestCase):
tdstream = '1.1.1.1, 2.2.2.' tdstream = '1.1.1.1, 2.2.2.'
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
get_trusted_downstream(tdstream) get_trusted_downstream(tdstream)
def test_get_origin_setting(self):
options.debug = False
options.origin = '*'
with self.assertRaises(ValueError):
get_origin_setting(options)
options.debug = True
self.assertEqual(get_origin_setting(options), '*')
options.origin = random.choice(['Same', 'Primary'])
self.assertEqual(get_origin_setting(options), options.origin.lower())
options.origin = ''
with self.assertRaises(ValueError):
get_origin_setting(options)
options.origin = ','
with self.assertRaises(ValueError):
get_origin_setting(options)
options.origin = 'www.example.com, https://www.example.org'
result = {'http://www.example.com', 'https://www.example.org'}
self.assertEqual(get_origin_setting(options), result)
options.origin = 'www.example.com:80, www.example.org:443'
result = {'http://www.example.com', 'https://www.example.org'}
self.assertEqual(get_origin_setting(options), result)

View File

@ -2,7 +2,7 @@ import unittest
from webssh.utils import ( from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes, is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
to_int, is_ip_hostname, is_same_primary_domain to_int, is_ip_hostname, is_same_primary_domain, parse_origin_from_url
) )
@ -90,3 +90,34 @@ class TestUitls(unittest.TestCase):
domain1 = 'xxx.www.example.com' domain1 = 'xxx.www.example.com'
domain2 = 'xxx.www2.example.com' domain2 = 'xxx.www2.example.com'
self.assertTrue(is_same_primary_domain(domain1, domain2)) self.assertTrue(is_same_primary_domain(domain1, domain2))
def test_parse_origin_from_url(self):
url = ''
self.assertIsNone(parse_origin_from_url(url))
url = 'www.example.com'
self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
url = 'http://www.example.com'
self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
url = 'www.example.com:80'
self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
url = 'http://www.example.com:80'
self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
url = 'www.example.com:443'
self.assertEqual(parse_origin_from_url(url), 'https://www.example.com')
url = 'https://www.example.com'
self.assertEqual(parse_origin_from_url(url), 'https://www.example.com')
url = 'https://www.example.com:443'
self.assertEqual(parse_origin_from_url(url), 'https://www.example.com')
url = 'https://www.example.com:80'
self.assertEqual(parse_origin_from_url(url), url)
url = 'http://www.example.com:443'
self.assertEqual(parse_origin_from_url(url), url)

View File

@ -57,6 +57,7 @@ class MixinHandler(object):
def initialize(self, loop=None): def initialize(self, loop=None):
self.check_request() self.check_request()
self.loop = loop self.loop = loop
self.origin_policy = self.settings.get('origin_policy')
def check_request(self): def check_request(self):
context = self.request.connection.context context = self.request.connection.context
@ -364,22 +365,26 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
self.worker_ref = None self.worker_ref = None
def check_origin(self, origin): def check_origin(self, origin):
cows = options.cows if self.origin_policy == '*':
return True
parsed_origin = urlparse(origin) parsed_origin = urlparse(origin)
origin = parsed_origin.netloc netloc = parsed_origin.netloc.lower()
origin = origin.lower() logging.debug('netloc: {}'.format(netloc))
logging.debug('origin: {}'.format(origin))
host = self.request.headers.get('Host') host = self.request.headers.get('Host')
logging.debug('host: {}'.format(host)) logging.debug('host: {}'.format(host))
if cows == 0: if netloc == host:
return origin == host return True
elif cows == 1:
return is_same_primary_domain(origin.rsplit(':', 1)[0], if self.origin_policy == 'same':
return False
elif self.origin_policy == 'primary':
return is_same_primary_domain(netloc.rsplit(':', 1)[0],
host.rsplit(':', 1)[0]) host.rsplit(':', 1)[0])
else: else:
return True return origin in self.origin_policy
def open(self): def open(self):
self.src_addr = self.get_client_addr() self.src_addr = self.get_client_addr()

View File

@ -7,7 +7,7 @@ from tornado.options import define
from webssh.policy import ( from webssh.policy import (
load_host_keys, get_policy_class, check_policy_setting load_host_keys, get_policy_class, check_policy_setting
) )
from webssh.utils import to_ip_address from webssh.utils import to_ip_address, parse_origin_from_url
from webssh._version import __version__ from webssh._version import __version__
@ -34,10 +34,12 @@ define('fbidhttp', type=bool, default=True,
help='Forbid public plain http incoming requests') help='Forbid public plain http incoming requests')
define('xheaders', type=bool, default=True, help='Support xheaders') define('xheaders', type=bool, default=True, help='Support xheaders')
define('xsrf', type=bool, default=True, help='CSRF protection') define('xsrf', type=bool, default=True, help='CSRF protection')
define('cows', type=int, default=0, help='Cross origin websocket, ' define('origin', default='same', help='''Origin policy,
'0: matches host name and port number, ' 'same': same origin policy, matches host name and port number;
'1: matches primary domain only, ' 'primary': primary domain policy, matches primary domain only;
'?: matches nothing, allow all cross-origin websockets') '<domains>': custom domains policy, matches any domain in the <domains> list
separated by comma;
'*': wildcard policy, matches any domain, allowed in debug mode only.''')
define('wpintvl', type=int, default=0, help='Websocket ping interval') define('wpintvl', type=int, default=0, help='Websocket ping interval')
define('maxconn', type=int, default=20, help='Maximum connections per client') define('maxconn', type=int, default=20, help='Maximum connections per client')
define('version', type=bool, help='Show version information', define('version', type=bool, help='Show version information',
@ -54,7 +56,8 @@ def get_app_settings(options):
static_path=os.path.join(base_dir, 'webssh', 'static'), static_path=os.path.join(base_dir, 'webssh', 'static'),
websocket_ping_interval=options.wpintvl, websocket_ping_interval=options.wpintvl,
debug=options.debug, debug=options.debug,
xsrf_cookies=options.xsrf xsrf_cookies=options.xsrf,
origin_policy=get_origin_setting(options)
) )
return settings return settings
@ -121,3 +124,28 @@ def get_trusted_downstream(tdstream):
to_ip_address(ip) to_ip_address(ip)
result.add(ip) result.add(ip)
return result return result
def get_origin_setting(options):
if options.origin == '*':
if not options.debug:
raise ValueError(
'Wildcard origin policy is only allowed in debug mode.'
)
else:
return '*'
origin = options.origin.lower()
if origin in ['same', 'primary']:
return origin
origins = set()
for url in origin.split(','):
orig = parse_origin_from_url(url)
if orig:
origins.add(orig)
if not origins:
raise ValueError('Empty origin list')
return origins

View File

@ -6,6 +6,11 @@ try:
except ImportError: except ImportError:
UnicodeType = str UnicodeType = str
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
numeric = re.compile(r'[0-9]+$') numeric = re.compile(r'[0-9]+$')
allowed = re.compile(r'(?!-)[a-z0-9-]{1,63}(?<!-)$', re.IGNORECASE) allowed = re.compile(r'(?!-)[a-z0-9-]{1,63}(?<!-)$', re.IGNORECASE)
@ -101,3 +106,29 @@ def is_same_primary_domain(domain1, domain2):
c = domain1[i] if l1 > m else domain2[i] c = domain1[i] if l1 > m else domain2[i]
return c == '.' return c == '.'
def parse_origin_from_url(url):
url = url.strip()
if not url:
return
if not (url.startswith('http://') or url.startswith('https://') or
url.startswith('//')):
url = '//' + url
parsed = urlparse(url)
port = parsed.port
scheme = parsed.scheme
if scheme == '':
scheme = 'https' if port == 443 else 'http'
if port == 443 and scheme == 'https':
netloc = parsed.netloc.replace(':443', '')
elif port == 80 and scheme == 'http':
netloc = parsed.netloc.replace(':80', '')
else:
netloc = parsed.netloc
return '{}://{}'.format(scheme, netloc)