mirror of https://github.com/huashengdun/webssh
Added function for limiting connections for every client(ip)
parent
c126856daa
commit
2653a3e35a
|
@ -15,6 +15,7 @@ from webssh.settings import (
|
||||||
get_app_settings, get_server_settings, max_body_size
|
get_app_settings, get_server_settings, max_body_size
|
||||||
)
|
)
|
||||||
from webssh.utils import to_str
|
from webssh.utils import to_str
|
||||||
|
from webssh.worker import clients
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
@ -447,6 +448,7 @@ class OtherTestBase(AsyncHTTPTestCase):
|
||||||
hostfile = ''
|
hostfile = ''
|
||||||
syshostfile = ''
|
syshostfile = ''
|
||||||
tdstream = ''
|
tdstream = ''
|
||||||
|
maxconn = 20
|
||||||
body = {
|
body = {
|
||||||
'hostname': '127.0.0.1',
|
'hostname': '127.0.0.1',
|
||||||
'port': '',
|
'port': '',
|
||||||
|
@ -464,6 +466,7 @@ class OtherTestBase(AsyncHTTPTestCase):
|
||||||
options.hostfile = self.hostfile
|
options.hostfile = self.hostfile
|
||||||
options.syshostfile = self.syshostfile
|
options.syshostfile = self.syshostfile
|
||||||
options.tdstream = self.tdstream
|
options.tdstream = self.tdstream
|
||||||
|
options.maxconn = self.maxconn
|
||||||
app = make_app(make_handlers(loop, options), get_app_settings(options))
|
app = make_app(make_handlers(loop, options), get_app_settings(options))
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
@ -670,3 +673,46 @@ class TestAppWithPutRequest(OtherTestBase):
|
||||||
url, method='PUT', body=body, headers=self.headers
|
url, method='PUT', body=body, headers=self.headers
|
||||||
)
|
)
|
||||||
self.assertIn('Method Not Allowed', ctx.exception.message)
|
self.assertIn('Method Not Allowed', ctx.exception.message)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAppWithTooManyConnections(OtherTestBase):
|
||||||
|
|
||||||
|
maxconn = 1
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
clients.clear()
|
||||||
|
super(TestAppWithTooManyConnections, self).setUp()
|
||||||
|
|
||||||
|
@tornado.testing.gen_test
|
||||||
|
def test_app_with_too_many_connections(self):
|
||||||
|
url = self.get_url('/')
|
||||||
|
client = self.get_http_client()
|
||||||
|
body = urlencode(dict(self.body, username='foo'))
|
||||||
|
response = yield client.fetch(url, method='POST', body=body,
|
||||||
|
headers=self.headers)
|
||||||
|
data = json.loads(to_str(response.body))
|
||||||
|
worker_id = data['id']
|
||||||
|
self.assertIsNotNone(worker_id)
|
||||||
|
self.assertIsNotNone(data['encoding'])
|
||||||
|
self.assertIsNone(data['status'])
|
||||||
|
|
||||||
|
response = yield client.fetch(url, method='POST', body=body,
|
||||||
|
headers=self.headers)
|
||||||
|
data = json.loads(to_str(response.body))
|
||||||
|
self.assertIsNone(data['id'])
|
||||||
|
self.assertIsNone(data['encoding'])
|
||||||
|
self.assertEqual(data['status'], 'Too many connections.')
|
||||||
|
|
||||||
|
ws_url = url.replace('http', 'ws') + 'ws?id=' + worker_id
|
||||||
|
ws = yield tornado.websocket.websocket_connect(ws_url)
|
||||||
|
msg = yield ws.read_message()
|
||||||
|
self.assertIsNotNone(msg)
|
||||||
|
|
||||||
|
response = yield client.fetch(url, method='POST', body=body,
|
||||||
|
headers=self.headers)
|
||||||
|
data = json.loads(to_str(response.body))
|
||||||
|
self.assertIsNone(data['id'])
|
||||||
|
self.assertIsNone(data['encoding'])
|
||||||
|
self.assertEqual(data['status'], 'Too many connections.')
|
||||||
|
|
||||||
|
ws.close()
|
||||||
|
|
|
@ -15,7 +15,7 @@ from webssh.utils import (
|
||||||
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
|
is_valid_ip_address, is_valid_port, is_valid_hostname, to_bytes, to_str,
|
||||||
to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname
|
to_int, to_ip_address, UnicodeType, is_name_open_to_public, is_ip_hostname
|
||||||
)
|
)
|
||||||
from webssh.worker import Worker, recycle_worker, workers
|
from webssh.worker import Worker, recycle_worker, clients
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
|
@ -311,8 +311,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
|
|
||||||
chan = ssh.invoke_shell(term='xterm')
|
chan = ssh.invoke_shell(term='xterm')
|
||||||
chan.setblocking(0)
|
chan.setblocking(0)
|
||||||
worker = Worker(self.loop, ssh, chan, dst_addr)
|
worker = Worker(self.loop, ssh, chan, dst_addr, self.src_addr)
|
||||||
worker.src_addr = self.get_client_addr()
|
|
||||||
worker.encoding = self.get_default_encoding(ssh)
|
worker.encoding = self.get_default_encoding(ssh)
|
||||||
return worker
|
return worker
|
||||||
|
|
||||||
|
@ -337,6 +336,10 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
# for testing purpose only
|
# for testing purpose only
|
||||||
raise ValueError('Uncaught exception')
|
raise ValueError('Uncaught exception')
|
||||||
|
|
||||||
|
self.src_addr = self.get_client_addr()
|
||||||
|
if len(clients.get(self.src_addr[0], {})) >= options.maxconn:
|
||||||
|
raise tornado.web.HTTPError(403, 'Too many connections.')
|
||||||
|
|
||||||
future = Future()
|
future = Future()
|
||||||
t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,))
|
t = threading.Thread(target=self.ssh_connect_wrapped, args=(future,))
|
||||||
t.setDaemon(True)
|
t.setDaemon(True)
|
||||||
|
@ -347,6 +350,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
|
||||||
except (ValueError, paramiko.SSHException) as exc:
|
except (ValueError, paramiko.SSHException) as exc:
|
||||||
self.result.update(status=str(exc))
|
self.result.update(status=str(exc))
|
||||||
else:
|
else:
|
||||||
|
workers = clients.setdefault(worker.src_addr[0], {})
|
||||||
workers[worker.id] = worker
|
workers[worker.id] = worker
|
||||||
self.loop.call_later(DELAY, recycle_worker, worker)
|
self.loop.call_later(DELAY, recycle_worker, worker)
|
||||||
self.result.update(id=worker.id, encoding=worker.encoding)
|
self.result.update(id=worker.id, encoding=worker.encoding)
|
||||||
|
@ -363,14 +367,20 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
|
||||||
def open(self):
|
def open(self):
|
||||||
self.src_addr = self.get_client_addr()
|
self.src_addr = self.get_client_addr()
|
||||||
logging.info('Connected from {}:{}'.format(*self.src_addr))
|
logging.info('Connected from {}:{}'.format(*self.src_addr))
|
||||||
|
|
||||||
|
workers = clients.get(self.src_addr[0])
|
||||||
|
if not workers:
|
||||||
|
self.close(reason='Websocket authentication failed.')
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
worker_id = self.get_value('id')
|
worker_id = self.get_value('id')
|
||||||
except (tornado.web.MissingArgumentError, InvalidValueError) as exc:
|
except (tornado.web.MissingArgumentError, InvalidValueError) as exc:
|
||||||
self.close(reason=str(exc))
|
self.close(reason=str(exc))
|
||||||
else:
|
else:
|
||||||
worker = workers.get(worker_id)
|
worker = workers.get(worker_id)
|
||||||
if worker and worker.src_addr[0] == self.src_addr[0]:
|
if worker:
|
||||||
workers.pop(worker.id)
|
workers[worker_id] = None
|
||||||
self.set_nodelay(True)
|
self.set_nodelay(True)
|
||||||
worker.set_handler(self)
|
worker.set_handler(self)
|
||||||
self.worker_ref = weakref.ref(worker)
|
self.worker_ref = weakref.ref(worker)
|
||||||
|
|
|
@ -35,6 +35,7 @@ define('fbidhttp', type=bool, default=True,
|
||||||
define('xheaders', type=bool, default=True, help='Support xheaders')
|
define('xheaders', type=bool, default=True, help='Support xheaders')
|
||||||
define('xsrf', type=bool, default=True, help='CSRF protection')
|
define('xsrf', type=bool, default=True, help='CSRF protection')
|
||||||
define('wpintvl', type=int, default=0, help='Websocket ping interval')
|
define('wpintvl', type=int, default=0, help='Websocket ping interval')
|
||||||
|
define('maxconn', type=int, default=20, help='Maximum connections per client')
|
||||||
define('version', type=bool, help='Show version information',
|
define('version', type=bool, help='Show version information',
|
||||||
callback=print_version)
|
callback=print_version)
|
||||||
|
|
||||||
|
|
|
@ -7,23 +7,23 @@ from tornado.util import errno_from_exception
|
||||||
|
|
||||||
|
|
||||||
BUF_SIZE = 32 * 1024
|
BUF_SIZE = 32 * 1024
|
||||||
workers = {}
|
clients = {}
|
||||||
|
|
||||||
|
|
||||||
def recycle_worker(worker):
|
def recycle_worker(worker):
|
||||||
if worker.handler:
|
if worker.handler:
|
||||||
return
|
return
|
||||||
logging.warning('Recycling worker {}'.format(worker.id))
|
logging.warning('Recycling worker {}'.format(worker.id))
|
||||||
workers.pop(worker.id, None)
|
|
||||||
worker.close(reason='worker recycled')
|
worker.close(reason='worker recycled')
|
||||||
|
|
||||||
|
|
||||||
class Worker(object):
|
class Worker(object):
|
||||||
def __init__(self, loop, ssh, chan, dst_addr):
|
def __init__(self, loop, ssh, chan, dst_addr, src_addr):
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.ssh = ssh
|
self.ssh = ssh
|
||||||
self.chan = chan
|
self.chan = chan
|
||||||
self.dst_addr = dst_addr
|
self.dst_addr = dst_addr
|
||||||
|
self.src_addr = src_addr
|
||||||
self.fd = chan.fileno()
|
self.fd = chan.fileno()
|
||||||
self.id = str(id(self))
|
self.id = str(id(self))
|
||||||
self.data_to_dst = []
|
self.data_to_dst = []
|
||||||
|
@ -104,3 +104,6 @@ class Worker(object):
|
||||||
self.chan.close()
|
self.chan.close()
|
||||||
self.ssh.close()
|
self.ssh.close()
|
||||||
logging.info('Connection to {}:{} lost'.format(*self.dst_addr))
|
logging.info('Connection to {}:{} lost'.format(*self.dst_addr))
|
||||||
|
|
||||||
|
clients[self.src_addr[0]].pop(self.id, None)
|
||||||
|
logging.debug(clients)
|
||||||
|
|
Loading…
Reference in New Issue