# Mirror of https://github.com/jumpserver/jumpserver
# 247 lines, 7.7 KiB, Python
import time
|
|
|
|
from channels_redis.core import RedisChannelLayer as _RedisChannelLayer
|
|
|
|
from common.utils.lock import DistributedLock
|
|
from common.utils.connection import get_redis_client
|
|
from common.utils import lazyproperty
|
|
from common.utils import get_logger
|
|
|
|
# Module-level logger used by the cache classes and the channel layer in this file.
logger = get_logger(__file__)
|
|
|
|
|
|
class ComputeLock(DistributedLock):
    """Lock taken while a cache needs to be rebuilt, to avoid duplicate computation.

    The underlying ``DistributedLock`` is named ``compute:<key>``.
    """

    def __init__(self, key):
        lock_name = 'compute:{}'.format(key)
        super().__init__(name=lock_name)
|
|
|
|
|
|
class CacheFieldBase:
    """Declaration of one field on a ``Cache`` subclass.

    At most one source for the field's value may be given:

    * ``queryset`` -- its ``count()`` is used as the value;
    * ``compute_func_name`` -- name of the method on the cache instance that
      computes the value. When neither is given, a method named
      ``compute_<field name>`` is looked up instead.

    ``field_type`` is the Python type the stored/computed value is cast to.

    Raises:
        ValueError: if both ``queryset`` and ``compute_func_name`` are given.
    """

    field_type = str

    def __init__(self, queryset=None, compute_func_name=None):
        # An explicit check (rather than ``assert``) so it still runs under
        # ``python -O``; identity tests avoid calling ``__eq__`` on a queryset.
        if queryset is not None and compute_func_name is not None:
            raise ValueError('queryset and compute_func_name can only have one')
        self.compute_func_name = compute_func_name
        self.queryset = queryset
|
|
|
|
|
|
class CharField(CacheFieldBase):
    """Cache field whose value is converted with ``str``."""

    field_type = str
|
|
|
|
|
|
class IntegerField(CacheFieldBase):
    """Cache field whose value is converted with ``int``."""

    field_type = int
|
|
|
|
|
|
class CacheType(type):
    """Metaclass that turns ``CacheFieldBase`` class attributes into descriptors.

    Every ``CacheFieldBase`` attribute of the class being created is replaced
    by a ``CacheValueDesc`` bound to the attribute name. All fields of the
    class, including those inherited from base classes, are collected into the
    ``field_desc_mapper`` class attribute (``{name: CacheValueDesc}``).
    """

    def __new__(cls, name, bases, attrs: dict):
        # Start from the fields of the base classes. The descriptors themselves
        # are inherited through normal attribute lookup, but without this step a
        # subclass's ``field_desc_mapper`` would only list its own fields. Then
        # ``init_all_values`` would skip inherited fields, and ``refresh`` on them
        # would fail with KeyError. Bases are applied in reverse so that the
        # first base wins, as in the MRO.
        field_desc_mapper = {}
        for base in reversed(bases):
            field_desc_mapper.update(getattr(base, 'field_desc_mapper', {}))

        # Iterate over a snapshot because ``attrs`` is modified in the loop.
        for attr_name, value in list(attrs.items()):
            if isinstance(value, CacheFieldBase):
                desc = CacheValueDesc(attr_name, value)
                attrs[attr_name] = desc
                field_desc_mapper[attr_name] = desc
            elif attr_name in field_desc_mapper:
                # The subclass shadowed an inherited field with a plain attribute,
                # so that name is no longer a cache field on this class.
                del field_desc_mapper[attr_name]

        attrs['field_desc_mapper'] = field_desc_mapper
        return type.__new__(cls, name, bases, attrs)
