# -*- coding: utf-8 -*-
#
import inspect
from functools import partial
import time
from typing import Callable

from django.utils.http import urlencode
from django.core.cache import cache
from django.conf import settings
from django.contrib import auth
from django.utils.translation import ugettext as _
from rest_framework.request import Request
from django.contrib.auth import (
    BACKEND_SESSION_KEY, load_backend,
    PermissionDenied, user_login_failed, _clean_credentials,
)
from django.core.exceptions import ImproperlyConfigured
from django.shortcuts import reverse, redirect, get_object_or_404

from common.utils import get_request_ip, get_logger, bulk_get, FlashMessageUtil
from acls.models import LoginACL
from users.models import User
from users.utils import LoginBlockUtil, MFABlockUtils, LoginIpBlockUtil
from . import errors
from .signals import post_auth_success, post_auth_failed

# Module-level logger, obtained through the project's get_logger helper
logger = get_logger(__name__)


def _get_backends(return_tuples=False):
    """
    Load the backends listed in settings.AUTHENTICATION_BACKENDS and keep
    only those whose ``is_enabled()`` returns a true value.

    :param return_tuples: when true, every entry is a
        ``(backend, backend_path)`` tuple; otherwise just the backend object.
    :return: list of enabled backends, in the order given by the settings.
    :raises ImproperlyConfigured: when no enabled backend is left.
    """
    enabled = []
    for path in settings.AUTHENTICATION_BACKENDS:
        candidate = load_backend(path)
        # Only backends that report themselves as enabled are kept
        if candidate.is_enabled():
            entry = (candidate, path) if return_tuples else candidate
            enabled.append(entry)
    if enabled:
        return enabled
    raise ImproperlyConfigured(
        'No authentication backends have been defined. Does '
        'AUTHENTICATION_BACKENDS contain anything?'
    )


# Monkeypatch: replace django.contrib.auth's private backend loader with the
# version above, which skips backends whose is_enabled() is false.
auth._get_backends = _get_backends


def authenticate(request=None, **credentials):
    """
    If the given credentials are valid, return a User object.

    Replacement for django.contrib.auth.authenticate that, in addition to
    the stock per-backend flow, asks every backend whether the username
    (before authenticating) and the resulting user (after authenticating)
    are allowed to authenticate through it.

    Returns None and sends ``user_login_failed`` when no backend accepts
    the credentials or a backend raises PermissionDenied.
    """
    username = credentials.get('username')

    for backend, backend_path in _get_backends(return_tuples=True):
        logger.info(f'Try using auth backend: {backend!s}')

        # Cheap pre-check on the username, so no time is wasted on a
        # backend that would refuse this user anyway.
        if not backend.username_allow_authenticate(username):
            continue

        # Stock Django logic: skip backends whose authenticate() signature
        # cannot take these credentials.
        signature = inspect.signature(backend.authenticate)
        try:
            signature.bind(request, **credentials)
        except TypeError:
            continue

        try:
            user = backend.authenticate(request, **credentials)
        except PermissionDenied:
            # The backend vetoes this login altogether: stop trying others.
            break

        # The backend must both return a user and allow that user.
        if user is not None and backend.user_allow_authenticate(user):
            # Annotate the user object with the path of the backend.
            user.backend = backend_path
            return user

    # Nothing accepted the credentials (or a backend vetoed): fire the signal.
    user_login_failed.send(
        sender=__name__,
        credentials=_clean_credentials(credentials),
        request=request,
    )


# Monkeypatch: make django.contrib.auth.authenticate resolve to the
# customised implementation defined above.
auth.authenticate = authenticate


