@ -1,21 +1,58 @@
import re
import signal
import time
import paramiko
from functools import wraps
from sshtunnel import SSHTunnelForwarder
class OldSSHTransport ( paramiko . transport . Transport ) :
_preferred_pubkeys = (
" ssh-ed25519 " ,
" ecdsa-sha2-nistp256 " ,
" ecdsa-sha2-nistp384 " ,
" ecdsa-sha2-nistp521 " ,
" ssh-rsa " ,
" rsa-sha2-256 " ,
" rsa-sha2-512 " ,
" ssh-dss " ,
)
DEFAULT_RE = ' .* '
SU_PROMPT_LOCALIZATIONS = [
' Password ' ,
' 암호 ' ,
' パスワード ' ,
' Adgangskode ' ,
' Contraseña ' ,
' Contrasenya ' ,
' Hasło ' ,
' Heslo ' ,
' Jelszó ' ,
' Lösenord ' ,
' Mật khẩu ' ,
' Mot de passe ' ,
' Parola ' ,
' Parool ' ,
' Pasahitza ' ,
' Passord ' ,
' Passwort ' ,
' Salasana ' ,
' Sandi ' ,
' Senha ' ,
' Wachtwoord ' ,
' ססמה ' ,
' Лозинка ' ,
' Парола ' ,
' Пароль ' ,
' गुप्तशब्द ' ,
' शब्दकूट ' ,
' సంకేతపదము ' ,
' හස්පදය ' ,
' 密码 ' ,
' 密碼 ' ,
' 口令 ' ,
]
def get_become_prompt_re ( ) :
b_password_string = " | " . join ( ( r ' ( \ w+ \' s )? ' + p ) for p in SU_PROMPT_LOCALIZATIONS )
b_password_string = b_password_string + ' ?(:|: ) ? '
return re . compile ( b_password_string , flags = re . IGNORECASE )
become_prompt_re = get_become_prompt_re ( )
def common_argument_spec ( ) :
@ -26,8 +63,12 @@ def common_argument_spec():
login_password = dict ( type = ' str ' , required = False , no_log = True ) ,
login_secret_type = dict ( type = ' str ' , required = False , default = ' password ' ) ,
login_private_key_path = dict ( type = ' str ' , required = False , no_log = True ) ,
first_conn_delay_time = dict ( type = ' float ' , required = False , default = 0.5 ) ,
gateway_args = dict ( type = ' str ' , required = False , default = ' ' ) ,
recv_timeout = dict ( type = ' int ' , required = False , default = 30 ) ,
delay_time = dict ( type = ' int ' , required = False , default = 2 ) ,
prompt = dict ( type = ' str ' , required = False , default = ' .* ' ) ,
answers = dict ( type = ' str ' , required = False , default = ' .* ' ) ,
commands = dict ( type = ' raw ' , required = False ) ,
become = dict ( type = ' bool ' , default = False , required = False ) ,
become_method = dict ( type = ' str ' , required = False ) ,
@ -40,19 +81,57 @@ def common_argument_spec():
return options
class SSHClient :
TIMEOUT = 20
SLEEP_INTERVAL = 2
COMPLETE_FLAG = ' complete '
def raise_timeout ( name = ' ' ) :
def decorate ( func ) :
@wraps ( func )
def wrapper ( self , * args , * * kwargs ) :
def handler ( signum , frame ) :
raise TimeoutError ( f ' { name } timed out, wait { timeout } s ' )
try :
timeout = getattr ( self , ' timeout ' , 0 )
if timeout > 0 :
signal . signal ( signal . SIGALRM , handler )
signal . alarm ( timeout )
return func ( self , * args , * * kwargs )
except Exception as error :
signal . alarm ( 0 )
raise error
return wrapper
return decorate
class OldSSHTransport ( paramiko . transport . Transport ) :
_preferred_pubkeys = (
" ssh-ed25519 " ,
" ecdsa-sha2-nistp256 " ,
" ecdsa-sha2-nistp384 " ,
" ecdsa-sha2-nistp521 " ,
" ssh-rsa " ,
" rsa-sha2-256 " ,
" rsa-sha2-512 " ,
" ssh-dss " ,
)
class SSHClient :
def __init__ ( self , module ) :
self . module = module
self . channel = None
self . is_connect = False
self . gateway_server = None
self . client = paramiko . SSHClient ( )
self . client . set_missing_host_key_policy ( paramiko . AutoAddPolicy ( ) )
self . connect_params = self . get_connect_params ( )
self . _channel = None
self . buffer_size = 1024
self . connect_params = self . get_connect_params ( )
self . prompt = self . module . params [ ' prompt ' ]
self . timeout = self . module . params [ ' recv_timeout ' ]
@property
def channel ( self ) :
if self . _channel is None :
self . connect ( )
return self . _channel
def get_connect_params ( self ) :
params = {
@ -73,22 +152,7 @@ class SSHClient:
params [ ' transport_factory ' ] = OldSSHTransport
return params
def _get_channel ( self ) :
self . channel = self . client . invoke_shell ( )
# 读取首次登陆终端返回的消息
self . channel . recv ( 2048 )
# 网络设备一般登录有延迟,等终端有返回后再执行命令
delay_time = self . module . params [ ' first_conn_delay_time ' ]
time . sleep ( delay_time )
@staticmethod
def _is_match_user ( user , content ) :
# 正常命令切割后是[命令,用户名,交互前缀]
content_list = content . split ( ) if len ( content . split ( ) ) > = 3 else None
return content_list and user in content_list
def switch_user ( self ) :
self . _get_channel ( )
if not self . module . params [ ' become ' ] :
return
method = self . module . params [ ' become_method ' ]
@ -102,22 +166,73 @@ class SSHClient:
else :
self . module . fail_json ( msg = ' Become method %s not support ' % method )
return
commands = [ f ' { switch_method } { username } ' , password ]
su_output , err_msg = self . execute ( commands )
if err_msg :
return err_msg
i_output , err_msg = self . execute (
[ f ' whoami && echo " { self . COMPLETE_FLAG } " ' ] ,
validate_output = True
__ , e_msg = self . execute (
[ f ' { switch_method } { username } ' , password , ' whoami ' ] ,
[ become_prompt_re , DEFAULT_RE , username ]
)
if e rr _msg:
return err_msg
if e_msg :
self . module . fail_json ( msg = ' Become user %s failed. ' % username )
if self . _is_match_user ( username , i_output ) :
err_msg = ' '
else :
err_msg = su_output
return err_msg
def connect ( self ) :
self . before_runner_start ( )
try :
self . client . connect ( * * self . connect_params )
self . _channel = self . client . invoke_shell ( )
self . _get_match_recv ( )
self . switch_user ( )
except Exception as error :
self . module . fail_json ( msg = str ( error ) )
@staticmethod
def _fit_answers ( commands , answers ) :
if answers is None or not isinstance ( answers , list ) :
answers = [ DEFAULT_RE ] * len ( commands )
elif len ( answers ) < len ( commands ) :
answers + = [ DEFAULT_RE ] * ( len ( commands ) - len ( answers ) )
return answers
@staticmethod
def __match ( re_ , content ) :
re_pattern = re_
if isinstance ( re_ , str ) :
re_pattern = re . compile ( re_ , re . DOTALL | re . IGNORECASE )
elif not isinstance ( re_pattern , re . Pattern ) :
raise ValueError ( f ' { re_ } should be a regular expression ' )
return bool ( re_pattern . search ( content ) )
@raise_timeout ( ' Recv message ' )
def _get_match_recv ( self , answer_reg = DEFAULT_RE ) :
last_output , output = ' ' , ' '
while True :
if self . channel . recv_ready ( ) :
recv = self . channel . recv ( self . buffer_size ) . decode ( )
output + = recv
if output and last_output != output :
fin_reg = self . prompt if answer_reg == DEFAULT_RE else answer_reg
if self . __match ( fin_reg , output ) :
break
last_output = output
time . sleep ( 0.01 )
return output
@raise_timeout ( ' Wait send message ' )
def _check_send ( self ) :
while not self . channel . send_ready ( ) :
time . sleep ( 0.01 )
time . sleep ( self . module . params [ ' delay_time ' ] )
def execute ( self , commands , answers = None ) :
all_output , error_msg = ' ' , ' '
try :
answers = self . _fit_answers ( commands , answers )
for index , command in enumerate ( commands ) :
self . _check_send ( )
self . channel . send ( command + ' \n ' )
all_output + = f ' { self . _get_match_recv ( answers [ index ] ) } \n '
except Exception as e :
error_msg = str ( e )
return all_output , error_msg
def local_gateway_prepare ( self ) :
gateway_args = self . module . params [ ' gateway_args ' ] or ' '
@ -160,48 +275,15 @@ class SSHClient:
def after_runner_end ( self ) :
self . local_gateway_clean ( )
def connect ( self ) :
try :
self . before_runner_start ( )
self . client . connect ( * * self . connect_params )
self . is_connect = True
err_msg = self . switch_user ( )
self . after_runner_end ( )
except Exception as err :
err_msg = str ( err )
return err_msg
def __enter__ ( self ) :
return self
def _get_recv ( self , size = 1024 , encoding = ' utf-8 ' ) :
output = self . channel . recv ( size ) . decode ( encoding )
return output
def execute ( self , commands , validate_output = False ) :
if not self . is_connect :
self . connect ( )
output , error_msg = ' ' , ' '
try :
for command in commands :
self . channel . send ( command + ' \n ' )
if not validate_output :
time . sleep ( self . SLEEP_INTERVAL )
output + = self . _get_recv ( )
continue
start_time = time . time ( )
while self . COMPLETE_FLAG not in output :
if time . time ( ) - start_time > self . TIMEOUT :
error_msg = output
print ( " 切换用户操作超时,跳出循环。 " )
break
time . sleep ( self . SLEEP_INTERVAL )
received_output = self . _get_recv ( ) . replace ( f ' " { self . COMPLETE_FLAG } " ' , ' ' )
output + = received_output
except Exception as e :
error_msg = str ( e )
return output , error_msg
def __del__ ( self ) :
def __exit__ ( self , exc_type , exc_val , exc_tb ) :
try :
self . channel . close ( )
self . client . close ( )
except Exception :
self . after_runner_end ( )
if self . channel :
self . channel . close ( )
if self . client :
self . client . close ( )
except Exception : # noqa
pass