优化代码

pull/6/head
ibuler 2014-12-31 22:58:37 +08:00
parent 9ba7ad147a
commit 73da287c6c
1 changed files with 100 additions and 119 deletions

View File

@ -17,12 +17,10 @@ from multiprocessing import Pool
from Crypto.Cipher import AES
from binascii import b2a_hex, a2b_hex
from ConfigParser import ConfigParser
from django.core.exceptions import ObjectDoesNotExist
os.environ['DJANGO_SETTINGS_MODULE'] = 'jumpserver.settings'
django.setup()
from juser.models import User
from jasset.models import Asset
from jlog.models import Log
@ -32,49 +30,50 @@ try:
import tty
except ImportError:
print '\033[1;31mOnly postfix supported.\033[0m'
time.sleep(3)
sys.exit()
CURRENT_DIR = os.path.abspath('.')
CURRENT_DIR = os.path.dirname(__file__)
CONF = ConfigParser()
CONF.read(os.path.join(CURRENT_DIR, 'jumpserver.conf'))
LOG_DIR = os.path.join(CURRENT_DIR, 'logs')
# Web generate user ssh_key dir.
SSH_KEY_DIR = os.path.join(CURRENT_DIR, 'keys')
# User upload the server key to this dir.
SERVER_KEY_DIR = os.path.join(SSH_KEY_DIR, 'server')
# The key of decryptor.
KEY = CONF.get('web', 'key')
# Login user.
LOGIN_NAME = getpass.getuser()
#LOGIN_NAME = os.getlogin()
USER_KEY_FILE = os.path.join(SERVER_KEY_DIR, LOGIN_NAME)
if not os.path.isfile(USER_KEY_FILE):
USER_KEY_FILE = None
def green_print(string):
print '\033[1;32m%s\033[0m' % string
def color_print(msg, color='blue'):
"""Print colorful string."""
color_msg = {'blue': '\033[1;36m%s\033[0m',
'green': '\033[1;32m%s\033[0m',
'red': '\033[1;31m%s\033[0m'}
print color_msg.get(color, 'blue') % msg
def blue_print(string):
print '\033[1;36m%s\033[0m' % string
def red_print(string):
print '\033[1;31m%s\033[0m' % string
def alert_print(string):
red_print('AlertError: %s' % string)
def color_print_exit(msg, color='red'):
"""Print colorful string and exit."""
color_print(msg, color=color)
time.sleep(2)
sys.exit()
class ServerError(Exception):
def __init__(self, error):
self.error = error
def __str__(self):
return self.error
__repr__ = __str__
pass
class PyCrypt(object):
"""It's used to encrypt and decrypt password."""
"""This class used to encrypt and decrypt password."""
def __init__(self, key):
self.key = key
@ -83,19 +82,21 @@ class PyCrypt(object):
def encrypt(self, text):
cryptor = AES.new(self.key, self.mode, b'0000000000000000')
length = 16
count = len(text)
if count < length:
add = (length - count)
text += ('\0' * add)
elif count > length:
add = (length - (count % length))
text += ('\0' * add)
try:
count = len(text)
except TypeError:
raise ServerError('Encrypt password error, TYpe error.')
add = (length - (count % length))
text += ('\0' * add)
ciphertext = cryptor.encrypt(text)
return b2a_hex(ciphertext)
def decrypt(self, text):
cryptor = AES.new(self.key, self.mode, b'0000000000000000')
plain_text = cryptor.decrypt(a2b_hex(text))
try:
plain_text = cryptor.decrypt(a2b_hex(text))
except TypeError:
raise ServerError('Decrypt password error, TYpe error.')
return plain_text.rstrip('\0')
@ -119,26 +120,28 @@ def set_win_size(sig, data):
pass
def posix_shell(chan, username, host):
"""
Use paramiko channel connect server and logging.
"""
def get_object(model, **kwargs):
try:
the_object = model.objects.get(kwargs)
except ObjectDoesNotExist:
raise ServerError('Object get %s failed.' % str(kwargs.values()))
return the_object
def log_record(username, host):
"""Logging user command and output."""
connect_log_dir = os.path.join(LOG_DIR, 'connect')
timestamp_start = int(time.time())
today = time.strftime('%Y%m%d', time.localtime(timestamp_start))
date_now = time.strftime('%Y%m%d%H%M%S', time.localtime(timestamp_start))
time_now = time.strftime('%H%M%S', time.localtime(timestamp_start))
today_connect_log_dir = os.path.join(connect_log_dir, today)
log_filename = '%s_%s_%s.log' % (username, host, date_now)
log_filename = '%s_%s_%s.log' % (username, host, time_now)
log_file_path = os.path.join(today_connect_log_dir, log_filename)
try:
user = User.objects.get(username=username)
asset = Asset.objects.get(ip=host)
except ObjectDoesNotExist:
raise ServerError('user %s or asset %s does not exist.' % (username, host))
pid = os.getpid()
user = get_object(User, username=username)
asset = get_object(Asset, ip=host)
if not os.path.isdir(today_connect_log_dir):
try:
os.makedirs(today_connect_log_dir)
@ -153,7 +156,14 @@ def posix_shell(chan, username, host):
log = Log(user=user, asset=asset, log_path=log_file_path, start_time=timestamp_start, pid=pid)
log.save()
return log_file, log
def posix_shell(chan, username, host):
"""
Use paramiko channel connect server interactive.
"""
log_file, log = log_record(username, host)
old_tty = termios.tcgetattr(sys.stdin)
try:
tty.setraw(sys.stdin.fileno())
@ -194,6 +204,7 @@ def posix_shell(chan, username, host):
def get_user_host(username):
"""Get the hosts of under the user control."""
hosts_attr = {}
try:
user = User.objects.get(username=username)
@ -203,71 +214,50 @@ def get_user_host(username):
perm_all = user.permission_set.all()
for perm in perm_all:
hosts_attr[perm.asset.ip] = [perm.asset.id, perm.asset.comment]
hosts = hosts_attr.keys()
hosts.sort()
return hosts_attr, hosts
return hosts_attr
def get_connect_item(username, ip):
cryptor = PyCrypt(KEY)
try:
asset = Asset.objects.get(ip=ip)
port = asset.port
except ObjectDoesNotExist:
raise ServerError("Host %s does not exist." % ip)
asset = Asset.objects.get(Asset, ip=ip)
port = asset.port
if not asset.is_active:
raise ServerError('Host %s is not active.' % ip)
try:
user = User.objects.get(username=username)
except ObjectDoesNotExist:
raise ServerError('User %s does not exist.' % username)
user = get_object(User, username=username)
if not user.is_active:
raise ServerError('User %s is not active.' % username)
if asset.login_type == 'L':
try:
ldap_pwd = cryptor.decrypt(user.ldap_pwd)
except TypeError:
raise ServerError('Decrypt %s ldap password error.' % username)
return 'L', username, ldap_pwd, ip, port
elif asset.login_type == 'S':
try:
ssh_key_pwd = cryptor.decrypt(user.ssh_key_pwd2)
except TypeError:
raise ServerError('Decrypt %s ssh key password error.' % username)
return 'S', username, ssh_key_pwd, ip, port
elif asset.login_type == 'P':
try:
ssh_pwd = cryptor.decrypt(user.ssh_pwd)
except TypeError:
raise ServerError('Decrypt %s ssh password error.' % username)
return 'P', username, ssh_pwd, ip, port
login_type_dict = {
'L': user.ldap_pwd,
'S': user.ssh_key_pwd2,
'P': user.ssh_pwd,
}
if asset.login_type in login_type_dict:
password = cryptor.decrypt(login_type_dict[asset.login_type])
return username, password, ip, port
elif asset.login_type == 'M':
perms = asset.permission_set.filter(user=user)
try:
if perms:
perm = perms[0]
except IndexError:
else:
raise ServerError('Permission %s to %s does not exist.' % (username, ip))
if perm.role == 'SU':
username_super = asset.username_super
try:
password_super = cryptor.decrypt(asset.password_super)
except TypeError:
raise ServerError('Decrypt %s map to %s password in %s error.' % (username, username_super, ip))
return 'M', username_super, password_super, ip, port
password_super = cryptor.decrypt(asset.password_super)
return username_super, password_super, ip, port
elif perm.role == 'CU':
username_common = asset.username_common
try:
password_common = asset.password_common
except TypeError:
raise ServerError('Decrypt %s map to %s password in %s error.' % (username, username_common, ip))
return 'CU', username_common, password_common, ip, port
password_common = asset.password_common
return username_common, password_common, ip, port
else:
raise ServerError('Perm in %s for %s map role is not in ["SU", "CU"].' % (ip, username))
@ -276,20 +266,18 @@ def get_connect_item(username, ip):
def verify_connect(username, part_ip):
ip_matched = []
hosts_mix, hosts = get_user_host(username)
for ip in hosts:
if part_ip in ip:
ip_matched.append(ip)
hosts_attr = get_user_host(username)
hosts = hosts_attr.keys()
ip_matched = [ip for ip in hosts if part_ip in ip]
if len(ip_matched) > 1:
for ip in ip_matched:
print '[%s] %s -- %s' % (hosts_mix[ip][0], ip, hosts_mix[ip][1])
print '[%s] %s -- %s' % (hosts_attr[ip][0], ip, hosts_attr[ip][1])
elif len(ip_matched) < 1:
red_print('No Permission or No host.')
color_print('No Permission or No host.', 'red')
else:
login_type, username, password, host, port = get_connect_item(username, ip_matched[0])
connect(username, password, host, port, LOGIN_NAME, login_type=login_type)
connect(username, password, host, port, LOGIN_NAME)
def print_prompt():
@ -303,33 +291,26 @@ def print_prompt():
def print_user_host(username):
hosts_attr, hosts = get_user_host(username)
hosts_attr = get_user_host(username)
hosts = hosts_attr.keys()
hosts.sort()
for ip in hosts:
print '[%s] %s -- %s' % (hosts_attr[ip][0], ip, hosts_attr[ip][1])
def connect(username, password, host, port, login_name, login_type='L'):
def connect(username, password, host, port, login_name):
"""
Connect server.
"""
ps1 = "PS1='[\u@%s \W]\$ '\n" % host
login_msg = "clear;echo -e '\\033[32mLogin %s done. Enjoy it.\\033[0m'\n" % host
user_key_file = os.path.join(SERVER_KEY_DIR, username)
if os.path.isfile(user_key_file):
key_filename = user_key_file
else:
key_filename = None
# Make a ssh connection
ssh = paramiko.SSHClient()
ssh.load_system_host_keys()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
try:
if login_type == 'L':
ssh.connect(host, port=port, username=username, password=password, key_filename=key_filename, compress=True)
else:
ssh.connect(host, port=port, username=username, password=password, compress=True)
ssh.connect(host, port=port, username=username, password=password, key_filename=USER_KEY_FILE, compress=True)
except paramiko.ssh_exception.AuthenticationException, paramiko.ssh_exception.SSHException:
raise ServerError('Authentication Error.')
except socket.error:
@ -362,25 +343,25 @@ def remote_exec_cmd(ip, port, username, password, cmd):
time.sleep(3)
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(ip, port, username, password, timeout=5)
ssh.connect(ip, port, username, password, key_filename=USER_KEY_FILE, timeout=5)
stdin, stdout, stderr = ssh.exec_command("bash -l -c '%s'" % cmd)
out = stdout.readlines()
err = stderr.readlines()
blue_print(ip + ':')
color_print('%s:', 'blue')
for i in out:
green_print(" " * 4 + i.strip())
color_print(" " * 4 + i.strip(), 'green')
for j in err:
red_print(" " * 4 + j.strip())
color_print(" " * 4 + j.strip(), 'red')
ssh.close()
except Exception as e:
blue_print(ip + ':')
red_print(str(e))
color_print(ip + ':', 'blue')
color_print(str(e), 'red')
def multi_remote_exec_cmd(hosts, username, cmd):
pool = Pool(processes=3)
for host in hosts:
login_type, username, password, ip, port = get_connect_item(username, host)
username, password, ip, port = get_connect_item(username, host)
pool.apply_async(remote_exec_cmd, (ip, port, username, password, cmd))
pool.close()
pool.join()
@ -388,8 +369,8 @@ def multi_remote_exec_cmd(hosts, username, cmd):
def exec_cmd_servers(username):
hosts = []
green_print("Input the Host IP(s),Separated by Commas, q/Q to Quit.\n \
You can choose in the following IP(s), Use Linux / Unix glob.")
color_print("Input the Host IP(s),Separated by Commas, q/Q to Quit.\n \
You can choose in the following IP(s), Use Linux / Unix glob.", 'green')
print_user_host(LOGIN_NAME)
while True:
inputs = raw_input('\033[1;32mip(s)>: \033[0m')
@ -400,11 +381,11 @@ def exec_cmd_servers(username):
if fnmatch.fnmatch(host, inputs):
hosts.append(host.strip())
if len(hosts) == 0:
red_print("Check again, Not matched any ip!")
color_print("Check again, Not matched any ip!", 'red')
continue
else:
print "You matched ip: %s" % hosts
green_print("Input the Command , The command will be Execute on servers, q/Q to quit.")
color_print("Input the Command , The command will be Execute on servers, q/Q to quit.", 'green')
while True:
cmd = raw_input('\033[1;32mCmd(s): \033[0m')
if cmd in ['q', 'Q']:
@ -440,6 +421,6 @@ if __name__ == '__main__':
try:
verify_connect(LOGIN_NAME, option)
except ServerError, e:
red_print(e)
color_print(e, 'red')
except IndexError:
pass