Redefined AutoAddPolicy class for thread-safety

pull/12/head
Sheng 2018-04-09 22:08:45 +08:00
parent 0f8771f077
commit 838d453336
1 changed files with 24 additions and 15 deletions

39
main.py
View File

@ -8,7 +8,6 @@ import uuid
import weakref import weakref
import paramiko import paramiko
import tornado.gen import tornado.gen
import tornado.ioloop
import tornado.web import tornado.web
import tornado.websocket import tornado.websocket
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
@ -27,7 +26,6 @@ define('port', default=8888, help='listen port', type=int)
define('debug', default=False, help='debug mode', type=bool) define('debug', default=False, help='debug mode', type=bool)
define('policy', default='warning', define('policy', default='warning',
help='missing host key policy, reject|autoadd|warning') help='missing host key policy, reject|autoadd|warning')
define('period', default=10, help='seconds for periodic callback', type=int)
BUF_SIZE = 1024 BUF_SIZE = 1024
@ -35,6 +33,28 @@ DELAY = 3
workers = {} workers = {}
class AutoAddPolicy(paramiko.client.MissingHostKeyPolicy):
"""
thread-safe AutoAddPolicy
"""
lock = threading.Lock()
def missing_host_key(self, client, hostname, key):
with self.lock:
keytype = key.get_name()
logging.info(
'Adding {} host key for {}'.format(keytype, hostname)
)
client._host_keys.add(hostname, keytype, key)
with open(client._host_keys_filename, 'a') as f:
f.write('{} {} {}\n'.format(
hostname, keytype, key.get_base64()
))
paramiko.client.AutoAddPolicy = AutoAddPolicy
class Worker(object): class Worker(object):
def __init__(self, loop, ssh, chan, dst_addr): def __init__(self, loop, ssh, chan, dst_addr):
self.loop = loop self.loop = loop
@ -209,6 +229,7 @@ class IndexHandler(MixinHandler, tornado.web.RequestHandler):
ssh = paramiko.SSHClient() ssh = paramiko.SSHClient()
ssh._system_host_keys = self.settings['system_host_keys'] ssh._system_host_keys = self.settings['system_host_keys']
ssh._host_keys = self.settings['host_keys'] ssh._host_keys = self.settings['host_keys']
ssh._host_keys_filename = self.settings['host_keys_filename']
ssh.set_missing_host_key_policy(self.settings['policy']) ssh.set_missing_host_key_policy(self.settings['policy'])
args = self.get_args() args = self.get_args()
@ -314,14 +335,6 @@ def get_host_keys(path):
return paramiko.hostkeys.HostKeys() return paramiko.hostkeys.HostKeys()
def save_host_keys(host_keys, filename):
length = len(host_keys)
if length != host_keys._last_len:
logging.info('Updating {}'.format(filename))
host_keys.save(filename)
host_keys._last_len = length
def get_policy_class(policy): def get_policy_class(policy):
origin_policy = policy origin_policy = policy
policy = policy.lower() policy = policy.lower()
@ -347,11 +360,6 @@ def get_application_settings():
if policy_class is paramiko.client.AutoAddPolicy: if policy_class is paramiko.client.AutoAddPolicy:
host_keys.save(filename) # for permission test host_keys.save(filename) # for permission test
host_keys._last_len = len(host_keys)
tornado.ioloop.PeriodicCallback(
lambda: save_host_keys(host_keys, filename),
options.period * 1000 # milliseconds
).start()
elif policy_class is paramiko.client.RejectPolicy: elif policy_class is paramiko.client.RejectPolicy:
if not host_keys and not system_host_keys: if not host_keys and not system_host_keys:
raise ValueError('Empty known_hosts with reject policy?') raise ValueError('Empty known_hosts with reject policy?')
@ -362,6 +370,7 @@ def get_application_settings():
cookie_secret=uuid.uuid4().hex, cookie_secret=uuid.uuid4().hex,
xsrf_cookies=True, xsrf_cookies=True,
host_keys=host_keys, host_keys=host_keys,
host_keys_filename=filename,
system_host_keys=system_host_keys, system_host_keys=system_host_keys,
policy=policy_class(), policy=policy_class(),
debug=options.debug debug=options.debug