mirror of https://github.com/huashengdun/webssh
Support custom origin configuration
parent
8a8d741230
commit
c35f801235
|
@ -215,7 +215,7 @@ class TestWsockHandler(unittest.TestCase):
|
|||
request = HTTPServerRequest(uri='/')
|
||||
obj = Mock(spec=WsockHandler, request=request)
|
||||
|
||||
options.cows = 0
|
||||
obj.origin_policy = 'same'
|
||||
request.headers['Host'] = 'www.example.com:4433'
|
||||
origin = 'https://www.example.com:4433'
|
||||
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||
|
@ -223,7 +223,7 @@ class TestWsockHandler(unittest.TestCase):
|
|||
origin = 'https://www.example.com'
|
||||
self.assertFalse(WsockHandler.check_origin(obj, origin))
|
||||
|
||||
options.cows = 1
|
||||
obj.origin_policy = 'primary'
|
||||
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||
|
||||
origin = 'https://blog.example.com'
|
||||
|
@ -232,5 +232,18 @@ class TestWsockHandler(unittest.TestCase):
|
|||
origin = 'https://blog.example.org'
|
||||
self.assertFalse(WsockHandler.check_origin(obj, origin))
|
||||
|
||||
options.cows = 2
|
||||
origin = 'https://blog.example.org'
|
||||
obj.origin_policy = {'https://blog.example.org'}
|
||||
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||
|
||||
origin = 'http://blog.example.org'
|
||||
obj.origin_policy = {'http://blog.example.org'}
|
||||
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||
|
||||
origin = 'http://blog.example.org'
|
||||
obj.origin_policy = {'https://blog.example.org'}
|
||||
self.assertFalse(WsockHandler.check_origin(obj, origin))
|
||||
|
||||
obj.origin_policy = '*'
|
||||
origin = 'https://blog.example.org'
|
||||
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import io
|
||||
import random
|
||||
import ssl
|
||||
import sys
|
||||
import os.path
|
||||
|
@ -10,7 +11,7 @@ from tests.utils import make_tests_data_path
|
|||
from webssh.policy import load_host_keys
|
||||
from webssh.settings import (
|
||||
get_host_keys_settings, get_policy_setting, base_dir, print_version,
|
||||
get_ssl_context, get_trusted_downstream
|
||||
get_ssl_context, get_trusted_downstream, get_origin_setting
|
||||
)
|
||||
from webssh.utils import UnicodeType
|
||||
from webssh._version import __version__
|
||||
|
@ -137,3 +138,31 @@ class TestSettings(unittest.TestCase):
|
|||
tdstream = '1.1.1.1, 2.2.2.'
|
||||
with self.assertRaises(ValueError):
|
||||
get_trusted_downstream(tdstream)
|
||||
|
||||
def test_get_origin_setting(self):
|
||||
options.debug = False
|
||||
options.origin = '*'
|
||||
with self.assertRaises(ValueError):
|
||||
get_origin_setting(options)
|
||||
|
||||
options.debug = True
|
||||
self.assertEqual(get_origin_setting(options), '*')
|
||||
|
||||
options.origin = random.choice(['Same', 'Primary'])
|
||||
self.assertEqual(get_origin_setting(options), options.origin.lower())
|
||||
|
||||
options.origin = ''
|
||||
with self.assertRaises(ValueError):
|
||||
get_origin_setting(options)
|
||||
|
||||
options.origin = ','
|
||||
with self.assertRaises(ValueError):
|
||||
get_origin_setting(options)
|
||||
|
||||
options.origin = 'www.example.com, https://www.example.org'
|
||||
result = {'http://www.example.com', 'https://www.example.org'}
|
||||
self.assertEqual(get_origin_setting(options), result)
|
||||
|
||||
options.origin = 'www.example.com:80, www.example.org:443'
|
||||
result = {'http://www.example.com', 'https://www.example.org'}
|
||||
self.assertEqual(get_origin_setting(options), result)
|
||||
|
|
|
@ -2,7 +2,7 @@ import unittest
|
|||
|
||||
from webssh.utils import (
|
||||
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
|
||||
to_int, is_ip_hostname, is_same_primary_domain
|
||||
to_int, is_ip_hostname, is_same_primary_domain, parse_origin_from_url
|
||||
)
|
||||
|
||||
|
||||
|
@ -90,3 +90,34 @@ class TestUitls(unittest.TestCase):
|
|||
domain1 = 'xxx.www.example.com'
|
||||
domain2 = 'xxx.www2.example.com'
|
||||
self.assertTrue(is_same_primary_domain(domain1, domain2))
|
||||
|
||||
def test_parse_origin_from_url(self):
|
||||
url = ''
|
||||
self.assertIsNone(parse_origin_from_url(url))
|
||||
|
||||
url = 'www.example.com'
|
||||
self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
|
||||
|
||||
url = 'http://www.example.com'
|
||||
self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
|
||||
|
||||
url = 'www.example.com:80'
|
||||
self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
|
||||
|
||||
url = 'http://www.example.com:80'
|
||||
self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
|
||||
|
||||
url = 'www.example.com:443'
|
||||
self.assertEqual(parse_origin_from_url(url), 'https://www.example.com')
|
||||
|
||||
url = 'https://www.example.com'
|
||||
self.assertEqual(parse_origin_from_url(url), 'https://www.example.com')
|
||||
|
||||
url = 'https://www.example.com:443'
|
||||
self.assertEqual(parse_origin_from_url(url), 'https://www.example.com')
|
||||
|
||||
url = 'https://www.example.com:80'
|
||||
self.assertEqual(parse_origin_from_url(url), url)
|
||||
|
||||
url = 'http://www.example.com:443'
|
||||
self.assertEqual(parse_origin_from_url(url), url)
|
||||
|
|
|
@ -57,6 +57,7 @@ class MixinHandler(object):
|
|||
def initialize(self, loop=None):
|
||||
self.check_request()
|
||||
self.loop = loop
|
||||
self.origin_policy = self.settings.get('origin_policy')
|
||||
|
||||
def check_request(self):
|
||||
context = self.request.connection.context
|
||||
|
@ -364,22 +365,26 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
|
|||
self.worker_ref = None
|
||||
|
||||
def check_origin(self, origin):
|
||||
cows = options.cows
|
||||
if self.origin_policy == '*':
|
||||
return True
|
||||
|
||||
parsed_origin = urlparse(origin)
|
||||
origin = parsed_origin.netloc
|
||||
origin = origin.lower()
|
||||
logging.debug('origin: {}'.format(origin))
|
||||
netloc = parsed_origin.netloc.lower()
|
||||
logging.debug('netloc: {}'.format(netloc))
|
||||
|
||||
host = self.request.headers.get('Host')
|
||||
logging.debug('host: {}'.format(host))
|
||||
|
||||
if cows == 0:
|
||||
return origin == host
|
||||
elif cows == 1:
|
||||
return is_same_primary_domain(origin.rsplit(':', 1)[0],
|
||||
if netloc == host:
|
||||
return True
|
||||
|
||||
if self.origin_policy == 'same':
|
||||
return False
|
||||
elif self.origin_policy == 'primary':
|
||||
return is_same_primary_domain(netloc.rsplit(':', 1)[0],
|
||||
host.rsplit(':', 1)[0])
|
||||
else:
|
||||
return True
|
||||
return origin in self.origin_policy
|
||||
|
||||
def open(self):
|
||||
self.src_addr = self.get_client_addr()
|
||||
|
|
|
@ -7,7 +7,7 @@ from tornado.options import define
|
|||
from webssh.policy import (
|
||||
load_host_keys, get_policy_class, check_policy_setting
|
||||
)
|
||||
from webssh.utils import to_ip_address
|
||||
from webssh.utils import to_ip_address, parse_origin_from_url
|
||||
from webssh._version import __version__
|
||||
|
||||
|
||||
|
@ -34,10 +34,12 @@ define('fbidhttp', type=bool, default=True,
|
|||
help='Forbid public plain http incoming requests')
|
||||
define('xheaders', type=bool, default=True, help='Support xheaders')
|
||||
define('xsrf', type=bool, default=True, help='CSRF protection')
|
||||
define('cows', type=int, default=0, help='Cross origin websocket, '
|
||||
'0: matches host name and port number, '
|
||||
'1: matches primary domain only, '
|
||||
'?: matches nothing, allow all cross-origin websockets')
|
||||
define('origin', default='same', help='''Origin policy,
|
||||
'same': same origin policy, matches host name and port number;
|
||||
'primary': primary domain policy, matches primary domain only;
|
||||
'<domains>': custom domains policy, matches any domain in the <domains> list
|
||||
separated by comma;
|
||||
'*': wildcard policy, matches any domain, allowed in debug mode only.''')
|
||||
define('wpintvl', type=int, default=0, help='Websocket ping interval')
|
||||
define('maxconn', type=int, default=20, help='Maximum connections per client')
|
||||
define('version', type=bool, help='Show version information',
|
||||
|
@ -54,7 +56,8 @@ def get_app_settings(options):
|
|||
static_path=os.path.join(base_dir, 'webssh', 'static'),
|
||||
websocket_ping_interval=options.wpintvl,
|
||||
debug=options.debug,
|
||||
xsrf_cookies=options.xsrf
|
||||
xsrf_cookies=options.xsrf,
|
||||
origin_policy=get_origin_setting(options)
|
||||
)
|
||||
return settings
|
||||
|
||||
|
@ -121,3 +124,28 @@ def get_trusted_downstream(tdstream):
|
|||
to_ip_address(ip)
|
||||
result.add(ip)
|
||||
return result
|
||||
|
||||
|
||||
def get_origin_setting(options):
|
||||
if options.origin == '*':
|
||||
if not options.debug:
|
||||
raise ValueError(
|
||||
'Wildcard origin policy is only allowed in debug mode.'
|
||||
)
|
||||
else:
|
||||
return '*'
|
||||
|
||||
origin = options.origin.lower()
|
||||
if origin in ['same', 'primary']:
|
||||
return origin
|
||||
|
||||
origins = set()
|
||||
for url in origin.split(','):
|
||||
orig = parse_origin_from_url(url)
|
||||
if orig:
|
||||
origins.add(orig)
|
||||
|
||||
if not origins:
|
||||
raise ValueError('Empty origin list')
|
||||
|
||||
return origins
|
||||
|
|
|
@ -6,6 +6,11 @@ try:
|
|||
except ImportError:
|
||||
UnicodeType = str
|
||||
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
except ImportError:
|
||||
from urlparse import urlparse
|
||||
|
||||
|
||||
numeric = re.compile(r'[0-9]+$')
|
||||
allowed = re.compile(r'(?!-)[a-z0-9-]{1,63}(?<!-)$', re.IGNORECASE)
|
||||
|
@ -101,3 +106,29 @@ def is_same_primary_domain(domain1, domain2):
|
|||
|
||||
c = domain1[i] if l1 > m else domain2[i]
|
||||
return c == '.'
|
||||
|
||||
|
||||
def parse_origin_from_url(url):
|
||||
url = url.strip()
|
||||
if not url:
|
||||
return
|
||||
|
||||
if not (url.startswith('http://') or url.startswith('https://') or
|
||||
url.startswith('//')):
|
||||
url = '//' + url
|
||||
|
||||
parsed = urlparse(url)
|
||||
port = parsed.port
|
||||
scheme = parsed.scheme
|
||||
|
||||
if scheme == '':
|
||||
scheme = 'https' if port == 443 else 'http'
|
||||
|
||||
if port == 443 and scheme == 'https':
|
||||
netloc = parsed.netloc.replace(':443', '')
|
||||
elif port == 80 and scheme == 'http':
|
||||
netloc = parsed.netloc.replace(':80', '')
|
||||
else:
|
||||
netloc = parsed.netloc
|
||||
|
||||
return '{}://{}'.format(scheme, netloc)
|
||||
|
|
Loading…
Reference in New Issue