jumpserver/apps/common/utils.py

413 lines
11 KiB
Python
Raw Normal View History

2016-09-03 11:05:50 +00:00
# -*- coding: utf-8 -*-
#
import re
2018-04-02 05:19:31 +00:00
import sys
2017-03-23 16:27:33 +00:00
from collections import OrderedDict
from six import string_types
2016-12-25 05:15:28 +00:00
import base64
2016-11-06 16:39:26 +00:00
import os
2016-09-13 17:08:26 +00:00
from itertools import chain
2016-09-15 03:19:36 +00:00
import logging
2016-10-26 11:10:14 +00:00
import datetime
2016-12-25 05:15:28 +00:00
import time
import hashlib
from email.utils import formatdate
import calendar
import threading
2017-12-13 09:21:08 +00:00
from io import StringIO
2017-12-10 16:29:25 +00:00
import uuid
2018-06-01 08:22:52 +00:00
from functools import wraps
2016-09-03 11:05:50 +00:00
2016-11-06 16:39:26 +00:00
import paramiko
2016-11-09 15:49:10 +00:00
import sshpubkeys
2016-11-01 09:21:16 +00:00
from itsdangerous import TimedJSONWebSignatureSerializer, JSONWebSignatureSerializer, \
2016-10-14 16:49:59 +00:00
BadSignature, SignatureExpired
2016-09-03 11:05:50 +00:00
from django.shortcuts import reverse as dj_reverse
from django.conf import settings
2016-09-10 13:08:10 +00:00
from django.utils import timezone
2016-09-03 11:05:50 +00:00
2016-11-06 16:39:26 +00:00
# Loose UUID shape check: exactly 36 characters drawn from ASCII letters,
# digits and '-'.
# NOTE(review): the pattern is not anchored at the end and is not restricted
# to hex digits, so ``UUID_PATTERN.match`` also accepts longer strings that
# merely start with such a 36-character run.
UUID_PATTERN = re.compile(r'[0-9a-zA-Z\-]{36}')
2016-09-03 11:05:50 +00:00
2016-10-14 16:49:59 +00:00
def reverse(view_name, urlconf=None, args=None, kwargs=None,
            current_app=None, external=False):
    """Resolve ``view_name`` to a URL path using Django's ``reverse``.

    :param external: when true, the path is prefixed with
        ``settings.SITE_URL`` (with leading/trailing slashes stripped) to
        form an absolute URL.
    :return: the URL string.
    """
    path = dj_reverse(
        view_name,
        urlconf=urlconf,
        args=args,
        kwargs=kwargs,
        current_app=current_app,
    )
    if not external:
        return path
    site_root = settings.SITE_URL.strip('/')
    return site_root + path
def get_object_or_none(model, **kwargs):
    """Return the ``model`` instance matching ``kwargs``, or None if absent.

    Only ``model.DoesNotExist`` is swallowed; any other exception raised by
    ``model.objects.get`` propagates to the caller.
    """
    try:
        return model.objects.get(**kwargs)
    except model.DoesNotExist:
        return None
2016-09-07 16:40:59 +00:00
2017-12-24 10:53:07 +00:00
class Singleton(type):
    """Metaclass making every class that uses it produce a single instance.

    The first call constructs the instance (running ``__init__`` with the
    arguments of that call). Every later call returns the same object and
    ignores its arguments. Each class using this metaclass keeps its own
    instance.
    """

    def __init__(cls, *args, **kwargs):
        cls.__instance = None
        super().__init__(*args, **kwargs)

    def __call__(cls, *args, **kwargs):
        instance = cls.__instance
        if instance is None:
            instance = super().__call__(*args, **kwargs)
            cls.__instance = instance
        return instance
