mirror of https://github.com/huashengdun/webssh
Support custom origin configuration
parent
8a8d741230
commit
c35f801235
|
@ -215,7 +215,7 @@ class TestWsockHandler(unittest.TestCase):
|
||||||
request = HTTPServerRequest(uri='/')
|
request = HTTPServerRequest(uri='/')
|
||||||
obj = Mock(spec=WsockHandler, request=request)
|
obj = Mock(spec=WsockHandler, request=request)
|
||||||
|
|
||||||
options.cows = 0
|
obj.origin_policy = 'same'
|
||||||
request.headers['Host'] = 'www.example.com:4433'
|
request.headers['Host'] = 'www.example.com:4433'
|
||||||
origin = 'https://www.example.com:4433'
|
origin = 'https://www.example.com:4433'
|
||||||
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||||
|
@ -223,7 +223,7 @@ class TestWsockHandler(unittest.TestCase):
|
||||||
origin = 'https://www.example.com'
|
origin = 'https://www.example.com'
|
||||||
self.assertFalse(WsockHandler.check_origin(obj, origin))
|
self.assertFalse(WsockHandler.check_origin(obj, origin))
|
||||||
|
|
||||||
options.cows = 1
|
obj.origin_policy = 'primary'
|
||||||
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||||
|
|
||||||
origin = 'https://blog.example.com'
|
origin = 'https://blog.example.com'
|
||||||
|
@ -232,5 +232,18 @@ class TestWsockHandler(unittest.TestCase):
|
||||||
origin = 'https://blog.example.org'
|
origin = 'https://blog.example.org'
|
||||||
self.assertFalse(WsockHandler.check_origin(obj, origin))
|
self.assertFalse(WsockHandler.check_origin(obj, origin))
|
||||||
|
|
||||||
options.cows = 2
|
origin = 'https://blog.example.org'
|
||||||
|
obj.origin_policy = {'https://blog.example.org'}
|
||||||
|
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||||
|
|
||||||
|
origin = 'http://blog.example.org'
|
||||||
|
obj.origin_policy = {'http://blog.example.org'}
|
||||||
|
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||||
|
|
||||||
|
origin = 'http://blog.example.org'
|
||||||
|
obj.origin_policy = {'https://blog.example.org'}
|
||||||
|
self.assertFalse(WsockHandler.check_origin(obj, origin))
|
||||||
|
|
||||||
|
obj.origin_policy = '*'
|
||||||
|
origin = 'https://blog.example.org'
|
||||||
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import io
|
import io
|
||||||
|
import random
|
||||||
import ssl
|
import ssl
|
||||||
import sys
|
import sys
|
||||||
import os.path
|
import os.path
|
||||||
|
@ -10,7 +11,7 @@ from tests.utils import make_tests_data_path
|
||||||
from webssh.policy import load_host_keys
|
from webssh.policy import load_host_keys
|
||||||
from webssh.settings import (
|
from webssh.settings import (
|
||||||
get_host_keys_settings, get_policy_setting, base_dir, print_version,
|
get_host_keys_settings, get_policy_setting, base_dir, print_version,
|
||||||
get_ssl_context, get_trusted_downstream
|
get_ssl_context, get_trusted_downstream, get_origin_setting
|
||||||
)
|
)
|
||||||
from webssh.utils import UnicodeType
|
from webssh.utils import UnicodeType
|
||||||
from webssh._version import __version__
|
from webssh._version import __version__
|
||||||
|
@ -137,3 +138,31 @@ class TestSettings(unittest.TestCase):
|
||||||
tdstream = '1.1.1.1, 2.2.2.'
|
tdstream = '1.1.1.1, 2.2.2.'
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
get_trusted_downstream(tdstream)
|
get_trusted_downstream(tdstream)
|
||||||
|
|
||||||
|
def test_get_origin_setting(self):
|
||||||
|
options.debug = False
|
||||||
|
options.origin = '*'
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
get_origin_setting(options)
|
||||||
|
|
||||||
|
options.debug = True
|
||||||
|
self.assertEqual(get_origin_setting(options), '*')
|
||||||
|
|
||||||
|
options.origin = random.choice(['Same', 'Primary'])
|
||||||
|
self.assertEqual(get_origin_setting(options), options.origin.lower())
|
||||||
|
|
||||||
|
options.origin = ''
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
get_origin_setting(options)
|
||||||
|
|
||||||
|
options.origin = ','
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
get_origin_setting(options)
|
||||||
|
|
||||||
|
options.origin = 'www.example.com, https://www.example.org'
|
||||||
|
result = {'http://www.example.com', 'https://www.example.org'}
|
||||||
|
self.assertEqual(get_origin_setting(options), result)
|
||||||
|
|
||||||
|
options.origin = 'www.example.com:80, www.example.org:443'
|
||||||
|
result = {'http://www.example.com', 'https://www.example.org'}
|
||||||
|
self.assertEqual(get_origin_setting(options), result)
|
||||||
|
|
|
@ -2,7 +2,7 @@ import unittest
|
||||||
|
|
||||||
from webssh.utils import (
|
from webssh.utils import (
|
||||||
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
|
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
|
||||||
to_int, is_ip_hostname, is_same_primary_domain
|
to_int, is_ip_hostname, is_same_primary_domain, parse_origin_from_url
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,3 +90,34 @@ class TestUitls(unittest.TestCase):
|
||||||
domain1 = 'xxx.www.example.com'
|
domain1 = 'xxx.www.example.com'
|
||||||
domain2 = 'xxx.www2.example.com'
|
domain2 = 'xxx.www2.example.com'
|
||||||
self.assertTrue(is_same_primary_domain(domain1, domain2))
|
self.assertTrue(is_same_primary_domain(domain1, domain2))
|
||||||
|
|
||||||
|
def test_parse_origin_from_url(self):
|
||||||
|
url = ''
|
||||||
|
self.assertIsNone(parse_origin_from_url(url))
|
||||||
|
|
||||||
|
url = 'www.example.com'
|
||||||
|
self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
|
||||||
|
|
||||||
|
url = 'http://www.example.com'
|
||||||
|
self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
|
||||||
|
|
||||||
|
url = 'www.example.com:80'
|
||||||
|
self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
|
||||||
|
|
||||||
|
url = 'http://www.example.com:80'
|
||||||
|
self.assertEqual(parse_origin_from_url(url), 'http://www.example.com')
|
||||||
|
|
||||||
|
url = 'www.example.com:443'
|
||||||
|
self.assertEqual(parse_origin_from_url(url), 'https://www.example.com')
|
||||||
|
|
||||||
|
url = 'https://www.example.com'
|
||||||
|
self.assertEqual(parse_origin_from_url(url), 'https://www.example.com')
|
||||||
|
|
||||||
|
url = 'https://www.example.com:443'
|
||||||
|
self.assertEqual(parse_origin_from_url(url), 'https://www.example.com')
|
||||||
|
|
||||||
|
url = 'https://www.example.com:80'
|
||||||
|
self.assertEqual(parse_origin_from_url(url), url)
|
||||||
|
|
||||||
|
url = 'http://www.example.com:443'
|
||||||
|
self.assertEqual(parse_origin_from_url(url), url)
|
||||||
|
|
|
@ -57,6 +57,7 @@ class MixinHandler(object):
|
||||||
def initialize(self, loop=None):
|
def initialize(self, loop=None):
|
||||||
self.check_request()
|
self.check_request()
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
|
self.origin_policy = self.settings.get('origin_policy')
|
||||||
|
|
||||||
def check_request(self):
|
def check_request(self):
|
||||||
context = self.request.connection.context
|
context = self.request.connection.context
|
||||||
|
@ -364,22 +365,26 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
|
||||||
self.worker_ref = None
|
self.worker_ref = None
|
||||||
|
|
||||||
def check_origin(self, origin):
|
def check_origin(self, origin):
|
||||||
cows = options.cows
|
if self.origin_policy == '*':
|
||||||
|
return True
|
||||||
|
|
||||||
parsed_origin = urlparse(origin)
|
parsed_origin = urlparse(origin)
|
||||||
origin = parsed_origin.netloc
|
netloc = parsed_origin.netloc.lower()
|
||||||
origin = origin.lower()
|
logging.debug('netloc: {}'.format(netloc))
|
||||||
logging.debug('origin: {}'.format(origin))
|
|
||||||
|
|
||||||
host = self.request.headers.get('Host')
|
host = self.request.headers.get('Host')
|
||||||
logging.debug('host: {}'.format(host))
|
logging.debug('host: {}'.format(host))
|
||||||
|
|
||||||
if cows == 0:
|
if netloc == host:
|
||||||
return origin == host
|
return True
|
||||||
elif cows == 1:
|
|
||||||
return is_same_primary_domain(origin.rsplit(':', 1)[0],
|
if self.origin_policy == 'same':
|
||||||
|
return False
|
||||||
|
elif self.origin_policy == 'primary':
|
||||||
|
return is_same_primary_domain(netloc.rsplit(':', 1)[0],
|
||||||
host.rsplit(':', 1)[0])
|
host.rsplit(':', 1)[0])
|
||||||
else:
|
else:
|
||||||
return True
|
return origin in self.origin_policy
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
self.src_addr = self.get_client_addr()
|
self.src_addr = self.get_client_addr()
|
||||||
|
|
|
@ -7,7 +7,7 @@ from tornado.options import define
|
||||||
from webssh.policy import (
|
from webssh.policy import (
|
||||||
load_host_keys, get_policy_class, check_policy_setting
|
load_host_keys, get_policy_class, check_policy_setting
|
||||||
)
|
)
|
||||||
from webssh.utils import to_ip_address
|
from webssh.utils import to_ip_address, parse_origin_from_url
|
||||||
from webssh._version import __version__
|
from webssh._version import __version__
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,10 +34,12 @@ define('fbidhttp', type=bool, default=True,
|
||||||
help='Forbid public plain http incoming requests')
|
help='Forbid public plain http incoming requests')
|
||||||
define('xheaders', type=bool, default=True, help='Support xheaders')
|
define('xheaders', type=bool, default=True, help='Support xheaders')
|
||||||
define('xsrf', type=bool, default=True, help='CSRF protection')
|
define('xsrf', type=bool, default=True, help='CSRF protection')
|
||||||
define('cows', type=int, default=0, help='Cross origin websocket, '
|
define('origin', default='same', help='''Origin policy,
|
||||||
'0: matches host name and port number, '
|
'same': same origin policy, matches host name and port number;
|
||||||
'1: matches primary domain only, '
|
'primary': primary domain policy, matches primary domain only;
|
||||||
'?: matches nothing, allow all cross-origin websockets')
|
'<domains>': custom domains policy, matches any domain in the <domains> list
|
||||||
|
separated by comma;
|
||||||
|
'*': wildcard policy, matches any domain, allowed in debug mode only.''')
|
||||||
define('wpintvl', type=int, default=0, help='Websocket ping interval')
|
define('wpintvl', type=int, default=0, help='Websocket ping interval')
|
||||||
define('maxconn', type=int, default=20, help='Maximum connections per client')
|
define('maxconn', type=int, default=20, help='Maximum connections per client')
|
||||||
define('version', type=bool, help='Show version information',
|
define('version', type=bool, help='Show version information',
|
||||||
|
@ -54,7 +56,8 @@ def get_app_settings(options):
|
||||||
static_path=os.path.join(base_dir, 'webssh', 'static'),
|
static_path=os.path.join(base_dir, 'webssh', 'static'),
|
||||||
websocket_ping_interval=options.wpintvl,
|
websocket_ping_interval=options.wpintvl,
|
||||||
debug=options.debug,
|
debug=options.debug,
|
||||||
xsrf_cookies=options.xsrf
|
xsrf_cookies=options.xsrf,
|
||||||
|
origin_policy=get_origin_setting(options)
|
||||||
)
|
)
|
||||||
return settings
|
return settings
|
||||||
|
|
||||||
|
@ -121,3 +124,28 @@ def get_trusted_downstream(tdstream):
|
||||||
to_ip_address(ip)
|
to_ip_address(ip)
|
||||||
result.add(ip)
|
result.add(ip)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_origin_setting(options):
|
||||||
|
if options.origin == '*':
|
||||||
|
if not options.debug:
|
||||||
|
raise ValueError(
|
||||||
|
'Wildcard origin policy is only allowed in debug mode.'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return '*'
|
||||||
|
|
||||||
|
origin = options.origin.lower()
|
||||||
|
if origin in ['same', 'primary']:
|
||||||
|
return origin
|
||||||
|
|
||||||
|
origins = set()
|
||||||
|
for url in origin.split(','):
|
||||||
|
orig = parse_origin_from_url(url)
|
||||||
|
if orig:
|
||||||
|
origins.add(orig)
|
||||||
|
|
||||||
|
if not origins:
|
||||||
|
raise ValueError('Empty origin list')
|
||||||
|
|
||||||
|
return origins
|
||||||
|
|
|
@ -6,6 +6,11 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
UnicodeType = str
|
UnicodeType = str
|
||||||
|
|
||||||
|
try:
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
except ImportError:
|
||||||
|
from urlparse import urlparse
|
||||||
|
|
||||||
|
|
||||||
numeric = re.compile(r'[0-9]+$')
|
numeric = re.compile(r'[0-9]+$')
|
||||||
allowed = re.compile(r'(?!-)[a-z0-9-]{1,63}(?<!-)$', re.IGNORECASE)
|
allowed = re.compile(r'(?!-)[a-z0-9-]{1,63}(?<!-)$', re.IGNORECASE)
|
||||||
|
@ -101,3 +106,29 @@ def is_same_primary_domain(domain1, domain2):
|
||||||
|
|
||||||
c = domain1[i] if l1 > m else domain2[i]
|
c = domain1[i] if l1 > m else domain2[i]
|
||||||
return c == '.'
|
return c == '.'
|
||||||
|
|
||||||
|
|
||||||
|
def parse_origin_from_url(url):
|
||||||
|
url = url.strip()
|
||||||
|
if not url:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not (url.startswith('http://') or url.startswith('https://') or
|
||||||
|
url.startswith('//')):
|
||||||
|
url = '//' + url
|
||||||
|
|
||||||
|
parsed = urlparse(url)
|
||||||
|
port = parsed.port
|
||||||
|
scheme = parsed.scheme
|
||||||
|
|
||||||
|
if scheme == '':
|
||||||
|
scheme = 'https' if port == 443 else 'http'
|
||||||
|
|
||||||
|
if port == 443 and scheme == 'https':
|
||||||
|
netloc = parsed.netloc.replace(':443', '')
|
||||||
|
elif port == 80 and scheme == 'http':
|
||||||
|
netloc = parsed.netloc.replace(':80', '')
|
||||||
|
else:
|
||||||
|
netloc = parsed.netloc
|
||||||
|
|
||||||
|
return '{}://{}'.format(scheme, netloc)
|
||||||
|
|
Loading…
Reference in New Issue