mirror of https://github.com/huashengdun/webssh
Added function for limiting connections for every client(ip)
parent
c126856daa
commit
2653a3e35a
|
@ -15,6 +15,7 @@ from webssh.settings import (
|
|||
get_app_settings, get_server_settings, max_body_size
|
||||
)
|
||||
from webssh.utils import to_str
|
||||
from webssh.worker import clients
|
||||
|
||||
try:
|
||||
from urllib.parse import urlencode
|
||||
|
@ -447,6 +448,7 @@ class OtherTestBase(AsyncHTTPTestCase):
|
|||
hostfile = ''
|
||||
syshostfile = ''
|
||||
tdstream = ''
|
||||
maxconn = 20
|
||||
body = {
|
||||
'hostname': '127.0.0.1',
|
||||
'port': '',
|
||||
|
@ -464,6 +466,7 @@ class OtherTestBase(AsyncHTTPTestCase):
|
|||
options.hostfile = self.hostfile
|
||||
options.syshostfile = self.syshostfile
|
||||
options.tdstream = self.tdstream
|
||||
options.maxconn = self.maxconn
|
||||
app = make_app(make_handlers(loop, options), get_app_settings(options))
|
||||
return app
|
||||
|
||||
|
@ -670,3 +673,46 @@ class TestAppWithPutRequest(OtherTestBase):
|
|||
url, method='PUT', body=body, headers=self.headers
|
||||
)
|
||||
self.assertIn('Method Not Allowed', ctx.exception.message)
|
||||
|
||||
|
||||
class TestAppWithTooManyConnections(OtherTestBase):
|
||||
|
||||
maxconn = 1
|
||||
|
||||
def setUp(self):
|
||||
clients.clear()
|
||||
super(TestAppWithTooManyConnections, self).setUp()
|
||||
|
||||
@tornado.testing.gen_test
|
||||
def test_app_with_too_many_connections(self):
|
||||
url = self.get_url('/')
|
||||
client = self.get_http_client()
|
||||
body = urlencode(dict(self.body, username='foo'))
|
||||
response = yield client.fetch(url, method='POST', body=body,
|
||||
headers=self.headers)
|
||||
data = json.loads(to_str(response.body))
|
||||
worker_id = data['id']
|
||||
self.assertIsNotNone(worker_id)
|
||||
self.assertIsNotNone(data['encoding'])
|
||||
self.assertIsNone(data['status'])
|
||||
|
||||
response = yield client.fetch(url, method='POST', body=body,
|
||||
headers=self.headers)
|
||||
data = json.loads(to_str(response.body))
|
||||
self.assertIsNone(data['id'])
|
||||
self.assertIsNone(data['encoding'])
|
||||
self.assertEqual(data['status'], 'Too many connections.')
|
||||
|
||||
ws_url = url.replace('http', 'ws') + 'ws?id=' + worker_id
|
||||
ws = yield tornado.websocket.websocket_connect(ws_url)
|
||||
msg = yield ws.read_message()
|
||||
self.assertIsNotNone(msg)
|
||||
|
||||
response = yield client.fetch(url, method='POST', body=body,
|
||||
headers=self.headers)
|
||||
data = json.loads(to_str(response.body))
|
||||
self.assertIsNone(data['id'])
|
||||
self.assertIsNone(data['encoding'])
|
||||
self.assertEqual(data['status'], 'Too many connections.')
|
||||
|
||||
ws.close()
|
||||
|
|
|
@ -15,7 +15,7 @@ from webssh.utils import (
|
|||
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
|
||||
to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname
|
||||
)
|
||||
from webssh.worker import Worker, recycle_worker, workers
|
||||
from webssh.worker import Worker, recycle_worker, clients
|
||||
|
||||
try:
|
||||
from concurrent.futures import Future
|
||||
|
@ -311,8 +311,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
|||
|
||||
chan = ssh.invoke_shell(term='xterm')
|
||||
chan.setblocking(0)
|
||||
worker = Worker(self.loop, ssh, chan, dst_addr)
|
||||
worker.src_addr = self.get_client_addr()
|
||||
worker = Worker(self.loop, ssh, chan, dst_addr, self.src_addr)
|
||||
worker.encoding = self.get_default_encoding(ssh)
|
||||
return worker
|
||||
|
||||
|
@ -337,6 +336,10 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
|||
# for testing purpose only
|
||||
raise ValueError('Uncaught exception')
|
||||
|
||||
self.src_addr = self.get_client_addr()
|
||||
if len(clients.get(self.src_addr[0], {})) >= options.maxconn:
|
||||
raise tornado.web.HTTPError(403, 'Too many connections.')
|
||||
|
||||
future = Future()
|
||||
t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,))
|
||||
t.setDaemon(True)
|
||||
|
@ -347,6 +350,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
|||
except (ValueError, paramiko.SSHException) as exc:
|
||||
self.result.update(status=str(exc))
|
||||
else:
|
||||
workers = clients.setdefault(worker.src_addr[0], {})
|
||||
workers[worker.id] = worker
|
||||
self.loop.call_later(DELAY, recycle_worker, worker)
|
||||
self.result.update(id=worker.id, encoding=worker.encoding)
|
||||
|
@ -363,14 +367,20 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
|
|||
def open(self):
|
||||
self.src_addr = self.get_client_addr()
|
||||
logging.info('Connected from {}:{}'.format(*self.src_addr))
|
||||
|
||||
workers = clients.get(self.src_addr[0])
|
||||
if not workers:
|
||||
self.close(reason='Websocket authentication failed.')
|
||||
return
|
||||
|
||||
try:
|
||||
worker_id = self.get_value('id')
|
||||
except (tornado.web.MissingArgumentError, InvalidValueError) as exc:
|
||||
self.close(reason=str(exc))
|
||||
else:
|
||||
worker = workers.get(worker_id)
|
||||
if worker and worker.src_addr[0] == self.src_addr[0]:
|
||||
workers.pop(worker.id)
|
||||
if worker:
|
||||
workers[worker_id] = None
|
||||
self.set_nodelay(True)
|
||||
worker.set_handler(self)
|
||||
self.worker_ref = weakref.ref(worker)
|
||||
|
|
|
@ -35,6 +35,7 @@ define('fbidhttp', type=bool, default=True,
|
|||
define('xheaders', type=bool, default=True, help='Support xheaders')
|
||||
define('xsrf', type=bool, default=True, help='CSRF protection')
|
||||
define('wpintvl', type=int, default=0, help='Websocket ping interval')
|
||||
define('maxconn', type=int, default=20, help='Maximum connections per client')
|
||||
define('version', type=bool, help='Show version information',
|
||||
callback=print_version)
|
||||
|
||||
|
|
|
@ -7,23 +7,23 @@ from tornado.util import errno_from_exception
|
|||
|
||||
|
||||
BUF_SIZE = 32 * 1024
|
||||
workers = {}
|
||||
clients = {}
|
||||
|
||||
|
||||
def recycle_worker(worker):
|
||||
if worker.handler:
|
||||
return
|
||||
logging.warning('Recycling worker {}'.format(worker.id))
|
||||
workers.pop(worker.id, None)
|
||||
worker.close(reason='worker recycled')
|
||||
|
||||
|
||||
class Worker(object):
|
||||
def __init__(self, loop, ssh, chan, dst_addr):
|
||||
def __init__(self, loop, ssh, chan, dst_addr, src_addr):
|
||||
self.loop = loop
|
||||
self.ssh = ssh
|
||||
self.chan = chan
|
||||
self.dst_addr = dst_addr
|
||||
self.src_addr = src_addr
|
||||
self.fd = chan.fileno()
|
||||
self.id = str(id(self))
|
||||
self.data_to_dst = []
|
||||
|
@ -104,3 +104,6 @@ class Worker(object):
|
|||
self.chan.close()
|
||||
self.ssh.close()
|
||||
logging.info('Connection to {}:{} lost'.format(*self.dst_addr))
|
||||
|
||||
clients[self.src_addr[0]].pop(self.id, None)
|
||||
logging.debug(clients)
|
||||
|
|
Loading…
Reference in New Issue