From c4eacbabc6fc31dbfa9d2d8cccd35228537fc671 Mon Sep 17 00:00:00 2001 From: xinwen Date: Tue, 9 Mar 2021 13:57:58 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E7=BC=93?= =?UTF-8?q?=E5=AD=98=E6=A1=86=E6=9E=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/cache.py | 164 ++++++++++++++++++++++++------------------- apps/orgs/cache.py | 39 ---------- apps/orgs/caches.py | 44 +++++++++++- 3 files changed, 135 insertions(+), 112 deletions(-) delete mode 100644 apps/orgs/cache.py diff --git a/apps/common/cache.py b/apps/common/cache.py index 0bed4fa30..b16d1e7dd 100644 --- a/apps/common/cache.py +++ b/apps/common/cache.py @@ -1,13 +1,24 @@ -import json -from django.core.cache import cache +import time + +from redis import Redis from common.utils.lock import DistributedLock from common.utils import lazyproperty from common.utils import get_logger +from jumpserver.const import CONFIG logger = get_logger(__file__) +class ComputeLock(DistributedLock): + """ + 需要重建缓存的时候加上该锁,避免重复计算 + """ + def __init__(self, key): + name = f'compute:{key}' + super().__init__(name=name) + + class CacheFieldBase: field_type = str @@ -25,7 +36,7 @@ class IntegerField(CacheFieldBase): field_type = int -class CacheBase(type): +class CacheType(type): def __new__(cls, name, bases, attrs: dict): to_update = {} field_desc_mapper = {} @@ -41,12 +52,31 @@ class CacheBase(type): return type.__new__(cls, name, bases, attrs) -class Cache(metaclass=CacheBase): +class Cache(metaclass=CacheType): field_desc_mapper: dict timeout = None def __init__(self): self._data = None + self.redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD) + + def __getitem__(self, item): + return self.field_desc_mapper[item] + + def __contains__(self, item): + return item in self.field_desc_mapper + + def get_field(self, name): + return self.field_desc_mapper[name] + + @property + def fields(self): + return self.field_desc_mapper.values() + + @property + def field_names(self): + names = self.field_desc_mapper.keys() + return names @lazyproperty def key_suffix(self): @@ -64,91 +94,75 @@ class Cache(metaclass=CacheBase): @property def data(self): if self._data is None: - data = self.get_data() - if data is None: - # 缓存中没有数据时,去数据库获取 - self.compute_and_set_all_data() + data = self.load_data_from_db() + if not data: + with ComputeLock(self.key): + data = self.load_data_from_db() + if not data: + # 缓存中没有数据时,去数据库获取 + self.init_all_values() return self._data - def get_data(self) -> dict: - data = cache.get(self.key) + def to_internal_value(self, data: dict): + internal_data = {} + for k, v in data.items(): + field = k.decode() + if field in self: + value = self[field].to_internal_value(v.decode()) + internal_data[field] = value + else: + logger.warn(f'Cache got invalid field: ' + f'key={self.key} ' + f'invalid_field={field} ' + f'valid_fields={self.field_names}') + return internal_data + + def load_data_from_db(self) -> dict: + data = self.redis.hgetall(self.key) logger.debug(f'Get data from cache: key={self.key} data={data}') - if data is not None: - data = json.loads(data) + if data: + data = self.to_internal_value(data) self._data = data return data - def set_data(self, data): - self._data = data - to_json = json.dumps(data) - logger.info(f'Set data to cache: key={self.key} data={to_json} timeout={self.timeout}') - cache.set(self.key, to_json, timeout=self.timeout) + def save_data_to_db(self, data): + logger.info(f'Set data to cache: key={self.key} data={data}') + self.redis.hset(self.key, mapping=data) + self.load_data_from_db() + + def compute_values(self, *fields): + field_objs = [] + for field in fields: + field_objs.append(self[field]) - def compute_data(self, *fields): - field_descs = [] - if not fields: - field_descs = self.field_desc_mapper.values() - else: - for field in fields: - assert field in self.field_desc_mapper, f'{field} is not a valid field' - field_descs.append(self.field_desc_mapper[field]) data = { - field_desc.field_name: field_desc.compute_value(self) - for field_desc in field_descs + field_obj.field_name: field_obj.compute_value(self) + for field_obj in field_objs } return data - def compute_and_set_all_data(self, computed_data: dict = None): - """ - TODO 怎样防止并发更新全部数据,浪费数据库资源 - """ - uncomputed_keys = () - if computed_data: - computed_keys = computed_data.keys() - all_keys = self.field_desc_mapper.keys() - uncomputed_keys = all_keys - computed_keys - else: - computed_data = {} - data = self.compute_data(*uncomputed_keys) - data.update(computed_data) - self.set_data(data) + def init_all_values(self): + t_start = time.time() + logger.info(f'Start init cache: key={self.key}') + data = self.compute_values(*self.field_names) + self.save_data_to_db(data) + logger.info(f'End init cache: cost={time.time()-t_start} key={self.key}') return data - def refresh_part_data_with_lock(self, refresh_data): - with DistributedLock(name=f'{self.key}.refresh'): - data = self.get_data() - if data is not None: - data.update(refresh_data) - self.set_data(data) - return data - - def expire_fields_with_lock(self, *fields): - with DistributedLock(name=f'{self.key}.refresh'): - data = self.get_data() - if data is not None: - logger.info(f'Expire cached fields: key={self.key} fields={fields}') - for f in fields: - data.pop(f, None) - self.set_data(data) - return data - def refresh(self, *fields): if not fields: # 没有指定 field 要刷新所有的值 - self.compute_and_set_all_data() + self.init_all_values() return - data = self.get_data() - if data is None: + data = self.load_data_from_db() + if not data: # 缓存中没有数据,设置所有的值 - self.compute_and_set_all_data() + self.init_all_values() return - refresh_data = self.compute_data(*fields) - if not self.refresh_part_data_with_lock(refresh_data): - # 刷新部分失败,缓存中没有数据,更新所有的值 - self.compute_and_set_all_data(refresh_data) - return + refresh_values = self.compute_values(*fields) + self.save_data_to_db(refresh_values) def get_key_suffix(self): raise NotImplementedError @@ -157,12 +171,13 @@ class Cache(metaclass=CacheBase): self._data = None def expire(self, *fields): + self._data = None if not fields: - self._data = None logger.info(f'Delete cached key: key={self.key}') - cache.delete(self.key) + self.redis.delete(self.key) else: - self.expire_fields_with_lock(*fields) + self.redis.hdel(self.key, *fields) + logger.info(f'Expire cached fields: key={self.key} fields={fields}') class CacheValueDesc: @@ -185,6 +200,8 @@ class CacheValueDesc: return value def compute_value(self, instance: Cache): + t_start = time.time() + logger.info(f'Start compute cache field: field={self.field_name} key={instance.key}') if self.field_type.queryset is not None: new_value = self.field_type.queryset.count() else: @@ -197,5 +214,8 @@ class CacheValueDesc: new_value = compute_func() new_value = self.field_type.field_type(new_value) - logger.info(f'Compute cache field value: key={instance.key} field={self.field_name} value={new_value}') + logger.info(f'End compute cache field: cost={time.time()-t_start} field={self.field_name} value={new_value} key={instance.key}') return new_value + + def to_internal_value(self, value): + return self.field_type.field_type(value) diff --git a/apps/orgs/cache.py b/apps/orgs/cache.py deleted file mode 100644 index f0a9cb83e..000000000 --- a/apps/orgs/cache.py +++ /dev/null @@ -1,39 +0,0 @@ -from django.db.transaction import on_commit - -from common.cache import * -from .utils import current_org, tmp_to_org -from .tasks import refresh_org_cache_task -from orgs.models import Organization - - -class OrgRelatedCache(Cache): - - def __init__(self): - super().__init__() - self.current_org = Organization.get_instance(current_org.id) - - def get_current_org(self): - """ - 暴露给子类控制组织的回调 - 1. 在交互式环境下能控制组织 - 2. 在 celery 任务下能控制组织 - """ - return self.current_org - - def compute_data(self, *fields): - with tmp_to_org(self.get_current_org()): - return super().compute_data(*fields) - - def refresh_async(self, *fields): - """ - 在事务提交之后再发送信号,防止因事务的隔离性导致未获得最新的数据 - """ - def func(): - logger.info(f'CACHE: Send refresh task {self}.{fields}') - refresh_org_cache_task.delay(self, *fields) - on_commit(func) - - def expire(self, *fields): - def func(): - super(OrgRelatedCache, self).expire(*fields) - on_commit(func) diff --git a/apps/orgs/caches.py b/apps/orgs/caches.py index b7e086d6a..9c29659e4 100644 --- a/apps/orgs/caches.py +++ b/apps/orgs/caches.py @@ -1,4 +1,10 @@ -from .cache import OrgRelatedCache, IntegerField +from django.db.transaction import on_commit +from orgs.models import Organization +from orgs.tasks import refresh_org_cache_task +from orgs.utils import current_org, tmp_to_org + +from common.cache import Cache, IntegerField +from common.utils import get_logger from users.models import UserGroup, User from assets.models import Node, AdminUser, SystemUser, Domain, Gateway from applications.models import Application @@ -6,6 +12,42 @@ from perms.models import AssetPermission, ApplicationPermission from .models import OrganizationMember +logger = get_logger(__file__) + + +class OrgRelatedCache(Cache): + + def __init__(self): + super().__init__() + self.current_org = Organization.get_instance(current_org.id) + + def get_current_org(self): + """ + 暴露给子类控制组织的回调 + 1. 在交互式环境下能控制组织 + 2. 在 celery 任务下能控制组织 + """ + return self.current_org + + def compute_values(self, *fields): + with tmp_to_org(self.get_current_org()): + return super().compute_values(*fields) + + def refresh_async(self, *fields): + """ + 在事务提交之后再发送信号,防止因事务的隔离性导致未获得最新的数据 + """ + def func(): + logger.info(f'CACHE: Send refresh task {self}.{fields}') + refresh_org_cache_task.delay(self, *fields) + on_commit(func) + + def expire(self, *fields): + def func(): + super(OrgRelatedCache, self).expire(*fields) + on_commit(func) + + class OrgResourceStatisticsCache(OrgRelatedCache): users_amount = IntegerField() groups_amount = IntegerField(queryset=UserGroup.objects)