mirror of https://github.com/jumpserver/jumpserver
				
				
				
			
		
			
				
	
	
		
			436 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			436 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
# -*- coding: utf-8 -*-
 | 
						|
#
 | 
						|
import datetime
 | 
						|
import ipaddress
 | 
						|
import logging
 | 
						|
import os
 | 
						|
import platform
 | 
						|
import re
 | 
						|
import socket
 | 
						|
import time
 | 
						|
import uuid
 | 
						|
from collections import OrderedDict
 | 
						|
from functools import wraps
 | 
						|
from itertools import chain
 | 
						|
 | 
						|
import html2text
 | 
						|
import psutil
 | 
						|
from django.conf import settings
 | 
						|
from django.templatetags.static import static
 | 
						|
 | 
						|
from common.permissions import ServiceAccountSignaturePermission
 | 
						|
 | 
						|
UUID_PATTERN = re.compile(r'\w{8}(-\w{4}){3}-\w{12}')
 | 
						|
ipip_db = None
 | 
						|
 | 
						|
 | 
						|
def combine_seq(s1, s2, callback=None):
 | 
						|
    for s in (s1, s2):
 | 
						|
        if not hasattr(s, '__iter__'):
 | 
						|
            return []
 | 
						|
 | 
						|
    seq = chain(s1, s2)
 | 
						|
    if callback:
 | 
						|
        seq = map(callback, seq)
 | 
						|
    return seq
 | 
						|
 | 
						|
 | 
						|
def get_logger(name=''):
 | 
						|
    if '/' in name:
 | 
						|
        name = os.path.basename(name).replace('.py', '')
 | 
						|
    return logging.getLogger('jumpserver.%s' % name)
 | 
						|
 | 
						|
 | 
						|
def get_syslogger(name=''):
 | 
						|
    return logging.getLogger('syslog.%s' % name)
 | 
						|
 | 
						|
 | 
						|
def timesince(dt, since='', default="just now"):
 | 
						|
    """
 | 
						|
    Returns string representing "time since" e.g.
 | 
						|
    3 days, 5 hours.
 | 
						|
    """
 | 
						|
 | 
						|
    if not since:
 | 
						|
        since = datetime.datetime.utcnow()
 | 
						|
 | 
						|
    if since is None:
 | 
						|
        return default
 | 
						|
 | 
						|
    diff = since - dt
 | 
						|
 | 
						|
    periods = (
 | 
						|
        (diff.days / 365, "year", "years"),
 | 
						|
        (diff.days / 30, "month", "months"),
 | 
						|
        (diff.days / 7, "week", "weeks"),
 | 
						|
        (diff.days, "day", "days"),
 | 
						|
        (diff.seconds / 3600, "hour", "hours"),
 | 
						|
        (diff.seconds / 60, "minute", "minutes"),
 | 
						|
        (diff.seconds, "second", "seconds"),
 | 
						|
    )
 | 
						|
 | 
						|
    for period, singular, plural in periods:
 | 
						|
        if period:
 | 
						|
            return "%d %s" % (period, singular if period == 1 else plural)
 | 
						|
    return default
 | 
						|
 | 
						|
 | 
						|
def setattr_bulk(seq, key, value):
 | 
						|
    def set_attr(obj):
 | 
						|
        setattr(obj, key, value)
 | 
						|
        return obj
 | 
						|
 | 
						|
    return map(set_attr, seq)
 | 
						|
 | 
						|
 | 
						|
def set_or_append_attr_bulk(seq, key, value):
 | 
						|
    for obj in seq:
 | 
						|
        ori = getattr(obj, key, None)
 | 
						|
        if ori:
 | 
						|
            value += " " + ori
 | 
						|
        setattr(obj, key, value)
 | 
						|
 | 
						|
 | 
						|