class CommonMixin:
    """
    Request helpers shared by the authentication mixins: client IP lookup,
    credential-error construction and reading login state from the session.
    """
    request: Request

    def get_request_ip(self):
        """
        Return the client IP.

        A non-empty ``remote_addr`` entry in ``request.data`` (when the
        request has ``data``) wins; otherwise the module-level
        ``get_request_ip`` helper is consulted.
        """
        remote_addr = ''
        if hasattr(self.request, 'data'):
            remote_addr = self.request.data.get('remote_addr', '')
        return remote_addr or get_request_ip(self.request)

    def raise_credential_error(self, error):
        """Raise the pre-bound CredentialError, passing ``error`` through."""
        credential_error = self.partial_credential_error(error=error)
        raise credential_error

    def _set_partial_credential_error(self, username, ip, request):
        """
        Store a ``functools.partial`` of errors.CredentialError with
        username, ip and request already bound, for raise_credential_error.
        """
        bound = {'username': username, 'ip': ip, 'request': request}
        self.partial_credential_error = partial(errors.CredentialError, **bound)

    def get_user_from_session(self):
        """
        Return the user described by the current session.

        An already logged-in ``request.user`` is returned when the session
        carries BACKEND_SESSION_KEY; otherwise the user is looked up from the
        ``user_id`` stored by the password step.

        :raises errors.SessionEmptyError: when the session is empty, or the
            password mark / user id is missing, or the mark has expired.
        """
        session = self.request.session
        if session.is_empty():
            raise errors.SessionEmptyError()

        # All three conditions are evaluated up front, none is short-circuited
        current_user = self.request.user
        not_anonymous = not self.request.user.is_anonymous
        has_backend = BACKEND_SESSION_KEY in self.request.session
        if current_user and not_anonymous and has_backend:
            logged_in = self.request.user
            logged_in.backend = self.request.session[BACKEND_SESSION_KEY]
            return logged_in

        user_id = session.get('user_id')
        auth_password = session.get('auth_password')
        expired_at = session.get('auth_password_expired_at')
        # No expiry timestamp in the session means "not expired"
        expired = bool(expired_at) and expired_at < time.time()

        if not (user_id and auth_password) or expired:
            raise errors.SessionEmptyError()

        found = get_object_or_404(User, pk=user_id)
        found.backend = self.request.session.get("auth_backend")
        return found

    def get_auth_data(self, data):
        """
        Extract the login fields from ``data`` and prepare the credential
        error factory.

        :return: ``(username, password, public_key, ip, auto_login)``, where
            ``password`` has the stripped ``challenge`` value appended.
        """
        fields = ['username', 'password', 'challenge', 'public_key', 'auto_login']
        username, password, challenge, public_key, auto_login = bulk_get(data, fields, default='')
        ip = self.get_request_ip()
        self._set_partial_credential_error(username=username, ip=ip, request=self.request)
        full_password = password + challenge.strip()
        return username, full_password, public_key, ip, auto_login


class AuthPreCheckMixin:
    """
    Checks that run before credentials are verified: login blocking by
    IP / username and the "only existing users may authenticate" policy.
    """
    request: Request
    get_request_ip: Callable
    raise_credential_error: Callable

    def _check_is_block(self, username, raise_exception=True):
        """
        Check whether login is blocked for the request IP or for
        ``username`` + IP.

        :param username: login name being checked (may be None when the
            request carried no username).
        :param raise_exception: when true, raise the BlockLoginError;
            when false, return it instead.
        :return: None when not blocked, otherwise the BlockLoginError
            instance (only when ``raise_exception`` is false).
        :raises errors.BlockGlobalIpLoginError: when the IP itself is
            blocked; this is raised regardless of ``raise_exception``.
        :raises errors.BlockLoginError: when username + IP is blocked and
            ``raise_exception`` is true.
        """
        ip = self.get_request_ip()

        if LoginIpBlockUtil(ip).is_block():
            raise errors.BlockGlobalIpLoginError(username=username, ip=ip)

        if not LoginBlockUtil(username, ip).is_block():
            return None

        # Formatting (rather than "+") keeps this from failing on a None username
        logger.warning(f'Ip was blocked: {username}:{ip}')
        # Build the exception once and either raise or return that instance
        exception = errors.BlockLoginError(username=username, ip=ip)
        if raise_exception:
            raise exception
        return exception

    def check_is_block(self, raise_exception=True):
        """
        Run the block check for the username submitted with the request.

        The username is read from ``request.data`` when the request has it,
        otherwise from ``request.POST``.

        :return: whatever ``_check_is_block`` returns, i.e. None or, with
            ``raise_exception=False``, the BlockLoginError instance.
        """
        if hasattr(self.request, 'data'):
            username = self.request.data.get("username")
        else:
            username = self.request.POST.get("username")

        return self._check_is_block(username, raise_exception)

    def _check_only_allow_exists_user_auth(self, username):
        """
        When settings.ONLY_ALLOW_EXIST_USER_AUTH is on, raise a credential
        error for usernames that have no User row yet.
        """
        # Only pre-existing users are allowed to authenticate
        if not settings.ONLY_ALLOW_EXIST_USER_AUTH:
            return

        exist = User.objects.filter(username=username).exists()
        if not exist:
            logger.error(f"Only allow exist user auth, login failed: {username}")
            self.raise_credential_error(errors.reason_user_not_exist)


