mirror of https://github.com/huashengdun/webssh
Added an option for configuring cross-origin websocket level
parent
b51e823973
commit
5c8bd84b95
|
@ -5,7 +5,7 @@ from tornado.httputil import HTTPServerRequest
|
||||||
from tornado.options import options
|
from tornado.options import options
|
||||||
from tests.utils import read_file, make_tests_data_path
|
from tests.utils import read_file, make_tests_data_path
|
||||||
from webssh.handler import (
|
from webssh.handler import (
|
||||||
MixinHandler, IndexHandler, InvalidValueError, open_to_public
|
MixinHandler, IndexHandler, WsockHandler, InvalidValueError, open_to_public
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -202,3 +202,30 @@ class TestIndexHandler(unittest.TestCase):
|
||||||
|
|
||||||
with self.assertRaises(paramiko.PasswordRequiredException):
|
with self.assertRaises(paramiko.PasswordRequiredException):
|
||||||
pkey = IndexHandler.get_pkey_obj(key, '', fname)
|
pkey = IndexHandler.get_pkey_obj(key, '', fname)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWsockHandler(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_check_origin(self):
|
||||||
|
request = HTTPServerRequest(uri='/')
|
||||||
|
obj = Mock(spec=WsockHandler, request=request)
|
||||||
|
|
||||||
|
options.cows = 0
|
||||||
|
request.headers['Host'] = 'www.example.com:4433'
|
||||||
|
origin = 'https://www.example.com:4433'
|
||||||
|
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||||
|
|
||||||
|
origin = 'https://www.example.com'
|
||||||
|
self.assertFalse(WsockHandler.check_origin(obj, origin))
|
||||||
|
|
||||||
|
options.cows = 1
|
||||||
|
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||||
|
|
||||||
|
origin = 'https://blog.example.com'
|
||||||
|
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||||
|
|
||||||
|
origin = 'https://blog.example.org'
|
||||||
|
self.assertFalse(WsockHandler.check_origin(obj, origin))
|
||||||
|
|
||||||
|
options.cows = 2
|
||||||
|
self.assertTrue(WsockHandler.check_origin(obj, origin))
|
||||||
|
|
|
@ -3,7 +3,7 @@ import unittest
|
||||||
from webssh.utils import (
|
from webssh.utils import (
|
||||||
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
|
is_valid_ip_address, is_valid_port, is_valid_hostname, to_str, to_bytes,
|
||||||
to_int, on_public_network_interface, get_ips_by_name, is_ip_hostname,
|
to_int, on_public_network_interface, get_ips_by_name, is_ip_hostname,
|
||||||
is_name_open_to_public
|
is_name_open_to_public, is_same_primary_domain
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,3 +79,32 @@ class TestUitls(unittest.TestCase):
|
||||||
self.assertTrue(is_ip_hostname('127.0.0.1'))
|
self.assertTrue(is_ip_hostname('127.0.0.1'))
|
||||||
self.assertFalse(is_ip_hostname('localhost'))
|
self.assertFalse(is_ip_hostname('localhost'))
|
||||||
self.assertFalse(is_ip_hostname('www.google.com'))
|
self.assertFalse(is_ip_hostname('www.google.com'))
|
||||||
|
|
||||||
|
def test_is_same_primary_domain(self):
|
||||||
|
domain1 = 'localhost'
|
||||||
|
domain2 = 'localhost'
|
||||||
|
self.assertTrue(is_same_primary_domain(domain1, domain2))
|
||||||
|
|
||||||
|
domain1 = 'localhost'
|
||||||
|
domain2 = 'test'
|
||||||
|
self.assertFalse(is_same_primary_domain(domain1, domain2))
|
||||||
|
|
||||||
|
domain1 = 'example.com'
|
||||||
|
domain2 = 'example.com'
|
||||||
|
self.assertTrue(is_same_primary_domain(domain1, domain2))
|
||||||
|
|
||||||
|
domain1 = 'www.example.com'
|
||||||
|
domain2 = 'example.com'
|
||||||
|
self.assertTrue(is_same_primary_domain(domain1, domain2))
|
||||||
|
|
||||||
|
domain1 = 'wwwexample.com'
|
||||||
|
domain2 = 'example.com'
|
||||||
|
self.assertFalse(is_same_primary_domain(domain1, domain2))
|
||||||
|
|
||||||
|
domain1 = 'www.example.com'
|
||||||
|
domain2 = 'www2.example.com'
|
||||||
|
self.assertTrue(is_same_primary_domain(domain1, domain2))
|
||||||
|
|
||||||
|
domain1 = 'xxx.www.example.com'
|
||||||
|
domain2 = 'xxx.www2.example.com'
|
||||||
|
self.assertTrue(is_same_primary_domain(domain1, domain2))
|
||||||
|
|
|
@ -13,7 +13,8 @@ from tornado.ioloop import IOLoop
|
||||||
from tornado.options import options
|
from tornado.options import options
|
||||||
from webssh.utils import (
|
from webssh.utils import (
|
||||||
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
|
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
|
||||||
to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname
|
to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname,
|
||||||
|
is_same_primary_domain
|
||||||
)
|
)
|
||||||
from webssh.worker import Worker, recycle_worker, clients
|
from webssh.worker import Worker, recycle_worker, clients
|
||||||
|
|
||||||
|
@ -27,6 +28,11 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
JSONDecodeError = ValueError
|
JSONDecodeError = ValueError
|
||||||
|
|
||||||
|
try:
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
except ImportError:
|
||||||
|
from urlparse import urlparse
|
||||||
|
|
||||||
|
|
||||||
DELAY = 3
|
DELAY = 3
|
||||||
KEY_MAX_SIZE = 16384
|
KEY_MAX_SIZE = 16384
|
||||||
|
@ -364,6 +370,24 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
|
||||||
super(WsockHandler, self).initialize(loop)
|
super(WsockHandler, self).initialize(loop)
|
||||||
self.worker_ref = None
|
self.worker_ref = None
|
||||||
|
|
||||||
|
def check_origin(self, origin):
|
||||||
|
cows = options.cows
|
||||||
|
parsed_origin = urlparse(origin)
|
||||||
|
origin = parsed_origin.netloc
|
||||||
|
origin = origin.lower()
|
||||||
|
logging.debug('origin: {}'.format(origin))
|
||||||
|
|
||||||
|
host = self.request.headers.get('Host')
|
||||||
|
logging.debug('host: {}'.format(host))
|
||||||
|
|
||||||
|
if cows == 0:
|
||||||
|
return origin == host
|
||||||
|
elif cows == 1:
|
||||||
|
return is_same_primary_domain(origin.rsplit(':', 1)[0],
|
||||||
|
host.rsplit(':', 1)[0])
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
self.src_addr = self.get_client_addr()
|
self.src_addr = self.get_client_addr()
|
||||||
logging.info('Connected from {}:{}'.format(*self.src_addr))
|
logging.info('Connected from {}:{}'.format(*self.src_addr))
|
||||||
|
|
|
@ -34,6 +34,10 @@ define('fbidhttp', type=bool, default=True,
|
||||||
help='Forbid public plain http incoming requests')
|
help='Forbid public plain http incoming requests')
|
||||||
define('xheaders', type=bool, default=True, help='Support xheaders')
|
define('xheaders', type=bool, default=True, help='Support xheaders')
|
||||||
define('xsrf', type=bool, default=True, help='CSRF protection')
|
define('xsrf', type=bool, default=True, help='CSRF protection')
|
||||||
|
define('cows', type=int, default=0, help='Cross origin websocket, '
|
||||||
|
'0: matches host name and port number'
|
||||||
|
'1: matches primary domain only'
|
||||||
|
'?: matches nothing, allow all cross-origin websockets')
|
||||||
define('wpintvl', type=int, default=0, help='Websocket ping interval')
|
define('wpintvl', type=int, default=0, help='Websocket ping interval')
|
||||||
define('maxconn', type=int, default=20, help='Maximum connections per client')
|
define('maxconn', type=int, default=20, help='Maximum connections per client')
|
||||||
define('version', type=bool, help='Show version information',
|
define('version', type=bool, help='Show version information',
|
||||||
|
|
|
@ -96,3 +96,31 @@ def is_name_open_to_public(name):
|
||||||
for ip in get_ips_by_name(name):
|
for ip in get_ips_by_name(name):
|
||||||
if on_public_network_interface(ip):
|
if on_public_network_interface(ip):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def is_same_primary_domain(domain1, domain2):
|
||||||
|
i = -1
|
||||||
|
dots = 0
|
||||||
|
l1 = len(domain1)
|
||||||
|
l2 = len(domain2)
|
||||||
|
m = 0 - min(l1, l2)
|
||||||
|
|
||||||
|
while i >= m:
|
||||||
|
c1 = domain1[i]
|
||||||
|
c2 = domain2[i]
|
||||||
|
|
||||||
|
if c1 == c2:
|
||||||
|
if c1 == '.':
|
||||||
|
dots += 1
|
||||||
|
if dots == 2:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
i -= 1
|
||||||
|
|
||||||
|
if l1 == l2:
|
||||||
|
return True
|
||||||
|
|
||||||
|
c = domain1[i] if l1 > m else domain2[i]
|
||||||
|
return c == '.'
|
||||||
|
|
Loading…
Reference in New Issue