jumpserver/apps/common/cache.py

247 lines
7.7 KiB
Python
Raw Normal View History

2021-03-09 05:57:58 +00:00
import time
from channels_redis.core import RedisChannelLayer as _RedisChannelLayer
from common.utils.lock import DistributedLock
from common.utils.connection import get_redis_client
from common.utils import lazyproperty
from common.utils import get_logger
# Module-level logger used by every class in this file.
logger = get_logger(__file__)
2021-03-09 05:57:58 +00:00
class ComputeLock(DistributedLock):
    """Distributed lock taken while a cache has to be rebuilt.

    Holding it around the rebuild avoids the same values being computed
    repeatedly by concurrent callers. The underlying lock name is the cache
    key prefixed with ``compute:``.
    """

    def __init__(self, key):
        super().__init__(name='compute:{}'.format(key))
class CacheFieldBase:
    """Declarative description of one field of a ``Cache`` subclass.

    A field's value comes either from ``queryset.count()`` or from a method
    on the cache instance (``compute_func_name``, or ``compute_<field_name>``
    when no name is given). At most one of the two sources may be supplied.

    :param queryset: object whose ``count()`` yields the field value.
    :param compute_func_name: name of the cache method computing the value.
    :raises ValueError: if both ``queryset`` and ``compute_func_name`` are set.
    """
    # Python type that computed and stored values are converted to.
    field_type = str

    def __init__(self, queryset=None, compute_func_name=None):
        # Explicit check (not ``assert``) so it is not stripped under ``python -O``;
        # identity tests avoid invoking ``__eq__`` on the queryset object.
        if queryset is not None and compute_func_name is not None:
            raise ValueError('queryset and compute_func_name can only have one')
        self.compute_func_name = compute_func_name
        self.queryset = queryset


class CharField(CacheFieldBase):
    """Cache field whose value is stored and returned as ``str``."""
    field_type = str


class IntegerField(CacheFieldBase):
    """Cache field whose value is stored and returned as ``int``."""
    field_type = int
2021-03-09 05:57:58 +00:00
class CacheType(type):
    """Metaclass turning declared cache fields into descriptors.

    Every class attribute that is a ``CacheFieldBase`` instance is replaced by
    a ``CacheValueDesc`` descriptor, and the ``{name: descriptor}`` map is
    stored on the new class as ``field_desc_mapper``.

    Fields declared on base classes are included as well, so a subclass of a
    concrete cache still computes and exposes its parents' fields. A field
    redeclared in the subclass overrides the inherited one; an inherited field
    shadowed by a plain (non-field) attribute is dropped.
    """

    def __new__(cls, name, bases, attrs: dict):
        field_desc_mapper = {}
        # Apply bases right-to-left so the leftmost base wins, mirroring the MRO.
        for base in reversed(bases):
            field_desc_mapper.update(getattr(base, 'field_desc_mapper', {}))

        to_update = {}
        for k, v in attrs.items():
            if isinstance(v, CacheFieldBase):
                to_update[k] = CacheValueDesc(k, v)
            elif k in field_desc_mapper:
                # The subclass replaced an inherited field with something else.
                del field_desc_mapper[k]

        attrs.update(to_update)
        field_desc_mapper.update(to_update)
        attrs['field_desc_mapper'] = field_desc_mapper
        return type.__new__(cls, name, bases, attrs)