def capacity_convert(size, expect='auto', rate=1000):
 | 
						|
    """
 | 
						|
    :param size: '100MB', '1G'
 | 
						|
    :param expect: 'K, M, G, T
 | 
						|
    :param rate: Default 1000, may be 1024
 | 
						|
    :return:
 | 
						|
    """
 | 
						|
    rate_mapping = (
 | 
						|
        ('K', rate),
 | 
						|
        ('KB', rate),
 | 
						|
        ('M', rate ** 2),
 | 
						|
        ('MB', rate ** 2),
 | 
						|
        ('G', rate ** 3),
 | 
						|
        ('GB', rate ** 3),
 | 
						|
        ('T', rate ** 4),
 | 
						|
        ('TB', rate ** 4),
 | 
						|
    )
 | 
						|
 | 
						|
    rate_mapping = OrderedDict(rate_mapping)
 | 
						|
 | 
						|
    std_size = 0  # To KB
 | 
						|
    for unit in rate_mapping:
 | 
						|
        if size.endswith(unit):
 | 
						|
            try:
 | 
						|
                std_size = float(size.strip(unit).strip()) * rate_mapping[unit]
 | 
						|
            except ValueError:
 | 
						|
                pass
 | 
						|
 | 
						|
    if expect == 'auto':
 | 
						|
        for unit, rate_ in rate_mapping.items():
 | 
						|
            if rate > std_size / rate_ >= 1 or unit == "T":
 | 
						|
                expect = unit
 | 
						|
                break
 | 
						|
 | 
						|
    if expect not in rate_mapping:
 | 
						|
        expect = 'K'
 | 
						|
 | 
						|
    expect_size = std_size / rate_mapping[expect]
 | 
						|
    return expect_size, expect
 | 
						|
 | 
						|
 | 
						|
def sum_capacity(cap_list):
 | 
						|
    total = 0
 | 
						|
    for cap in cap_list:
 | 
						|
        size, _ = capacity_convert(cap, expect='K')
 | 
						|
        total += size
 | 
						|
    total = '{} K'.format(total)
 | 
						|
    return capacity_convert(total, expect='auto')
 | 
						|
 | 
						|
 | 
						|
def get_short_uuid_str():
 | 
						|
    return str(uuid.uuid4()).split('-')[-1]
 | 
						|
 | 
						|
 | 
						|
def is_uuid(seq):
 | 
						|
    if isinstance(seq, uuid.UUID):
 | 
						|
        return True
 | 
						|
    elif isinstance(seq, str) and UUID_PATTERN.match(seq):
 | 
						|
        return True
 | 
						|
    elif isinstance(seq, (list, tuple)):
 | 
						|
        return all([is_uuid(x) for x in seq])
 | 
						|
    return False
 | 
						|
 | 
						|
 | 
						|
def get_request_ip(request):
 | 
						|
    x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR', '').split(',')
 | 
						|
    if x_forwarded_for and x_forwarded_for[0]:
 | 
						|
        login_ip = x_forwarded_for[0].split(":")[0]
 | 
						|
        return login_ip
 | 
						|
 | 
						|
    login_ip = request.META.get('REMOTE_ADDR', '')
 | 
						|
    return login_ip
 | 
						|
 | 
						|
 | 
						|
def get_request_ip_or_data(request):
 | 
						|
    ip = ''
 | 
						|
 | 
						|
    if hasattr(request, 'data') and isinstance(request.data, dict) and request.data.get('remote_addr', ''):
 | 
						|
        permission = ServiceAccountSignaturePermission()
 | 
						|
        if permission.has_permission(request, None):
 | 
						|
            ip = request.data.get('remote_addr', '')
 | 
						|
    ip = ip or get_request_ip(request)
 | 
						|
    return ip
 | 
						|
 | 
						|
 | 
						|
def get_request_user_agent(request):
 | 
						|
    user_agent = request.META.get('HTTP_USER_AGENT', '')
 | 
						|
    return user_agent
 | 
						|
 | 
						|
 | 
						|
def validate_ip(ip):
 | 
						|
    try:
 | 
						|
        ipaddress.ip_address(ip)
 | 
						|
        return True
 | 
						|
    except ValueError:
 | 
						|
        pass
 | 
						|
    return False
 | 
						|
 | 
						|
 | 
						|
def with_cache(func):
 | 
						|
    cache = {}
 | 
						|
    key = "_{}.{}".format(func.__module__, func.__name__)
 | 
						|
 | 
						|
    @wraps(func)
 | 
						|
    def wrapper(*args, **kwargs):
 | 
						|
        cached = cache.get(key)
 | 
						|
        if cached:
 | 
						|
            return cached
 | 
						|
        res = func(*args, **kwargs)
 | 
						|
        cache[key] = res
 | 
						|
        return res
 | 
						|
 | 
						|
    return wrapper
 | 
						|
 | 
						|
 | 
						|
logger = get_logger(__name__)
 | 
						|
 | 
						|
 | 
						|
