from django.core.validators import MinValueValidator, MaxValueValidator
from django.db import models
from django.utils.translation import gettext_lazy as _

from assets.models import Asset
from common.db.fields import PortField
from common.db.models import JMSBaseModel
from common.utils.ip import contains_ip


class Endpoint(JMSBaseModel):
    """Connection endpoint: a host plus one port per supported protocol.

    A port value of 0 marks that protocol as disabled on the endpoint
    (see ``is_valid_for``). The row whose id equals ``default_id`` is the
    built-in default endpoint: it is accepted for every protocol and
    ``delete()`` refuses to remove it.
    """

    name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
    host = models.CharField(max_length=256, blank=True, verbose_name=_('Host'))
    # A port value of 0 means the protocol is disabled
    https_port = PortField(default=443, verbose_name=_('HTTPS port'))
    http_port = PortField(default=80, verbose_name=_('HTTP port'))
    ssh_port = PortField(default=2222, verbose_name=_('SSH port'))
    rdp_port = PortField(default=3389, verbose_name=_('RDP port'))
    mysql_port = PortField(default=33061, verbose_name=_('MySQL port'))
    mariadb_port = PortField(default=33062, verbose_name=_('MariaDB port'))
    postgresql_port = PortField(default=54320, verbose_name=_('PostgreSQL port'))
    redis_port = PortField(default=63790, verbose_name=_('Redis port'))
    sqlserver_port = PortField(default=14330, verbose_name=_('SQLServer port'))

    comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))

    # Fixed primary key of the built-in default endpoint
    default_id = '00000000-0000-0000-0000-000000000001'

    class Meta:
        verbose_name = _('Endpoint')
        ordering = ('name',)

    def __str__(self):
        return self.name

    def get_port(self, target_instance, protocol):
        """Return the port this endpoint exposes for ``protocol``.

        An Oracle asset reached over the oracle protocol gets its port from
        ``db_port_manager``. Otherwise the ``<protocol>_port`` attribute is
        read, with sftp and telnet sharing the SSH port; a protocol without
        such an attribute yields 0.
        """
        from terminal.utils import db_port_manager
        from assets.const import DatabaseTypes, Protocol

        uses_oracle_mapping = (
            isinstance(target_instance, Asset)
            and target_instance.is_type(DatabaseTypes.ORACLE)
            and protocol == Protocol.oracle
        )
        if uses_oracle_mapping:
            return db_port_manager.get_port_by_db(target_instance)

        # sftp and telnet have no dedicated field; they reuse the SSH port
        port_protocol = Protocol.ssh if protocol in (Protocol.sftp, Protocol.telnet) else protocol
        return getattr(self, f'{port_protocol}_port', 0)

    def is_default(self):
        """Tell whether this instance is the built-in default endpoint."""
        return self.default_id == str(self.id)

    def delete(self, using=None, keep_parents=False):
        """Delete the endpoint, silently doing nothing for the default one."""
        if not self.is_default():
            return super().delete(using, keep_parents)
        return None

    def is_valid_for(self, target_instance, protocol):
        """Return True for the default endpoint or when the protocol port is non-zero."""
        return self.is_default() or self.get_port(target_instance, protocol) != 0

    @classmethod
    def get_or_create_default(cls, request=None):
        """Fetch the default endpoint, creating it on first use.

        ``request`` is accepted but not used here.
        """
        initial_values = {
            'id': cls.default_id,
            'name': 'Default',
            'host': '',
            'https_port': 0,
            'http_port': 0,
        }
        default_endpoint, _ = cls.objects.get_or_create(
            id=cls.default_id, defaults=initial_values
        )
        return default_endpoint

    @classmethod
    def handle_endpoint_host(cls, endpoint, request=None):
        """Fill an empty ``endpoint.host`` from the current request's host.

        Only the in-memory instance is modified; nothing is saved here.
        The endpoint is returned unchanged when it already has a host or
        when no request is given.
        """
        if endpoint.host or not request:
            return endpoint

        # Dynamically fill in the current request host, dropping any port
        raw_host = request.get_host()
        if raw_host.startswith('['):
            # IPv6 literal: keep the brackets, strip a trailing ":port"
            bracketed = raw_host.split(']:')[0]
            endpoint.host = bracketed.rstrip(']') + ']'
        else:
            endpoint.host = raw_host.partition(':')[0]
        return endpoint

    @classmethod
    def match_by_instance_label(cls, instance, protocol, request=None):
        """Pick an endpoint named by the asset's ``endpoint`` labels.

        A Session is first resolved to its asset. Endpoints whose name
        matches one of the label values are tried from most to least
        recently updated, and the first one valid for ``protocol`` is
        returned with its host filled in. Returns None when the instance
        is not an asset, has no such label, or nothing matches.
        """
        from assets.models import Asset
        from terminal.models import Session

        asset = instance.get_asset() if isinstance(instance, Session) else instance
        if not isinstance(asset, Asset):
            return None

        label_values = asset.labels.filter(label__name='endpoint') \
            .values_list('label__value', flat=True)
        if not label_values:
            return None

        candidates = cls.objects.filter(name__in=list(label_values)).order_by('-date_updated')
        for candidate in candidates:
            if candidate.is_valid_for(asset, protocol):
                return cls.handle_endpoint_host(candidate, request)
        return None


