mirror of https://github.com/huashengdun/webssh
Added an option for configuring cross-origin websocket level
parent
b51e823973
commit
5c8bd84b95
|
@ -5,7 +5,7 @@ from tornado.httputil import HTTPServerRequest
|
|||
from tornado.options import options
|
||||
from tests.utils import read_file, make_tests_data_path
|
||||
from webssh.handler import (
|
||||
MixinHandler, IndexHandler, InvalidValueError, open_to_public
|
||||
MixinHandler, IndexHandler, WsockHandler, InvalidValueError, open_to_public
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -202,3 +202,30 @@ class TestIndexHandler(unittest.TestCase):
|
|||
|
||||
with self.assertRaises(paramiko.PasswordRequiredException):
|
||||
pkey = IndexHandler.get_pkey_obj(key, '', fname)
|
||||
|
||||
|
||||
class TestWsockHandler(unittest.TestCase):
|
||||
|
||||
def test_check_origin(self):
|
||||
request = HTTPServerRequest(uri='/')
|
||||
obj = Mock(spec=WsockHandler, request=request)
|
||||
|
||||
options.cows = 0
|
||||
request.headers['Host'] = 'www.example.com:4433'
|
||||
origin = 'https://www.example.com:4433'
|
||||
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||
|
||||
origin = 'https://www.example.com'
|
||||
self.assertFalse(WsockHandler.check_origin(obj, origin))
|
||||
|
||||
options.cows = 1
|
||||
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||
|
||||
origin = 'https://blog.example.com'
|
||||
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||
|
||||
origin = 'https://blog.example.org'
|
||||
self.assertFalse(WsockHandler.check_origin(obj, origin))
|
||||
|
||||
options.cows = 2
|
||||
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||
|
|
|
@ -3,7 +3,7 @@ import unittest
|
|||
from webssh.utils import (
|
||||
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
|
||||
to_int, on_public_network_interface, get_ips_by_name, is_ip_hostname,
|
||||
is_name_open_to_public
|
||||
is_name_open_to_public, is_same_primary_domain
|
||||
)
|
||||
|
||||
|
||||
|
@ -79,3 +79,32 @@ class TestUitls(unittest.TestCase):
|
|||
self.assertTrue(is_ip_hostname('127.0.0.1'))
|
||||
self.assertFalse(is_ip_hostname('localhost'))
|
||||
self.assertFalse(is_ip_hostname('www.google.com'))
|
||||
|
||||
def test_is_same_primary_domain(self):
|
||||
domain1 = 'localhost'
|
||||
domain2 = 'localhost'
|
||||
self.assertTrue(is_same_primary_domain(domain1, domain2))
|
||||
|
||||
domain1 = 'localhost'
|
||||
domain2 = 'test'
|
||||
self.assertFalse(is_same_primary_domain(domain1, domain2))
|
||||
|
||||
domain1 = 'example.com'
|
||||
domain2 = 'example.com'
|
||||
self.assertTrue(is_same_primary_domain(domain1, domain2))
|
||||
|
||||
domain1 = 'www.example.com'
|
||||
domain2 = 'example.com'
|
||||
self.assertTrue(is_same_primary_domain(domain1, domain2))
|
||||
|
||||
domain1 = 'wwwexample.com'
|
||||
domain2 = 'example.com'
|
||||
self.assertFalse(is_same_primary_domain(domain1, domain2))
|
||||
|
||||
domain1 = 'www.example.com'
|
||||
domain2 = 'www2.example.com'
|
||||
self.assertTrue(is_same_primary_domain(domain1, domain2))
|
||||
|
||||
domain1 = 'xxx.www.example.com'
|
||||
domain2 = 'xxx.www2.example.com'
|
||||
self.assertTrue(is_same_primary_domain(domain1, domain2))
|
||||
|
|
|
@ -13,7 +13,8 @@ from tornado.ioloop import IOLoop
|
|||
from tornado.options import options
|
||||
from webssh.utils import (
|
||||
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
|
||||
to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname
|
||||
to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname,
|
||||
is_same_primary_domain
|
||||
)
|
||||
from webssh.worker import Worker, recycle_worker, clients
|
||||
|
||||
|
@ -27,6 +28,11 @@ try:
|
|||
except ImportError:
|
||||
JSONDecodeError = ValueError
|
||||
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
except ImportError:
|
||||
from urlparse import urlparse
|
||||
|
||||
|
||||
DELAY = 3
|
||||
KEY_MAX_SIZE = 16384
|
||||
|
@ -364,6 +370,24 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
|
|||
super(WsockHandler, self).initialize(loop)
|
||||
self.worker_ref = None
|
||||
|
||||
def check_origin(self, origin):
|
||||
cows = options.cows
|
||||
parsed_origin = urlparse(origin)
|
||||
origin = parsed_origin.netloc
|
||||
origin = origin.lower()
|
||||
logging.debug('origin: {}'.format(origin))
|
||||
|
||||
host = self.request.headers.get('Host')
|
||||
logging.debug('host: {}'.format(host))
|
||||
|
||||
if cows == 0:
|
||||
return origin == host
|
||||
elif cows == 1:
|
||||
return is_same_primary_domain(origin.rsplit(':', 1)[0],
|
||||
host.rsplit(':', 1)[0])
|
||||
else:
|
||||
return True
|
||||
|
||||
def open(self):
|
||||
self.src_addr = self.get_client_addr()
|
||||
logging.info('Connected from {}:{}'.format(*self.src_addr))
|
||||
|
|
|
@ -34,6 +34,10 @@ define('fbidhttp', type=bool, default=True,
|
|||
help='Forbid public plain http incoming requests')
|
||||
define('xheaders', type=bool, default=True, help='Support xheaders')
|
||||
define('xsrf', type=bool, default=True, help='CSRF protection')
|
||||
define('cows', type=int, default=0, help='Cross origin websocket, '
|
||||
'0: matches host name and port number'
|
||||
'1: matches primary domain only'
|
||||
'?: matches nothing, allow all cross-origin websockets')
|
||||
define('wpintvl', type=int, default=0, help='Websocket ping interval')
|
||||
define('maxconn', type=int, default=20, help='Maximum connections per client')
|
||||
define('version', type=bool, help='Show version information',
|
||||
|
|
|
@ -96,3 +96,31 @@ def is_name_open_to_public(name):
|
|||
for ip in get_ips_by_name(name):
|
||||
if on_public_network_interface(ip):
|
||||
return True
|
||||
|
||||
|
||||
def is_same_primary_domain(domain1, domain2):
|
||||
i = -1
|
||||
dots = 0
|
||||
l1 = len(domain1)
|
||||
l2 = len(domain2)
|
||||
m = 0 - min(l1, l2)
|
||||
|
||||
while i >= m:
|
||||
c1 = domain1[i]
|
||||
c2 = domain2[i]
|
||||
|
||||
if c1 == c2:
|
||||
if c1 == '.':
|
||||
dots += 1
|
||||
if dots == 2:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
i -= 1
|
||||
|
||||
if l1 == l2:
|
||||
return True
|
||||
|
||||
c = domain1[i] if l1 > m else domain2[i]
|
||||
return c == '.'
|
||||
|
|
Loading…
Reference in New Issue