"""Tornado handlers that open SSH sessions and bridge them to websockets."""

import io
import logging
import socket
import threading
import traceback
import weakref

import paramiko
import tornado.gen
import tornado.web
import tornado.websocket
from tornado.ioloop import IOLoop

from worker import Worker, recycle_worker, workers

try:
    from concurrent.futures import Future
except ImportError:
    from tornado.concurrent import Future


# Delay (as given to loop.call_later) before recycle_worker runs on a worker
# that has just been registered by IndexHandler.post.
DELAY = 3


class MixinHandler(object):
    """Mixin that resolves the client address from reverse-proxy headers."""

    def get_client_addr(self):
        """Return ``(ip, port)`` built from X-Real-Ip / X-Real-Port.

        Returns None when neither header is present. When only one header
        is present, or the port is not an integer, a warning is logged and
        None is returned so subclasses can fall back to the socket peer.
        """
        headers = self.request.headers
        ip = headers.get('X-Real-Ip')
        port = headers.get('X-Real-Port')

        if not ip and not port:
            return None

        if ip and port:
            try:
                return (ip, int(port))
            except ValueError:
                # Malformed port header: treat like a misconfigured proxy.
                pass

        logging.warning('Wrong nginx configuration.')
        return None


class IndexHandler(MixinHandler, tornado.web.RequestHandler):
    """Serve the index page and create SSH workers from posted form data."""

    def initialize(self, loop, policy, host_keys_settings):
        """Store the IO loop, host key policy and host key settings."""
        self.loop = loop
        self.policy = policy
        self.host_keys_settings = host_keys_settings

    def get_privatekey(self):
        """Return the uploaded 'privatekey' file as text, or None if absent."""
        try:
            data = self.request.files.get('privatekey')[0]['body']
        except TypeError:
            # files.get() returned None: no private key was uploaded.
            return None
        return data.decode('utf-8')

    def get_specific_pkey(self, pkeycls, privatekey, password):
        """Try to load ``privatekey`` with ``pkeycls``.

        Returns the key on success and None when paramiko raises
        SSHException for this key class.

        Raises:
            ValueError: the key is encrypted and no password was supplied.
        """
        logging.info('Trying {}'.format(pkeycls.__name__))
        try:
            pkey = pkeycls.from_private_key(io.StringIO(privatekey),
                                            password=password)
        except paramiko.PasswordRequiredException:
            raise ValueError('Need password to decrypt the private key.')
        except paramiko.SSHException:
            return None
        return pkey

    def get_pkey(self, privatekey, password):
        """Load ``privatekey`` trying RSA, DSS, ECDSA and Ed25519 in order.

        Raises:
            ValueError: no key class could load the key.
        """
        password = password.encode('utf-8') if password else None

        # The ``or`` chain stops at the first key class that succeeds.
        pkey = (
            self.get_specific_pkey(paramiko.RSAKey, privatekey, password)
            or self.get_specific_pkey(paramiko.DSSKey, privatekey, password)
            or self.get_specific_pkey(paramiko.ECDSAKey, privatekey, password)
            or self.get_specific_pkey(paramiko.Ed25519Key, privatekey,
                                      password)
        )
        if not pkey:
            raise ValueError('Not a valid private key file or '
                             'wrong password for decrypting the private key.')
        return pkey

    def get_port(self):
        """Return the 'port' argument as an int in the range 1..65535.

        Raises:
            ValueError: the argument is empty, non-numeric or out of range.
        """
        value = self.get_value('port')
        try:
            port = int(value)
        except ValueError:
            port = 0

        if 0 < port < 65536:
            return port

        raise ValueError('Invalid port {}'.format(value))

    def get_value(self, name):
        """Return request argument ``name``; raise ValueError if it is empty."""
        value = self.get_argument(name)
        if not value:
            raise ValueError('Empty {}'.format(name))
        return value

    def get_args(self):
        """Collect ``(hostname, port, username, password, pkey)`` from the request.

        ``pkey`` is None when no private key file was uploaded.
        """
        hostname = self.get_value('hostname')
        port = self.get_port()
        username = self.get_value('username')
        password = self.get_argument('password')
        privatekey = self.get_privatekey()
        pkey = self.get_pkey(privatekey, password) if privatekey else None

        args = (hostname, port, username, password, pkey)
        # Log the arguments with the password masked so the plaintext
        # credential never reaches the log output.
        masked = '***' if password else password
        logging.debug((hostname, port, username, masked, pkey))
        return args

    def get_client_addr(self):
        """Return the proxied client address, else the socket peer name."""
        addr = super(IndexHandler, self).get_client_addr()
        if addr:
            return addr
        return self.request.connection.stream.socket.getpeername()

    def ssh_connect(self):
        """Open an SSH shell to the requested host and wrap it in a Worker.

        The SSH client is closed again if any step after its creation fails.

        Raises:
            ValueError: invalid arguments, connection failure, rejected
                authentication type or bad host key.
        """
        ssh = paramiko.SSHClient()
        ssh._system_host_keys = self.host_keys_settings['system_host_keys']
        ssh._host_keys = self.host_keys_settings['host_keys']
        ssh._host_keys_filename = self.host_keys_settings['host_keys_filename']
        ssh.set_missing_host_key_policy(self.policy)

        args = self.get_args()
        dst_addr = (args[0], args[1])
        logging.info('Connecting to {}:{}'.format(*dst_addr))

        ready = False
        try:
            try:
                ssh.connect(*args, timeout=6)
            except socket.error:
                raise ValueError(
                    'Unable to connect to {}:{}'.format(*dst_addr))
            except paramiko.BadAuthenticationType:
                raise ValueError('Authentication failed.')
            except paramiko.BadHostKeyException:
                raise ValueError('Bad host key.')

            chan = ssh.invoke_shell(term='xterm')
            chan.setblocking(0)
            worker = Worker(self.loop, ssh, chan, dst_addr)
            worker.src_addr = self.get_client_addr()
            ready = True
        finally:
            if not ready:
                # Do not leak the client/transport when setup fails midway.
                ssh.close()

        return worker

    def ssh_connect_wrapped(self, future):
        """Run ssh_connect and deliver its result or exception to ``future``."""
        try:
            worker = self.ssh_connect()
        except Exception as exc:
            logging.error(traceback.format_exc())
            future.set_exception(exc)
        else:
            future.set_result(worker)

    def get(self):
        """Render the index page."""
        self.render('index.html')

    @tornado.gen.coroutine
    def post(self):
        """Connect over SSH in a background thread and report the outcome.

        Writes ``{'id': ..., 'status': ...}``: ``id`` is the new worker id
        on success; ``status`` is the error text on failure.
        """
        worker_id = None
        status = None

        # The blocking connect runs in a daemon thread; the coroutine
        # waits on the future it resolves.
        future = Future()
        thread = threading.Thread(target=self.ssh_connect_wrapped,
                                  args=(future,))
        thread.daemon = True
        thread.start()

        try:
            worker = yield future
        except Exception as exc:
            status = str(exc)
        else:
            worker_id = worker.id
            workers[worker_id] = worker
            self.loop.call_later(DELAY, recycle_worker, worker)

        self.write(dict(id=worker_id, status=status))


