diff --git a/main.py b/main.py index 013752a..315052e 100644 --- a/main.py +++ b/main.py @@ -40,18 +40,29 @@ class AutoAddPolicy(paramiko.client.MissingHostKeyPolicy): """ lock = threading.Lock() + def is_missing_host_keys(self, client, hostname, key): + k = client._host_keys.lookup(hostname) + if k is None: + return True + host_key = k.get(key.get_name(), None) + if host_key is None: + return True + if host_key != key: + raise paramiko.BadHostKeyException(hostname, key, host_key) + def missing_host_key(self, client, hostname, key): with self.lock: - keytype = key.get_name() - logging.info( - 'Adding {} host key for {}'.format(keytype, hostname) - ) - client._host_keys.add(hostname, keytype, key) - - with open(client._host_keys_filename, 'a') as f: - f.write('{} {} {}\n'.format( - hostname, keytype, key.get_base64() - )) + if self.is_missing_host_keys(client, hostname, key): + keytype = key.get_name() + logging.info( + 'Adding {} host key for {}'.format(keytype, hostname) + ) + client._host_keys.add(hostname, keytype, key) + + with open(client._host_keys_filename, 'a') as f: + f.write('{} {} {}\n'.format( + hostname, keytype, key.get_base64() + )) paramiko.client.AutoAddPolicy = AutoAddPolicy