diff --git a/connect.py b/connect.py index 394f3897f..dee57392e 100755 --- a/connect.py +++ b/connect.py @@ -17,12 +17,10 @@ from multiprocessing import Pool from Crypto.Cipher import AES from binascii import b2a_hex, a2b_hex from ConfigParser import ConfigParser - from django.core.exceptions import ObjectDoesNotExist os.environ['DJANGO_SETTINGS_MODULE'] = 'jumpserver.settings' django.setup() - from juser.models import User from jasset.models import Asset from jlog.models import Log @@ -32,49 +30,50 @@ try: import tty except ImportError: print '\033[1;31mOnly postfix supported.\033[0m' + time.sleep(3) sys.exit() -CURRENT_DIR = os.path.abspath('.') +CURRENT_DIR = os.path.dirname(__file__) CONF = ConfigParser() CONF.read(os.path.join(CURRENT_DIR, 'jumpserver.conf')) LOG_DIR = os.path.join(CURRENT_DIR, 'logs') +# Web generate user ssh_key dir. SSH_KEY_DIR = os.path.join(CURRENT_DIR, 'keys') +# User upload the server key to this dir. SERVER_KEY_DIR = os.path.join(SSH_KEY_DIR, 'server') +# The key of decryptor. KEY = CONF.get('web', 'key') +# Login user. LOGIN_NAME = getpass.getuser() #LOGIN_NAME = os.getlogin() +USER_KEY_FILE = os.path.join(SERVER_KEY_DIR, LOGIN_NAME) + +if not os.path.isfile(USER_KEY_FILE): + USER_KEY_FILE = None -def green_print(string): - print '\033[1;32m%s\033[0m' % string +def color_print(msg, color='blue'): + """Print colorful string.""" + color_msg = {'blue': '\033[1;36m%s\033[0m', + 'green': '\033[1;32m%s\033[0m', + 'red': '\033[1;31m%s\033[0m'} + + print color_msg.get(color, 'blue') % msg -def blue_print(string): - print '\033[1;36m%s\033[0m' % string - - -def red_print(string): - print '\033[1;31m%s\033[0m' % string - - -def alert_print(string): - red_print('AlertError: %s' % string) +def color_print_exit(msg, color='red'): + """Print colorful string and exit.""" + color_print(msg, color=color) time.sleep(2) sys.exit() class ServerError(Exception): - def __init__(self, error): - self.error = error - - def __str__(self): - return self.error - - __repr__ = __str__ + pass class PyCrypt(object): - """It's used to encrypt and decrypt password.""" + """This class used to encrypt and decrypt password.""" def __init__(self, key): self.key = key @@ -83,19 +82,21 @@ class PyCrypt(object): def encrypt(self, text): cryptor = AES.new(self.key, self.mode, b'0000000000000000') length = 16 - count = len(text) - if count < length: - add = (length - count) - text += ('\0' * add) - elif count > length: - add = (length - (count % length)) - text += ('\0' * add) + try: + count = len(text) + except TypeError: + raise ServerError('Encrypt password error, TYpe error.') + add = (length - (count % length)) + text += ('\0' * add) ciphertext = cryptor.encrypt(text) return b2a_hex(ciphertext) def decrypt(self, text): cryptor = AES.new(self.key, self.mode, b'0000000000000000') - plain_text = cryptor.decrypt(a2b_hex(text)) + try: + plain_text = cryptor.decrypt(a2b_hex(text)) + except TypeError: + raise ServerError('Decrypt password error, TYpe error.') return plain_text.rstrip('\0') @@ -119,26 +120,28 @@ def set_win_size(sig, data): pass -def posix_shell(chan, username, host): - """ - Use paramiko channel connect server and logging. - """ +def get_object(model, **kwargs): + try: + the_object = model.objects.get(kwargs) + except ObjectDoesNotExist: + raise ServerError('Object get %s failed.' % str(kwargs.values())) + return the_object + + +def log_record(username, host): + """Logging user command and output.""" connect_log_dir = os.path.join(LOG_DIR, 'connect') timestamp_start = int(time.time()) today = time.strftime('%Y%m%d', time.localtime(timestamp_start)) - date_now = time.strftime('%Y%m%d%H%M%S', time.localtime(timestamp_start)) - + time_now = time.strftime('%H%M%S', time.localtime(timestamp_start)) today_connect_log_dir = os.path.join(connect_log_dir, today) - log_filename = '%s_%s_%s.log' % (username, host, date_now) + log_filename = '%s_%s_%s.log' % (username, host, time_now) log_file_path = os.path.join(today_connect_log_dir, log_filename) - - try: - user = User.objects.get(username=username) - asset = Asset.objects.get(ip=host) - except ObjectDoesNotExist: - raise ServerError('user %s or asset %s does not exist.' % (username, host)) - pid = os.getpid() + + user = get_object(User, username=username) + asset = get_object(Asset, ip=host) + if not os.path.isdir(today_connect_log_dir): try: os.makedirs(today_connect_log_dir) @@ -153,7 +156,14 @@ def posix_shell(chan, username, host): log = Log(user=user, asset=asset, log_path=log_file_path, start_time=timestamp_start, pid=pid) log.save() + return log_file, log + +def posix_shell(chan, username, host): + """ + Use paramiko channel connect server interactive. + """ + log_file, log = log_record(username, host) old_tty = termios.tcgetattr(sys.stdin) try: tty.setraw(sys.stdin.fileno()) @@ -194,6 +204,7 @@ def posix_shell(chan, username, host): def get_user_host(username): + """Get the hosts of under the user control.""" hosts_attr = {} try: user = User.objects.get(username=username) @@ -203,71 +214,50 @@ def get_user_host(username): perm_all = user.permission_set.all() for perm in perm_all: hosts_attr[perm.asset.ip] = [perm.asset.id, perm.asset.comment] - hosts = hosts_attr.keys() - hosts.sort() - return hosts_attr, hosts + return hosts_attr def get_connect_item(username, ip): cryptor = PyCrypt(KEY) - try: - asset = Asset.objects.get(ip=ip) - port = asset.port - except ObjectDoesNotExist: - raise ServerError("Host %s does not exist." % ip) + asset = Asset.objects.get(Asset, ip=ip) + port = asset.port if not asset.is_active: raise ServerError('Host %s is not active.' % ip) - try: - user = User.objects.get(username=username) - except ObjectDoesNotExist: - raise ServerError('User %s does not exist.' % username) + user = get_object(User, username=username) if not user.is_active: raise ServerError('User %s is not active.' % username) - if asset.login_type == 'L': - try: - ldap_pwd = cryptor.decrypt(user.ldap_pwd) - except TypeError: - raise ServerError('Decrypt %s ldap password error.' % username) - return 'L', username, ldap_pwd, ip, port - elif asset.login_type == 'S': - try: - ssh_key_pwd = cryptor.decrypt(user.ssh_key_pwd2) - except TypeError: - raise ServerError('Decrypt %s ssh key password error.' % username) - return 'S', username, ssh_key_pwd, ip, port - elif asset.login_type == 'P': - try: - ssh_pwd = cryptor.decrypt(user.ssh_pwd) - except TypeError: - raise ServerError('Decrypt %s ssh password error.' % username) - return 'P', username, ssh_pwd, ip, port + login_type_dict = { + 'L': user.ldap_pwd, + 'S': user.ssh_key_pwd2, + 'P': user.ssh_pwd, + } + + if asset.login_type in login_type_dict: + password = cryptor.decrypt(login_type_dict[asset.login_type]) + + return username, password, ip, port + elif asset.login_type == 'M': perms = asset.permission_set.filter(user=user) - try: + if perms: perm = perms[0] - except IndexError: + else: raise ServerError('Permission %s to %s does not exist.' % (username, ip)) if perm.role == 'SU': username_super = asset.username_super - try: - password_super = cryptor.decrypt(asset.password_super) - except TypeError: - raise ServerError('Decrypt %s map to %s password in %s error.' % (username, username_super, ip)) - return 'M', username_super, password_super, ip, port + password_super = cryptor.decrypt(asset.password_super) + return username_super, password_super, ip, port elif perm.role == 'CU': username_common = asset.username_common - try: - password_common = asset.password_common - except TypeError: - raise ServerError('Decrypt %s map to %s password in %s error.' % (username, username_common, ip)) - return 'CU', username_common, password_common, ip, port + password_common = asset.password_common + return username_common, password_common, ip, port else: raise ServerError('Perm in %s for %s map role is not in ["SU", "CU"].' % (ip, username)) @@ -276,20 +266,18 @@ def get_connect_item(username, ip): def verify_connect(username, part_ip): - ip_matched = [] - hosts_mix, hosts = get_user_host(username) - for ip in hosts: - if part_ip in ip: - ip_matched.append(ip) + hosts_attr = get_user_host(username) + hosts = hosts_attr.keys() + ip_matched = [ip for ip in hosts if part_ip in ip] if len(ip_matched) > 1: for ip in ip_matched: - print '[%s] %s -- %s' % (hosts_mix[ip][0], ip, hosts_mix[ip][1]) + print '[%s] %s -- %s' % (hosts_attr[ip][0], ip, hosts_attr[ip][1]) elif len(ip_matched) < 1: - red_print('No Permission or No host.') + color_print('No Permission or No host.', 'red') else: login_type, username, password, host, port = get_connect_item(username, ip_matched[0]) - connect(username, password, host, port, LOGIN_NAME, login_type=login_type) + connect(username, password, host, port, LOGIN_NAME) def print_prompt(): @@ -303,33 +291,26 @@ def print_prompt(): def print_user_host(username): - hosts_attr, hosts = get_user_host(username) + hosts_attr = get_user_host(username) + hosts = hosts_attr.keys() + hosts.sort() for ip in hosts: print '[%s] %s -- %s' % (hosts_attr[ip][0], ip, hosts_attr[ip][1]) -def connect(username, password, host, port, login_name, login_type='L'): +def connect(username, password, host, port, login_name): """ Connect server. """ ps1 = "PS1='[\u@%s \W]\$ '\n" % host login_msg = "clear;echo -e '\\033[32mLogin %s done. Enjoy it.\\033[0m'\n" % host - user_key_file = os.path.join(SERVER_KEY_DIR, username) - - if os.path.isfile(user_key_file): - key_filename = user_key_file - else: - key_filename = None # Make a ssh connection ssh = paramiko.SSHClient() ssh.load_system_host_keys() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) try: - if login_type == 'L': - ssh.connect(host, port=port, username=username, password=password, key_filename=key_filename, compress=True) - else: - ssh.connect(host, port=port, username=username, password=password, compress=True) + ssh.connect(host, port=port, username=username, password=password, key_filename=USER_KEY_FILE, compress=True) except paramiko.ssh_exception.AuthenticationException, paramiko.ssh_exception.SSHException: raise ServerError('Authentication Error.') except socket.error: @@ -362,25 +343,25 @@ def remote_exec_cmd(ip, port, username, password, cmd): time.sleep(3) ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(ip, port, username, password, timeout=5) + ssh.connect(ip, port, username, password, key_filename=USER_KEY_FILE, timeout=5) stdin, stdout, stderr = ssh.exec_command("bash -l -c '%s'" % cmd) out = stdout.readlines() err = stderr.readlines() - blue_print(ip + ':') + color_print('%s:', 'blue') for i in out: - green_print(" " * 4 + i.strip()) + color_print(" " * 4 + i.strip(), 'green') for j in err: - red_print(" " * 4 + j.strip()) + color_print(" " * 4 + j.strip(), 'red') ssh.close() except Exception as e: - blue_print(ip + ':') - red_print(str(e)) + color_print(ip + ':', 'blue') + color_print(str(e), 'red') def multi_remote_exec_cmd(hosts, username, cmd): pool = Pool(processes=3) for host in hosts: - login_type, username, password, ip, port = get_connect_item(username, host) + username, password, ip, port = get_connect_item(username, host) pool.apply_async(remote_exec_cmd, (ip, port, username, password, cmd)) pool.close() pool.join() @@ -388,8 +369,8 @@ def multi_remote_exec_cmd(hosts, username, cmd): def exec_cmd_servers(username): hosts = [] - green_print("Input the Host IP(s),Separated by Commas, q/Q to Quit.\n \ - You can choose in the following IP(s), Use Linux / Unix glob.") + color_print("Input the Host IP(s),Separated by Commas, q/Q to Quit.\n \ + You can choose in the following IP(s), Use Linux / Unix glob.", 'green') print_user_host(LOGIN_NAME) while True: inputs = raw_input('\033[1;32mip(s)>: \033[0m') @@ -400,11 +381,11 @@ def exec_cmd_servers(username): if fnmatch.fnmatch(host, inputs): hosts.append(host.strip()) if len(hosts) == 0: - red_print("Check again, Not matched any ip!") + color_print("Check again, Not matched any ip!", 'red') continue else: print "You matched ip: %s" % hosts - green_print("Input the Command , The command will be Execute on servers, q/Q to quit.") + color_print("Input the Command , The command will be Execute on servers, q/Q to quit.", 'green') while True: cmd = raw_input('\033[1;32mCmd(s): \033[0m') if cmd in ['q', 'Q']: @@ -440,6 +421,6 @@ if __name__ == '__main__': try: verify_connect(LOGIN_NAME, option) except ServerError, e: - red_print(e) + color_print(e, 'red') except IndexError: pass