From 221bd81583a4078480687acba6d532e038306f40 Mon Sep 17 00:00:00 2001 From: Sheng Date: Thu, 15 Mar 2018 01:09:51 +0800 Subject: [PATCH] Refactored code --- main.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index e69c425..f0f5564 100644 --- a/main.py +++ b/main.py @@ -197,7 +197,10 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler): def ssh_connect(self): ssh = paramiko.SSHClient() - ssh.load_host_keys(self.settings['host_file']) + if isinstance(self.settings['policy'], paramiko.client.AutoAddPolicy): + ssh.load_host_keys(self.settings['host_file']) + else: + ssh._host_keys = self.settings.get('host_keys') ssh.set_missing_host_key_policy(self.settings['policy']) args = self.get_args() dst_addr = (args[0], args[1]) @@ -284,13 +287,15 @@ def recycle(worker): def get_host_keys(path): if os.path.exists(path) and os.path.isfile(path): return paramiko.hostkeys.HostKeys(filename=path) + return paramiko.hostkeys.HostKeys() def create_host_file(host_file): host_keys = get_host_keys(host_file) if not host_keys: - host_keys = get_host_keys(os.path.expanduser("~/.ssh/known_hosts")) + host_keys = get_host_keys(os.path.expanduser('~/.ssh/known_hosts')) host_keys.save(host_file) + return host_keys def get_policy_class(policy): @@ -311,7 +316,7 @@ def get_policy_class(policy): def main(): base_dir = os.path.dirname(__file__) host_file = os.path.join(base_dir, 'known_hosts') - create_host_file(host_file) + host_keys = create_host_file(host_file) settings = { 'template_path': os.path.join(base_dir, 'templates'), @@ -329,6 +334,7 @@ def main(): settings.update( debug=options.debug, host_file=host_file, + host_keys=host_keys, policy=get_policy_class(options.policy)() ) app = tornado.web.Application(handlers, **settings)