# mirror of https://github.com/jumpserver/jumpserver
import time

from channels_redis.core import RedisChannelLayer as _RedisChannelLayer

from common.utils.lock import DistributedLock
from common.utils.connection import get_redis_client
from common.utils import lazyproperty
from common.utils import get_logger

logger = get_logger(__file__)


class ComputeLock(DistributedLock):
    """
    Acquired while rebuilding cached values, so that concurrent
    workers do not recompute the same cache
    """
    def __init__(self, key):
        name = f'compute:{key}'
        super().__init__(name=name)


class CacheFieldBase:
    field_type = str

    def __init__(self, queryset=None, compute_func_name=None):
        # A field takes its value from exactly one source: a queryset to
        # count, or a compute method on the Cache instance
        assert None in (queryset, compute_func_name), \
            'Only one of queryset and compute_func_name may be given'
        self.compute_func_name = compute_func_name
        self.queryset = queryset


class CharField(CacheFieldBase):
    field_type = str


class IntegerField(CacheFieldBase):
    field_type = int


class CacheType(type):
    def __new__(cls, name, bases, attrs: dict):
        to_update = {}
        field_desc_mapper = {}

        # Replace each declared cache field with a CacheValueDesc descriptor
        # and keep a name -> descriptor mapping on the class
        for k, v in attrs.items():
            if isinstance(v, CacheFieldBase):
                desc = CacheValueDesc(k, v)
                to_update[k] = desc
                field_desc_mapper[k] = desc

        attrs.update(to_update)
        attrs['field_desc_mapper'] = field_desc_mapper
        return type.__new__(cls, name, bases, attrs)


class Cache(metaclass=CacheType):
    field_desc_mapper: dict
    timeout = None

    def __init__(self):
        self._data = None
        self.redis = get_redis_client()

    def __getitem__(self, item):
        return self.field_desc_mapper[item]

    def __contains__(self, item):
        return item in self.field_desc_mapper

    def get_field(self, name):
        return self.field_desc_mapper[name]

    @property
    def fields(self):
        return self.field_desc_mapper.values()

    @property
    def field_names(self):
        names = self.field_desc_mapper.keys()
        return names

    @lazyproperty
    def key_suffix(self):
        return self.get_key_suffix()

    @property
    def key_prefix(self):
        clz = self.__class__
        return f'cache.{clz.__module__}.{clz.__name__}'

    @property
    def key(self):
        return f'{self.key_prefix}.{self.key_suffix}'

    @property
    def data(self):
        if self._data is None:
            data = self.load_data_from_db()
            if not data:
                # Double-check under the lock so only one worker recomputes
                with ComputeLock(self.key):
                    data = self.load_data_from_db()
                    if not data:
                        # Nothing in the cache: compute the values afresh
                        self.init_all_values()
        return self._data

    def to_internal_value(self, data: dict):
        internal_data = {}
        for k, v in data.items():
            field = k.decode()
            if field in self:
                value = self[field].to_internal_value(v.decode())
                internal_data[field] = value
            else:
                logger.warning(f'Cache got invalid field: '
                               f'key={self.key} '
                               f'invalid_field={field} '
                               f'valid_fields={self.field_names}')
        return internal_data

    def load_data_from_db(self) -> dict:
        data = self.redis.hgetall(self.key)
        logger.debug(f'Get data from cache: key={self.key} data={data}')
        if data:
            data = self.to_internal_value(data)
            self._data = data
        return data

    def save_data_to_db(self, data):
        logger.debug(f'Set data to cache: key={self.key} data={data}')
        self.redis.hset(self.key, mapping=data)
        self.load_data_from_db()

    def compute_values(self, *fields):
        field_objs = []
        for field in fields:
            field_objs.append(self[field])

        data = {
            field_obj.field_name: field_obj.compute_value(self)
            for field_obj in field_objs
        }
        return data

    def init_all_values(self):
        t_start = time.time()
        logger.debug(f'Start init cache: key={self.key}')
        data = self.compute_values(*self.field_names)
        self.save_data_to_db(data)
        logger.debug(f'End init cache: cost={time.time()-t_start} key={self.key}')
        return data

    def refresh(self, *fields):
        if not fields:
            # No fields given: refresh every value
            self.init_all_values()
            return

        data = self.load_data_from_db()
        if not data:
            # Nothing cached yet: set all values
            self.init_all_values()
            return

        refresh_values = self.compute_values(*fields)
        self.save_data_to_db(refresh_values)

    def get_key_suffix(self):
        raise NotImplementedError

    def reload(self):
        self._data = None

    def expire(self, *fields):
        self._data = None
        if not fields:
            self.redis.delete(self.key)
        else:
            self.redis.hdel(self.key, *fields)
            logger.debug(f'Expire cached fields: key={self.key} fields={fields}')


class CacheValueDesc:
    def __init__(self, field_name, field_type: CacheFieldBase):
        self.field_name = field_name
        self.field_type = field_type
        self._data = None

    def __repr__(self):
        clz = self.__class__
        return f'<{clz.__name__} {self.field_name} {self.field_type}>'

    def __get__(self, instance: Cache, owner):
        if instance is None:
            return self
        if self.field_name not in instance.data:
            instance.refresh(self.field_name)
        # Use .get() so a value missing in edge cases does not raise
        value = instance.data.get(self.field_name)
        return value

    def compute_value(self, instance: Cache):
        t_start = time.time()
        logger.debug(f'Start compute cache field: field={self.field_name} key={instance.key}')
        if self.field_type.queryset is not None:
            new_value = self.field_type.queryset.count()
        else:
            compute_func_name = self.field_type.compute_func_name
            if not compute_func_name:
                compute_func_name = f'compute_{self.field_name}'
            compute_func = getattr(instance, compute_func_name, None)
            assert compute_func is not None, \
                f'Define `{compute_func_name}` method in {instance.__class__}'
            new_value = compute_func()

        new_value = self.field_type.field_type(new_value)
        logger.debug(f'End compute cache field: cost={time.time()-t_start} '
                     f'field={self.field_name} value={new_value} key={instance.key}')
        return new_value

    def to_internal_value(self, value):
        return self.field_type.field_type(value)
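

# A minimal usage sketch (hypothetical class, not part of this module):
# declaring a field on a Cache subclass makes the CacheType metaclass wrap it
# in a CacheValueDesc descriptor. Reading the attribute loads the Redis hash
# and computes any missing value via `compute_<field_name>` (or a queryset
# count, if the field was declared with one).
class ExampleStatsCache(Cache):
    users_amount = IntegerField()

    def __init__(self, org_id):
        super().__init__()
        self.org_id = org_id

    def get_key_suffix(self):
        # One Redis hash per organization
        return f'org_{self.org_id}'

    def compute_users_amount(self):
        # Stand-in for a real query, e.g. a Django queryset count
        return 0

# Usage:
#   cache = ExampleStatsCache(org_id='default')
#   cache.users_amount            # loads from Redis, computing if missing
#   cache.refresh('users_amount') # recompute and persist a single field
#   cache.expire()                # drop the whole hash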


class RedisChannelLayer(_RedisChannelLayer):
    async def _brpop_with_clean(self, index, channel, timeout):
        # Move any messages left in the backup queue back onto the channel,
        # then drop the backup queue
        cleanup_script = """
            local backed_up = redis.call('ZRANGE', ARGV[2], 0, -1, 'WITHSCORES')
            for i = #backed_up, 1, -2 do
                redis.call('ZADD', ARGV[1], backed_up[i], backed_up[i - 1])
            end
            redis.call('DEL', ARGV[2])
        """
        backup_queue = self._backup_channel_name(channel)
        async with self.connection(index) as connection:
            # Some cloud-hosted Redis offerings (e.g. Alibaba Cloud) restrict
            # script execution, so EVAL may fail here
            try:
                await connection.eval(cleanup_script, keys=[], args=[channel, backup_queue])
            except Exception:
                pass
            result = await connection.bzpopmin(channel, timeout=timeout)

            if result is not None:
                _, member, timestamp = result
                await connection.zadd(backup_queue, float(timestamp), member)
            else:
                member = None
            return member
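

# A minimal configuration sketch (assumption, not part of this module):
# Django Channels selects its layer via the CHANNEL_LAYERS setting, so a
# deployment could point it at this subclass. The 'common.cache' module path
# is inferred from the imports above and may differ in practice.
#
# CHANNEL_LAYERS = {
#     'default': {
#         'BACKEND': 'common.cache.RedisChannelLayer',
#         'CONFIG': {'hosts': ['redis://127.0.0.1:6379/0']},
#     },
# }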