class MFAMixin:
    """
    MFA (multi-factor authentication) checks and the session bookkeeping
    that records a passed MFA step.
    """
    request: Request
    get_user_from_session: Callable
    get_request_ip: Callable

    def _check_if_no_active_mfa(self, user):
        """
        Raise errors.MFAUnsetError, carrying the OTP enable URL, when
        ``user`` has no active MFA backend.
        """
        active_mfa_mapper = user.active_mfa_backends_mapper
        if not active_mfa_mapper:
            url = reverse('authentication:user-otp-enable-start')
            raise errors.MFAUnsetError(user, self.request, url)

    def _check_login_page_mfa_if_need(self, user):
        """
        Verify an MFA code submitted together with the login form.

        Does nothing unless settings.SECURITY_MFA_IN_LOGIN_PAGE is on, the
        user has active MFA backends and a ``code`` was submitted.
        """
        if not settings.SECURITY_MFA_IN_LOGIN_PAGE:
            return
        if not user.active_mfa_backends:
            return

        request = self.request
        data = request.data if hasattr(request, 'data') else request.POST
        code = data.get('code')
        mfa_type = data.get('mfa_type', 'otp')

        if not code:
            return
        self._do_check_user_mfa(code, mfa_type, user=user)

    def check_user_mfa_if_need(self, user):
        """
        Raise errors.MFARequiredError listing the user's active MFA types,
        unless the session already has ``auth_mfa`` set or the user does not
        have MFA enabled.
        """
        if self.request.session.get('auth_mfa'):
            return
        if not user.mfa_enabled:
            return

        active_mfa_names = user.active_mfa_backends_mapper.keys()
        raise errors.MFARequiredError(mfa_types=tuple(active_mfa_names))

    def mark_mfa_ok(self, mfa_type):
        """Record in the session that MFA of ``mfa_type`` passed just now."""
        self.request.session['auth_mfa'] = 1
        self.request.session['auth_mfa_time'] = time.time()
        self.request.session['auth_mfa_required'] = 0
        self.request.session['auth_mfa_type'] = mfa_type

    def clean_mfa_mark(self):
        """Remove every MFA-related key written by mark_mfa_ok from the session."""
        keys = ['auth_mfa', 'auth_mfa_time', 'auth_mfa_required', 'auth_mfa_type']
        for k in keys:
            self.request.session.pop(k, '')

    def check_mfa_is_block(self, username, ip, raise_exception=True):
        """
        Check whether MFA attempts are blocked for ``username`` + ``ip``.

        :param raise_exception: when true, raise the BlockMFAError; when
            false, return it instead.
        :return: None when not blocked, otherwise the BlockMFAError
            instance (only when ``raise_exception`` is false).
        :raises errors.BlockMFAError: when blocked and ``raise_exception``
            is true.
        """
        if not MFABlockUtils(username, ip).is_block():
            return None

        # Formatting (rather than "+") cannot fail on non-str values
        logger.warning(f'Ip was blocked: {username}:{ip}')
        exception = errors.BlockMFAError(username=username, request=self.request, ip=ip)
        if raise_exception:
            raise exception
        return exception

    def _do_check_user_mfa(self, code, mfa_type, user=None):
        """
        Verify ``code`` against the user's MFA backend of type ``mfa_type``.

        :param user: user to check; taken from the session when omitted.
        :return: None. On success the session is marked via mark_mfa_ok;
            nothing happens when the user does not have MFA enabled.
        :raises errors.BlockMFAError: when MFA is blocked for this user/IP.
        :raises errors.MFAFailedError: when the backend is missing, inactive
            or rejects the code.
        """
        user = user if user else self.get_user_from_session()
        if not user.mfa_enabled:
            return

        # Check whether MFA is currently blocked for this user/IP
        ip = self.get_request_ip()
        self.check_mfa_is_block(user.username, ip)

        ok = False
        mfa_backend = user.get_mfa_backend_by_type(mfa_type)
        backend_error = _('The MFA type ({}) is not enabled')
        if not mfa_backend:
            msg = backend_error.format(mfa_type)
        elif not mfa_backend.is_active():
            msg = backend_error.format(mfa_backend.display_name)
        else:
            ok, msg = mfa_backend.check_code(code)

        if ok:
            self.mark_mfa_ok(mfa_type)
            return

        raise errors.MFAFailedError(
            username=user.username,
            request=self.request,
            ip=ip, mfa_type=mfa_type,
            error=msg
        )

    @staticmethod
    def get_user_mfa_context(user=None):
        """Return the template context holding the MFA backends for ``user``."""
        mfa_backends = User.get_user_mfa_backends(user)
        return {'mfa_backends': mfa_backends}

    @staticmethod
    def incr_mfa_failed_time(username, ip):
        """Increase the failed-MFA counter kept for ``username`` + ``ip``."""
        util = MFABlockUtils(username, ip)
        util.incr_failed_count()


