refactor: Refactor the cache framework

pull/5700/head^2
xinwen 4 years ago committed by Jiangjie.Bai
parent ccb0509d85
commit c4eacbabc6

@@ -1,13 +1,24 @@
-import json
-from django.core.cache import cache
+import time
+from redis import Redis
 from common.utils.lock import DistributedLock
 from common.utils import lazyproperty
 from common.utils import get_logger
+from jumpserver.const import CONFIG

 logger = get_logger(__file__)


+class ComputeLock(DistributedLock):
+    """
+    Take this lock when the cache needs to be rebuilt, to avoid duplicate computation
+    """
+    def __init__(self, key):
+        name = f'compute:{key}'
+        super().__init__(name=name)


 class CacheFieldBase:
     field_type = str
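
Review note: ComputeLock just derives a per-key DistributedLock name (compute:<key>), so rebuilds of the same cache key serialize while different keys stay independent. A minimal usage sketch (the key string is illustrative, not from this commit):

    # Only one worker rebuilds a given key; others block here, then re-read.
    with ComputeLock('OrgResourceStatisticsCache.default'):
        pass  # recompute and write the cache
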
@@ -25,7 +36,7 @@ class IntegerField(CacheFieldBase):
     field_type = int


-class CacheBase(type):
+class CacheType(type):
     def __new__(cls, name, bases, attrs: dict):
         to_update = {}
         field_desc_mapper = {}
@@ -41,12 +52,31 @@ class CacheBase(type):
         return type.__new__(cls, name, bases, attrs)


-class Cache(metaclass=CacheBase):
+class Cache(metaclass=CacheType):
     field_desc_mapper: dict
     timeout = None

     def __init__(self):
         self._data = None
+        self.redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
+
+    def __getitem__(self, item):
+        return self.field_desc_mapper[item]
+
+    def __contains__(self, item):
+        return item in self.field_desc_mapper
+
+    def get_field(self, name):
+        return self.field_desc_mapper[name]
+
+    @property
+    def fields(self):
+        return self.field_desc_mapper.values()
+
+    @property
+    def field_names(self):
+        names = self.field_desc_mapper.keys()
+        return names

     @lazyproperty
     def key_suffix(self):
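
Review note: the metaclass collects every declared field into field_desc_mapper, and the new dunder methods make a Cache instance behave like a read-only mapping of field descriptors. A hedged sketch, assuming a configured JumpServer environment so CONFIG.REDIS_HOST etc. resolve (DemoCache and its field are illustrative, not part of this commit):

    from common.cache import Cache, IntegerField

    class DemoCache(Cache):
        demo_amount = IntegerField()

        def get_key_suffix(self):
            return 'demo'

    stats = DemoCache()
    print('demo_amount' in stats)     # True, via __contains__
    print(stats['demo_amount'])       # the field descriptor, via __getitem__
    print(list(stats.field_names))    # ['demo_amount']
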
@@ -64,91 +94,75 @@ class Cache(metaclass=CacheBase):
     @property
     def data(self):
         if self._data is None:
-            data = self.get_data()
-            if data is None:
-                # No data in the cache; fetch it from the database
-                self.compute_and_set_all_data()
+            data = self.load_data_from_db()
+            if not data:
+                with ComputeLock(self.key):
+                    data = self.load_data_from_db()
+                    if not data:
+                        # No data in the cache; fetch it from the database
+                        self.init_all_values()
         return self._data
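
The rewritten data property is the classic double-checked locking pattern: read, and only on a miss take the lock and read again, so that of N concurrent readers only one recomputes. The same idea, self-contained with threading primitives (illustrative, not this commit's code):

    import threading

    _lock = threading.Lock()
    _cache = {}

    def get(key, compute):
        value = _cache.get(key)
        if value is None:
            with _lock:                  # only one thread recomputes
                value = _cache.get(key)  # re-check after acquiring the lock
                if value is None:
                    value = _cache[key] = compute()
        return value
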
-    def get_data(self) -> dict:
-        data = cache.get(self.key)
+    def to_internal_value(self, data: dict):
+        internal_data = {}
+        for k, v in data.items():
+            field = k.decode()
+            if field in self:
+                value = self[field].to_internal_value(v.decode())
+                internal_data[field] = value
+            else:
+                logger.warn(f'Cache got invalid field: '
+                            f'key={self.key} '
+                            f'invalid_field={field} '
+                            f'valid_fields={self.field_names}')
+        return internal_data
+
+    def load_data_from_db(self) -> dict:
+        data = self.redis.hgetall(self.key)
         logger.debug(f'Get data from cache: key={self.key} data={data}')
-        if data is not None:
-            data = json.loads(data)
+        if data:
+            data = self.to_internal_value(data)
             self._data = data
         return data

-    def set_data(self, data):
-        self._data = data
-        to_json = json.dumps(data)
-        logger.info(f'Set data to cache: key={self.key} data={to_json} timeout={self.timeout}')
-        cache.set(self.key, to_json, timeout=self.timeout)
+    def save_data_to_db(self, data):
+        logger.info(f'Set data to cache: key={self.key} data={data}')
+        self.redis.hset(self.key, mapping=data)
+        self.load_data_from_db()
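
Context for the decode() calls in to_internal_value: redis-py returns hash field names and values as bytes unless the client is created with decode_responses=True. A quick illustration (assumes a reachable Redis on localhost; the key and field are illustrative):

    from redis import Redis

    r = Redis(host='localhost', port=6379)
    r.hset('demo:key', mapping={'demo_amount': 42})
    raw = r.hgetall('demo:key')       # {b'demo_amount': b'42'}
    data = {k.decode(): int(v.decode()) for k, v in raw.items()}
    print(data)                       # {'demo_amount': 42}
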
-    def compute_data(self, *fields):
-        field_descs = []
-        if not fields:
-            field_descs = self.field_desc_mapper.values()
-        else:
-            for field in fields:
-                assert field in self.field_desc_mapper, f'{field} is not a valid field'
-                field_descs.append(self.field_desc_mapper[field])
+    def compute_values(self, *fields):
+        field_objs = []
+        for field in fields:
+            field_objs.append(self[field])
         data = {
-            field_desc.field_name: field_desc.compute_value(self)
-            for field_desc in field_descs
+            field_obj.field_name: field_obj.compute_value(self)
+            for field_obj in field_objs
         }
         return data

-    def compute_and_set_all_data(self, computed_data: dict = None):
-        """
-        TODO: how to keep concurrent full recomputations from wasting database resources
-        """
-        uncomputed_keys = ()
-        if computed_data:
-            computed_keys = computed_data.keys()
-            all_keys = self.field_desc_mapper.keys()
-            uncomputed_keys = all_keys - computed_keys
-        else:
-            computed_data = {}
-        data = self.compute_data(*uncomputed_keys)
-        data.update(computed_data)
-        self.set_data(data)
+    def init_all_values(self):
+        t_start = time.time()
+        logger.info(f'Start init cache: key={self.key}')
+        data = self.compute_values(*self.field_names)
+        self.save_data_to_db(data)
+        logger.info(f'End init cache: cost={time.time()-t_start} key={self.key}')
         return data
-    def refresh_part_data_with_lock(self, refresh_data):
-        with DistributedLock(name=f'{self.key}.refresh'):
-            data = self.get_data()
-            if data is not None:
-                data.update(refresh_data)
-                self.set_data(data)
-                return data
-
-    def expire_fields_with_lock(self, *fields):
-        with DistributedLock(name=f'{self.key}.refresh'):
-            data = self.get_data()
-            if data is not None:
-                logger.info(f'Expire cached fields: key={self.key} fields={fields}')
-                for f in fields:
-                    data.pop(f, None)
-                self.set_data(data)
-                return data
     def refresh(self, *fields):
         if not fields:
             # No fields specified; refresh all values
-            self.compute_and_set_all_data()
+            self.init_all_values()
             return

-        data = self.get_data()
-        if data is None:
+        data = self.load_data_from_db()
+        if not data:
             # No data in the cache; set all values
-            self.compute_and_set_all_data()
+            self.init_all_values()
             return

-        refresh_data = self.compute_data(*fields)
-        if not self.refresh_part_data_with_lock(refresh_data):
-            # Partial refresh failed; the cache is empty, so recompute all values
-            self.compute_and_set_all_data(refresh_data)
-            return
+        refresh_values = self.compute_values(*fields)
+        self.save_data_to_db(refresh_values)

     def get_key_suffix(self):
         raise NotImplementedError
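
With the per-field HSET write path, refresh no longer needs the old lock-and-merge dance: a partial refresh writes only the named fields and leaves the rest of the hash untouched. Continuing the hypothetical DemoCache sketch above (fields need a queryset or compute function for the computation step to succeed):

    stats.refresh()                # no fields: full rebuild via init_all_values()
    stats.refresh('demo_amount')   # one field: compute_values() + HSET merge
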
@@ -157,12 +171,13 @@ class Cache(metaclass=CacheBase):
         self._data = None

     def expire(self, *fields):
+        self._data = None
         if not fields:
-            self._data = None
             logger.info(f'Delete cached key: key={self.key}')
-            cache.delete(self.key)
+            self.redis.delete(self.key)
         else:
-            self.expire_fields_with_lock(*fields)
+            self.redis.hdel(self.key, *fields)
+            logger.info(f'Expire cached fields: key={self.key} fields={fields}')


 class CacheValueDesc:
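
expire now maps straight onto Redis primitives: whole-key expiry is DEL, per-field expiry is HDEL, and the in-memory _data snapshot is dropped in both cases. Continuing the sketch:

    stats.expire()                 # DEL the whole hash; next read rebuilds it
    stats.expire('demo_amount')    # HDEL a single field from the hash
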
@@ -185,6 +200,8 @@ class CacheValueDesc:
         return value

     def compute_value(self, instance: Cache):
+        t_start = time.time()
+        logger.info(f'Start compute cache field: field={self.field_name} key={instance.key}')
         if self.field_type.queryset is not None:
             new_value = self.field_type.queryset.count()
         else:
@@ -197,5 +214,8 @@
             new_value = compute_func()
         new_value = self.field_type.field_type(new_value)
-        logger.info(f'Compute cache field value: key={instance.key} field={self.field_name} value={new_value}')
+        logger.info(f'End compute cache field: cost={time.time()-t_start} field={self.field_name} value={new_value} key={instance.key}')
         return new_value
+
+    def to_internal_value(self, value):
+        return self.field_type.field_type(value)
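
The new CacheValueDesc.to_internal_value is the read-side counterpart of compute_value: Redis stores everything as strings, so the decoded value is cast back with the declared field type. For an IntegerField:

    field = IntegerField()               # its field_type is int
    assert field.field_type('42') == 42  # what to_internal_value does per value
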

@@ -1,39 +0,0 @@
-from django.db.transaction import on_commit
-from common.cache import *
-from .utils import current_org, tmp_to_org
-from .tasks import refresh_org_cache_task
-from orgs.models import Organization
-
-
-class OrgRelatedCache(Cache):
-    def __init__(self):
-        super().__init__()
-        self.current_org = Organization.get_instance(current_org.id)
-
-    def get_current_org(self):
-        """
-        Callback exposed to subclasses for controlling the organization
-        1. Allows controlling the org in an interactive environment
-        2. Allows controlling the org in celery tasks
-        """
-        return self.current_org
-
-    def compute_data(self, *fields):
-        with tmp_to_org(self.get_current_org()):
-            return super().compute_data(*fields)
-
-    def refresh_async(self, *fields):
-        """
-        Send the signal only after the transaction commits, so transaction
-        isolation cannot hide the latest data
-        """
-        def func():
-            logger.info(f'CACHE: Send refresh task {self}.{fields}')
-            refresh_org_cache_task.delay(self, *fields)
-        on_commit(func)
-
-    def expire(self, *fields):
-        def func():
-            super(OrgRelatedCache, self).expire(*fields)
-        on_commit(func)

@@ -1,4 +1,10 @@
-from .cache import OrgRelatedCache, IntegerField
+from django.db.transaction import on_commit
+from orgs.models import Organization
+from orgs.tasks import refresh_org_cache_task
+from orgs.utils import current_org, tmp_to_org
+from common.cache import Cache, IntegerField
+from common.utils import get_logger
 from users.models import UserGroup, User
 from assets.models import Node, AdminUser, SystemUser, Domain, Gateway
 from applications.models import Application
@@ -6,6 +12,42 @@ from perms.models import AssetPermission, ApplicationPermission
 from .models import OrganizationMember

+logger = get_logger(__file__)
+
+
+class OrgRelatedCache(Cache):
+    def __init__(self):
+        super().__init__()
+        self.current_org = Organization.get_instance(current_org.id)
+
+    def get_current_org(self):
+        """
+        Callback exposed to subclasses for controlling the organization
+        1. Allows controlling the org in an interactive environment
+        2. Allows controlling the org in celery tasks
+        """
+        return self.current_org
+
+    def compute_values(self, *fields):
+        with tmp_to_org(self.get_current_org()):
+            return super().compute_values(*fields)
+
+    def refresh_async(self, *fields):
+        """
+        Send the signal only after the transaction commits, so transaction
+        isolation cannot hide the latest data
+        """
+        def func():
+            logger.info(f'CACHE: Send refresh task {self}.{fields}')
+            refresh_org_cache_task.delay(self, *fields)
+        on_commit(func)
+
+    def expire(self, *fields):
+        def func():
+            super(OrgRelatedCache, self).expire(*fields)
+        on_commit(func)
+
+
 class OrgResourceStatisticsCache(OrgRelatedCache):
     users_amount = IntegerField()
     groups_amount = IntegerField(queryset=UserGroup.objects)
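
How the pieces combine, sketched under the assumption of a working JumpServer org context and a running celery worker (illustrative usage, not part of this commit):

    from orgs.models import Organization
    from orgs.utils import tmp_to_org

    org = Organization.objects.first()
    with tmp_to_org(org):
        stats = OrgResourceStatisticsCache()
        print(stats.data['users_amount'])    # computed on first read, cached in Redis
        stats.refresh_async('users_amount')  # recompute via celery after commit
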