|
|
|
|
|
|
class Cache(metaclass=CacheType):
    """Cache whose values live in one Redis hash and are declared as class fields.

    Subclasses declare ``CacheFieldBase`` attributes (turned into descriptors
    by ``CacheType``) and implement ``get_key_suffix``. All field values of an
    instance are stored in the Redis hash at ``self.key``. A decoded local copy
    is kept in ``self._data``.
    """

    # Filled in by CacheType: {field_name: CacheValueDesc}.
    field_desc_mapper: dict
    # NOTE(review): not read by any method of this class; the hash is written
    # without an expiry.
    timeout = None

    def __init__(self):
        # Decoded contents of the Redis hash; ``None`` means "not loaded yet".
        self._data = None
        self.redis = get_redis_client()

    def __getitem__(self, item):
        """Return the field descriptor registered under ``item`` (KeyError if unknown)."""
        return self.field_desc_mapper[item]

    def __contains__(self, item):
        """Return whether ``item`` is a declared field name."""
        return item in self.field_desc_mapper

    def get_field(self, name):
        """Return the field descriptor registered under ``name`` (KeyError if unknown)."""
        return self.field_desc_mapper[name]

    @property
    def fields(self):
        """View of all field descriptors."""
        return self.field_desc_mapper.values()

    @property
    def field_names(self):
        """View of all declared field names."""
        return self.field_desc_mapper.keys()

    @lazyproperty
    def key_suffix(self):
        """Per-instance part of the Redis key, obtained from ``get_key_suffix``."""
        return self.get_key_suffix()

    @property
    def key_prefix(self):
        """Class-specific part of the Redis key: ``cache.<module>.<class name>``."""
        clz = self.__class__
        return f'cache.{clz.__module__}.{clz.__name__}'

    @property
    def key(self):
        """Full Redis key of the hash that stores this instance's values."""
        return f'{self.key_prefix}.{self.key_suffix}'

    @property
    def data(self):
        """Return the decoded cache contents, building them on a miss.

        On first access the Redis hash is read. If it is empty, ``ComputeLock``
        is acquired and the hash is read again before anything is computed.
        This way a caller that waited on the lock reuses values just written by
        another holder instead of recomputing them.
        """
        if self._data is None:
            data = self.load_data_from_db()
            if not data:
                with ComputeLock(self.key):
                    data = self.load_data_from_db()
                    if not data:
                        # Nothing cached: compute every field from its source.
                        self.init_all_values()
        return self._data

    def to_internal_value(self, data: dict):
        """Decode a raw Redis hash into ``{field_name: typed_value}``.

        Keys and values are ``.decode()``d, so they are expected to be bytes.
        Entries whose name is not a declared field are skipped with a warning.
        """
        internal_data = {}
        for k, v in data.items():
            field = k.decode()
            if field in self:
                internal_data[field] = self[field].to_internal_value(v.decode())
            else:
                logger.warning(f'Cache got invalid field: '
                               f'key={self.key} '
                               f'invalid_field={field} '
                               f'valid_fields={self.field_names}')
        return internal_data

    def load_data_from_db(self) -> dict:
        """Read the Redis hash, decode it, store it in ``self._data`` and return it.

        When the hash is empty or missing, the (empty) raw result is stored and
        returned as-is.
        """
        data = self.redis.hgetall(self.key)
        logger.debug(f'Get data from cache: key={self.key} data={data}')
        if data:
            data = self.to_internal_value(data)
        self._data = data
        return data

    def save_data_to_db(self, data):
        """Write ``data`` into the Redis hash, then reload ``self._data`` from Redis."""
        logger.debug(f'Set data to cache: key={self.key} data={data}')
        if data:
            # redis-py raises DataError for HSET with an empty mapping, which a
            # cache with no declared fields would otherwise trigger.
            self.redis.hset(self.key, mapping=data)
        self.load_data_from_db()

    def compute_values(self, *fields):
        """Compute fresh values for the named fields and return ``{name: value}``.

        All names are resolved before any computation starts, so an unknown name
        raises KeyError without computing anything.
        """
        field_objs = [self[field] for field in fields]
        return {
            field_obj.field_name: field_obj.compute_value(self)
            for field_obj in field_objs
        }

    def init_all_values(self):
        """Compute every declared field, store the result in Redis and return it."""
        t_start = time.time()
        logger.debug(f'Start init cache: key={self.key}')
        data = self.compute_values(*self.field_names)
        self.save_data_to_db(data)
        logger.debug(f'End init cache: cost={time.time()-t_start} key={self.key}')
        return data

    def refresh(self, *fields):
        """Recompute and store the given fields, or all fields when none are given.

        If the Redis hash is currently empty, every field is computed even when
        only some names were requested.
        """
        if not fields:
            # No fields given: refresh every value.
            self.init_all_values()
            return

        data = self.load_data_from_db()
        if not data:
            # Nothing cached yet: populate all values.
            self.init_all_values()
            return

        refresh_values = self.compute_values(*fields)
        self.save_data_to_db(refresh_values)

    def get_key_suffix(self):
        """Return the per-instance part of the Redis key; subclasses must implement it."""
        raise NotImplementedError

    def reload(self):
        """Forget the local copy so the next ``data`` access re-reads Redis."""
        self._data = None

    def expire(self, *fields):
        """Forget the local copy and delete the whole hash, or just the named fields."""
        self._data = None
        if not fields:
            self.redis.delete(self.key)
        else:
            self.redis.hdel(self.key, *fields)
        logger.debug(f'Expire cached fields: key={self.key} fields={fields}')
