import time

from channels_redis.core import RedisChannelLayer as _RedisChannelLayer

from common.utils.lock import DistributedLock
from common.utils.connection import get_redis_client
from common.utils import lazyproperty
from common.utils import get_logger

# Module-level logger used by the cache classes in this file.
logger = get_logger(__file__)


class ComputeLock(DistributedLock):
    """
    Lock held while a cache is being (re)built, so that the same values are
    not computed repeatedly.

    The lock name is the given key prefixed with ``compute:``.
    """

    def __init__(self, key):
        super().__init__(name=f'compute:{key}')


class CacheFieldBase:
    """
    Declaration of a single cache field.

    A field's value comes either from ``queryset.count()`` or from a method on
    the cache instance named by ``compute_func_name`` (when omitted, the
    descriptor looks for ``compute_<field name>``). At most one of
    ``queryset`` and ``compute_func_name`` may be given.

    :raises ValueError: if both ``queryset`` and ``compute_func_name`` are given.
    """
    # Type that computed and stored values are converted to.
    field_type = str

    def __init__(self, queryset=None, compute_func_name=None):
        # Explicit check rather than `assert`: asserts are stripped under `python -O`.
        if queryset is not None and compute_func_name is not None:
            raise ValueError('queryset and compute_func_name can only have one')
        self.compute_func_name = compute_func_name
        self.queryset = queryset


class CharField(CacheFieldBase):
    """Cache field whose computed and stored values are converted to ``str``."""

    field_type = str


class IntegerField(CacheFieldBase):
    """Cache field whose computed and stored values are converted to ``int``."""

    field_type = int


class CacheType(type):
    """
    Metaclass that turns ``CacheFieldBase`` class attributes into descriptors.

    Every such attribute is replaced by a ``CacheValueDesc`` and recorded in the
    class-level ``field_desc_mapper`` (field name -> descriptor). Fields declared
    on base classes are inherited. Without this, a subclass of a concrete cache
    would keep its parent's descriptors but omit them from ``field_desc_mapper``,
    so those fields would never be initialised and refreshing them would raise
    ``KeyError``.
    """

    def __new__(cls, name, bases, attrs: dict):
        field_desc_mapper = {}
        # Start from the bases' fields. Iterate in reverse so bases listed
        # earlier win, as they do in attribute lookup.
        for base in reversed(bases):
            field_desc_mapper.update(getattr(base, 'field_desc_mapper', None) or {})

        to_update = {}
        for k, v in attrs.items():
            if isinstance(v, CacheFieldBase):
                desc = CacheValueDesc(k, v)
                to_update[k] = desc
                field_desc_mapper[k] = desc
            elif k in field_desc_mapper:
                # A plain attribute shadows an inherited field, so it is no
                # longer a cache field on this class.
                del field_desc_mapper[k]

        attrs.update(to_update)
        attrs['field_desc_mapper'] = field_desc_mapper
        return type.__new__(cls, name, bases, attrs)


