mirror of https://github.com/jumpserver/jumpserver
refactor: 重构缓存框架
parent
ccb0509d85
commit
c4eacbabc6
|
@ -1,13 +1,24 @@
|
||||||
import json
|
import time
|
||||||
from django.core.cache import cache
|
|
||||||
|
from redis import Redis
|
||||||
|
|
||||||
from common.utils.lock import DistributedLock
|
from common.utils.lock import DistributedLock
|
||||||
from common.utils import lazyproperty
|
from common.utils import lazyproperty
|
||||||
from common.utils import get_logger
|
from common.utils import get_logger
|
||||||
|
from jumpserver.const import CONFIG
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
logger = get_logger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
class ComputeLock(DistributedLock):
|
||||||
|
"""
|
||||||
|
需要重建缓存的时候加上该锁,避免重复计算
|
||||||
|
"""
|
||||||
|
def __init__(self, key):
|
||||||
|
name = f'compute:{key}'
|
||||||
|
super().__init__(name=name)
|
||||||
|
|
||||||
|
|
||||||
class CacheFieldBase:
|
class CacheFieldBase:
|
||||||
field_type = str
|
field_type = str
|
||||||
|
|
||||||
|
@ -25,7 +36,7 @@ class IntegerField(CacheFieldBase):
|
||||||
field_type = int
|
field_type = int
|
||||||
|
|
||||||
|
|
||||||
class CacheBase(type):
|
class CacheType(type):
|
||||||
def __new__(cls, name, bases, attrs: dict):
|
def __new__(cls, name, bases, attrs: dict):
|
||||||
to_update = {}
|
to_update = {}
|
||||||
field_desc_mapper = {}
|
field_desc_mapper = {}
|
||||||
|
@ -41,12 +52,31 @@ class CacheBase(type):
|
||||||
return type.__new__(cls, name, bases, attrs)
|
return type.__new__(cls, name, bases, attrs)
|
||||||
|
|
||||||
|
|
||||||
class Cache(metaclass=CacheBase):
|
class Cache(metaclass=CacheType):
|
||||||
field_desc_mapper: dict
|
field_desc_mapper: dict
|
||||||
timeout = None
|
timeout = None
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._data = None
|
self._data = None
|
||||||
|
self.redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.field_desc_mapper[item]
|
||||||
|
|
||||||
|
def __contains__(self, item):
|
||||||
|
return item in self.field_desc_mapper
|
||||||
|
|
||||||
|
def get_field(self, name):
|
||||||
|
return self.field_desc_mapper[name]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fields(self):
|
||||||
|
return self.field_desc_mapper.values()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def field_names(self):
|
||||||
|
names = self.field_desc_mapper.keys()
|
||||||
|
return names
|
||||||
|
|
||||||
@lazyproperty
|
@lazyproperty
|
||||||
def key_suffix(self):
|
def key_suffix(self):
|
||||||
|
@ -64,91 +94,75 @@ class Cache(metaclass=CacheBase):
|
||||||
@property
|
@property
|
||||||
def data(self):
|
def data(self):
|
||||||
if self._data is None:
|
if self._data is None:
|
||||||
data = self.get_data()
|
data = self.load_data_from_db()
|
||||||
if data is None:
|
if not data:
|
||||||
|
with ComputeLock(self.key):
|
||||||
|
data = self.load_data_from_db()
|
||||||
|
if not data:
|
||||||
# 缓存中没有数据时,去数据库获取
|
# 缓存中没有数据时,去数据库获取
|
||||||
self.compute_and_set_all_data()
|
self.init_all_values()
|
||||||
return self._data
|
return self._data
|
||||||
|
|
||||||
def get_data(self) -> dict:
|
def to_internal_value(self, data: dict):
|
||||||
data = cache.get(self.key)
|
internal_data = {}
|
||||||
|
for k, v in data.items():
|
||||||
|
field = k.decode()
|
||||||
|
if field in self:
|
||||||
|
value = self[field].to_internal_value(v.decode())
|
||||||
|
internal_data[field] = value
|
||||||
|
else:
|
||||||
|
logger.warn(f'Cache got invalid field: '
|
||||||
|
f'key={self.key} '
|
||||||
|
f'invalid_field={field} '
|
||||||
|
f'valid_fields={self.field_names}')
|
||||||
|
return internal_data
|
||||||
|
|
||||||
|
def load_data_from_db(self) -> dict:
|
||||||
|
data = self.redis.hgetall(self.key)
|
||||||
logger.debug(f'Get data from cache: key={self.key} data={data}')
|
logger.debug(f'Get data from cache: key={self.key} data={data}')
|
||||||
if data is not None:
|
if data:
|
||||||
data = json.loads(data)
|
data = self.to_internal_value(data)
|
||||||
self._data = data
|
self._data = data
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def set_data(self, data):
|
def save_data_to_db(self, data):
|
||||||
self._data = data
|
logger.info(f'Set data to cache: key={self.key} data={data}')
|
||||||
to_json = json.dumps(data)
|
self.redis.hset(self.key, mapping=data)
|
||||||
logger.info(f'Set data to cache: key={self.key} data={to_json} timeout={self.timeout}')
|
self.load_data_from_db()
|
||||||
cache.set(self.key, to_json, timeout=self.timeout)
|
|
||||||
|
|
||||||
def compute_data(self, *fields):
|
def compute_values(self, *fields):
|
||||||
field_descs = []
|
field_objs = []
|
||||||
if not fields:
|
|
||||||
field_descs = self.field_desc_mapper.values()
|
|
||||||
else:
|
|
||||||
for field in fields:
|
for field in fields:
|
||||||
assert field in self.field_desc_mapper, f'{field} is not a valid field'
|
field_objs.append(self[field])
|
||||||
field_descs.append(self.field_desc_mapper[field])
|
|
||||||
data = {
|
data = {
|
||||||
field_desc.field_name: field_desc.compute_value(self)
|
field_obj.field_name: field_obj.compute_value(self)
|
||||||
for field_desc in field_descs
|
for field_obj in field_objs
|
||||||
}
|
}
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def compute_and_set_all_data(self, computed_data: dict = None):
|
def init_all_values(self):
|
||||||
"""
|
t_start = time.time()
|
||||||
TODO 怎样防止并发更新全部数据,浪费数据库资源
|
logger.info(f'Start init cache: key={self.key}')
|
||||||
"""
|
data = self.compute_values(*self.field_names)
|
||||||
uncomputed_keys = ()
|
self.save_data_to_db(data)
|
||||||
if computed_data:
|
logger.info(f'End init cache: cost={time.time()-t_start} key={self.key}')
|
||||||
computed_keys = computed_data.keys()
|
|
||||||
all_keys = self.field_desc_mapper.keys()
|
|
||||||
uncomputed_keys = all_keys - computed_keys
|
|
||||||
else:
|
|
||||||
computed_data = {}
|
|
||||||
data = self.compute_data(*uncomputed_keys)
|
|
||||||
data.update(computed_data)
|
|
||||||
self.set_data(data)
|
|
||||||
return data
|
|
||||||
|
|
||||||
def refresh_part_data_with_lock(self, refresh_data):
|
|
||||||
with DistributedLock(name=f'{self.key}.refresh'):
|
|
||||||
data = self.get_data()
|
|
||||||
if data is not None:
|
|
||||||
data.update(refresh_data)
|
|
||||||
self.set_data(data)
|
|
||||||
return data
|
|
||||||
|
|
||||||
def expire_fields_with_lock(self, *fields):
|
|
||||||
with DistributedLock(name=f'{self.key}.refresh'):
|
|
||||||
data = self.get_data()
|
|
||||||
if data is not None:
|
|
||||||
logger.info(f'Expire cached fields: key={self.key} fields={fields}')
|
|
||||||
for f in fields:
|
|
||||||
data.pop(f, None)
|
|
||||||
self.set_data(data)
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def refresh(self, *fields):
|
def refresh(self, *fields):
|
||||||
if not fields:
|
if not fields:
|
||||||
# 没有指定 field 要刷新所有的值
|
# 没有指定 field 要刷新所有的值
|
||||||
self.compute_and_set_all_data()
|
self.init_all_values()
|
||||||
return
|
return
|
||||||
|
|
||||||
data = self.get_data()
|
data = self.load_data_from_db()
|
||||||
if data is None:
|
if not data:
|
||||||
# 缓存中没有数据,设置所有的值
|
# 缓存中没有数据,设置所有的值
|
||||||
self.compute_and_set_all_data()
|
self.init_all_values()
|
||||||
return
|
return
|
||||||
|
|
||||||
refresh_data = self.compute_data(*fields)
|
refresh_values = self.compute_values(*fields)
|
||||||
if not self.refresh_part_data_with_lock(refresh_data):
|
self.save_data_to_db(refresh_values)
|
||||||
# 刷新部分失败,缓存中没有数据,更新所有的值
|
|
||||||
self.compute_and_set_all_data(refresh_data)
|
|
||||||
return
|
|
||||||
|
|
||||||
def get_key_suffix(self):
|
def get_key_suffix(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -157,12 +171,13 @@ class Cache(metaclass=CacheBase):
|
||||||
self._data = None
|
self._data = None
|
||||||
|
|
||||||
def expire(self, *fields):
|
def expire(self, *fields):
|
||||||
if not fields:
|
|
||||||
self._data = None
|
self._data = None
|
||||||
|
if not fields:
|
||||||
logger.info(f'Delete cached key: key={self.key}')
|
logger.info(f'Delete cached key: key={self.key}')
|
||||||
cache.delete(self.key)
|
self.redis.delete(self.key)
|
||||||
else:
|
else:
|
||||||
self.expire_fields_with_lock(*fields)
|
self.redis.hdel(self.key, *fields)
|
||||||
|
logger.info(f'Expire cached fields: key={self.key} fields={fields}')
|
||||||
|
|
||||||
|
|
||||||
class CacheValueDesc:
|
class CacheValueDesc:
|
||||||
|
@ -185,6 +200,8 @@ class CacheValueDesc:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def compute_value(self, instance: Cache):
|
def compute_value(self, instance: Cache):
|
||||||
|
t_start = time.time()
|
||||||
|
logger.info(f'Start compute cache field: field={self.field_name} key={instance.key}')
|
||||||
if self.field_type.queryset is not None:
|
if self.field_type.queryset is not None:
|
||||||
new_value = self.field_type.queryset.count()
|
new_value = self.field_type.queryset.count()
|
||||||
else:
|
else:
|
||||||
|
@ -197,5 +214,8 @@ class CacheValueDesc:
|
||||||
new_value = compute_func()
|
new_value = compute_func()
|
||||||
|
|
||||||
new_value = self.field_type.field_type(new_value)
|
new_value = self.field_type.field_type(new_value)
|
||||||
logger.info(f'Compute cache field value: key={instance.key} field={self.field_name} value={new_value}')
|
logger.info(f'End compute cache field: cost={time.time()-t_start} field={self.field_name} value={new_value} key={instance.key}')
|
||||||
return new_value
|
return new_value
|
||||||
|
|
||||||
|
def to_internal_value(self, value):
|
||||||
|
return self.field_type.field_type(value)
|
||||||
|
|
|
@ -1,39 +0,0 @@
|
||||||
from django.db.transaction import on_commit
|
|
||||||
|
|
||||||
from common.cache import *
|
|
||||||
from .utils import current_org, tmp_to_org
|
|
||||||
from .tasks import refresh_org_cache_task
|
|
||||||
from orgs.models import Organization
|
|
||||||
|
|
||||||
|
|
||||||
class OrgRelatedCache(Cache):
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.current_org = Organization.get_instance(current_org.id)
|
|
||||||
|
|
||||||
def get_current_org(self):
|
|
||||||
"""
|
|
||||||
暴露给子类控制组织的回调
|
|
||||||
1. 在交互式环境下能控制组织
|
|
||||||
2. 在 celery 任务下能控制组织
|
|
||||||
"""
|
|
||||||
return self.current_org
|
|
||||||
|
|
||||||
def compute_data(self, *fields):
|
|
||||||
with tmp_to_org(self.get_current_org()):
|
|
||||||
return super().compute_data(*fields)
|
|
||||||
|
|
||||||
def refresh_async(self, *fields):
|
|
||||||
"""
|
|
||||||
在事务提交之后再发送信号,防止因事务的隔离性导致未获得最新的数据
|
|
||||||
"""
|
|
||||||
def func():
|
|
||||||
logger.info(f'CACHE: Send refresh task {self}.{fields}')
|
|
||||||
refresh_org_cache_task.delay(self, *fields)
|
|
||||||
on_commit(func)
|
|
||||||
|
|
||||||
def expire(self, *fields):
|
|
||||||
def func():
|
|
||||||
super(OrgRelatedCache, self).expire(*fields)
|
|
||||||
on_commit(func)
|
|
|
@ -1,4 +1,10 @@
|
||||||
from .cache import OrgRelatedCache, IntegerField
|
from django.db.transaction import on_commit
|
||||||
|
from orgs.models import Organization
|
||||||
|
from orgs.tasks import refresh_org_cache_task
|
||||||
|
from orgs.utils import current_org, tmp_to_org
|
||||||
|
|
||||||
|
from common.cache import Cache, IntegerField
|
||||||
|
from common.utils import get_logger
|
||||||
from users.models import UserGroup, User
|
from users.models import UserGroup, User
|
||||||
from assets.models import Node, AdminUser, SystemUser, Domain, Gateway
|
from assets.models import Node, AdminUser, SystemUser, Domain, Gateway
|
||||||
from applications.models import Application
|
from applications.models import Application
|
||||||
|
@ -6,6 +12,42 @@ from perms.models import AssetPermission, ApplicationPermission
|
||||||
from .models import OrganizationMember
|
from .models import OrganizationMember
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
class OrgRelatedCache(Cache):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.current_org = Organization.get_instance(current_org.id)
|
||||||
|
|
||||||
|
def get_current_org(self):
|
||||||
|
"""
|
||||||
|
暴露给子类控制组织的回调
|
||||||
|
1. 在交互式环境下能控制组织
|
||||||
|
2. 在 celery 任务下能控制组织
|
||||||
|
"""
|
||||||
|
return self.current_org
|
||||||
|
|
||||||
|
def compute_values(self, *fields):
|
||||||
|
with tmp_to_org(self.get_current_org()):
|
||||||
|
return super().compute_values(*fields)
|
||||||
|
|
||||||
|
def refresh_async(self, *fields):
|
||||||
|
"""
|
||||||
|
在事务提交之后再发送信号,防止因事务的隔离性导致未获得最新的数据
|
||||||
|
"""
|
||||||
|
def func():
|
||||||
|
logger.info(f'CACHE: Send refresh task {self}.{fields}')
|
||||||
|
refresh_org_cache_task.delay(self, *fields)
|
||||||
|
on_commit(func)
|
||||||
|
|
||||||
|
def expire(self, *fields):
|
||||||
|
def func():
|
||||||
|
super(OrgRelatedCache, self).expire(*fields)
|
||||||
|
on_commit(func)
|
||||||
|
|
||||||
|
|
||||||
class OrgResourceStatisticsCache(OrgRelatedCache):
|
class OrgResourceStatisticsCache(OrgRelatedCache):
|
||||||
users_amount = IntegerField()
|
users_amount = IntegerField()
|
||||||
groups_amount = IntegerField(queryset=UserGroup.objects)
|
groups_amount = IntegerField(queryset=UserGroup.objects)
|
||||||
|
|
Loading…
Reference in New Issue