diff --git a/.gitignore b/.gitignore index e182e7b..df9be38 100644 --- a/.gitignore +++ b/.gitignore @@ -59,3 +59,6 @@ target/ # temporary file *.swp + +# known_hosts file +known_hosts diff --git a/main.py b/main.py index adc2de0..97efac4 100644 --- a/main.py +++ b/main.py @@ -17,22 +17,15 @@ from tornado.util import errno_from_exception define('address', default='127.0.0.1', help='listen address') define('port', default=8888, help='listen port', type=int) define('debug', default=False, help='debug mode', type=bool) +define('policy', default='reject', + help='missing host key polilcy, reject|autoadd|warning') BUF_SIZE = 1024 DELAY = 3 -base_dir = os.path.dirname(__file__) workers = {} -def recycle(worker): - if worker.handler: - return - logging.debug('Recycling worker {}'.format(worker.id)) - workers.pop(worker.id, None) - worker.close() - - class Worker(object): def __init__(self, ssh, chan, dst_addr): self.loop = IOLoop.current() @@ -204,8 +197,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): def ssh_connect(self): ssh = paramiko.SSHClient() - ssh.load_system_host_keys() - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.load_host_keys(self.settings['host_file']) + ssh.set_missing_host_key_policy(self.settings['policy']) args = self.get_args() dst_addr = (args[0], args[1]) logging.info('Connecting to {}:{}'.format(*dst_addr)) @@ -215,6 +208,8 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): raise ValueError('Unable to connect to {}:{}'.format(*dst_addr)) except paramiko.BadAuthenticationType: raise ValueError('Authentication failed.') + except paramiko.BadHostKeyException: + raise ValueError('Bad host key.') chan = ssh.invoke_shell(term='xterm') chan.setblocking(0) worker = Worker(ssh, chan, dst_addr) @@ -278,7 +273,46 @@ class WsockHandler(MixinHandler, tornado.websocket.WebSocketHandler): worker.close() +def recycle(worker): + if worker.handler: + return + logging.debug('Recycling worker {}'.format(worker.id)) + workers.pop(worker.id, None) + worker.close() + + +def get_host_keys(path): + if os.path.exists(path) and os.path.isfile(path): + return paramiko.hostkeys.HostKeys(filename=path) + + +def create_host_file(host_file): + host_keys = get_host_keys(host_file) + if not host_keys: + host_keys = get_host_keys(os.path.expanduser("~/.ssh/known_hosts")) + host_keys.save(host_file) + + +def get_policy_class(policy): + origin_policy = policy + policy = policy.lower() + if not policy.endswith('policy'): + policy += 'policy' + + dic = {k.lower(): v for k, v in vars(paramiko.client).items()} + + try: + cls = dic[policy] + except KeyError: + raise ValueError('Unknown policy {!r}'.format(origin_policy)) + return cls + + def main(): + base_dir = os.path.dirname(__file__) + host_file = os.path.join(base_dir, 'known_hosts') + create_host_file(host_file) + settings = { 'template_path': os.path.join(base_dir, 'templates'), 'static_path': os.path.join(base_dir, 'static'), @@ -292,7 +326,11 @@ def main(): ] parse_command_line() - settings.update(debug=options.debug) + settings.update( + debug=options.debug, + host_file=host_file, + policy=get_policy_class(options.policy)() + ) app = tornado.web.Application(handlers, **settings) app.listen(options.port, options.address) logging.info('Listening on {}:{}'.format(options.address, options.port))