class AuthPostCheckMixin:
    """
    Password-quality checks that run after the credentials were accepted.
    Each check raises an error carrying a flash-message URL that leads to
    the reset-password page.
    """

    @classmethod
    def generate_reset_password_url_with_flash_msg(cls, user, message):
        """
        Build a flash-message URL that shows ``message`` and then redirects
        to the reset-password page with a fresh reset token for ``user``.
        """
        base_url = reverse('authentication:reset-password')
        token_query = urlencode({'token': user.generate_reset_token()})
        redirect_url = '{}?{}'.format(base_url, token_query)

        flash_payload = dict(
            title=_('Please change your password'),
            message=message,
            interval=3,
            redirect_url=redirect_url,
        )
        return FlashMessageUtil.gen_message_url(flash_payload)

    @classmethod
    def _check_passwd_is_too_simple(cls, user: User, password):
        """Raise errors.PasswordTooSimple for a superuser whose password is 'admin'."""
        if not (user.is_superuser and password == 'admin'):
            return
        hint = _('Your password is too simple, please change it for security')
        raise errors.PasswordTooSimple(
            cls.generate_reset_password_url_with_flash_msg(user, message=hint)
        )

    @classmethod
    def _check_passwd_need_update(cls, user: User):
        """Raise errors.PasswordNeedUpdate when ``user.need_update_password`` is set."""
        if not user.need_update_password:
            return
        hint = _('You should to change your password before login')
        raise errors.PasswordNeedUpdate(
            cls.generate_reset_password_url_with_flash_msg(user, hint)
        )

    @classmethod
    def _check_password_require_reset_or_not(cls, user: User):
        """Raise errors.PasswordRequireResetError when the user's password has expired."""
        if not user.password_has_expired:
            return
        hint = _('Your password has expired, please reset before logging in')
        raise errors.PasswordRequireResetError(
            cls.generate_reset_password_url_with_flash_msg(user, hint)
        )