|
|
|
|
|
|
class CacheValueDesc:
    """Descriptor that exposes one cache field as an attribute of a ``Cache``.

    ``field_type`` holds the ``CacheFieldBase`` instance declared on the class.
    Its own ``field_type`` attribute is the Python type values are cast to.
    """

    def __init__(self, field_name, field_type: CacheFieldBase):
        self.field_name = field_name
        self.field_type = field_type
        # NOTE(review): not read by any method of this class.
        self._data = None

    def __repr__(self):
        clz = self.__class__
        return f'<{clz.__name__} {self.field_name} {self.field_type}>'

    def __get__(self, instance: Cache, owner):
        """Return the cached value, refreshing this field first if it is missing.

        When accessed on the class rather than an instance, the descriptor
        itself is returned.
        """
        if instance is None:
            return self
        if self.field_name not in instance.data:
            instance.refresh(self.field_name)
        # Use ``.get`` so that an edge case that still leaves the field missing
        # returns None instead of raising.
        value = instance.data.get(self.field_name)
        return value

    def compute_value(self, instance: Cache):
        """Compute this field's value for ``instance`` and cast it to the field type.

        If the field was declared with a ``queryset``, its ``count()`` is used.
        Otherwise the method named ``compute_func_name`` (default
        ``compute_<field_name>``) is called on ``instance``.

        Raises:
            NotImplementedError: if that compute method does not exist.
        """
        t_start = time.time()
        logger.debug(f'Start compute cache field: field={self.field_name} key={instance.key}')
        if self.field_type.queryset is not None:
            new_value = self.field_type.queryset.count()
        else:
            compute_func_name = self.field_type.compute_func_name or f'compute_{self.field_name}'
            compute_func = getattr(instance, compute_func_name, None)
            if compute_func is None:
                # Explicit raise rather than ``assert`` so the check survives
                # ``python -O``. Without it the failure would be a confusing
                # "'NoneType' object is not callable".
                raise NotImplementedError(
                    f'Define `{compute_func_name}` method in {instance.__class__}'
                )
            new_value = compute_func()

        new_value = self.field_type.field_type(new_value)
        logger.debug(f'End compute cache field: cost={time.time()-t_start} field={self.field_name} value={new_value} key={instance.key}')
        return new_value

    def to_internal_value(self, value):
        """Cast a decoded string value from Redis to this field's Python type."""
        return self.field_type.field_type(value)
|
|
|
|
|
|
class RedisChannelLayer(_RedisChannelLayer):
    """channels_redis layer whose backup-queue cleanup tolerates Redis without EVAL."""

    async def _brpop_with_clean(self, index, channel, timeout):
        """Pop the next message for ``channel`` and keep a backup copy of it.

        This overrides the base-class method. Before popping, a Lua script moves
        every entry in the channel's backup sorted set back into the channel,
        keeping its score, and deletes the backup set. The popped member is then
        added to the backup set with its original score.

        Returns the popped member, or ``None`` when ``bzpopmin`` returns nothing
        (timeout).
        """
        cleanup_script = """
            local backed_up = redis.call('ZRANGE', ARGV[2], 0, -1, 'WITHSCORES')
            for i = #backed_up, 1, -2 do
                redis.call('ZADD', ARGV[1], backed_up[i], backed_up[i - 1])
            end
            redis.call('DEL', ARGV[2])
        """
        backup_queue = self._backup_channel_name(channel)
        async with self.connection(index) as connection:
            # Redis from some cloud vendors rejects this operation (unsupported;
            # e.g. Alibaba Cloud restricts it). The cleanup is best-effort, so
            # log the failure and continue. ``except Exception`` is used instead
            # of a bare ``except`` so that ``asyncio.CancelledError`` (a
            # BaseException on Python 3.8+), KeyboardInterrupt and SystemExit
            # still propagate.
            try:
                await connection.eval(cleanup_script, keys=[], args=[channel, backup_queue])
            except Exception as exc:
                logger.debug(f'Channel layer cleanup script failed: channel={channel} error={exc!r}')
            result = await connection.bzpopmin(channel, timeout=timeout)

            if result is not None:
                _, member, timestamp = result
                await connection.zadd(backup_queue, float(timestamp), member)
            else:
                member = None
            return member
|