diff --git a/webssh/handler.py b/webssh/handler.py index 3a4a68f..6efcb60 100644 --- a/webssh/handler.py +++ b/webssh/handler.py @@ -348,7 +348,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): chan = ssh.invoke_shell(term='xterm') chan.setblocking(0) - worker = Worker(self.loop, ssh, chan, dst_addr, self.src_addr) + worker = Worker(self.loop, ssh, chan, dst_addr) worker.encoding = self.get_default_encoding(ssh) return worker @@ -378,8 +378,9 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): # for testing purpose only raise ValueError('Uncaught exception') - self.src_addr = self.get_client_addr() - if len(clients.get(self.src_addr[0], {})) >= options.maxconn: + ip, port = self.get_client_addr() + workers = clients.get(ip, {}) + if workers and len(workers) >= options.maxconn: raise tornado.web.HTTPError(403, 'Too many live connections.') self.check_origin() @@ -397,7 +398,9 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): logging.error(traceback.format_exc()) self.result.update(status=str(exc)) else: - workers = clients.setdefault(worker.src_addr[0], {}) + if not workers: + clients[ip] = workers + worker.src_addr = (ip, port) workers[worker.id] = worker self.loop.call_later(DELAY, recycle_worker, worker) self.result.update(id=worker.id, encoding=worker.encoding) diff --git a/webssh/worker.py b/webssh/worker.py index d90cec7..8e7e278 100644 --- a/webssh/worker.py +++ b/webssh/worker.py @@ -7,7 +7,22 @@ from tornado.util import errno_from_exception BUF_SIZE = 32 * 1024 -clients = {} +clients = {} # {ip: {id: worker}} + + +def clear_worker(worker, clients): + ip = worker.src_addr[0] + workers = clients.get(ip) + if workers: + try: + workers.pop(worker.id) + except KeyError: + pass + else: + if not workers: + clients.pop(ip) + if not clients: + clients.clear() def recycle_worker(worker): @@ -18,12 +33,11 @@ def recycle_worker(worker): class Worker(object): - def __init__(self, loop, ssh, chan, dst_addr, src_addr): + def __init__(self, loop, ssh, chan, dst_addr): self.loop = loop self.ssh = ssh self.chan = chan self.dst_addr = dst_addr - self.src_addr = src_addr self.fd = chan.fileno() self.id = str(id(self)) self.data_to_dst = [] @@ -110,5 +124,5 @@ class Worker(object): self.ssh.close() logging.info('Connection to {}:{} lost'.format(*self.dst_addr)) - clients[self.src_addr[0]].pop(self.id, None) + clear_worker(self, clients) logging.debug(clients)