class Cache(metaclass=CacheType):
    """
    Base class for a group of computed values cached together in one Redis hash.

    Subclasses declare fields (``CharField`` / ``IntegerField``) as class
    attributes. The metaclass turns them into descriptors and records them in
    ``field_desc_mapper``. Reading a field returns the cached value, computing
    and storing missing values first. Subclasses must implement
    ``get_key_suffix`` so each cached object gets its own Redis key.

    Note: despite their names, ``load_data_from_db`` and ``save_data_to_db``
    read and write the Redis hash; the source data is only touched by the
    field computations.
    """
    # Filled in by the metaclass: field name -> field descriptor.
    field_desc_mapper: dict
    # Optional TTL (in seconds) applied to the Redis hash on every write.
    # A falsy value (the default None) means the hash never expires.
    timeout = None

    def __init__(self):
        # Decoded copy of the Redis hash; None until it has been loaded.
        self._data = None
        self.redis = get_redis_client()

    def __getitem__(self, item):
        """Return the descriptor of field ``item``; raise KeyError if unknown."""
        return self.field_desc_mapper[item]

    def __contains__(self, item):
        """Return True if ``item`` is a declared field name."""
        return item in self.field_desc_mapper

    def get_field(self, name):
        """Return the descriptor of field ``name``; raise KeyError if unknown."""
        return self.field_desc_mapper[name]

    @property
    def fields(self):
        """Descriptors of all declared fields."""
        return self.field_desc_mapper.values()

    @property
    def field_names(self):
        """Names of all declared fields."""
        return self.field_desc_mapper.keys()

    @lazyproperty
    def key_suffix(self):
        """Per-instance part of the Redis key, as returned by ``get_key_suffix``."""
        return self.get_key_suffix()

    @property
    def key_prefix(self):
        """Class-specific part of the Redis key: ``cache.<module>.<class>``."""
        clz = self.__class__
        return f'cache.{clz.__module__}.{clz.__name__}'

    @property
    def key(self):
        """Full Redis key of the hash holding this cache's values."""
        return f'{self.key_prefix}.{self.key_suffix}'

    @property
    def data(self):
        """
        Return the decoded cache contents, loading or building them on first use.

        If the Redis hash is empty, every field is computed while holding a
        ``ComputeLock`` on the key, to avoid duplicate computation. The hash is
        re-read after the lock is acquired, in case someone else filled it
        in the meantime.
        """
        if self._data is None:
            # load_data_from_db() sets self._data when the hash is non-empty.
            data = self.load_data_from_db()
            if not data:
                with ComputeLock(self.key):
                    data = self.load_data_from_db()
                    if not data:
                        # Nothing cached yet: compute every field and store it.
                        self.init_all_values()
        return self._data

    def to_internal_value(self, data: dict):
        """
        Decode a raw ``hgetall`` result into ``{field_name: typed_value}``.

        Keys and values are expected as bytes. They are decoded, and each value
        is converted by its field's type. Unknown fields are logged and dropped.
        """
        internal_data = {}
        for k, v in data.items():
            field = k.decode()
            if field in self:
                internal_data[field] = self[field].to_internal_value(v.decode())
            else:
                # Logger.warn is a deprecated alias; use Logger.warning.
                logger.warning(f'Cache got invalid field: '
                               f'key={self.key} '
                               f'invalid_field={field} '
                               f'valid_fields={self.field_names}')
        return internal_data

    def load_data_from_db(self) -> dict:
        """
        Read the Redis hash. If it is non-empty, decode it and store it in ``self._data``.

        Returns the decoded dict, or the raw (empty) ``hgetall`` result when the
        hash is empty.
        """
        data = self.redis.hgetall(self.key)
        logger.debug(f'Get data from cache: key={self.key} data={data}')
        if data:
            data = self.to_internal_value(data)
            self._data = data
        return data

    def save_data_to_db(self, data):
        """Write ``data`` into the Redis hash, apply ``timeout`` if set, then reload."""
        logger.debug(f'Set data to cache: key={self.key} data={data}')
        self.redis.hset(self.key, mapping=data)
        if self.timeout:
            # Honor the declared TTL. Re-applying it on every write also
            # extends the lifetime of a hash that is still being refreshed.
            self.redis.expire(self.key, self.timeout)
        self.load_data_from_db()

    def compute_values(self, *fields):
        """Compute fresh values for ``fields`` and return ``{field_name: value}``."""
        # Resolve every field first so an unknown name fails before any
        # (potentially expensive) computation runs.
        field_objs = [self[field] for field in fields]
        return {
            field_obj.field_name: field_obj.compute_value(self)
            for field_obj in field_objs
        }

    def init_all_values(self):
        """Compute every declared field, store the result in Redis and return it."""
        t_start = time.time()
        logger.debug(f'Start init cache: key={self.key}')
        data = self.compute_values(*self.field_names)
        self.save_data_to_db(data)
        logger.debug(f'End init cache: cost={time.time()-t_start} key={self.key}')
        return data

    def refresh(self, *fields):
        """
        Recompute and store the given fields, or every field when none are given.

        If the Redis hash is currently empty, all fields are computed anyway.
        """
        if not fields:
            # No field specified: refresh every value.
            self.init_all_values()
            return

        data = self.load_data_from_db()
        if not data:
            # Nothing cached yet: populate every field.
            self.init_all_values()
            return

        refresh_values = self.compute_values(*fields)
        self.save_data_to_db(refresh_values)

    def get_key_suffix(self):
        """Return the per-instance part of the Redis key; subclasses must implement it."""
        raise NotImplementedError

    def reload(self):
        """Drop the in-memory copy so the next access re-reads Redis."""
        self._data = None

    def expire(self, *fields):
        """
        Invalidate cached values.

        With no arguments the whole Redis hash is deleted; otherwise only
        ``fields`` are removed from it. The in-memory copy is dropped in both cases.
        """
        self._data = None
        if not fields:
            self.redis.delete(self.key)
        else:
            self.redis.hdel(self.key, *fields)
            logger.debug(f'Expire cached fields: key={self.key} fields={fields}')