2021-03-09 05:57:58 +00:00
class Cache(metaclass=CacheType):
    """Redis-hash backed cache composed of declaratively computed fields.

    Subclasses declare fields as class attributes (e.g. ``CharField`` /
    ``IntegerField``) and implement ``get_key_suffix``. All field values live
    in one Redis hash at ``self.key``; missing values are computed on access.
    Note: the ``*_db`` methods below read/write that Redis hash, not a SQL DB.
    """
    # Set by ``CacheType``: maps field name -> ``CacheValueDesc``.
    field_desc_mapper: dict
    # NOTE(review): not read anywhere in this class, so no expiry is ever set
    # on the Redis key; confirm whether subclasses expect it to be applied.
    timeout = None

    def __init__(self):
        # Decoded field values; ``None`` means "not loaded from Redis yet".
        self._data = None
        self.redis = get_redis_client()

    def __getitem__(self, item):
        return self.field_desc_mapper[item]

    def __contains__(self, item):
        return item in self.field_desc_mapper

    def get_field(self, name):
        return self.field_desc_mapper[name]

    @property
    def fields(self):
        return self.field_desc_mapper.values()

    @property
    def field_names(self):
        return self.field_desc_mapper.keys()

    @lazyproperty
    def key_suffix(self):
        return self.get_key_suffix()

    @property
    def key_prefix(self):
        clz = self.__class__
        return f'cache.{clz.__module__}.{clz.__name__}'

    @property
    def key(self):
        return f'{self.key_prefix}.{self.key_suffix}'

    @property
    def data(self):
        """Return all cached field values, computing them if Redis has none."""
        if self._data is None:
            data = self.load_data_from_db()
            if not data:
                with ComputeLock(self.key):
                    # Re-check under the lock: another worker may have filled
                    # the hash meanwhile (assumes the lock blocks until held).
                    data = self.load_data_from_db()
                    if not data:
                        # Nothing cached yet: compute every field and store it.
                        self.init_all_values()
        return self._data

    def to_internal_value(self, data: dict):
        """Convert a raw Redis hash into ``{field_name: typed_value}``.

        Keys and values are decoded only when they are ``bytes``, so clients
        configured to return ``str`` also work. Unknown fields are logged and
        dropped.
        """
        internal_data = {}
        for k, v in data.items():
            field = k.decode() if isinstance(k, bytes) else k
            if field in self:
                raw = v.decode() if isinstance(v, bytes) else v
                internal_data[field] = self[field].to_internal_value(raw)
            else:
                logger.warning(f'Cache got invalid field: '
                               f'key={self.key} '
                               f'invalid_field={field} '
                               f'valid_fields={self.field_names}')
        return internal_data

    def load_data_from_db(self) -> dict:
        """Read the Redis hash, decode it, store it in ``self._data`` and return it.

        Returns a falsy (empty) value when the hash holds no data.
        """
        data = self.redis.hgetall(self.key)
        logger.debug(f'Get data from cache: key={self.key} data={data}')
        if data:
            data = self.to_internal_value(data)
        self._data = data
        return data

    def save_data_to_db(self, data):
        """Write ``data`` into the Redis hash, then reload ``self._data``."""
        logger.debug(f'Set data to cache: key={self.key} data={data}')
        # redis-py rejects HSET with an empty mapping (e.g. a cache declaring
        # no fields), so only write when there is something to store.
        if data:
            self.redis.hset(self.key, mapping=data)
        self.load_data_from_db()

    def compute_values(self, *fields):
        """Compute fresh values for the named fields; returns ``{name: value}``."""
        # Resolve every field first so an unknown name fails before any work.
        field_objs = [self[field] for field in fields]
        return {
            field_obj.field_name: field_obj.compute_value(self)
            for field_obj in field_objs
        }

    def init_all_values(self):
        """Compute every declared field, store the result and return it."""
        t_start = time.time()
        logger.debug(f'Start init cache: key={self.key}')
        data = self.compute_values(*self.field_names)
        self.save_data_to_db(data)
        logger.debug(f'End init cache: cost={time.time()-t_start} key={self.key}')
        return data

    def refresh(self, *fields):
        """Recompute and store ``fields``; with no arguments recompute everything."""
        if not fields:
            # No field given: refresh all values.
            self.init_all_values()
            return

        data = self.load_data_from_db()
        if not data:
            # Nothing cached yet: populate every field, not just the requested ones.
            self.init_all_values()
            return

        refresh_values = self.compute_values(*fields)
        self.save_data_to_db(refresh_values)

    def get_key_suffix(self):
        """Return the per-instance part of the Redis key; subclasses must override."""
        raise NotImplementedError

    def reload(self):
        """Forget the locally loaded data so the next access re-reads Redis."""
        self._data = None

    def expire(self, *fields):
        """Delete the given fields from Redis, or the whole hash if none given."""
        self._data = None
        if not fields:
            self.redis.delete(self.key)
        else:
            self.redis.hdel(self.key, *fields)
        logger.debug(f'Expire cached fields: key={self.key} fields={fields}')