class AuthACLMixin:
    """
    Login ACL enforcement and the ticket-based login confirmation flow.
    """
    request: Request
    get_request_ip: Callable

    def _check_login_acl(self, user, ip):
        """
        Apply the login ACL to ``user`` / ``ip``.

        :raises errors.LoginIPNotAllowed: when denied with limit type 'ip'.
        :raises errors.TimePeriodNotAllowed: when denied with limit type 'time'.
        """
        # The ACL may restrict who can log in
        allowed, limit_type = LoginACL.allow_user_to_login(user, ip)
        if not allowed:
            if limit_type == 'ip':
                raise errors.LoginIPNotAllowed(username=user.username, request=self.request)
            if limit_type == 'time':
                raise errors.TimePeriodNotAllowed(username=user.username, request=self.request)

    def get_ticket(self):
        """
        Return the login-confirm ticket whose id is stored in the session
        under ``auth_ticket_id``, or None when there is no id or no match.
        """
        from tickets.models import ApplyLoginTicket
        ticket_id = self.request.session.get("auth_ticket_id")
        logger.debug(f'Login confirm ticket id: {ticket_id}')
        if not ticket_id:
            return None
        return ApplyLoginTicket.all().filter(id=ticket_id).first()

    def get_ticket_or_create(self, confirm_setting):
        """
        Return the session's ticket, creating a new one through
        ``confirm_setting`` (and remembering its id in the session) when
        there is none or the existing one is closed.
        """
        ticket = self.get_ticket()
        needs_new = not ticket or ticket.is_status(ticket.Status.closed)
        if needs_new:
            ticket = confirm_setting.create_confirm_ticket(self.request)
            self.request.session['auth_ticket_id'] = str(ticket.id)
        return ticket

    def check_user_login_confirm(self):
        """
        Inspect the session's confirmation ticket.

        Marks the session with ``auth_confirm`` when the ticket is approved.

        :raises errors.LoginConfirmWaitError: while the ticket is still open.
        :raises errors.LoginConfirmOtherError: when there is no ticket, or
            it is in any state other than approved.
        """
        ticket = self.get_ticket()
        if not ticket:
            raise errors.LoginConfirmOtherError('', "Not found")

        if ticket.is_status(ticket.Status.open):
            raise errors.LoginConfirmWaitError(ticket.id)

        if ticket.is_state(ticket.State.approved):
            self.request.session["auth_confirm"] = "1"
            return

        # Rejected and closed tickets are reported the same way
        if ticket.is_state(ticket.State.rejected) or ticket.is_state(ticket.State.closed):
            raise errors.LoginConfirmOtherError(
                ticket.id, ticket.get_state_display()
            )

        raise errors.LoginConfirmOtherError(
            ticket.id, ticket.get_status_display()
        )

    def check_user_login_confirm_if_need(self, user):
        """
        Start or continue the login confirmation flow, unless the session
        is already confirmed or the ACL lookup reports it as not allowed.
        """
        ip = self.get_request_ip()
        is_allowed, confirm_setting = LoginACL.allow_user_confirm_if_need(user, ip)
        already_confirmed = self.request.session.get('auth_confirm')
        if already_confirmed or not is_allowed:
            return
        self.get_ticket_or_create(confirm_setting)
        self.check_user_login_confirm()