class CacheValueDesc:
    """
    Descriptor for one cache field, installed on ``Cache`` subclasses by ``CacheType``.

    Reading the attribute on an instance returns the field's cached value.
    If the value is missing from the cache, it is computed and stored first.
    Reading it on the class returns the descriptor itself.
    """

    def __init__(self, field_name, field_type: CacheFieldBase):
        self.field_name = field_name
        # The declared field object (e.g. CharField / IntegerField instance).
        self.field_type = field_type
        # NOTE(review): not read anywhere in this file; kept in case external code uses it.
        self._data = None

    def __repr__(self):
        clz = self.__class__
        return f'<{clz.__name__} {self.field_name} {self.field_type}>'

    def __get__(self, instance: Cache, owner):
        if instance is None:
            return self
        data = instance.data
        if self.field_name not in data:
            instance.refresh(self.field_name)
            # refresh() stores and reloads the cache, so re-read the data.
            data = instance.data
        # .get(): in edge cases the value may still be missing; return None rather than raise.
        return data.get(self.field_name)

    def compute_value(self, instance: Cache):
        """
        Compute this field's value for ``instance``, converted to the field type.

        Uses ``queryset.count()`` when the field declares a queryset. Otherwise
        it calls the instance method named by ``compute_func_name``, which
        defaults to ``compute_<field name>``.

        :raises NotImplementedError: if that compute method is missing or not callable.
        """
        t_start = time.time()
        logger.debug(f'Start compute cache field: field={self.field_name} key={instance.key}')
        if self.field_type.queryset is not None:
            new_value = self.field_type.queryset.count()
        else:
            compute_func_name = self.field_type.compute_func_name or f'compute_{self.field_name}'
            compute_func = getattr(instance, compute_func_name, None)
            # Explicit check instead of `assert` (stripped under `python -O`);
            # also rejects a non-callable attribute with a clear message.
            if not callable(compute_func):
                raise NotImplementedError(
                    f'Define `{compute_func_name}` method in {instance.__class__}'
                )
            new_value = compute_func()

        new_value = self.field_type.field_type(new_value)
        logger.debug(
            f'End compute cache field: cost={time.time()-t_start} '
            f'field={self.field_name} value={new_value} key={instance.key}'
        )
        return new_value

    def to_internal_value(self, value):
        """Convert a decoded stored value to this field's type."""
        return self.field_type.field_type(value)


class RedisChannelLayer(_RedisChannelLayer):
    """
    channels_redis layer whose receive path tolerates Redis servers that
    reject the backup-queue cleanup Lua script.
    """

    async def _brpop_with_clean(self, index, channel, timeout):
        """
        Pop the next message from ``channel`` and record it in the backup queue.

        The cleanup step moves any backed-up members (with their scores) back
        into the channel's sorted set, then deletes the backup queue. Some cloud
        Redis services (e.g. Alibaba Cloud) reject this script, so its failure
        is tolerated. Returns the popped member, or None when ``bzpopmin``
        times out.
        """
        cleanup_script = """
            local backed_up = redis.call('ZRANGE', ARGV[2], 0, -1, 'WITHSCORES')
            for i = #backed_up, 1, -2 do
                redis.call('ZADD', ARGV[1], backed_up[i], backed_up[i - 1])
            end
            redis.call('DEL', ARGV[2])
        """
        backup_queue = self._backup_channel_name(channel)
        async with self.connection(index) as connection:
            try:
                await connection.eval(cleanup_script, keys=[], args=[channel, backup_queue])
            except Exception as e:
                # Best effort: skip cleanup if the server refuses the script.
                # Catch Exception, not a bare `except:`, so asyncio.CancelledError
                # (a BaseException since Python 3.8) still cancels the receive.
                logger.debug(f'Channel layer cleanup script skipped: channel={channel} error={e}')
            result = await connection.bzpopmin(channel, timeout=timeout)

            if result is not None:
                _, member, timestamp = result
                # Keep a copy in the backup queue until it is cleaned up again.
                await connection.zadd(backup_queue, float(timestamp), member)
            else:
                member = None
            return member