class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler):
    """Websocket endpoint that attaches a client to a registered worker."""

    def initialize(self, loop):
        """Store the IO loop; no worker is attached yet."""
        self.loop = loop
        self.worker_ref = None

    def get_client_addr(self):
        """Return the proxied client address, else the socket peer name."""
        addr = super(WsockHandler, self).get_client_addr()
        if addr:
            return addr
        return self.stream.socket.getpeername()

    def open(self):
        """Attach to the worker named by the 'id' argument.

        The socket is closed when no such worker is registered or when the
        worker was created from a different client IP.
        """
        self.src_addr = self.get_client_addr()
        logging.info('Connected from {}:{}'.format(*self.src_addr))

        worker = workers.get(self.get_argument('id'))
        if worker and worker.src_addr[0] == self.src_addr[0]:
            workers.pop(worker.id)
            self.set_nodelay(True)
            worker.set_handler(self)
            # Hold the worker only weakly from this handler.
            self.worker_ref = weakref.ref(worker)
            self.loop.add_handler(worker.fd, worker, IOLoop.READ)
        else:
            self.close()

    def on_message(self, message):
        """Queue ``message`` on the worker and trigger a write."""
        logging.debug('{!r} from {}:{}'.format(message, *self.src_addr))

        worker = self.worker_ref() if self.worker_ref else None
        if worker is None:
            # No worker was attached, or it has already been collected:
            # there is nothing to forward to, so drop the connection.
            self.close()
            return

        worker.data_to_dst.append(message)
        worker.on_write()

    def on_close(self):
        """Close the attached worker, if it still exists."""
        logging.info('Disconnected from {}:{}'.format(*self.src_addr))
        worker = self.worker_ref() if self.worker_ref else None
        if worker:
            worker.close()