class CacheValueDesc:
    """Descriptor exposing one cache field as an attribute of a ``Cache``.

    Reading the attribute on an instance returns the cached value, asking the
    instance to ``refresh`` the field first when it is missing from its data.
    """

    def __init__(self, field_name, field_type: 'CacheFieldBase'):
        self.field_name = field_name
        # The declared field object: holds ``queryset`` / ``compute_func_name``
        # and the Python type (``field_type.field_type``) values are cast to.
        self.field_type = field_type
        self._data = None

    def __repr__(self):
        clz = self.__class__
        return f'<{clz.__name__} {self.field_name} {self.field_type}>'

    def __get__(self, instance: 'Cache', owner):
        if instance is None:
            # Accessed on the class itself: return the descriptor.
            return self
        if self.field_name not in instance.data:
            instance.refresh(self.field_name)
        # ``.get`` so an edge case where the value is still missing yields
        # ``None`` instead of raising.
        return instance.data.get(self.field_name)

    def compute_value(self, instance: 'Cache'):
        """Compute this field's value for ``instance``, cast to the field type.

        :raises NotImplementedError: if no queryset is declared and the compute
            method is not defined on ``instance``.
        """
        t_start = time.time()
        logger.debug(f'Start compute cache field: field={self.field_name} key={instance.key}')
        queryset = self.field_type.queryset
        if queryset is not None:
            # The queryset is declared once at class level and shared. Take a
            # fresh clone via ``.all()`` so an already-evaluated queryset does
            # not answer ``count()`` from its stale result cache.
            if hasattr(queryset, 'all'):
                queryset = queryset.all()
            new_value = queryset.count()
        else:
            compute_func_name = self.field_type.compute_func_name
            if not compute_func_name:
                compute_func_name = f'compute_{self.field_name}'
            compute_func = getattr(instance, compute_func_name, None)
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            if compute_func is None:
                raise NotImplementedError(
                    f'Define `{compute_func_name}` method in {instance.__class__}'
                )
            new_value = compute_func()
        new_value = self.field_type.field_type(new_value)
        logger.debug(f'End compute cache field: cost={time.time()-t_start} field={self.field_name} value={new_value} key={instance.key}')
        return new_value

    def to_internal_value(self, value):
        """Convert a decoded Redis value to this field's Python type."""
        return self.field_type.field_type(value)
class RedisChannelLayer(_RedisChannelLayer):
    """channels_redis layer that tolerates Redis servers rejecting EVAL."""

    async def _brpop_with_clean(self, index, channel, timeout):
        """Pop the next message from ``channel``, keeping a backup copy.

        Presumably overrides the channels_redis method of the same name. The
        backup-queue cleanup script is best-effort here: some managed Redis
        services restrict EVAL, so a failure is logged and skipped.

        Returns the popped member, or ``None`` when ``bzpopmin`` times out.
        """
        # Moves every (member, score) pair from the backup sorted set ARGV[2]
        # back into the channel sorted set ARGV[1], then deletes the backup.
        cleanup_script = """
            local backed_up = redis.call('ZRANGE', ARGV[2], 0, -1, 'WITHSCORES')
            for i = #backed_up, 1, -2 do
                redis.call('ZADD', ARGV[1], backed_up[i], backed_up[i - 1])
            end
            redis.call('DEL', ARGV[2])
        """
        backup_queue = self._backup_channel_name(channel)
        async with self.connection(index) as connection:
            # Some cloud providers' Redis reject this call (e.g. Alibaba Cloud
            # restricts it), so the cleanup is best-effort.
            try:
                await connection.eval(cleanup_script, keys=[], args=[channel, backup_queue])
            except Exception as e:
                # ``Exception`` rather than a bare ``except`` so task
                # cancellation (asyncio.CancelledError) still propagates.
                logger.debug(f'Skip redis channel backup cleanup: channel={channel} error={e}')
            result = await connection.bzpopmin(channel, timeout=timeout)
            if result is not None:
                _, member, timestamp = result
                # Keep a copy in the backup set until the next cleanup.
                await connection.zadd(backup_queue, float(timestamp), member)
            else:
                member = None
            return member