def timeit(func):
 | 
						|
    def wrapper(*args, **kwargs):
 | 
						|
        name = func
 | 
						|
        for attr in ('__qualname__', '__name__'):
 | 
						|
            if hasattr(func, attr):
 | 
						|
                name = getattr(func, attr)
 | 
						|
                break
 | 
						|
 | 
						|
        logger.debug("Start call: {}".format(name))
 | 
						|
        now = time.time()
 | 
						|
        result = func(*args, **kwargs)
 | 
						|
        using = (time.time() - now) * 1000
 | 
						|
        msg = "Ends  call: {}, using: {:.1f}ms".format(name, using)
 | 
						|
        logger.debug(msg)
 | 
						|
        return result
 | 
						|
 | 
						|
    return wrapper
 | 
						|
 | 
						|
 | 
						|
def group_obj_by_count(objs, count=50):
 | 
						|
    objs_grouped = [
 | 
						|
        objs[i:i + count] for i in range(0, len(objs), count)
 | 
						|
    ]
 | 
						|
    return objs_grouped
 | 
						|
 | 
						|
 | 
						|
def dict_get_any(d, keys):
 | 
						|
    for key in keys:
 | 
						|
        value = d.get(key)
 | 
						|
        if value:
 | 
						|
            return value
 | 
						|
    return None
 | 
						|
 | 
						|
 | 
						|
class lazyproperty:
 | 
						|
    def __init__(self, func):
 | 
						|
        self.func = func
 | 
						|
 | 
						|
    def __get__(self, instance, cls):
 | 
						|
        if instance is None:
 | 
						|
            return self
 | 
						|
        else:
 | 
						|
            value = self.func(instance)
 | 
						|
            setattr(instance, self.func.__name__, value)
 | 
						|
            return value
 | 
						|
 | 
						|
 | 
						|
def get_disk_usage(path):
 | 
						|
    return psutil.disk_usage(path=path).percent
 | 
						|
 | 
						|
 | 
						|
def get_cpu_load():
 | 
						|
    cpu_load_1, cpu_load_5, cpu_load_15 = psutil.getloadavg()
 | 
						|
    cpu_count = psutil.cpu_count()
 | 
						|
    single_cpu_load_1 = cpu_load_1 / cpu_count
 | 
						|
    single_cpu_load_1 = '%.2f' % single_cpu_load_1
 | 
						|
    return float(single_cpu_load_1)
 | 
						|
 | 
						|
 | 
						|
def get_docker_mem_usage_if_limit():
 | 
						|
    try:
 | 
						|
        with open('/sys/fs/cgroup/memory/memory.limit_in_bytes') as f:
 | 
						|
            limit_in_bytes = int(f.readline())
 | 
						|
            total = psutil.virtual_memory().total
 | 
						|
            if limit_in_bytes >= total:
 | 
						|
                raise ValueError('Not limit')
 | 
						|
 | 
						|
        with open('/sys/fs/cgroup/memory/memory.usage_in_bytes') as f:
 | 
						|
            usage_in_bytes = int(f.readline())
 | 
						|
 | 
						|
        with open('/sys/fs/cgroup/memory/memory.stat') as f:
 | 
						|
            inactive_file = 0
 | 
						|
            for line in f:
 | 
						|
                if line.startswith('total_inactive_file'):
 | 
						|
                    name, inactive_file = line.split()
 | 
						|
                    break
 | 
						|
 | 
						|
                if line.startswith('inactive_file'):
 | 
						|
                    name, inactive_file = line.split()
 | 
						|
                    continue
 | 
						|
 | 
						|
            inactive_file = int(inactive_file)
 | 
						|
        return ((usage_in_bytes - inactive_file) / limit_in_bytes) * 100
 | 
						|
 | 
						|
    except Exception:
 | 
						|
        return None
 | 
						|
 | 
						|
 | 
						|
def get_memory_usage():
 | 
						|
    usage = get_docker_mem_usage_if_limit()
 | 
						|
    if usage is not None:
 | 
						|
        return usage
 | 
						|
    return psutil.virtual_memory().percent
 | 
						|
 | 
						|
 | 
						|
class Time:
 | 
						|
    def __init__(self):
 | 
						|
        self._timestamps = []
 | 
						|
        self._msgs = []
 | 
						|
 | 
						|
    def begin(self):
 | 
						|
        self._timestamps.append(time.time())
 | 
						|
 | 
						|
    def time(self, msg):
 | 
						|
        self._timestamps.append(time.time())
 | 
						|
        self._msgs.append(msg)
 | 
						|
 | 
						|
    def print(self):
 | 
						|
        last, *timestamps = self._timestamps
 | 
						|
        for timestamp, msg in zip(timestamps, self._msgs):
 | 
						|
            logger.debug(f'TIME_IT: {msg} {timestamp - last}')
 | 
						|
            last = timestamp
 | 
						|
 | 
						|
 | 
						|
