From 311fcfedc926b5c82cf7a1d651b9390e6181decc Mon Sep 17 00:00:00 2001 From: Sheng Date: Tue, 6 Mar 2018 09:34:55 +0800 Subject: [PATCH] Added a MixinHandler --- main.py | 57 ++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/main.py b/main.py index 537918f..578453f 100644 --- a/main.py +++ b/main.py @@ -71,12 +71,12 @@ class Worker(object): if errno_from_exception(e) in _ERRNO_CONNRESET: self.close() else: - logging.debug('"{}" from {}'.format(data, self.dst_addr)) + logging.debug('"{}" from {}:{}'.format(data, *self.dst_addr)) if not data: self.close() return - logging.debug('"{}" to {}'.format(data, self.handler.src_addr)) + logging.debug('"{}" to {}:{}'.format(data, *self.handler.src_addr)) try: self.handler.write_message(data) except tornado.websocket.WebSocketClosedError: @@ -88,7 +88,7 @@ class Worker(object): return data = ''.join(self.data_to_dst) - logging.debug('"{}" to {}'.format(data, self.dst_addr)) + logging.debug('"{}" to {}:{}'.format(data, *self.dst_addr)) try: sent = self.chan.send(data) @@ -114,10 +114,28 @@ class Worker(object): self.handler.close() self.chan.close() self.ssh.close() - logging.info('Connection to {} lost'.format(self.dst_addr)) + logging.info('Connection to {}:{} lost'.format(*self.dst_addr)) -class IndexHandler(tornado.web.RequestHandler): +class MixinHandler(object): + + def get_addr(self): + ip = self.request.headers.get('X-Real-Ip') + port = self.request.headers.get('X-Real-Port') + + if ip and port: + addr = (ip, int(port)) + elif not ip and not port: + if not getattr(self, 'stream', None): + self.stream = self.request.connection.stream + addr = self.stream.socket.getpeername() + else: + raise ValueError('Wrong nginx configuration.') + + return addr + + +class IndexHandler(MixinHandler, tornado.web.RequestHandler): def get_privatekey(self): try: data = self.request.files.get('privatekey')[0]['body'] @@ -184,12 +202,12 @@ class IndexHandler(tornado.web.RequestHandler): ssh.load_system_host_keys() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) args = self.get_args() - dst_addr = '{}:{}'.format(*args[:2]) - logging.info('Connecting to {}'.format(dst_addr)) + dst_addr = (args[0], args[1]) + logging.info('Connecting to {}:{}'.format(*dst_addr)) try: ssh.connect(*args, timeout=6) except socket.error: - raise ValueError('Unable to connect to {}'.format(dst_addr)) + raise ValueError('Unable to connect to {}:{}'.format(*dst_addr)) except paramiko.BadAuthenticationType: raise ValueError('Authentication failed.') chan = ssh.invoke_shell(term='xterm') @@ -211,47 +229,44 @@ class IndexHandler(tornado.web.RequestHandler): logging.error(traceback.format_exc()) status = str(e) else: + worker.src_addr = self.get_addr() worker_id = worker.id workers[worker_id] = worker self.write(dict(id=worker_id, status=status)) -class WsockHandler(tornado.websocket.WebSocketHandler): +class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): def __init__(self, *args, **kwargs): self.loop = IOLoop.current() self.worker_ref = None super(self.__class__, self).__init__(*args, **kwargs) - def get_addr(self): - ip = self.request.headers.get_list('X-Real-Ip') - port = self.request.headers.get_list('X-Real-Port') - addr = ':'.join(ip + port) - if not addr: - addr = '{}:{}'.format(*self.stream.socket.getpeername()) - return addr - def open(self): self.src_addr = self.get_addr() - logging.info('Connected from {}'.format(self.src_addr)) + logging.info('Connected from {}:{}'.format(*self.src_addr)) worker = workers.pop(self.get_argument('id'), None) if not worker: - self.close(reason='Invalid worker id') + self.close(reason='Invalid worker id.') return + if self.src_addr[0] != worker.src_addr[0]: + self.close(reason='Invalid client addr.') + return + self.set_nodelay(True) worker.set_handler(self) self.worker_ref = weakref.ref(worker) self.loop.add_handler(worker.fd, worker, IOLoop.READ) def on_message(self, message): - logging.debug('"{}" from {}'.format(message, self.src_addr)) + logging.debug('"{}" from {}:{}'.format(message, *self.src_addr)) worker = self.worker_ref() worker.data_to_dst.append(message) worker.on_write() def on_close(self): - logging.info('Disconnected from {}'.format(self.src_addr)) + logging.info('Disconnected from {}:{}'.format(*self.src_addr)) worker = self.worker_ref() if self.worker_ref else None if worker: worker.close()