mirror of https://github.com/jumpserver/jumpserver
parent
92d369aaca
commit
815973fb63
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,310 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
import uuid
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import AbstractUser
|
||||
from django.db import models
|
||||
from django.shortcuts import reverse
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
|
||||
from common.db import fields, models as jms_models
|
||||
from common.utils import (
|
||||
date_expired_default,
|
||||
get_logger,
|
||||
)
|
||||
from labels.mixins import LabeledMixin
|
||||
from orgs.utils import current_org
|
||||
from ._auth import AuthMixin, MFAMixin
|
||||
from ._json import JSONFilterMixin
|
||||
from ._role import RoleMixin
|
||||
from ._source import SourceMixin, Source
|
||||
from ._token import TokenMixin
|
||||
|
||||
logger = get_logger(__file__)
|
||||
__all__ = [
|
||||
"User",
|
||||
"UserPasswordHistory",
|
||||
"MFAMixin"
|
||||
]
|
||||
|
||||
|
||||
class User(
|
||||
AuthMixin,
|
||||
SourceMixin,
|
||||
TokenMixin,
|
||||
RoleMixin,
|
||||
MFAMixin,
|
||||
LabeledMixin,
|
||||
JSONFilterMixin,
|
||||
AbstractUser,
|
||||
):
|
||||
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
|
||||
username = models.CharField(max_length=128, unique=True, verbose_name=_("Username"))
|
||||
name = models.CharField(max_length=128, verbose_name=_("Name"))
|
||||
email = models.EmailField(max_length=128, unique=True, verbose_name=_("Email"))
|
||||
groups = models.ManyToManyField(
|
||||
"users.UserGroup",
|
||||
related_name="users",
|
||||
blank=True,
|
||||
verbose_name=_("User group"),
|
||||
)
|
||||
role = models.CharField(
|
||||
default="User", max_length=10, blank=True, verbose_name=_("Role")
|
||||
)
|
||||
is_service_account = models.BooleanField(
|
||||
default=False, verbose_name=_("Is service account")
|
||||
)
|
||||
avatar = models.ImageField(upload_to="avatar", null=True, verbose_name=_("Avatar"))
|
||||
wechat = fields.EncryptCharField(
|
||||
max_length=128, blank=True, verbose_name=_("Wechat")
|
||||
)
|
||||
phone = fields.EncryptCharField(
|
||||
max_length=128, blank=True, null=True, verbose_name=_("Phone")
|
||||
)
|
||||
mfa_level = models.SmallIntegerField(
|
||||
default=0, choices=MFAMixin.MFA_LEVEL_CHOICES, verbose_name=_("MFA")
|
||||
)
|
||||
otp_secret_key = fields.EncryptCharField(
|
||||
max_length=128, blank=True, null=True, verbose_name=_("OTP secret key")
|
||||
)
|
||||
# Todo: Auto generate key, let user download
|
||||
private_key = fields.EncryptTextField(
|
||||
blank=True, null=True, verbose_name=_("Private key")
|
||||
)
|
||||
public_key = fields.EncryptTextField(
|
||||
blank=True, null=True, verbose_name=_("Public key")
|
||||
)
|
||||
comment = models.TextField(blank=True, null=True, verbose_name=_("Comment"))
|
||||
is_first_login = models.BooleanField(default=True, verbose_name=_("Is first login"))
|
||||
date_expired = models.DateTimeField(
|
||||
default=date_expired_default,
|
||||
blank=True,
|
||||
null=True,
|
||||
db_index=True,
|
||||
verbose_name=_("Date expired"),
|
||||
)
|
||||
created_by = models.CharField(
|
||||
max_length=30, default="", blank=True, verbose_name=_("Created by")
|
||||
)
|
||||
updated_by = models.CharField(
|
||||
max_length=30, default="", blank=True, verbose_name=_("Updated by")
|
||||
)
|
||||
date_password_last_updated = models.DateTimeField(
|
||||
auto_now_add=True,
|
||||
blank=True,
|
||||
null=True,
|
||||
verbose_name=_("Date password last updated"),
|
||||
)
|
||||
need_update_password = models.BooleanField(
|
||||
default=False, verbose_name=_("Need update password")
|
||||
)
|
||||
source = models.CharField(
|
||||
max_length=30,
|
||||
default=Source.local,
|
||||
choices=Source.choices,
|
||||
verbose_name=_("Source"),
|
||||
)
|
||||
wecom_id = models.CharField(
|
||||
null=True, default=None, max_length=128, verbose_name=_("WeCom")
|
||||
)
|
||||
dingtalk_id = models.CharField(
|
||||
null=True, default=None, max_length=128, verbose_name=_("DingTalk")
|
||||
)
|
||||
feishu_id = models.CharField(
|
||||
null=True, default=None, max_length=128, verbose_name=_("FeiShu")
|
||||
)
|
||||
lark_id = models.CharField(
|
||||
null=True, default=None, max_length=128, verbose_name="Lark"
|
||||
)
|
||||
slack_id = models.CharField(
|
||||
null=True, default=None, max_length=128, verbose_name=_("Slack")
|
||||
)
|
||||
date_api_key_last_used = models.DateTimeField(
|
||||
null=True, blank=True, verbose_name=_("Date api key used")
|
||||
)
|
||||
date_updated = models.DateTimeField(auto_now=True, verbose_name=_("Date updated"))
|
||||
DATE_EXPIRED_WARNING_DAYS = 5
|
||||
|
||||
def __str__(self):
|
||||
return "{0.name}({0.username})".format(self)
|
||||
|
||||
@classmethod
|
||||
def get_queryset(cls):
|
||||
queryset = cls.objects.all()
|
||||
if not current_org.is_root():
|
||||
queryset = current_org.get_members()
|
||||
queryset = queryset.exclude(is_service_account=True)
|
||||
return queryset
|
||||
|
||||
@property
|
||||
def secret_key(self):
|
||||
instance = self.preferences.filter(name="secret_key").first()
|
||||
if not instance:
|
||||
return
|
||||
return instance.decrypt_value
|
||||
|
||||
@property
|
||||
def receive_backends(self):
|
||||
try:
|
||||
return self.user_msg_subscription.receive_backends
|
||||
except:
|
||||
return []
|
||||
|
||||
@property
|
||||
def is_otp_secret_key_bound(self):
|
||||
return bool(self.otp_secret_key)
|
||||
|
||||
def get_absolute_url(self):
|
||||
return reverse("users:user-detail", args=(self.id,))
|
||||
|
||||
@property
|
||||
def is_expired(self):
|
||||
if self.date_expired and self.date_expired < timezone.now():
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_password_authenticate(self):
|
||||
cas = self.Source.cas
|
||||
saml2 = self.Source.saml2
|
||||
oauth2 = self.Source.oauth2
|
||||
return self.source not in [cas, saml2, oauth2]
|
||||
|
||||
@property
|
||||
def expired_remain_days(self):
|
||||
date_remain = self.date_expired - timezone.now()
|
||||
return date_remain.days
|
||||
|
||||
@property
|
||||
def will_expired(self):
|
||||
if 0 <= self.expired_remain_days <= self.DATE_EXPIRED_WARNING_DAYS:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
@property
|
||||
def lang(self):
|
||||
return self.preference.get_value("lang")
|
||||
|
||||
@lang.setter
|
||||
def lang(self, value):
|
||||
self.preference.set_value('lang', value)
|
||||
|
||||
@property
|
||||
def preference(self):
|
||||
from users.models.preference import PreferenceManager
|
||||
return PreferenceManager(self)
|
||||
|
||||
@property
|
||||
def is_valid(self):
|
||||
if self.is_active and not self.is_expired:
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_required_attr_if_need(self):
|
||||
if not self.name:
|
||||
self.name = self.username
|
||||
if not self.email or "@" not in self.email:
|
||||
email = "{}@{}".format(self.username, settings.EMAIL_SUFFIX)
|
||||
if "@" in self.username:
|
||||
email = self.username
|
||||
self.email = email
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
self.set_required_attr_if_need()
|
||||
if self.username == "admin":
|
||||
self.role = "Admin"
|
||||
self.is_active = True
|
||||
return super().save(*args, **kwargs)
|
||||
|
||||
def is_member_of(self, user_group):
|
||||
if user_group in self.groups.all():
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_avatar(self, f):
|
||||
self.avatar.save(self.username, f)
|
||||
|
||||
@classmethod
|
||||
def get_avatar_url(cls, username):
|
||||
user_default = settings.STATIC_URL + "img/avatar/user.png"
|
||||
return user_default
|
||||
|
||||
def avatar_url(self):
|
||||
admin_default = settings.STATIC_URL + "img/avatar/admin.png"
|
||||
user_default = settings.STATIC_URL + "img/avatar/user.png"
|
||||
if self.avatar:
|
||||
return self.avatar.url
|
||||
if self.is_superuser:
|
||||
return admin_default
|
||||
else:
|
||||
return user_default
|
||||
|
||||
def unblock_login(self):
|
||||
from users.utils import LoginBlockUtil, MFABlockUtils
|
||||
|
||||
LoginBlockUtil.unblock_user(self.username)
|
||||
MFABlockUtils.unblock_user(self.username)
|
||||
|
||||
@property
|
||||
def login_blocked(self):
|
||||
from users.utils import LoginBlockUtil, MFABlockUtils
|
||||
|
||||
if LoginBlockUtil.is_user_block(self.username):
|
||||
return True
|
||||
if MFABlockUtils.is_user_block(self.username):
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete(self, using=None, keep_parents=False):
|
||||
if self.pk == 1 or self.username == "admin":
|
||||
raise PermissionDenied(_("Can not delete admin user"))
|
||||
return super(User, self).delete(using=using, keep_parents=keep_parents)
|
||||
|
||||
class Meta:
|
||||
ordering = ["username"]
|
||||
verbose_name = _("User")
|
||||
unique_together = (
|
||||
("dingtalk_id",),
|
||||
("wecom_id",),
|
||||
("feishu_id",),
|
||||
("lark_id",),
|
||||
("slack_id",),
|
||||
)
|
||||
permissions = [
|
||||
("invite_user", _("Can invite user")),
|
||||
("remove_user", _("Can remove user")),
|
||||
("match_user", _("Can match user")),
|
||||
]
|
||||
|
||||
def can_send_created_mail(self):
|
||||
if self.email and self.source == self.Source.local.value:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class UserPasswordHistory(models.Model):
|
||||
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
|
||||
password = models.CharField(max_length=128)
|
||||
user = models.ForeignKey(
|
||||
"users.User",
|
||||
related_name="history_passwords",
|
||||
on_delete=jms_models.CASCADE_SIGNAL_SKIP,
|
||||
verbose_name=_("User"),
|
||||
)
|
||||
date_created = models.DateTimeField(
|
||||
auto_now_add=True, verbose_name=_("Date created")
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.user} set at {self.date_created}"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("User password history")
|
@ -0,0 +1,273 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
import datetime
|
||||
from typing import Callable
|
||||
|
||||
import sshpubkeys
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.hashers import check_password
|
||||
from django.core.cache import cache
|
||||
from django.db import models
|
||||
from django.utils import timezone
|
||||
from django.utils.module_loading import import_string
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from common.utils import (
|
||||
get_logger,
|
||||
lazyproperty,
|
||||
)
|
||||
from users.signals import post_user_change_password
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
__all__ = ['MFAMixin', 'AuthMixin']
|
||||
|
||||
|
||||
class MFAMixin:
|
||||
mfa_level = 0
|
||||
otp_secret_key = ""
|
||||
MFA_LEVEL_CHOICES = (
|
||||
(0, _("Disabled")),
|
||||
(1, _("Enabled")),
|
||||
(2, _("Force enabled")),
|
||||
)
|
||||
is_org_admin: bool
|
||||
username: str
|
||||
phone: str
|
||||
|
||||
@property
|
||||
def mfa_enabled(self):
|
||||
if self.mfa_force_enabled:
|
||||
return True
|
||||
return self.mfa_level > 0
|
||||
|
||||
@property
|
||||
def mfa_force_enabled(self):
|
||||
force_level = settings.SECURITY_MFA_AUTH
|
||||
# 1 All users
|
||||
if force_level in [True, 1]:
|
||||
return True
|
||||
# 2 仅管理员强制开启
|
||||
if force_level == 2 and self.is_org_admin:
|
||||
return True
|
||||
# 3 仅用户开启
|
||||
return self.mfa_level == 2
|
||||
|
||||
def enable_mfa(self):
|
||||
if not self.mfa_level == 2:
|
||||
self.mfa_level = 1
|
||||
|
||||
def force_enable_mfa(self):
|
||||
self.mfa_level = 2
|
||||
|
||||
def disable_mfa(self):
|
||||
self.mfa_level = 0
|
||||
|
||||
def no_active_mfa(self):
|
||||
return len(self.active_mfa_backends) == 0
|
||||
|
||||
@lazyproperty
|
||||
def active_mfa_backends(self):
|
||||
backends = self.get_user_mfa_backends(self)
|
||||
active_backends = [b for b in backends if b.is_active()]
|
||||
return active_backends
|
||||
|
||||
@property
|
||||
def active_mfa_backends_mapper(self):
|
||||
return {b.name: b for b in self.active_mfa_backends}
|
||||
|
||||
@staticmethod
|
||||
def get_user_mfa_backends(user):
|
||||
backends = []
|
||||
for cls in settings.MFA_BACKENDS:
|
||||
cls = import_string(cls)
|
||||
if cls.global_enabled():
|
||||
backends.append(cls(user))
|
||||
return backends
|
||||
|
||||
def get_active_mfa_backend_by_type(self, mfa_type):
|
||||
backend = self.get_mfa_backend_by_type(mfa_type)
|
||||
if not backend or not backend.is_active():
|
||||
return None
|
||||
return backend
|
||||
|
||||
def get_mfa_backend_by_type(self, mfa_type):
|
||||
mfa_mapper = {b.name: b for b in self.get_user_mfa_backends(self)}
|
||||
backend = mfa_mapper.get(mfa_type)
|
||||
if not backend:
|
||||
return None
|
||||
return backend
|
||||
|
||||
|
||||
class AuthMixin:
|
||||
date_password_last_updated: datetime.datetime
|
||||
history_passwords: models.Manager
|
||||
need_update_password: bool
|
||||
public_key: str
|
||||
username: str
|
||||
is_local: bool
|
||||
set_password: Callable
|
||||
save: Callable
|
||||
history_passwords: models.Manager
|
||||
sect_cache_tpl = "user_sect_{}"
|
||||
id: str
|
||||
|
||||
@property
|
||||
def password_raw(self):
|
||||
raise AttributeError("Password raw is not a readable attribute")
|
||||
|
||||
#: Use this attr to set user object password, example
|
||||
#: user = User(username='example', password_raw='password', ...)
|
||||
#: It's equal:
|
||||
#: user = User(username='example', ...)
|
||||
#: user.set_password('password')
|
||||
@password_raw.setter
|
||||
def password_raw(self, password_raw_):
|
||||
self.set_password(password_raw_)
|
||||
|
||||
def set_password(self, raw_password):
|
||||
if self.can_update_password():
|
||||
if self.username:
|
||||
self.date_password_last_updated = timezone.now()
|
||||
post_user_change_password.send(self.__class__, user=self)
|
||||
super().set_password(raw_password) # noqa
|
||||
|
||||
def set_public_key(self, public_key):
|
||||
if self.can_update_ssh_key():
|
||||
self.public_key = public_key
|
||||
self.save()
|
||||
post_user_change_password.send(self.__class__, user=self)
|
||||
|
||||
def can_update_password(self):
|
||||
return self.is_local
|
||||
|
||||
def can_update_ssh_key(self):
|
||||
return self.can_use_ssh_key_login()
|
||||
|
||||
@staticmethod
|
||||
def can_use_ssh_key_login():
|
||||
return settings.TERMINAL_PUBLIC_KEY_AUTH
|
||||
|
||||
def is_history_password(self, password):
|
||||
allow_history_password_count = settings.OLD_PASSWORD_HISTORY_LIMIT_COUNT
|
||||
history_passwords = self.history_passwords.all().order_by("-date_created")[
|
||||
: int(allow_history_password_count)
|
||||
]
|
||||
|
||||
for history_password in history_passwords:
|
||||
if check_password(password, history_password.password):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def is_public_key_valid(self):
|
||||
"""
|
||||
Check if the user's ssh public key is valid.
|
||||
This function is used in base.html.
|
||||
"""
|
||||
if self.public_key:
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def public_key_obj(self):
|
||||
class PubKey(object):
|
||||
def __getattr__(self, item):
|
||||
return ""
|
||||
|
||||
if self.public_key:
|
||||
try:
|
||||
return sshpubkeys.SSHKey(self.public_key)
|
||||
except (TabError, TypeError):
|
||||
pass
|
||||
return PubKey()
|
||||
|
||||
def get_public_key_comment(self):
|
||||
return self.public_key_obj.comment
|
||||
|
||||
def get_public_key_hash_md5(self):
|
||||
if not callable(self.public_key_obj.hash_md5):
|
||||
return ""
|
||||
try:
|
||||
return self.public_key_obj.hash_md5()
|
||||
except:
|
||||
return ""
|
||||
|
||||
def reset_password(self, new_password):
|
||||
self.set_password(new_password)
|
||||
self.need_update_password = False
|
||||
self.save()
|
||||
|
||||
@property
|
||||
def date_password_expired(self):
|
||||
interval = settings.SECURITY_PASSWORD_EXPIRATION_TIME
|
||||
date_expired = self.date_password_last_updated + timezone.timedelta(
|
||||
days=int(interval)
|
||||
)
|
||||
return date_expired
|
||||
|
||||
@property
|
||||
def password_expired_remain_days(self):
|
||||
date_remain = self.date_password_expired - timezone.now()
|
||||
return date_remain.days
|
||||
|
||||
@property
|
||||
def password_has_expired(self):
|
||||
if self.is_local and self.password_expired_remain_days < 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def password_will_expired(self):
|
||||
if self.is_local and 0 <= self.password_expired_remain_days < 5:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_public_key_md5(key):
|
||||
try:
|
||||
key_obj = sshpubkeys.SSHKey(key)
|
||||
return key_obj.hash_md5()
|
||||
except Exception as e:
|
||||
return ""
|
||||
|
||||
def check_public_key(self, key):
|
||||
if not self.public_key:
|
||||
return False
|
||||
key_md5 = self.get_public_key_md5(key)
|
||||
if not key_md5:
|
||||
return False
|
||||
self_key_md5 = self.get_public_key_md5(self.public_key)
|
||||
return key_md5 == self_key_md5
|
||||
|
||||
def cache_login_password_if_need(self, password):
|
||||
from common.utils import signer
|
||||
|
||||
if not settings.CACHE_LOGIN_PASSWORD_ENABLED:
|
||||
return
|
||||
backend = getattr(self, "backend", "")
|
||||
if backend.lower().find("ldap") < 0:
|
||||
return
|
||||
if not password:
|
||||
return
|
||||
key = self.sect_cache_tpl.format(self.id)
|
||||
ttl = settings.CACHE_LOGIN_PASSWORD_TTL
|
||||
if not isinstance(ttl, int) or ttl <= 0:
|
||||
return
|
||||
secret = signer.sign(password)
|
||||
cache.set(key, secret, ttl)
|
||||
|
||||
def get_cached_password_if_has(self):
|
||||
from common.utils import signer
|
||||
|
||||
if not settings.CACHE_LOGIN_PASSWORD_ENABLED:
|
||||
return ""
|
||||
key = self.sect_cache_tpl.format(self.id)
|
||||
secret = cache.get(key)
|
||||
if not secret:
|
||||
return ""
|
||||
password = signer.unsign(secret)
|
||||
return password
|
||||
|
@ -0,0 +1,32 @@
|
||||
from django.db import models
|
||||
from django.db.models import Count
|
||||
|
||||
|
||||
class JSONFilterMixin:
|
||||
@staticmethod
|
||||
def get_json_filter_attr_q(name, value, match):
|
||||
from rbac.models import RoleBinding
|
||||
from orgs.utils import current_org
|
||||
|
||||
kwargs = {}
|
||||
if name == "system_roles":
|
||||
kwargs["scope"] = "system"
|
||||
elif name == "org_roles":
|
||||
kwargs["scope"] = "org"
|
||||
if not current_org.is_root():
|
||||
kwargs["org_id"] = current_org.id
|
||||
else:
|
||||
return None
|
||||
|
||||
bindings = RoleBinding.objects.filter(**kwargs, role__in=value)
|
||||
if match == "m2m_all":
|
||||
user_id = (
|
||||
bindings.values("user_id")
|
||||
.annotate(count=Count("user_id")) # 这里不能有 distinct 会导致 count 不准确, acls 中过滤用户时会出现问题
|
||||
.filter(count=len(value))
|
||||
.values_list("user_id", flat=True)
|
||||
)
|
||||
else:
|
||||
user_id = bindings.values_list("user_id", flat=True)
|
||||
return models.Q(id__in=user_id)
|
||||
|
@ -0,0 +1,106 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class Source(models.TextChoices):
|
||||
local = "local", _("Local")
|
||||
ldap = "ldap", "LDAP/AD"
|
||||
openid = "openid", "OpenID"
|
||||
radius = "radius", "Radius"
|
||||
cas = "cas", "CAS"
|
||||
saml2 = "saml2", "SAML2"
|
||||
oauth2 = "oauth2", "OAuth2"
|
||||
wecom = "wecom", _("WeCom")
|
||||
dingtalk = "dingtalk", _("DingTalk")
|
||||
feishu = "feishu", _("FeiShu")
|
||||
lark = "lark", _("Lark")
|
||||
slack = "slack", _("Slack")
|
||||
custom = "custom", "Custom"
|
||||
|
||||
|
||||
class SourceMixin:
|
||||
source: str
|
||||
_source_choices = []
|
||||
Source = Source
|
||||
|
||||
SOURCE_BACKEND_MAPPING = {
|
||||
Source.local: [
|
||||
settings.AUTH_BACKEND_MODEL,
|
||||
settings.AUTH_BACKEND_PUBKEY,
|
||||
],
|
||||
Source.ldap: [settings.AUTH_BACKEND_LDAP],
|
||||
Source.openid: [
|
||||
settings.AUTH_BACKEND_OIDC_PASSWORD,
|
||||
settings.AUTH_BACKEND_OIDC_CODE,
|
||||
],
|
||||
Source.radius: [settings.AUTH_BACKEND_RADIUS],
|
||||
Source.cas: [settings.AUTH_BACKEND_CAS],
|
||||
Source.saml2: [settings.AUTH_BACKEND_SAML2],
|
||||
Source.oauth2: [settings.AUTH_BACKEND_OAUTH2],
|
||||
Source.wecom: [settings.AUTH_BACKEND_WECOM],
|
||||
Source.feishu: [settings.AUTH_BACKEND_FEISHU],
|
||||
Source.lark: [settings.AUTH_BACKEND_LARK],
|
||||
Source.slack: [settings.AUTH_BACKEND_SLACK],
|
||||
Source.dingtalk: [settings.AUTH_BACKEND_DINGTALK],
|
||||
Source.custom: [settings.AUTH_BACKEND_CUSTOM],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_sources_enabled(cls):
|
||||
mapper = {
|
||||
cls.Source.local: True,
|
||||
cls.Source.ldap: settings.AUTH_LDAP,
|
||||
cls.Source.openid: settings.AUTH_OPENID,
|
||||
cls.Source.radius: settings.AUTH_RADIUS,
|
||||
cls.Source.cas: settings.AUTH_CAS,
|
||||
cls.Source.saml2: settings.AUTH_SAML2,
|
||||
cls.Source.oauth2: settings.AUTH_OAUTH2,
|
||||
cls.Source.wecom: settings.AUTH_WECOM,
|
||||
cls.Source.feishu: settings.AUTH_FEISHU,
|
||||
cls.Source.slack: settings.AUTH_SLACK,
|
||||
cls.Source.dingtalk: settings.AUTH_DINGTALK,
|
||||
cls.Source.custom: settings.AUTH_CUSTOM,
|
||||
}
|
||||
return [str(k) for k, v in mapper.items() if v]
|
||||
|
||||
@property
|
||||
def source_display(self):
|
||||
return self.get_source_display()
|
||||
|
||||
@property
|
||||
def is_local(self):
|
||||
return self.source == self.Source.local.value
|
||||
|
||||
@classmethod
|
||||
def get_source_choices(cls):
|
||||
if cls._source_choices:
|
||||
return cls._source_choices
|
||||
used = (
|
||||
cls.objects.values_list("source", flat=True).order_by("source").distinct()
|
||||
)
|
||||
enabled_sources = cls.get_sources_enabled()
|
||||
_choices = []
|
||||
for k, v in cls.Source.choices:
|
||||
if k in enabled_sources or k in used:
|
||||
_choices.append((k, v))
|
||||
cls._source_choices = _choices
|
||||
return cls._source_choices
|
||||
|
||||
@classmethod
|
||||
def get_user_allowed_auth_backend_paths(cls, username):
|
||||
if not settings.ONLY_ALLOW_AUTH_FROM_SOURCE or not username:
|
||||
return None
|
||||
user = cls.objects.filter(username=username).first()
|
||||
if not user:
|
||||
return None
|
||||
return user.get_allowed_auth_backend_paths()
|
||||
|
||||
def get_allowed_auth_backend_paths(self):
|
||||
if not settings.ONLY_ALLOW_AUTH_FROM_SOURCE:
|
||||
return None
|
||||
return self.SOURCE_BACKEND_MAPPING.get(self.source, [])
|
@ -0,0 +1,95 @@
|
||||
import base64
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.cache import cache
|
||||
from django.utils import timezone
|
||||
|
||||
from common.utils import (
|
||||
get_logger,
|
||||
random_string,
|
||||
)
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class TokenMixin:
|
||||
CACHE_KEY_USER_RESET_PASSWORD_PREFIX = "_KEY_USER_RESET_PASSWORD_{}"
|
||||
email = ""
|
||||
id = None
|
||||
|
||||
@property
|
||||
def private_token(self):
|
||||
return self.create_private_token()
|
||||
|
||||
def create_private_token(self):
|
||||
from authentication.models import PrivateToken
|
||||
|
||||
token, created = PrivateToken.objects.get_or_create(user=self)
|
||||
return token
|
||||
|
||||
def delete_private_token(self):
|
||||
from authentication.models import PrivateToken
|
||||
|
||||
PrivateToken.objects.filter(user=self).delete()
|
||||
|
||||
def refresh_private_token(self):
|
||||
self.delete_private_token()
|
||||
return self.create_private_token()
|
||||
|
||||
def create_bearer_token(self, request=None):
|
||||
expiration = settings.TOKEN_EXPIRATION or 3600
|
||||
if request:
|
||||
remote_addr = request.META.get("REMOTE_ADDR", "")
|
||||
else:
|
||||
remote_addr = "0.0.0.0"
|
||||
if not isinstance(remote_addr, bytes):
|
||||
remote_addr = remote_addr.encode("utf-8")
|
||||
remote_addr = base64.b16encode(remote_addr) # .replace(b'=', '')
|
||||
cache_key = "%s_%s" % (self.id, remote_addr)
|
||||
token = cache.get(cache_key)
|
||||
if not token:
|
||||
token = random_string(36)
|
||||
cache.set(token, self.id, expiration)
|
||||
cache.set("%s_%s" % (self.id, remote_addr), token, expiration)
|
||||
date_expired = timezone.now() + timezone.timedelta(seconds=expiration)
|
||||
return token, date_expired
|
||||
|
||||
def refresh_bearer_token(self, token):
|
||||
pass
|
||||
|
||||
def create_access_key(self):
|
||||
access_key = self.access_keys.create()
|
||||
return access_key
|
||||
|
||||
@property
|
||||
def access_key(self):
|
||||
return self.access_keys.first()
|
||||
|
||||
def generate_reset_token(self):
|
||||
token = random_string(50)
|
||||
key = self.CACHE_KEY_USER_RESET_PASSWORD_PREFIX.format(token)
|
||||
cache.set(key, {"id": self.id, "email": self.email}, 3600)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def validate_reset_password_token(cls, token):
|
||||
if not token:
|
||||
return None
|
||||
key = cls.CACHE_KEY_USER_RESET_PASSWORD_PREFIX.format(token)
|
||||
value = cache.get(key)
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
user_id = value.get("id", "")
|
||||
email = value.get("email", "")
|
||||
user = cls.objects.get(id=user_id, email=email)
|
||||
return user
|
||||
except (AttributeError, cls.DoesNotExist) as e:
|
||||
logger.error(e, exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def expired_reset_password_token(cls, token):
|
||||
key = cls.CACHE_KEY_USER_RESET_PASSWORD_PREFIX.format(token)
|
||||
cache.delete(key)
|
||||
|
Loading…
Reference in new issue