class AuthMixin(CommonMixin, AuthPreCheckMixin, AuthACLMixin, MFAMixin, AuthPostCheckMixin):
    """
    Full login flow assembled from the pre-check, ACL, MFA and post-check
    mixins, plus the session marks that record a passed password step.
    """
    request = None
    partial_credential_error = None

    # Cache key template of the "recent login failure" mark; filled with the client IP
    key_prefix_captcha = "_LOGIN_INVALID_{}"

    def _check_auth_user_is_valid(self, username, password, public_key):
        """
        Authenticate the credentials and make sure the account is usable.

        Stores the backend that authenticated the user in the session.

        :return: the authenticated user.
        :raises errors.CredentialError: (via raise_credential_error) when
            authentication fails or the user is expired / inactive.
        """
        user = authenticate(
            self.request, username=username,
            password=password, public_key=public_key
        )
        if not user:
            self.raise_credential_error(errors.reason_password_failed)

        self.request.session['auth_backend'] = getattr(user, 'backend', settings.AUTH_BACKEND_MODEL)

        if user.is_expired:
            self.raise_credential_error(errors.reason_user_expired)
        elif not user.is_active:
            self.raise_credential_error(errors.reason_user_inactive)
        return user

    def set_login_failed_mark(self):
        """Remember for 3600 seconds (in the cache) that a login failed from this IP."""
        ip = self.get_request_ip()
        cache.set(self.key_prefix_captcha.format(ip), 1, 3600)

    def check_is_need_captcha(self):
        """Return the cached login-failure mark for the request IP (falsy when absent)."""
        # A captcha is required when a login failed recently.
        # NOTE(review): this reads the IP with the module-level get_request_ip,
        # while set_login_failed_mark uses self.get_request_ip(), which prefers
        # request.data['remote_addr'] - the two keys can differ; confirm intent.
        ip = get_request_ip(self.request)
        need = cache.get(self.key_prefix_captcha.format(ip))
        return need

    def check_user_auth(self, valid_data=None):
        """
        Run the complete username/password login check.

        Order: block pre-checks, credential verification, login ACL,
        password post-checks, optional login-page MFA, then the session is
        marked and the failure counters are cleared.

        :param valid_data: mapping holding the submitted login fields.
        :return: the authenticated user.
        """
        # pre check
        self.check_is_block()
        username, password, public_key, ip, auto_login = self.get_auth_data(valid_data)
        self._check_only_allow_exists_user_auth(username)

        # check auth
        user = self._check_auth_user_is_valid(username, password, public_key)

        # Validate the login ACL rules
        self._check_login_acl(user, ip)

        # post check
        self._check_password_require_reset_or_not(user)
        self._check_passwd_is_too_simple(user, password)
        self._check_passwd_need_update(user)

        # Validate MFA here as well, if the login page shows the MFA input
        self._check_login_page_mfa_if_need(user)

        # Mark the password step as passed
        self.mark_password_ok(user=user, auto_login=auto_login)
        LoginBlockUtil(user.username, ip).clean_failed_count()
        LoginIpBlockUtil(ip).clean_block_if_need()
        return user

    def mark_password_ok(self, user, auto_login=False, auth_backend=None):
        """
        Record in the session that ``user`` passed the password step.

        :param auto_login: stored as-is under ``auto_login``.
        :param auth_backend: backend to record under ``auth_backend``
            (presumably a dotted backend path, like ``user.backend``). When
            omitted or falsy, ``user.backend`` is used, falling back to
            settings.AUTH_BACKEND_MODEL.
        """
        request = self.request
        request.session['auth_password'] = 1
        request.session['auth_password_expired_at'] = time.time() + settings.AUTH_EXPIRED_SECONDS
        request.session['user_id'] = str(user.id)
        request.session['auto_login'] = auto_login
        if not auth_backend:
            auth_backend = getattr(user, 'backend', settings.AUTH_BACKEND_MODEL)
        request.session['auth_backend'] = auth_backend

    def check_oauth2_auth(self, user: User, auth_backend):
        """
        Run the login checks for a user that was already authenticated by an
        external (OAuth2-style) provider, then mark the password step as ok.

        :param auth_backend: backend that authenticated the user; recorded
            in the session so later steps see the right backend.
        :return: the same ``user``.
        :raises errors.CredentialError: when the user is expired / inactive.
        """
        ip = self.get_request_ip()
        request = self.request

        self._set_partial_credential_error(user.username, ip, request)

        if user.is_expired:
            self.raise_credential_error(errors.reason_user_expired)
        elif not user.is_active:
            self.raise_credential_error(errors.reason_user_inactive)

        self._check_is_block(user.username)
        self._check_login_acl(user, ip)

        LoginBlockUtil(user.username, ip).clean_failed_count()
        LoginIpBlockUtil(ip).clean_block_if_need()
        MFABlockUtils(user.username, ip).clean_failed_count()

        # Pass the backend through; it used to be accepted but ignored
        self.mark_password_ok(user, False, auth_backend=auth_backend)
        return user

    def get_user_or_auth(self, valid_data):
        """
        Return the user from the session when the password step was already
        passed, otherwise authenticate with ``valid_data``.
        """
        request = self.request
        if request.session.get('auth_password'):
            return self.get_user_from_session()
        else:
            return self.check_user_auth(valid_data)

    def clear_auth_mark(self):
        """Drop the password, user id, confirm and ticket marks from the session."""
        keys = ['auth_password', 'user_id', 'auth_confirm', 'auth_ticket_id']
        for k in keys:
            self.request.session.pop(k, '')

    def send_auth_signal(self, success=True, user=None, username='', reason=''):
        """
        Send ``post_auth_success`` (with ``user``) or ``post_auth_failed``
        (with ``username`` and ``reason``), depending on ``success``.
        """
        if success:
            post_auth_success.send(
                sender=self.__class__, user=user, request=self.request
            )
        else:
            post_auth_failed.send(
                sender=self.__class__, username=username,
                request=self.request, reason=reason
            )

    def redirect_to_guard_view(self):
        """Redirect to the login guard view, keeping the current query string."""
        guard_url = reverse('authentication:login-guard')
        args = self.request.META.get('QUERY_STRING', '')
        if args:
            guard_url = "%s?%s" % (guard_url, args)
        return redirect(guard_url)