Added function for limiting connections for every client(ip)

pull/58/head
Sheng 2018-12-29 16:16:06 +08:00
parent c126856daa
commit 2653a3e35a
4 changed files with 68 additions and 8 deletions

View File

@ -15,6 +15,7 @@ from webssh.settings import (
get_app_settings, get_server_settings, max_body_size get_app_settings, get_server_settings, max_body_size
) )
from webssh.utils import to_str from webssh.utils import to_str
from webssh.worker import clients
try: try:
from urllib.parse import urlencode from urllib.parse import urlencode
@ -447,6 +448,7 @@ class OtherTestBase(AsyncHTTPTestCase):
hostfile = '' hostfile = ''
syshostfile = '' syshostfile = ''
tdstream = '' tdstream = ''
maxconn = 20
body = { body = {
'hostname': '127.0.0.1', 'hostname': '127.0.0.1',
'port': '', 'port': '',
@ -464,6 +466,7 @@ class OtherTestBase(AsyncHTTPTestCase):
options.hostfile = self.hostfile options.hostfile = self.hostfile
options.syshostfile = self.syshostfile options.syshostfile = self.syshostfile
options.tdstream = self.tdstream options.tdstream = self.tdstream
options.maxconn = self.maxconn
app = make_app(make_handlers(loop, options), get_app_settings(options)) app = make_app(make_handlers(loop, options), get_app_settings(options))
return app return app
@ -670,3 +673,46 @@ class TestAppWithPutRequest(OtherTestBase):
url, method='PUT', body=body, headers=self.headers url, method='PUT', body=body, headers=self.headers
) )
self.assertIn('Method Not Allowed', ctx.exception.message) self.assertIn('Method Not Allowed', ctx.exception.message)
class TestAppWithTooManyConnections(OtherTestBase):
maxconn = 1
def setUp(self):
clients.clear()
super(TestAppWithTooManyConnections, self).setUp()
@tornado.testing.gen_test
def test_app_with_too_many_connections(self):
url = self.get_url('/')
client = self.get_http_client()
body = urlencode(dict(self.body, username='foo'))
response = yield client.fetch(url, method='POST', body=body,
headers=self.headers)
data = json.loads(to_str(response.body))
worker_id = data['id']
self.assertIsNotNone(worker_id)
self.assertIsNotNone(data['encoding'])
self.assertIsNone(data['status'])
response = yield client.fetch(url, method='POST', body=body,
headers=self.headers)
data = json.loads(to_str(response.body))
self.assertIsNone(data['id'])
self.assertIsNone(data['encoding'])
self.assertEqual(data['status'], 'Too many connections.')
ws_url = url.replace('http', 'ws') + 'ws?id=' + worker_id
ws = yield tornado.websocket.websocket_connect(ws_url)
msg = yield ws.read_message()
self.assertIsNotNone(msg)
response = yield client.fetch(url, method='POST', body=body,
headers=self.headers)
data = json.loads(to_str(response.body))
self.assertIsNone(data['id'])
self.assertIsNone(data['encoding'])
self.assertEqual(data['status'], 'Too many connections.')
ws.close()

View File

