import logging import tornado.websocket from tornado.ioloop import IOLoop from tornado.iostream import _ERRNO_CONNRESET from tornado.util import errno_from_exception BUF_SIZE = 1024 workers = {} def recycle_worker(worker): if worker.handler: return logging.debug('Recycling worker {}'.format(worker.id)) workers.pop(worker.id, None) worker.close() class Worker(object): def __init__(self, loop, ssh, chan, dst_addr): self.loop = loop self.ssh = ssh self.chan = chan self.dst_addr = dst_addr self.fd = chan.fileno() self.id = str(id(self)) self.data_to_dst = [] self.handler = None self.mode = IOLoop.READ def __call__(self, fd, events): if events & IOLoop.READ: self.on_read() if events & IOLoop.WRITE: self.on_write() if events & IOLoop.ERROR: self.close() def set_handler(self, handler): if not self.handler: self.handler = handler def update_handler(self, mode): if self.mode != mode: self.loop.update_handler(self.fd, mode) self.mode = mode def on_read(self): logging.debug('worker {} on read'.format(self.id)) try: data = self.chan.recv(BUF_SIZE) except (OSError, IOError) as e: logging.error(e) if errno_from_exception(e) in _ERRNO_CONNRESET: self.close() else: logging.debug('{!r} from {}:{}'.format(data, *self.dst_addr)) if not data: self.close() return logging.debug('{!r} to {}:{}'.format(data, *self.handler.src_addr)) try: self.handler.write_message(data) except tornado.websocket.WebSocketClosedError: self.close() def on_write(self): logging.debug('worker {} on write'.format(self.id)) if not self.data_to_dst: return data = ''.join(self.data_to_dst) logging.debug('{!r} to {}:{}'.format(data, *self.dst_addr)) try: sent = self.chan.send(data) except (OSError, IOError) as e: logging.error(e) if errno_from_exception(e) in _ERRNO_CONNRESET: self.close() else: self.update_handler(IOLoop.WRITE) else: self.data_to_dst = [] data = data[sent:] if data: self.data_to_dst.append(data) self.update_handler(IOLoop.WRITE) else: self.update_handler(IOLoop.READ) def close(self): logging.debug('Closing worker {}'.format(self.id)) if self.handler: self.loop.remove_handler(self.fd) self.handler.close() self.chan.close() self.ssh.close() logging.info('Connection to {}:{} lost'.format(*self.dst_addr))