class EndpointRule(JMSBaseModel):
    """Rule that routes targets whose IP matches ``ip_group`` to an endpoint.

    Rules are evaluated in the model's default ordering (ascending
    ``priority`` first), and only active rules are considered by ``match``.
    """

    name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True)
    # Passed as-is to contains_ip() together with the target IP
    ip_group = models.JSONField(default=list, verbose_name=_('IP group'))
    priority = models.IntegerField(
        verbose_name=_("Priority"), validators=[MinValueValidator(1), MaxValueValidator(100)],
        unique=True, help_text=_("1-100, the lower the value will be match first"),
    )
    # Nullable: deleting the endpoint keeps the rule but clears the link
    endpoint = models.ForeignKey(
        'terminal.Endpoint', null=True, blank=True, related_name='rules',
        on_delete=models.SET_NULL, verbose_name=_("Endpoint"),
    )
    comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
    is_active = models.BooleanField(default=True, verbose_name=_('Is active'))

    class Meta:
        verbose_name = _('Endpoint rule')
        ordering = ('priority', 'is_active', 'name')

    def __str__(self):
        return f'{self.name}({self.priority})'

    @classmethod
    def match(cls, target_instance, target_ip, protocol):
        """Return the first active rule applicable to the target, or None.

        A rule applies when ``target_ip`` is contained in its ``ip_group``,
        it has an endpoint, and that endpoint is valid for the given
        target instance and protocol.
        """
        # endpoint is a single-valued foreign key, so select_related loads it
        # in the same query via a JOIN instead of issuing a second lookup.
        active_rules = cls.objects.select_related('endpoint').filter(is_active=True)
        for endpoint_rule in active_rules:
            if not contains_ip(target_ip, endpoint_rule.ip_group):
                continue
            # The FK is nullable; a rule without an endpoint cannot match
            if not endpoint_rule.endpoint:
                continue
            if not endpoint_rule.endpoint.is_valid_for(target_instance, protocol):
                continue
            return endpoint_rule
        return None

    @classmethod
    def match_endpoint(cls, target_instance, target_ip, protocol, request=None):
        """Resolve the endpoint to use for a target.

        Uses the endpoint of the first matching rule, falling back to the
        default endpoint when no rule matches. An empty endpoint host is
        then filled in from ``request`` when one is provided.
        """
        endpoint_rule = cls.match(target_instance, target_ip, protocol)
        if endpoint_rule:
            endpoint = endpoint_rule.endpoint
        else:
            endpoint = Endpoint.get_or_create_default(request)
        return Endpoint.handle_endpoint_host(endpoint, request)