@ -15,7 +15,7 @@ from webssh.utils import (
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str, is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname
) )
from webssh.worker import Worker, recycle_worker, workers from webssh.worker import Worker, recycle_worker, clients
try: try:
from concurrent.futures import Future from concurrent.futures import Future
@ -311,8 +311,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
chan = ssh.invoke_shell(term='xterm') chan = ssh.invoke_shell(term='xterm')
chan.setblocking(0) chan.setblocking(0)
worker = Worker(self.loop, ssh, chan, dst_addr) worker = Worker(self.loop, ssh, chan, dst_addr, self.src_addr)
worker.src_addr = self.get_client_addr()
worker.encoding = self.get_default_encoding(ssh) worker.encoding = self.get_default_encoding(ssh)
return worker return worker
@ -337,6 +336,10 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
# for testing purpose only # for testing purpose only
raise ValueError('Uncaught exception') raise ValueError('Uncaught exception')
self.src_addr = self.get_client_addr()
if len(clients.get(self.src_addr[0], {})) >= options.maxconn:
raise tornado.web.HTTPError(403, 'Too many connections.')
future = Future() future = Future()
t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,)) t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,))
t.setDaemon(True) t.setDaemon(True)
@ -347,6 +350,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
except (ValueError, paramiko.SSHException) as exc: except (ValueError, paramiko.SSHException) as exc:
self.result.update(status=str(exc)) self.result.update(status=str(exc))
else: else:
workers = clients.setdefault(worker.src_addr[0], {})
workers[worker.id] = worker workers[worker.id] = worker
self.loop.call_later(DELAY, recycle_worker, worker) self.loop.call_later(DELAY, recycle_worker, worker)
self.result.update(id=worker.id, encoding=worker.encoding) self.result.update(id=worker.id, encoding=worker.encoding)
@ -363,14 +367,20 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
def open(self): def open(self):
self.src_addr = self.get_client_addr() self.src_addr = self.get_client_addr()
logging.info('Connected from {}:{}'.format(*self.src_addr)) logging.info('Connected from {}:{}'.format(*self.src_addr))
workers = clients.get(self.src_addr[0])
if not workers:
self.close(reason='Websocket authentication failed.')
return
try: try:
worker_id = self.get_value('id') worker_id = self.get_value('id')
except (tornado.web.MissingArgumentError, InvalidValueError) as exc: except (tornado.web.MissingArgumentError, InvalidValueError) as exc:
self.close(reason=str(exc)) self.close(reason=str(exc))
else: else:
worker = workers.get(worker_id) worker = workers.get(worker_id)
if worker and worker.src_addr[0] == self.src_addr[0]: if worker:
workers.pop(worker.id) workers[worker_id] = None
self.set_nodelay(True) self.set_nodelay(True)
worker.set_handler(self) worker.set_handler(self)
self.worker_ref = weakref.ref(worker) self.worker_ref = weakref.ref(worker)

View File

@ -35,6 +35,7 @@ define('fbidhttp', type=bool, default=True,
define('xheaders', type=bool, default=True, help='Support xheaders') define('xheaders', type=bool, default=True, help='Support xheaders')
define('xsrf', type=bool, default=True, help='CSRF protection') define('xsrf', type=bool, default=True, help='CSRF protection')
define('wpintvl', type=int, default=0, help='Websocket ping interval') define('wpintvl', type=int, default=0, help='Websocket ping interval')
define('maxconn', type=int, default=20, help='Maximum connections per client')
define('version', type=bool, help='Show version information', define('version', type=bool, help='Show version information',
callback=print_version) callback=print_version)

View File

@ -7,23 +7,23 @@ from tornado.util import errno_from_exception
BUF_SIZE = 32 * 1024 BUF_SIZE = 32 * 1024
workers = {} clients = {}
def recycle_worker(worker): def recycle_worker(worker):
if worker.handler: if worker.handler:
return return
logging.warning('Recycling worker {}'.format(worker.id)) logging.warning('Recycling worker {}'.format(worker.id))
workers.pop(worker.id, None)
worker.close(reason='worker recycled') worker.close(reason='worker recycled')
class Worker(object): class Worker(object):
def __init__(self, loop, ssh, chan, dst_addr): def __init__(self, loop, ssh, chan, dst_addr, src_addr):
self.loop = loop self.loop = loop
self.ssh = ssh self.ssh = ssh
self.chan = chan self.chan = chan
self.dst_addr = dst_addr self.dst_addr = dst_addr
self.src_addr = src_addr
self.fd = chan.fileno() self.fd = chan.fileno()
self.id = str(id(self)) self.id = str(id(self))
self.data_to_dst = [] self.data_to_dst = []
@ -104,3 +104,6 @@ class Worker(object):
self.chan.close() self.chan.close()
self.ssh.close() self.ssh.close()
logging.info('Connection to {}:{} lost'.format(*self.dst_addr)) logging.info('Connection to {}:{} lost'.format(*self.dst_addr))
clients[self.src_addr[0]].pop(self.id, None)
logging.debug(clients)