Added an option for configuring cross-origin websocket level

pull/58/head
Sheng 2019-01-10 22:09:32 +08:00
parent b51e823973
commit 5c8bd84b95
5 changed files with 115 additions and 3 deletions

View File

@ -5,7 +5,7 @@ from tornado.httputil import HTTPServerRequest
from tornado.options import options from tornado.options import options
from tests.utils import read_file, make_tests_data_path from tests.utils import read_file, make_tests_data_path
from webssh.handler import ( from webssh.handler import (
MixinHandler, IndexHandler, InvalidValueError, open_to_public MixinHandler, IndexHandler, WsockHandler, InvalidValueError, open_to_public
) )
try: try:
@ -202,3 +202,30 @@ class TestIndexHandler(unittest.TestCase):
with self.assertRaises(paramiko.PasswordRequiredException): with self.assertRaises(paramiko.PasswordRequiredException):
pkey = IndexHandler.get_pkey_obj(key, '', fname) pkey = IndexHandler.get_pkey_obj(key, '', fname)
class TestWsockHandler(unittest.TestCase):
def test_check_origin(self):
request = HTTPServerRequest(uri='/')
obj = Mock(spec=WsockHandler, request=request)
options.cows = 0
request.headers['Host'] = 'www.example.com:4433'
origin = 'https://www.example.com:4433'
self.assertTrue(WsockHandler.check_origin(obj, origin))
origin = 'https://www.example.com'
self.assertFalse(WsockHandler.check_origin(obj, origin))
options.cows = 1
self.assertTrue(WsockHandler.check_origin(obj, origin))
origin = 'https://blog.example.com'
self.assertTrue(WsockHandler.check_origin(obj, origin))
origin = 'https://blog.example.org'
self.assertFalse(WsockHandler.check_origin(obj, origin))
options.cows = 2
self.assertTrue(WsockHandler.check_origin(obj, origin))

View File

@ -3,7 +3,7 @@ import unittest
from webssh.utils import ( from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes, is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
to_int, on_public_network_interface, get_ips_by_name, is_ip_hostname, to_int, on_public_network_interface, get_ips_by_name, is_ip_hostname,
is_name_open_to_public is_name_open_to_public, is_same_primary_domain
) )
@ -79,3 +79,32 @@ class TestUitls(unittest.TestCase):
self.assertTrue(is_ip_hostname('127.0.0.1')) self.assertTrue(is_ip_hostname('127.0.0.1'))
self.assertFalse(is_ip_hostname('localhost')) self.assertFalse(is_ip_hostname('localhost'))
self.assertFalse(is_ip_hostname('www.google.com')) self.assertFalse(is_ip_hostname('www.google.com'))
def test_is_same_primary_domain(self):
domain1 = 'localhost'
domain2 = 'localhost'
self.assertTrue(is_same_primary_domain(domain1, domain2))
domain1 = 'localhost'
domain2 = 'test'
self.assertFalse(is_same_primary_domain(domain1, domain2))
domain1 = 'example.com'
domain2 = 'example.com'
self.assertTrue(is_same_primary_domain(domain1, domain2))
domain1 = 'www.example.com'
domain2 = 'example.com'
self.assertTrue(is_same_primary_domain(domain1, domain2))
domain1 = 'wwwexample.com'
domain2 = 'example.com'
self.assertFalse(is_same_primary_domain(domain1, domain2))
domain1 = 'www.example.com'
domain2 = 'www2.example.com'
self.assertTrue(is_same_primary_domain(domain1, domain2))
domain1 = 'xxx.www.example.com'
domain2 = 'xxx.www2.example.com'
self.assertTrue(is_same_primary_domain(domain1, domain2))

View File

@ -13,7 +13,8 @@ from tornado.ioloop import IOLoop
from tornado.options import options from tornado.options import options
from webssh.utils import ( from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str, is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname,
is_same_primary_domain
) )
from webssh.worker import Worker, recycle_worker, clients from webssh.worker import Worker, recycle_worker, clients
@ -27,6 +28,11 @@ try:
except ImportError: except ImportError:
JSONDecodeError = ValueError JSONDecodeError = ValueError
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
DELAY = 3 DELAY = 3
KEY_MAX_SIZE = 16384 KEY_MAX_SIZE = 16384
@ -364,6 +370,24 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
super(WsockHandler, self).initialize(loop) super(WsockHandler, self).initialize(loop)
self.worker_ref = None self.worker_ref = None
def check_origin(self, origin):
cows = options.cows
parsed_origin = urlparse(origin)
origin = parsed_origin.netloc
origin = origin.lower()
logging.debug('origin: {}'.format(origin))
host = self.request.headers.get('Host')
logging.debug('host: {}'.format(host))
if cows == 0:
return origin == host
elif cows == 1:
return is_same_primary_domain(origin.rsplit(':', 1)[0],
host.rsplit(':', 1)[0])
else:
return True
def open(self): def open(self):
self.src_addr = self.get_client_addr() self.src_addr = self.get_client_addr()
logging.info('Connected from {}:{}'.format(*self.src_addr)) logging.info('Connected from {}:{}'.format(*self.src_addr))

View File

@ -34,6 +34,10 @@ define('fbidhttp', type=bool, default=True,
help='Forbid public plain http incoming requests') help='Forbid public plain http incoming requests')
define('xheaders', type=bool, default=True, help='Support xheaders') define('xheaders', type=bool, default=True, help='Support xheaders')
define('xsrf', type=bool, default=True, help='CSRF protection') define('xsrf', type=bool, default=True, help='CSRF protection')
define('cows', type=int, default=0, help='Cross origin websocket, '
'0: matches host name and port number'
'1: matches primary domain only'
'?: matches nothing, allow all cross-origin websockets')
define('wpintvl', type=int, default=0, help='Websocket ping interval') define('wpintvl', type=int, default=0, help='Websocket ping interval')
define('maxconn', type=int, default=20, help='Maximum connections per client') define('maxconn', type=int, default=20, help='Maximum connections per client')
define('version', type=bool, help='Show version information', define('version', type=bool, help='Show version information',

View File

@ -96,3 +96,31 @@ def is_name_open_to_public(name):
for ip in get_ips_by_name(name): for ip in get_ips_by_name(name):
if on_public_network_interface(ip): if on_public_network_interface(ip):
return True return True
def is_same_primary_domain(domain1, domain2):
i = -1
dots = 0
l1 = len(domain1)
l2 = len(domain2)
m = 0 - min(l1, l2)
while i >= m:
c1 = domain1[i]
c2 = domain2[i]
if c1 == c2:
if c1 == '.':
dots += 1
if dots == 2:
return True
else:
return False
i -= 1
if l1 == l2:
return True
c = domain1[i] if l1 > m else domain2[i]
return c == '.'