class Signer(metaclass=Singleton):
    """Sign and verify tokens, optionally with an expiry, using ``secret_key``.

    Built on the itsdangerous JSON Web Signature serializers. Verification
    failures return ``{}`` instead of raising.

    Because the class uses the ``Singleton`` metaclass, only the first
    ``Signer(...)`` call creates an instance. Later calls return that same
    object, so their ``secret_key`` argument is ignored.
    """

    def __init__(self, secret_key=None):
        self.secret_key = secret_key

    def sign(self, value):
        """Return the signed token for ``value`` as produced by ``dumps``.

        ``bytes`` values are decoded as UTF-8 before signing.
        """
        payload = value.decode("utf-8") if isinstance(value, bytes) else value
        serializer = JSONWebSignatureSerializer(self.secret_key)
        return serializer.dumps(payload)

    def unsign(self, value):
        """Verify a token produced by ``sign`` and return its payload.

        :return: None when ``value`` is None, ``{}`` when the signature is bad.
        """
        if value is None:
            return None
        serializer = JSONWebSignatureSerializer(self.secret_key)
        try:
            return serializer.loads(value)
        except BadSignature:
            return {}

    def sign_t(self, value, expires_in=3600):
        """Return a time-limited token for ``value`` as a str.

        ``expires_in`` is handed to ``TimedJSONWebSignatureSerializer`` as the
        token lifetime (in seconds, per itsdangerous).
        """
        serializer = TimedJSONWebSignatureSerializer(
            self.secret_key, expires_in=expires_in)
        token = serializer.dumps(value)
        return str(token, encoding="utf8")

    def unsign_t(self, value):
        """Verify a token produced by ``sign_t`` and return its payload.

        :return: ``{}`` when the signature is invalid or the token expired.
        """
        serializer = TimedJSONWebSignatureSerializer(self.secret_key)
        try:
            return serializer.loads(value)
        except (BadSignature, SignatureExpired):
            return {}
2016-10-14 16:49:59 +00:00
2016-09-10 13:08:10 +00:00
def date_expired_default():
    """Return the default expiry datetime: now plus ``DEFAULT_EXPIRED_YEARS``.

    A year is approximated as 365 days. If the setting is missing or cannot
    be converted to an int, 70 years is used.
    """
    try:
        years = int(settings.DEFAULT_EXPIRED_YEARS)
    except (TypeError, ValueError, AttributeError):
        # TypeError: e.g. the setting is None; ValueError: a non-numeric
        # string; AttributeError: the setting is not defined at all.
        years = 70
    # Use the stdlib timedelta directly rather than relying on
    # django.utils.timezone happening to re-export it.
    return timezone.now() + datetime.timedelta(days=365 * years)
2016-09-10 13:08:10 +00:00
2016-09-13 17:08:26 +00:00
def combine_seq(s1, s2, callback=None):
    """Lazily chain ``s1`` and ``s2``, optionally mapping ``callback`` over them.

    :return: an empty list if either argument has no ``__iter__``; otherwise
        an ``itertools.chain`` iterator, or a ``map`` over it when
        ``callback`` is truthy.
    """
    if not all(hasattr(part, '__iter__') for part in (s1, s2)):
        return []
    combined = chain(s1, s2)
    return map(callback, combined) if callback else combined
2016-09-15 03:19:36 +00:00
def get_logger(name=None):
    """Return a logger in the ``jumpserver`` logger hierarchy.

    ``get_logger('x')`` returns the ``jumpserver.x`` logger. Called without
    a name it returns the top-level ``jumpserver`` logger itself, rather
    than a spurious ``jumpserver.None`` child.
    """
    if name is None:
        return logging.getLogger('jumpserver')
    return logging.getLogger('jumpserver.%s' % name)
