jumpserver/apps/libs/ansible/modules_utils/remote_client.py

278 lines
9.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import re
import signal
import time
from functools import wraps
import paramiko
from sshtunnel import SSHTunnelForwarder
DEFAULT_RE = '.*'
SU_PROMPT_LOCALIZATIONS = [
'Password', '암호', 'パスワード', 'Adgangskode', 'Contraseña', 'Contrasenya',
'Hasło', 'Heslo', 'Jelszó', 'Lösenord', 'Mật khẩu', 'Mot de passe',
'Parola', 'Parool', 'Pasahitza', 'Passord', 'Passwort', 'Salasana',
'Sandi', 'Senha', 'Wachtwoord', 'ססמה', 'Лозинка', 'Парола', 'Пароль',
'गुप्तशब्द', 'शब्दकूट', 'సంకేతపదము', 'හස්පදය', '密码', '密碼', '口令',
]
def get_become_prompt_re():
pattern_segments = (r'(\w+\'s )?' + p for p in SU_PROMPT_LOCALIZATIONS)
prompt_pattern = "|".join(pattern_segments) + r' ?(:|) ?'
return re.compile(prompt_pattern, flags=re.IGNORECASE)
become_prompt_re = get_become_prompt_re()
def common_argument_spec():
options = dict(
login_host=dict(type='str', required=False, default='localhost'),
login_port=dict(type='int', required=False, default=22),
login_user=dict(type='str', required=False, default='root'),
login_password=dict(type='str', required=False, no_log=True),
login_secret_type=dict(type='str', required=False, default='password'),
login_private_key_path=dict(type='str', required=False, no_log=True),
gateway_args=dict(type='str', required=False, default=''),
recv_timeout=dict(type='int', required=False, default=30),
delay_time=dict(type='int', required=False, default=2),
prompt=dict(type='str', required=False, default='.*'),
answers=dict(type='str', required=False, default='.*'),
commands=dict(type='raw', required=False),
become=dict(type='bool', default=False, required=False),
become_method=dict(type='str', required=False),
become_user=dict(type='str', required=False),
become_password=dict(type='str', required=False, no_log=True),
become_private_key_path=dict(type='str', required=False, no_log=True),
old_ssh_version=dict(type='bool', default=False, required=False),
)
return options
def raise_timeout(name=''):
def decorate(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
def handler(signum, frame):
raise TimeoutError(f'{name} timed out, wait {timeout}s')
timeout = getattr(self, 'timeout', 0)
try:
if timeout > 0:
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout)
return func(self, *args, **kwargs)
except Exception as error:
signal.alarm(0)
raise error
return wrapper
return decorate
class OldSSHTransport(paramiko.transport.Transport):
_preferred_pubkeys = (
"ssh-ed25519",
"ecdsa-sha2-nistp256",
"ecdsa-sha2-nistp384",
"ecdsa-sha2-nistp521",
"ssh-rsa",
"rsa-sha2-256",
"rsa-sha2-512",
"ssh-dss",
)
class SSHClient:
def __init__(self, module):
self.module = module
self.gateway_server = None
self.client = paramiko.SSHClient()
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.connect_params = self.get_connect_params()
self._channel = None
self.buffer_size = 1024
self.prompt = self.module.params['prompt']
self.timeout = self.module.params['recv_timeout']
@property
def channel(self):
if self._channel is None:
self.connect()
return self._channel
def get_connect_params(self):
p = self.module.params
params = {
'allow_agent': False,
'look_for_keys': False,
'hostname': p['login_host'],
'port': p['login_port'],
'key_filename': p['login_private_key_path'] or None
}
if p['become']:
params['username'] = p['become_user']
params['password'] = p['become_password']
params['key_filename'] = p['become_private_key_path'] or None
else:
params['username'] = p['login_user']
params['password'] = p['login_password']
params['key_filename'] = p['login_private_key_path'] or None
if p['old_ssh_version']:
params['transport_factory'] = OldSSHTransport
return params
def switch_user(self):
p = self.module.params
if not p['become']:
return
method = p['become_method']
username = p['login_user']
if method == 'sudo':
switch_cmd = 'sudo su -'
pword = p['become_password']
elif method == 'su':
switch_cmd = 'su -'
pword = p['login_password']
else:
self.module.fail_json(msg=f'Become method {method} not supported.')
return
# Expected to see a prompt, type the password, and check the username
output, error = self.execute(
[f'{switch_cmd} {username}', pword, 'whoami'],
[become_prompt_re, DEFAULT_RE, username]
)
if error:
self.module.fail_json(msg=f'Failed to become user {username}. Output: {output}')
def connect(self):
self.before_runner_start()
try:
self.client.connect(**self.connect_params)
self._channel = self.client.invoke_shell()
self._get_match_recv()
self.switch_user()
except Exception as error:
self.module.fail_json(msg=str(error))
@staticmethod
def _fit_answers(commands, answers):
if answers is None or not isinstance(answers, list):
answers = [DEFAULT_RE] * len(commands)
elif len(answers) < len(commands):
answers += [DEFAULT_RE] * (len(commands) - len(answers))
return answers
@staticmethod
def __match(expression, content):
if isinstance(expression, str):
expression = re.compile(expression, re.DOTALL | re.IGNORECASE)
elif not isinstance(expression, re.Pattern):
raise ValueError(f'{expression} should be a regular expression')
return bool(expression.search(content))
@raise_timeout('Recv message')
def _get_match_recv(self, answer_reg=DEFAULT_RE):
buffer_str = ''
prev_str = ''
check_reg = self.prompt if answer_reg == DEFAULT_RE else answer_reg
while True:
if self.channel.recv_ready():
chunk = self.channel.recv(self.buffer_size).decode('utf-8', 'replace')
buffer_str += chunk
if buffer_str and buffer_str != prev_str:
if self.__match(check_reg, buffer_str):
break
prev_str = buffer_str
time.sleep(0.01)
return buffer_str
@raise_timeout('Wait send message')
def _check_send(self):
while not self.channel.send_ready():
time.sleep(0.01)
time.sleep(self.module.params['delay_time'])
def execute(self, commands, answers=None):
combined_output = ''
error_msg = ''
try:
answers = self._fit_answers(commands, answers)
for cmd, ans_regex in zip(commands, answers):
self._check_send()
self.channel.send(cmd + '\n')
combined_output += self._get_match_recv(ans_regex) + '\n'
except Exception as e:
error_msg = str(e)
return combined_output, error_msg
def local_gateway_prepare(self):
gateway_args = self.module.params['gateway_args'] or ''
pattern = (
r"(?:sshpass -p ([^ ]+))?\s*ssh -o Port=(\d+)\s+-o StrictHostKeyChecking=no\s+"
r"([\w@]+)@([\d.]+)\s+-W %h:%p -q(?: -i (.+))?'"
)
match = re.search(pattern, gateway_args)
if not match:
return
password, port, username, remote_addr, key_path = match.groups()
password = password or None
key_path = key_path or None
server = SSHTunnelForwarder(
(remote_addr, int(port)),
ssh_username=username,
ssh_password=password,
ssh_pkey=key_path,
remote_bind_address=(
self.module.params['login_host'],
self.module.params['login_port']
)
)
server.start()
self.connect_params['hostname'] = '127.0.0.1'
self.connect_params['port'] = server.local_bind_port
self.gateway_server = server
def local_gateway_clean(self):
if self.gateway_server:
self.gateway_server.stop()
def before_runner_start(self):
self.local_gateway_prepare()
def after_runner_end(self):
self.local_gateway_clean()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
try:
self.after_runner_end()
if self.channel:
self.channel.close()
if self.client:
self.client.close()
except Exception: # noqa
pass