def bulk_get(d, keys, default=None):
 | 
						|
    values = []
 | 
						|
    for key in keys:
 | 
						|
        values.append(d.get(key, default))
 | 
						|
    return values
 | 
						|
 | 
						|
 | 
						|
def unique(objects, key=None):
 | 
						|
    seen = OrderedDict()
 | 
						|
 | 
						|
    if key is None:
 | 
						|
        key = lambda item: item
 | 
						|
 | 
						|
    for obj in objects:
 | 
						|
        v = key(obj)
 | 
						|
        if v not in seen:
 | 
						|
            seen[v] = obj
 | 
						|
    return list(seen.values())
 | 
						|
 | 
						|
 | 
						|
def get_file_by_arch(dir, filename):
 | 
						|
    platform_name = platform.system()
 | 
						|
    arch = platform.machine()
 | 
						|
 | 
						|
    file_path = os.path.join(
 | 
						|
        settings.BASE_DIR, dir, platform_name, arch, filename
 | 
						|
    )
 | 
						|
    return file_path
 | 
						|
 | 
						|
 | 
						|
def pretty_string(data, max_length=128, ellipsis_str='...'):
 | 
						|
    """
 | 
						|
    params:
 | 
						|
       data: abcdefgh
 | 
						|
       max_length: 7
 | 
						|
       ellipsis_str: ...
 | 
						|
   return:
 | 
						|
       ab...gh
 | 
						|
    """
 | 
						|
    data = str(data)
 | 
						|
    if len(data) < max_length:
 | 
						|
        return data
 | 
						|
    remain_length = max_length - len(ellipsis_str)
 | 
						|
    half = remain_length // 2
 | 
						|
    if half <= 1:
 | 
						|
        return data[:max_length]
 | 
						|
    start = data[:half]
 | 
						|
    end = data[-half:]
 | 
						|
    data = f'{start}{ellipsis_str}{end}'
 | 
						|
    return data
 | 
						|
 | 
						|
 | 
						|
def group_by_count(it, count):
 | 
						|
    return [it[i:i + count] for i in range(0, len(it), count)]
 | 
						|
 | 
						|
 | 
						|
def test_ip_connectivity(host, port, timeout=0.5):
 | 
						|
    """
 | 
						|
    timeout: seconds
 | 
						|
    """
 | 
						|
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 | 
						|
    sock.settimeout(timeout)
 | 
						|
    result = sock.connect_ex((host, int(port)))
 | 
						|
    sock.close()
 | 
						|
    if result == 0:
 | 
						|
        connectivity = True
 | 
						|
    else:
 | 
						|
        connectivity = False
 | 
						|
    return connectivity
 | 
						|
 | 
						|
 | 
						|
def static_or_direct(logo_path):
 | 
						|
    if logo_path.startswith('img/'):
 | 
						|
        return static(logo_path)
 | 
						|
    else:
 | 
						|
        return logo_path
 | 
						|
 | 
						|
 | 
						|
def make_dirs(name, mode=0o755, exist_ok=False):
 | 
						|
    """ 默认权限设置为 0o755 """
 | 
						|
    return os.makedirs(name, mode=mode, exist_ok=exist_ok)
 | 
						|
 | 
						|
 | 
						|
def distinct(seq, key=None):
 | 
						|
    if key is None:
 | 
						|
        # 如果未提供关键字参数,则默认使用元素本身作为比较键
 | 
						|
        key = lambda x: x
 | 
						|
    seen = set()
 | 
						|
    result = []
 | 
						|
    for item in seq:
 | 
						|
        k = key(item)
 | 
						|
        if k not in seen:
 | 
						|
            seen.add(k)
 | 
						|
            result.append(item)
 | 
						|
    return result
 | 
						|
 | 
						|
 | 
						|
def is_macos():
 | 
						|
    return platform.system() == 'Darwin'
 | 
						|
 | 
						|
 | 
						|
def convert_html_to_markdown(html_str):
 | 
						|
    h = html2text.HTML2Text()
 | 
						|
    h.body_width = 0
 | 
						|
    h.ignore_links = True
 | 
						|
 | 
						|
    markdown = h.handle(html_str)
 | 
						|
    markdown = markdown.replace('\n\n', '\n')
 | 
						|
    markdown = markdown.replace('\n ', '\n')
 | 
						|
    return markdown
 |