2016-09-18 16:07:52 +00:00
2016-10-26 11:10:14 +00:00
def timesince(dt, since='', default="just now"):
    """Return a string such as "3 days" describing the time from ``dt`` to ``since``.

    Only the largest non-zero unit is reported, checked in the order years,
    months, weeks, days, hours, minutes, seconds. Months and years are
    approximated as 30 and 365 days.

    :param dt: start datetime.
    :param since: end datetime. The default ``''`` means "now", taken from
        ``datetime.datetime.utcnow()`` (a naive UTC datetime). ``None``
        makes the function return ``default``.
    :param default: returned when ``since`` is None or no unit is non-zero.
    """
    # Compare by value: ``is ''`` relied on CPython string interning.
    if since == '':
        since = datetime.datetime.utcnow()
    if since is None:
        return default
    diff = since - dt
    # Floor division: with true division a 3-day gap gave 3/365 (truthy),
    # which was rendered as "0 years".
    periods = (
        (diff.days // 365, "year", "years"),
        (diff.days // 30, "month", "months"),
        (diff.days // 7, "week", "weeks"),
        (diff.days, "day", "days"),
        (diff.seconds // 3600, "hour", "hours"),
        (diff.seconds // 60, "minute", "minutes"),
        (diff.seconds, "second", "seconds"),
    )
    for period, singular, plural in periods:
        if period:
            return "%d %s" % (period, singular if period == 1 else plural)
    return default
2016-11-01 09:21:16 +00:00
2017-12-13 09:21:08 +00:00
def ssh_key_string_to_obj(text, password=None):
    """Parse a private key string into a paramiko key object.

    Both the RSA and the DSS parser are attempted, in that order. A parser
    raising ``paramiko.SSHException`` is skipped.

    :return: the key from the last parser that succeeded, or None when
        neither accepts ``text``.
    """
    parsed = None
    for key_cls in (paramiko.RSAKey, paramiko.DSSKey):
        try:
            parsed = key_cls.from_private_key(StringIO(text), password=password)
        except paramiko.SSHException:
            continue
    return parsed
2017-12-21 03:31:13 +00:00
def ssh_pubkey_gen(private_key=None, username='jumpserver', hostname='localhost', password=None):
    """Build a public key line from a private key.

    :param private_key: bytes (decoded as UTF-8), a key string (parsed with
        ``ssh_key_string_to_obj`` using ``password``), or an already parsed
        paramiko RSA/DSS key object.
    :return: ``"<key type> <base64 key> <username>@<hostname>"``.
    :raises IOError: if no RSA or DSS key object could be obtained.
    """
    key = private_key
    if isinstance(key, bytes):
        key = key.decode("utf-8")
    if isinstance(key, string_types):
        key = ssh_key_string_to_obj(key, password=password)
    if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
        raise IOError('Invalid private key')
    fields = dict(
        key_type=key.get_name(),
        key_content=key.get_base64(),
        username=username,
        hostname=hostname,
    )
    return "%(key_type)s %(key_content)s %(username)s@%(hostname)s" % fields
def ssh_key_gen(length=2048, type='rsa', password=None, username='jumpserver', hostname=None):
    """Generate a new SSH key pair with paramiko.

    :param length: key size in bits, passed to paramiko's ``generate``.
    :param type: ``'rsa'`` or ``'dsa'``.
    :param password: optional passphrase used when writing the private key.
    :param username: user part of the public key comment.
    :param hostname: host part of the public key comment. Defaults to
        ``os.uname()[1]``.
    :return: ``(private_key_str, public_key_str)``.
    :raises IOError: if ``type`` is unsupported or key generation fails.
    """
    # Validate before the try block: inside it, this specific error was
    # caught and replaced by the generic message, so callers never saw it.
    if type not in ('rsa', 'dsa'):
        raise IOError('SSH private key must be `rsa` or `dsa`')
    if hostname is None:
        hostname = os.uname()[1]
    f = StringIO()
    try:
        if type == 'rsa':
            private_key_obj = paramiko.RSAKey.generate(length)
        else:
            private_key_obj = paramiko.DSSKey.generate(length)
        private_key_obj.write_private_key(f, password=password)
        private_key = f.getvalue()
        public_key = ssh_pubkey_gen(private_key_obj, username=username, hostname=hostname)
        return private_key, public_key
    except IOError as e:
        # Keep the underlying cause attached for debugging.
        raise IOError('These is error when generate ssh key.') from e
2017-12-21 03:31:13 +00:00
def validate_ssh_private_key(text, password=None):
    """Return True if ``text`` parses as an RSA or DSS private key.

    ``bytes`` input is decoded as UTF-8 first. If decoding fails, the key is
    rejected with False.
    """
    if isinstance(text, bytes):
        try:
            text = text.decode("utf-8")
        except UnicodeDecodeError:
            return False
    return ssh_key_string_to_obj(text, password=password) is not None
2016-11-09 15:49:10 +00:00
def validate_ssh_public_key(text):
    """Return True if ``sshpubkeys`` can parse ``text`` as a public key.

    Invalid keys, undecodable input and key types that ``sshpubkeys``
    reports as not implemented all yield False.
    """
    key = sshpubkeys.SSHKey(text)
    try:
        key.parse()
    except (sshpubkeys.InvalidKeyException, UnicodeDecodeError, NotImplementedError):
        return False
    return True
2016-11-10 08:59:50 +00:00
def setattr_bulk(seq, key, value):
    """Set attribute ``key`` to ``value`` on every object in ``seq``.

    :return: a list of the same objects, in their original order, so
        callers can keep iterating or chaining on the result.
    """
    objs = list(seq)
    # Assign eagerly: a lazy ``map`` meant nothing was set unless the
    # caller happened to iterate the returned value.
    for obj in objs:
        setattr(obj, key, value)
    return objs
2018-04-07 16:16:37 +00:00
def set_or_append_attr_bulk(seq, key, value):
    """Set attribute ``key`` to ``value`` on every object in ``seq``.

    If an object already has a truthy value for ``key``, that old value is
    kept after the new one, separated by a space: ``"<value> <old value>"``.
    """
    for obj in seq:
        ori = getattr(obj, key, None)
        # Compute a per-object result. Reassigning ``value`` itself made each
        # object's old value leak into every later object.
        new_value = value + " " + ori if ori else value
        setattr(obj, key, new_value)
2016-12-25 05:15:28 +00:00
def content_md5(data):
    """Return the Base64 encoding of the hex MD5 digest of ``data`` as a str.

    :param data: a str (hashed as UTF-8), bytes/bytearray, or an existing
        hashlib object whose ``hexdigest()`` is used as-is.

    NOTE: this encodes the *hex* digest, not the raw 16-byte digest used by
    the standard ``Content-MD5`` header (RFC 1864). Keep it this way for
    compatibility with whatever verifies these values (presumably peers
    using the same helper).
    """
    if isinstance(data, str):
        data = data.encode('utf-8')
    if isinstance(data, (bytes, bytearray)):
        # Previously bytes fell through to ``data.hexdigest()`` and crashed.
        data = hashlib.md5(data)
    value = base64.b64encode(data.hexdigest().encode('utf-8'))
    return value.decode('utf-8')
2016-12-25 05:15:28 +00:00
2018-01-25 08:38:40 +00:00
2016-12-25 05:15:28 +00:00
# Serializes time.strptime calls. Presumably this works around the known
# thread-safety problem of strptime's lazy first-call import of ``_strptime``.
_STRPTIME_LOCK = threading.Lock()
# HTTP-date format, e.g. "Sat, 05 Dec 2015 11:10:29 GMT".
_GMT_FORMAT = "%a, %d %b %Y %H:%M:%S GMT"
# ISO 8601 with a literal ".000Z" suffix, e.g. "2012-02-24T06:07:48.000Z".
_ISO8601_FORMAT = "%Y-%m-%dT%H:%M:%S.000Z"


def to_unixtime(time_string, format_string):
    """Parse ``time_string`` with ``format_string`` and return UNIX seconds.

    The parsed time is interpreted as UTC (via ``calendar.timegm``).

    :param time_string: the time as a str or ASCII-encoded bytes.
    :param format_string: a ``time.strptime`` format.
    :return: int seconds since 1970-01-01 00:00:00 UTC.
    """
    # Only decode bytes: calling ``.decode`` on a str fails under Python 3.
    if isinstance(time_string, bytes):
        time_string = time_string.decode("ascii")
    with _STRPTIME_LOCK:
        return int(calendar.timegm(time.strptime(time_string, format_string)))
def http_date(timeval=None):
    """Return an HTTP date string in GMT, e.g. "Sat, 05 Dec 2015 11:10:29 GMT".

    ``email.utils.formatdate`` is used instead of ``strftime`` because
    strftime's day and month names depend on the current locale.

    :param timeval: seconds since the epoch; None means the current time.
    """
    return formatdate(timeval=timeval, usegmt=True)
def http_to_unixtime(time_string):
    """Convert an HTTP date such as ``Sat, 05 Dec 2015 11:10:29 GMT`` to UNIX time.

    :return: int seconds since 1970-01-01 00:00:00 UTC.
    """
    return to_unixtime(time_string, format_string=_GMT_FORMAT)
def iso8601_to_unixtime(time_string):
    """Convert an ISO 8601 string like ``2012-02-24T06:07:48.000Z`` to UNIX time.

    :return: int seconds since the epoch (second precision).
    """
    return to_unixtime(time_string, format_string=_ISO8601_FORMAT)
def make_signature(access_key_secret, date=None):
    """Compute ``content_md5`` of the secret and a date string joined by a newline.

    :param date: an int UNIX timestamp (formatted with ``http_date``), an
        already formatted date as str or bytes (bytes are decoded), or None
        for the current time.
    """
    if isinstance(date, bytes):
        date = date.decode()
    if date is None:
        date_gmt = http_date(int(time.time()))
    elif isinstance(date, int):
        date_gmt = http_date(date)
    else:
        date_gmt = date
    message = str(access_key_secret) + "\n" + date_gmt
    return content_md5(message)
def encrypt_password(password, salt=None):
    """Hash ``password`` with passlib's ``sha512_crypt`` using 5000 rounds.

    :param password: plaintext password. A falsy value returns None.
    :param salt: optional fixed salt. None lets passlib generate one.
    :return: the hash string, or None.
    """
    from passlib.hash import sha512_crypt
    if not password:
        return None
    # Hash settings such as ``salt`` belong in ``using()``. Passing them to
    # ``hash()`` is deprecated since passlib 1.7.
    return sha512_crypt.using(rounds=5000, salt=salt).hash(password)
2017-03-15 16:19:47 +00:00
def capacity_convert(size, expect='auto', rate=1000):
    """Convert a capacity string such as ``'100MB'`` or ``'1 G'`` to another unit.

    :param size: a number followed by a unit suffix: K/KB, M/MB, G/GB or T/TB.
        Matching is case-insensitive and whitespace before the suffix is
        allowed. An unrecognised or unparseable string counts as 0.
    :param expect: the target unit (one of the suffixes above), or
        ``'auto'`` to pick the largest of K/M/G/T in which the value is at
        least 1. Unknown units fall back to ``'K'``.
    :param rate: multiplier between consecutive units. Default 1000; may
        be 1024.
    :return: ``(converted_size, unit)``, where ``converted_size`` is a float.
    """
    rate_mapping = OrderedDict((
        ('K', rate), ('KB', rate),
        ('M', rate ** 2), ('MB', rate ** 2),
        ('G', rate ** 3), ('GB', rate ** 3),
        ('T', rate ** 4), ('TB', rate ** 4),
    ))

    std_size = 0  # size in base units (bytes)
    text = size.strip().upper()
    for unit, multiplier in rate_mapping.items():
        if text.endswith(unit):
            # Slice the suffix off. ``str.strip(unit)`` removed characters,
            # not the suffix.
            try:
                std_size = float(text[:-len(unit)].strip()) * multiplier
            except ValueError:
                pass
            break

    if expect == 'auto':
        # Pick the largest unit with a value >= 1. The old strict ``> 1``
        # turned exactly 1 G into 1000000 K. Values of 1000 T or more used to
        # fall back to K.
        expect = 'K'
        for unit in ('K', 'M', 'G', 'T'):
            if std_size / rate_mapping[unit] >= 1:
                expect = unit
    if expect not in rate_mapping:
        expect = 'K'
    expect_size = std_size / rate_mapping[expect]
    return expect_size, expect
def sum_capacity(cap_list):
    """Add up capacity strings and return the total via ``capacity_convert``.

    Each entry is first converted to K. The summed K value is then
    re-expressed in an automatically chosen unit.

    :return: ``(size, unit)``.
    """
    total_k = sum(capacity_convert(cap, expect='K')[0] for cap in cap_list)
    return capacity_convert('{} K'.format(total_k), expect='auto')
2017-12-10 16:29:25 +00:00
def get_short_uuid_str():
    """Return the last group (12 lowercase hex characters) of a random UUID4."""
    return uuid.uuid4().hex[-12:]
2018-04-10 12:45:01 +00:00
def is_uuid(seq):
    """Return True if ``seq`` is a UUID, or an iterable consisting only of UUIDs.

    - A ``uuid.UUID`` instance is accepted as-is.
    - A str must match ``UUID_PATTERN`` over its whole length.
    - Any other iterable is checked element by element; an empty iterable
      yields True.
    - Non-iterable values yield False.
    """
    if isinstance(seq, uuid.UUID):
        return True
    if isinstance(seq, str):
        # fullmatch: ``match`` anchors only at the start, so longer strings
        # that merely begin with 36 such characters used to pass.
        return UUID_PATTERN.fullmatch(seq) is not None
    try:
        items = iter(seq)
    except TypeError:
        # Previously raised TypeError for e.g. None or an int.
        return False
    return all(is_uuid(item) for item in items)
2017-12-24 10:53:07 +00:00
def get_signer():
    """Return a ``Signer`` keyed with ``settings.SECRET_KEY``.

    ``Signer`` uses the ``Singleton`` metaclass, so every call returns the
    same instance.
    """
    return Signer(settings.SECRET_KEY)
2018-03-23 11:46:46 +00:00
2018-04-02 05:19:31 +00:00
class TeeObj:
    """File-like object that duplicates writes to stdout and to ``file_obj``.

    ``origin_stdout`` is captured when the class is defined. Output therefore
    goes to whatever ``sys.stdout`` was at import time, even if
    ``sys.stdout`` is replaced later. Asterisks are removed only from the
    copy written to ``file_obj``.
    """
    origin_stdout = sys.stdout

    def __init__(self, file_obj):
        self.file_obj = file_obj

    def write(self, msg):
        self.origin_stdout.write(msg)
        cleaned = msg.replace('*', '')
        self.file_obj.write(cleaned)

    def flush(self):
        for stream in (self.origin_stdout, self.file_obj):
            stream.flush()

    def close(self):
        # Only the file is closed; the captured stdout stays open.
        self.file_obj.close()
2018-04-10 12:45:01 +00:00
2018-06-01 08:22:52 +00:00
def with_cache(func):
    """Memoize ``func`` per distinct call arguments, for the life of the process.

    Results are stored in a dict private to the decorated function, keyed on
    the positional and keyword arguments. A zero-argument function is
    therefore computed once. Falsy results (None, 0, empty containers) are
    cached too.

    Calls with unhashable arguments bypass the cache. There is no eviction.
    """
    cache = {}

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Key on the arguments: previously every call shared one slot and
        # got the first result back regardless of its arguments.
        key = (args, tuple(sorted(kwargs.items())))
        try:
            # Membership via KeyError, not truthiness: ``if cached`` never
            # reused falsy results.
            return cache[key]
        except KeyError:
            pass
        except TypeError:
            # Unhashable argument: it cannot be a dict key, so just compute.
            return func(*args, **kwargs)
        res = func(*args, **kwargs)
        cache[key] = res
        return res
    return wrapper