Merge pull request #5728 from jumpserver/dev

v2.8 发版
pull/5813/head
Jiangjie.Bai 2021-03-11 21:17:39 +08:00 committed by GitHub
commit 174cc16980
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
234 changed files with 7528 additions and 4481 deletions

View File

@ -23,6 +23,7 @@ RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list \
&& sed -i 's/security.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list \ && sed -i 's/security.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list \
&& apt update \ && apt update \
&& grep -v '^#' ./requirements/deb_buster_requirements.txt | xargs apt -y install \ && grep -v '^#' ./requirements/deb_buster_requirements.txt | xargs apt -y install \
&& rm -rf /var/lib/apt/lists/* \
&& localedef -c -f UTF-8 -i zh_CN zh_CN.UTF-8 \ && localedef -c -f UTF-8 -i zh_CN zh_CN.UTF-8 \
&& cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime

3
apps/acls/admin.py Normal file
View File

@ -0,0 +1,3 @@
from django.contrib import admin
# Register your models here.

View File

@ -0,0 +1,3 @@
from .login_acl import *
from .login_asset_acl import *
from .login_asset_check import *

View File

@ -0,0 +1,19 @@
from common.permissions import IsOrgAdmin, HasQueryParamsUserAndIsCurrentOrgMember
from common.drf.api import JMSBulkModelViewSet
from ..models import LoginACL
from .. import serializers
__all__ = ['LoginACLViewSet', ]
class LoginACLViewSet(JMSBulkModelViewSet):
queryset = LoginACL.objects.all()
filterset_fields = ('name', 'user', )
search_fields = filterset_fields
permission_classes = (IsOrgAdmin, )
serializer_class = serializers.LoginACLSerializer
def get_permissions(self):
if self.action in ["retrieve", "list"]:
self.permission_classes = (IsOrgAdmin, HasQueryParamsUserAndIsCurrentOrgMember)
return super().get_permissions()

View File

@ -0,0 +1,15 @@
from orgs.mixins.api import OrgBulkModelViewSet
from common.permissions import IsOrgAdmin
from .. import models, serializers
__all__ = ['LoginAssetACLViewSet']
class LoginAssetACLViewSet(OrgBulkModelViewSet):
model = models.LoginAssetACL
filterset_fields = ('name', )
search_fields = filterset_fields
permission_classes = (IsOrgAdmin, )
serializer_class = serializers.LoginAssetACLSerializer

View File

@ -0,0 +1,105 @@
from django.shortcuts import get_object_or_404
from rest_framework.response import Response
from rest_framework.generics import CreateAPIView, RetrieveDestroyAPIView
from common.permissions import IsAppUser
from common.utils import reverse, lazyproperty
from orgs.utils import tmp_to_org, tmp_to_root_org
from tickets.models import Ticket
from ..models import LoginAssetACL
from .. import serializers
__all__ = ['LoginAssetCheckAPI', 'LoginAssetConfirmStatusAPI']
class LoginAssetCheckAPI(CreateAPIView):
permission_classes = (IsAppUser, )
serializer_class = serializers.LoginAssetCheckSerializer
def create(self, request, *args, **kwargs):
is_need_confirm, response_data = self.check_if_need_confirm()
return Response(data=response_data, status=200)
def check_if_need_confirm(self):
queries = {
'user': self.serializer.user, 'asset': self.serializer.asset,
'system_user': self.serializer.system_user,
'action': LoginAssetACL.ActionChoices.login_confirm
}
with tmp_to_org(self.serializer.org):
acl = LoginAssetACL.filter(**queries).valid().first()
if not acl:
is_need_confirm = False
response_data = {}
else:
is_need_confirm = True
response_data = self._get_response_data_of_need_confirm(acl)
response_data['need_confirm'] = is_need_confirm
return is_need_confirm, response_data
def _get_response_data_of_need_confirm(self, acl):
ticket = LoginAssetACL.create_login_asset_confirm_ticket(
user=self.serializer.user,
asset=self.serializer.asset,
system_user=self.serializer.system_user,
assignees=acl.reviewers.all(),
org_id=self.serializer.org.id
)
confirm_status_url = reverse(
view_name='acls:login-asset-confirm-status',
kwargs={'pk': str(ticket.id)}
)
ticket_detail_url = reverse(
view_name='api-tickets:ticket-detail',
kwargs={'pk': str(ticket.id)},
external=True, api_to_ui=True
)
ticket_detail_url = '{url}?type={type}'.format(url=ticket_detail_url, type=ticket.type)
data = {
'check_confirm_status': {'method': 'GET', 'url': confirm_status_url},
'close_confirm': {'method': 'DELETE', 'url': confirm_status_url},
'ticket_detail_url': ticket_detail_url,
'reviewers': [str(user) for user in ticket.assignees.all()],
}
return data
@lazyproperty
def serializer(self):
serializer = self.get_serializer(data=self.request.data)
serializer.is_valid(raise_exception=True)
return serializer
class LoginAssetConfirmStatusAPI(RetrieveDestroyAPIView):
permission_classes = (IsAppUser, )
def retrieve(self, request, *args, **kwargs):
if self.ticket.action_open:
status = 'await'
elif self.ticket.action_approve:
status = 'approve'
else:
status = 'reject'
data = {
'status': status,
'action': self.ticket.action,
'processor': self.ticket.processor_display
}
return Response(data=data, status=200)
def destroy(self, request, *args, **kwargs):
if self.ticket.status_open:
self.ticket.close(processor=self.ticket.applicant)
data = {
'action': self.ticket.action,
'status': self.ticket.status,
'processor': self.ticket.processor_display
}
return Response(data=data, status=200)
@lazyproperty
def ticket(self):
with tmp_to_root_org():
return get_object_or_404(Ticket, pk=self.kwargs['pk'])

5
apps/acls/apps.py Normal file
View File

@ -0,0 +1,5 @@
from django.apps import AppConfig
class AclsConfig(AppConfig):
name = 'acls'

9
apps/acls/const.py Normal file
View File

@ -0,0 +1,9 @@
from django.utils.translation import ugettext as _
common_help_text = _('Format for comma-delimited string, with * indicating a match all. ')
ip_group_help_text = common_help_text + _(
'Such as: '
'192.168.10.1, 192.168.1.0/24, 10.1.1.1-10.1.1.20, 2001:db8:2de::e13, 2001:db8:1a:1110::/64 '
)

View File

@ -0,0 +1,61 @@
# Generated by Django 3.1 on 2021-03-11 09:53
from django.conf import settings
import django.core.validators
from django.db import migrations, models
import django.db.models.deletion
import uuid
class Migration(migrations.Migration):
initial = True
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name='LoginACL',
fields=[
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('name', models.CharField(max_length=128, verbose_name='Name')),
('priority', models.IntegerField(default=50, help_text='1-100, the lower the value will be match first', validators=[django.core.validators.MinValueValidator(1), django.core.validators.MaxValueValidator(100)], verbose_name='Priority')),
('is_active', models.BooleanField(default=True, verbose_name='Active')),
('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
('ip_group', models.JSONField(default=list, verbose_name='Login IP')),
('action', models.CharField(choices=[('reject', 'Reject'), ('allow', 'Allow')], default='reject', max_length=64, verbose_name='Action')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='login_acls', to=settings.AUTH_USER_MODEL, verbose_name='User')),
],
options={
'ordering': ('priority', '-date_updated', 'name'),
},
),
migrations.CreateModel(
name='LoginAssetACL',
fields=[
('org_id', models.CharField(blank=True, db_index=True, default='', max_length=36, verbose_name='Organization')),
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('name', models.CharField(max_length=128, verbose_name='Name')),
('priority', models.IntegerField(default=50, help_text='1-100, the lower the value will be match first', validators=[django.core.validators.MinValueValidator(1), django.core.validators.MaxValueValidator(100)], verbose_name='Priority')),
('is_active', models.BooleanField(default=True, verbose_name='Active')),
('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
('users', models.JSONField(verbose_name='User')),
('system_users', models.JSONField(verbose_name='System User')),
('assets', models.JSONField(verbose_name='Asset')),
('action', models.CharField(choices=[('login_confirm', 'Login confirm')], default='login_confirm', max_length=64, verbose_name='Action')),
('reviewers', models.ManyToManyField(blank=True, related_name='review_login_asset_acls', to=settings.AUTH_USER_MODEL, verbose_name='Reviewers')),
],
options={
'ordering': ('priority', '-date_updated', 'name'),
'unique_together': {('name', 'org_id')},
},
),
]

View File

View File

@ -0,0 +1,2 @@
from .login_acl import *
from .login_asset_acl import *

35
apps/acls/models/base.py Normal file
View File

@ -0,0 +1,35 @@
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.core.validators import MinValueValidator, MaxValueValidator
from common.mixins import CommonModelMixin
__all__ = ['BaseACL', 'BaseACLQuerySet']
class BaseACLQuerySet(models.QuerySet):
def active(self):
return self.filter(is_active=True)
def inactive(self):
return self.filter(is_active=False)
def valid(self):
return self.active()
def invalid(self):
return self.inactive()
class BaseACL(CommonModelMixin):
name = models.CharField(max_length=128, verbose_name=_('Name'))
priority = models.IntegerField(
default=50, verbose_name=_("Priority"),
help_text=_("1-100, the lower the value will be match first"),
validators=[MinValueValidator(1), MaxValueValidator(100)]
)
is_active = models.BooleanField(default=True, verbose_name=_("Active"))
comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
class Meta:
abstract = True

View File

@ -0,0 +1,54 @@
from django.db import models
from django.utils.translation import ugettext_lazy as _
from .base import BaseACL, BaseACLQuerySet
from ..utils import contains_ip
class ACLManager(models.Manager):
def valid(self):
return self.get_queryset().valid()
class LoginACL(BaseACL):
class ActionChoices(models.TextChoices):
reject = 'reject', _('Reject')
allow = 'allow', _('Allow')
# 条件
ip_group = models.JSONField(default=list, verbose_name=_('Login IP'))
# 动作
action = models.CharField(
max_length=64, choices=ActionChoices.choices, default=ActionChoices.reject,
verbose_name=_('Action')
)
# 关联
user = models.ForeignKey(
'users.User', on_delete=models.CASCADE, related_name='login_acls', verbose_name=_('User')
)
objects = ACLManager.from_queryset(BaseACLQuerySet)()
class Meta:
ordering = ('priority', '-date_updated', 'name')
@property
def action_reject(self):
return self.action == self.ActionChoices.reject
@property
def action_allow(self):
return self.action == self.ActionChoices.allow
@staticmethod
def allow_user_to_login(user, ip):
acl = user.login_acls.valid().first()
if not acl:
return True
is_contained = contains_ip(ip, acl.ip_group)
if acl.action_allow and is_contained:
return True
if acl.action_reject and not is_contained:
return True
return False

View File

@ -0,0 +1,99 @@
from django.db import models
from django.db.models import Q
from django.utils.translation import ugettext_lazy as _
from orgs.mixins.models import OrgModelMixin, OrgManager
from .base import BaseACL, BaseACLQuerySet
from ..utils import contains_ip
class ACLManager(OrgManager):
def valid(self):
return self.get_queryset().valid()
class LoginAssetACL(BaseACL, OrgModelMixin):
class ActionChoices(models.TextChoices):
login_confirm = 'login_confirm', _('Login confirm')
# 条件
users = models.JSONField(verbose_name=_('User'))
system_users = models.JSONField(verbose_name=_('System User'))
assets = models.JSONField(verbose_name=_('Asset'))
# 动作
action = models.CharField(
max_length=64, choices=ActionChoices.choices, default=ActionChoices.login_confirm,
verbose_name=_('Action')
)
# 动作: 附加字段
# - login_confirm
reviewers = models.ManyToManyField(
'users.User', related_name='review_login_asset_acls', blank=True,
verbose_name=_("Reviewers")
)
objects = ACLManager.from_queryset(BaseACLQuerySet)()
class Meta:
unique_together = ('name', 'org_id')
ordering = ('priority', '-date_updated', 'name')
@classmethod
def filter(cls, user, asset, system_user, action):
queryset = cls.objects.filter(action=action)
queryset = cls.filter_user(user, queryset)
queryset = cls.filter_asset(asset, queryset)
queryset = cls.filter_system_user(system_user, queryset)
return queryset
@classmethod
def filter_user(cls, user, queryset):
queryset = queryset.filter(
Q(users__username_group__contains=user.username) |
Q(users__username_group__contains='*')
)
return queryset
@classmethod
def filter_asset(cls, asset, queryset):
queryset = queryset.filter(
Q(assets__hostname_group__contains=asset.hostname) |
Q(assets__hostname_group__contains='*')
)
ids = [q.id for q in queryset if contains_ip(asset.ip, q.assets.get('ip_group', []))]
queryset = cls.objects.filter(id__in=ids)
return queryset
@classmethod
def filter_system_user(cls, system_user, queryset):
queryset = queryset.filter(
Q(system_users__name_group__contains=system_user.name) |
Q(system_users__name_group__contains='*')
).filter(
Q(system_users__username_group__contains=system_user.username) |
Q(system_users__username_group__contains='*')
).filter(
Q(system_users__protocol_group__contains=system_user.protocol) |
Q(system_users__protocol_group__contains='*')
)
return queryset
@classmethod
def create_login_asset_confirm_ticket(cls, user, asset, system_user, assignees, org_id):
from tickets.const import TicketTypeChoices
from tickets.models import Ticket
data = {
'title': _('Login asset confirm') + ' ({})'.format(user),
'type': TicketTypeChoices.login_asset_confirm,
'meta': {
'apply_login_user': str(user),
'apply_login_asset': str(asset),
'apply_login_system_user': str(system_user),
},
'org_id': org_id,
}
ticket = Ticket.objects.create(**data)
ticket.assignees.set(assignees)
ticket.open(applicant=user)
return ticket

View File

@ -0,0 +1,3 @@
from .login_acl import *
from .login_asset_acl import *
from .login_asset_check import *

View File

@ -0,0 +1,49 @@
from django.utils.translation import ugettext as _
from rest_framework import serializers
from common.drf.serializers import BulkModelSerializer
from orgs.utils import current_org
from ..models import LoginACL
from ..utils import is_ip_address, is_ip_network, is_ip_segment
from .. import const
__all__ = ['LoginACLSerializer', ]
def ip_group_child_validator(ip_group_child):
is_valid = ip_group_child == '*' \
or is_ip_address(ip_group_child) \
or is_ip_network(ip_group_child) \
or is_ip_segment(ip_group_child)
if not is_valid:
error = _('IP address invalid: `{}`').format(ip_group_child)
raise serializers.ValidationError(error)
class LoginACLSerializer(BulkModelSerializer):
ip_group = serializers.ListField(
default=['*'], label=_('IP'), help_text=const.ip_group_help_text,
child=serializers.CharField(max_length=1024, validators=[ip_group_child_validator])
)
user_display = serializers.ReadOnlyField(source='user.name', label=_('User'))
action_display = serializers.ReadOnlyField(source='get_action_display', label=_('Action'))
class Meta:
model = LoginACL
fields = [
'id', 'name', 'priority', 'ip_group', 'user', 'user_display', 'action',
'action_display', 'is_active', 'comment', 'created_by', 'date_created', 'date_updated'
]
extra_kwargs = {
'priority': {'default': 50},
'is_active': {'default': True},
}
@staticmethod
def validate_user(user):
if user not in current_org.get_members():
error = _('The user `{}` is not in the current organization: `{}`').format(
user, current_org
)
raise serializers.ValidationError(error)
return user

View File

@ -0,0 +1,87 @@
from rest_framework import serializers
from django.utils.translation import ugettext as _
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from assets.models import SystemUser
from acls import models
from orgs.models import Organization
from .. import const
__all__ = ['LoginAssetACLSerializer']
class LoginAssetACLUsersSerializer(serializers.Serializer):
username_group = serializers.ListField(
default=['*'], child=serializers.CharField(max_length=128), label=_('Username'),
help_text=const.common_help_text
)
class LoginAssetACLAssestsSerializer(serializers.Serializer):
ip_group = serializers.ListField(
default=['*'], child=serializers.CharField(max_length=1024), label=_('IP'),
help_text=const.ip_group_help_text + _('(Domain name support)')
)
hostname_group = serializers.ListField(
default=['*'], child=serializers.CharField(max_length=128), label=_('Hostname'),
help_text=const.common_help_text
)
class LoginAssetACLSystemUsersSerializer(serializers.Serializer):
name_group = serializers.ListField(
default=['*'], child=serializers.CharField(max_length=128), label=_('Name'),
help_text=const.common_help_text
)
username_group = serializers.ListField(
default=['*'], child=serializers.CharField(max_length=128), label=_('Username'),
help_text=const.common_help_text
)
protocol_group = serializers.ListField(
default=['*'], child=serializers.CharField(max_length=16), label=_('Protocol'),
help_text=const.common_help_text + _('Protocol options: {}').format(
', '.join(SystemUser.ASSET_CATEGORY_PROTOCOLS)
)
)
@staticmethod
def validate_protocol_group(protocol_group):
unsupported_protocols = set(protocol_group) - set(SystemUser.ASSET_CATEGORY_PROTOCOLS + ['*'])
if unsupported_protocols:
error = _('Unsupported protocols: {}').format(unsupported_protocols)
raise serializers.ValidationError(error)
return protocol_group
class LoginAssetACLSerializer(BulkOrgResourceModelSerializer):
users = LoginAssetACLUsersSerializer()
assets = LoginAssetACLAssestsSerializer()
system_users = LoginAssetACLSystemUsersSerializer()
reviewers_amount = serializers.IntegerField(read_only=True, source='reviewers.count')
action_display = serializers.ReadOnlyField(source='get_action_display', label=_('Action'))
class Meta:
model = models.LoginAssetACL
fields = [
'id', 'name', 'priority', 'users', 'system_users', 'assets', 'action', 'action_display',
'is_active', 'comment', 'reviewers', 'reviewers_amount', 'created_by', 'date_created',
'date_updated', 'org_id'
]
extra_kwargs = {
"reviewers": {'allow_null': False, 'required': True},
'priority': {'default': 50},
'is_active': {'default': True},
}
def validate_reviewers(self, reviewers):
org_id = self.fields['org_id'].default()
org = Organization.get_instance(org_id)
if not org:
error = _('The organization `{}` does not exist'.format(org_id))
raise serializers.ValidationError(error)
users = org.get_members()
valid_reviewers = list(set(reviewers) & set(users))
if not valid_reviewers:
error = _('None of the reviewers belong to Organization `{}`'.format(org.name))
raise serializers.ValidationError(error)
return valid_reviewers

View File

@ -0,0 +1,71 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from orgs.utils import tmp_to_root_org
from common.utils import get_object_or_none, lazyproperty
from users.models import User
from assets.models import Asset, SystemUser
__all__ = ['LoginAssetCheckSerializer']
class LoginAssetCheckSerializer(serializers.Serializer):
user_id = serializers.UUIDField(required=True, allow_null=False)
asset_id = serializers.UUIDField(required=True, allow_null=False)
system_user_id = serializers.UUIDField(required=True, allow_null=False)
system_user_username = serializers.CharField(max_length=128, default='')
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.user = None
self.asset = None
self._system_user = None
self._system_user_username = None
def validate_user_id(self, user_id):
self.user = self.validate_object_exist(User, user_id)
return user_id
def validate_asset_id(self, asset_id):
self.asset = self.validate_object_exist(Asset, asset_id)
return asset_id
def validate_system_user_id(self, system_user_id):
self._system_user = self.validate_object_exist(SystemUser, system_user_id)
return system_user_id
def validate_system_user_username(self, system_user_username):
system_user_id = self.initial_data.get('system_user_id')
system_user = self.validate_object_exist(SystemUser, system_user_id)
if self._system_user.login_mode == SystemUser.LOGIN_MANUAL \
and not system_user.username \
and not system_user.username_same_with_user \
and not system_user_username:
error = 'Missing parameter: system_user_username'
raise serializers.ValidationError(error)
self._system_user_username = system_user_username
return system_user_username
@staticmethod
def validate_object_exist(model, field_id):
with tmp_to_root_org():
obj = get_object_or_none(model, pk=field_id)
if not obj:
error = '{} Model object does not exist'.format(model.__name__)
raise serializers.ValidationError(error)
return obj
@lazyproperty
def system_user(self):
if self._system_user.username_same_with_user:
username = self.user.username
elif self._system_user.login_mode == SystemUser.LOGIN_MANUAL:
username = self._system_user_username
else:
username = self._system_user.username
self._system_user.username = username
return self._system_user
@lazyproperty
def org(self):
return self.asset.org

3
apps/acls/tests.py Normal file
View File

@ -0,0 +1,3 @@
from django.test import TestCase
# Create your tests here.

View File

@ -0,0 +1 @@
from .api_urls import *

View File

@ -0,0 +1,18 @@
from django.urls import path
from rest_framework_bulk.routes import BulkRouter
from .. import api
app_name = 'acls'
router = BulkRouter()
router.register(r'login-acls', api.LoginACLViewSet, 'login-acl')
router.register(r'login-asset-acls', api.LoginAssetACLViewSet, 'login-asset-acl')
urlpatterns = [
path('login-asset/check/', api.LoginAssetCheckAPI.as_view(), name='login-asset-check'),
path('login-asset-confirm/<uuid:pk>/status/', api.LoginAssetConfirmStatusAPI.as_view(), name='login-asset-confirm-status')
]
urlpatterns += router.urls

68
apps/acls/utils.py Normal file
View File

@ -0,0 +1,68 @@
from ipaddress import ip_network, ip_address
def is_ip_address(address):
""" 192.168.10.1 """
try:
ip_address(address)
except ValueError:
return False
else:
return True
def is_ip_network(ip):
""" 192.168.1.0/24 """
try:
ip_network(ip)
except ValueError:
return False
else:
return True
def is_ip_segment(ip):
""" 10.1.1.1-10.1.1.20 """
if '-' not in ip:
return False
ip_address1, ip_address2 = ip.split('-')
return is_ip_address(ip_address1) and is_ip_address(ip_address2)
def in_ip_segment(ip, ip_segment):
ip1, ip2 = ip_segment.split('-')
ip1 = int(ip_address(ip1))
ip2 = int(ip_address(ip2))
ip = int(ip_address(ip))
return min(ip1, ip2) <= ip <= max(ip1, ip2)
def contains_ip(ip, ip_group):
"""
ip_group:
[192.168.10.1, 192.168.1.0/24, 10.1.1.1-10.1.1.20, 2001:db8:2de::e13, 2001:db8:1a:1110::/64.]
"""
if '*' in ip_group:
return True
for _ip in ip_group:
if is_ip_address(_ip):
# 192.168.10.1
if ip == _ip:
return True
elif is_ip_network(_ip) and is_ip_address(ip):
# 192.168.1.0/24
if ip_address(ip) in ip_network(_ip):
return True
elif is_ip_segment(_ip) and is_ip_address(ip):
# 10.1.1.1-10.1.1.20
if in_ip_segment(ip, _ip):
return True
else:
# is domain name
if ip == _ip:
return True
return False

View File

@ -77,8 +77,8 @@ class SerializeApplicationToTreeNodeMixin:
@staticmethod @staticmethod
def filter_organizations(applications): def filter_organizations(applications):
organizations_id = set(applications.values_list('org_id', flat=True)) organization_ids = set(applications.values_list('org_id', flat=True))
organizations = [Organization.get_instance(org_id) for org_id in organizations_id] organizations = [Organization.get_instance(org_id) for org_id in organization_ids]
return organizations return organizations
def serialize_applications_with_org(self, applications): def serialize_applications_with_org(self, applications):

View File

@ -3,6 +3,7 @@ from django.utils.translation import ugettext_lazy as _
from orgs.mixins.models import OrgModelMixin from orgs.mixins.models import OrgModelMixin
from common.mixins import CommonModelMixin from common.mixins import CommonModelMixin
from assets.models import Asset
from .. import const from .. import const
@ -35,3 +36,35 @@ class Application(CommonModelMixin, OrgModelMixin):
@property @property
def category_remote_app(self): def category_remote_app(self):
return self.category == const.ApplicationCategoryChoices.remote_app.value return self.category == const.ApplicationCategoryChoices.remote_app.value
def get_rdp_remote_app_setting(self):
from applications.serializers.attrs import get_serializer_class_by_application_type
if not self.category_remote_app:
raise ValueError(f"Not a remote app application: {self.name}")
serializer_class = get_serializer_class_by_application_type(self.type)
fields = serializer_class().get_fields()
parameters = [self.type]
for field_name in list(fields.keys()):
if field_name in ['asset']:
continue
value = self.attrs.get(field_name)
if not value:
continue
if field_name == 'path':
value = '\"%s\"' % value
parameters.append(str(value))
parameters = ' '.join(parameters)
return {
'program': '||jmservisor',
'working_directory': '',
'parameters': parameters
}
def get_remote_app_asset(self):
asset_id = self.attrs.get('asset')
if not asset_id:
raise ValueError("Remote App not has asset attr")
asset = Asset.objects.filter(id=asset_id).first()
return asset

View File

@ -27,31 +27,5 @@ class RemoteAppConnectionInfoSerializer(serializers.ModelSerializer):
return obj.attrs.get('asset') return obj.attrs.get('asset')
@staticmethod @staticmethod
def get_parameters(obj): def get_parameter_remote_app(obj):
""" return obj.get_rdp_remote_app_setting()
返回Guacamole需要的RemoteApp配置参数信息中的parameters参数
"""
from .attrs import get_serializer_class_by_application_type
serializer_class = get_serializer_class_by_application_type(obj.type)
fields = serializer_class().get_fields()
parameters = [obj.type]
for field_name in list(fields.keys()):
if field_name in ['asset']:
continue
value = obj.attrs.get(field_name)
if not value:
continue
if field_name == 'path':
value = '\"%s\"' % value
parameters.append(str(value))
parameters = ' '.join(parameters)
return parameters
def get_parameter_remote_app(self, obj):
return {
'program': '||jmservisor',
'working_directory': '',
'parameters': self.get_parameters(obj)
}

View File

@ -33,6 +33,10 @@ class AdminUserViewSet(OrgBulkModelViewSet):
search_fields = filterset_fields search_fields = filterset_fields
serializer_class = serializers.AdminUserSerializer serializer_class = serializers.AdminUserSerializer
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_classes = {
'default': serializers.AdminUserSerializer,
'retrieve': serializers.AdminUserDetailSerializer,
}
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset() queryset = super().get_queryset()

View File

@ -3,8 +3,6 @@
from assets.api import FilterAssetByNodeMixin from assets.api import FilterAssetByNodeMixin
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from rest_framework.generics import RetrieveAPIView from rest_framework.generics import RetrieveAPIView
from rest_framework.response import Response
from rest_framework import status
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from common.utils import get_logger, get_object_or_none from common.utils import get_logger, get_object_or_none

View File

@ -1,14 +1,15 @@
from typing import List from typing import List
from common.utils.common import timeit
from assets.models import Node, Asset from assets.models import Node, Asset
from assets.pagination import AssetLimitOffsetPagination from assets.pagination import NodeAssetTreePagination
from common.utils import lazyproperty, dict_get_any, is_uuid, get_object_or_none from common.utils import lazyproperty
from assets.utils import get_node, is_query_node_all_assets from assets.utils import get_node, is_query_node_all_assets
class SerializeToTreeNodeMixin: class SerializeToTreeNodeMixin:
permission_classes = ()
@timeit
def serialize_nodes(self, nodes: List[Node], with_asset_amount=False): def serialize_nodes(self, nodes: List[Node], with_asset_amount=False):
if with_asset_amount: if with_asset_amount:
def _name(node: Node): def _name(node: Node):
@ -45,6 +46,7 @@ class SerializeToTreeNodeMixin:
return platform return platform
return default return default
@timeit
def serialize_assets(self, assets, node_key=None): def serialize_assets(self, assets, node_key=None):
if node_key is None: if node_key is None:
get_pid = lambda asset: getattr(asset, 'parent_key', '') get_pid = lambda asset: getattr(asset, 'parent_key', '')
@ -79,7 +81,7 @@ class SerializeToTreeNodeMixin:
class FilterAssetByNodeMixin: class FilterAssetByNodeMixin:
pagination_class = AssetLimitOffsetPagination pagination_class = NodeAssetTreePagination
@lazyproperty @lazyproperty
def is_query_node_all_assets(self): def is_query_node_all_assets(self):

View File

@ -8,7 +8,6 @@ from rest_framework.response import Response
from rest_framework.decorators import action from rest_framework.decorators import action
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.shortcuts import get_object_or_404, Http404 from django.shortcuts import get_object_or_404, Http404
from django.utils.decorators import method_decorator
from django.db.models.signals import m2m_changed from django.db.models.signals import m2m_changed
from common.const.http import POST from common.const.http import POST
@ -17,17 +16,15 @@ from common.const.signals import PRE_REMOVE, POST_REMOVE
from assets.models import Asset from assets.models import Asset
from common.utils import get_logger, get_object_or_none from common.utils import get_logger, get_object_or_none
from common.tree import TreeNodeSerializer from common.tree import TreeNodeSerializer
from common.const.distributed_lock_key import UPDATE_NODE_TREE_LOCK_KEY
from orgs.mixins.api import OrgModelViewSet from orgs.mixins.api import OrgModelViewSet
from orgs.mixins import generics from orgs.mixins import generics
from orgs.lock import org_level_transaction_lock
from orgs.utils import current_org from orgs.utils import current_org
from assets.tasks import check_node_assets_amount_task
from ..hands import IsOrgAdmin from ..hands import IsOrgAdmin
from ..models import Node from ..models import Node
from ..tasks import ( from ..tasks import (
update_node_assets_hardware_info_manual, update_node_assets_hardware_info_manual,
test_node_assets_connectivity_manual, test_node_assets_connectivity_manual,
check_node_assets_amount_task
) )
from .. import serializers from .. import serializers
from .mixin import SerializeToTreeNodeMixin from .mixin import SerializeToTreeNodeMixin
@ -50,17 +47,17 @@ class NodeViewSet(OrgModelViewSet):
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = serializers.NodeSerializer serializer_class = serializers.NodeSerializer
@action(methods=[POST], detail=False, url_name='launch-check-assets-amount-task')
def launch_check_assets_amount_task(self, request):
task = check_node_assets_amount_task.delay(current_org.id)
return Response(data={'task': task.id})
# 仅支持根节点指直接创建子节点下的节点需要通过children接口创建 # 仅支持根节点指直接创建子节点下的节点需要通过children接口创建
def perform_create(self, serializer): def perform_create(self, serializer):
child_key = Node.org_root().get_next_child_key() child_key = Node.org_root().get_next_child_key()
serializer.validated_data["key"] = child_key serializer.validated_data["key"] = child_key
serializer.save() serializer.save()
@action(methods=[POST], detail=False, url_path='check_assets_amount_task')
def check_assets_amount_task(self, request):
task = check_node_assets_amount_task.delay(current_org.id)
return Response(data={'task': task.id})
def perform_update(self, serializer): def perform_update(self, serializer):
node = self.get_object() node = self.get_object()
if node.is_org_root() and node.value != serializer.validated_data['value']: if node.is_org_root() and node.value != serializer.validated_data['value']:
@ -130,9 +127,13 @@ class NodeChildrenApi(generics.ListCreateAPIView):
def get_object(self): def get_object(self):
pk = self.kwargs.get('pk') or self.request.query_params.get('id') pk = self.kwargs.get('pk') or self.request.query_params.get('id')
key = self.request.query_params.get("key") key = self.request.query_params.get("key")
if not pk and not key: if not pk and not key:
node = Node.org_root()
self.is_initial = True self.is_initial = True
if current_org.is_root():
node = None
else:
node = Node.org_root()
return node return node
if pk: if pk:
node = get_object_or_404(Node, pk=pk) node = get_object_or_404(Node, pk=pk)
@ -140,16 +141,26 @@ class NodeChildrenApi(generics.ListCreateAPIView):
node = get_object_or_404(Node, key=key) node = get_object_or_404(Node, key=key)
return node return node
def get_org_root_queryset(self, query_all):
if query_all:
return Node.objects.all()
else:
return Node.org_root_nodes()
def get_queryset(self): def get_queryset(self):
query_all = self.request.query_params.get("all", "0") == "all" query_all = self.request.query_params.get("all", "0") == "all"
if not self.instance:
return Node.objects.none() if self.is_initial and current_org.is_root():
return self.get_org_root_queryset(query_all)
if self.is_initial: if self.is_initial:
with_self = True with_self = True
else: else:
with_self = False with_self = False
if not self.instance:
return Node.objects.none()
if query_all: if query_all:
queryset = self.instance.get_all_children(with_self=with_self) queryset = self.instance.get_all_children(with_self=with_self)
else: else:
@ -181,12 +192,12 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
def get_assets(self): def get_assets(self):
include_assets = self.request.query_params.get('assets', '0') == '1' include_assets = self.request.query_params.get('assets', '0') == '1'
if not include_assets: if not self.instance or not include_assets:
return [] return []
assets = self.instance.get_assets().only( assets = self.instance.get_assets().only(
"id", "hostname", "ip", "os", "id", "hostname", "ip", "os", "platform_id",
"org_id", "protocols", "is_active" "org_id", "protocols", "is_active",
) ).prefetch_related('platform')
return self.serialize_assets(assets, self.instance.key) return self.serialize_assets(assets, self.instance.key)
@ -212,15 +223,13 @@ class NodeAddChildrenApi(generics.UpdateAPIView):
def put(self, request, *args, **kwargs): def put(self, request, *args, **kwargs):
instance = self.get_object() instance = self.get_object()
nodes_id = request.data.get("nodes") node_ids = request.data.get("nodes")
children = Node.objects.filter(id__in=nodes_id) children = Node.objects.filter(id__in=node_ids)
for node in children: for node in children:
node.parent = instance node.parent = instance
return Response("OK") return Response("OK")
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class NodeAddAssetsApi(generics.UpdateAPIView): class NodeAddAssetsApi(generics.UpdateAPIView):
model = Node model = Node
serializer_class = serializers.NodeAssetsSerializer serializer_class = serializers.NodeAssetsSerializer
@ -233,8 +242,6 @@ class NodeAddAssetsApi(generics.UpdateAPIView):
instance.assets.add(*tuple(assets)) instance.assets.add(*tuple(assets))
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class NodeRemoveAssetsApi(generics.UpdateAPIView): class NodeRemoveAssetsApi(generics.UpdateAPIView):
model = Node model = Node
serializer_class = serializers.NodeAssetsSerializer serializer_class = serializers.NodeAssetsSerializer
@ -247,12 +254,13 @@ class NodeRemoveAssetsApi(generics.UpdateAPIView):
node.assets.remove(*assets) node.assets.remove(*assets)
# 把孤儿资产添加到 root 节点 # 把孤儿资产添加到 root 节点
orphan_assets = Asset.objects.filter(id__in=[a.id for a in assets], nodes__isnull=True).distinct() orphan_assets = Asset.objects.filter(
id__in=[a.id for a in assets],
nodes__isnull=True
).distinct()
Node.org_root().assets.add(*orphan_assets) Node.org_root().assets.add(*orphan_assets)
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class MoveAssetsToNodeApi(generics.UpdateAPIView): class MoveAssetsToNodeApi(generics.UpdateAPIView):
model = Node model = Node
serializer_class = serializers.NodeAssetsSerializer serializer_class = serializers.NodeAssetsSerializer

View File

@ -87,13 +87,13 @@ class SystemUserTaskApi(generics.CreateAPIView):
permission_classes = (IsOrgAdmin,) permission_classes = (IsOrgAdmin,)
serializer_class = serializers.SystemUserTaskSerializer serializer_class = serializers.SystemUserTaskSerializer
def do_push(self, system_user, assets_id=None): def do_push(self, system_user, asset_ids=None):
if assets_id is None: if asset_ids is None:
task = push_system_user_to_assets_manual.delay(system_user) task = push_system_user_to_assets_manual.delay(system_user)
else: else:
username = self.request.query_params.get('username') username = self.request.query_params.get('username')
task = push_system_user_to_assets.delay( task = push_system_user_to_assets.delay(
system_user.id, assets_id, username=username system_user.id, asset_ids, username=username
) )
return task return task
@ -114,9 +114,9 @@ class SystemUserTaskApi(generics.CreateAPIView):
system_user = self.get_object() system_user = self.get_object()
if action == 'push': if action == 'push':
assets = [asset] if asset else assets assets = [asset] if asset else assets
assets_id = [asset.id for asset in assets] asset_ids = [asset.id for asset in assets]
assets_id = assets_id if assets_id else None asset_ids = asset_ids if asset_ids else None
task = self.do_push(system_user, assets_id) task = self.do_push(system_user, asset_ids)
else: else:
task = self.do_test(system_user) task = self.do_test(system_user)
data = getattr(serializer, '_data', {}) data = getattr(serializer, '_data', {})

View File

@ -40,7 +40,7 @@ class BaseBackend:
return values return values
@staticmethod @staticmethod
def make_assets_as_id(assets): def make_assets_as_ids(assets):
if not assets: if not assets:
return [] return []
if isinstance(assets[0], Asset): if isinstance(assets[0], Asset):

View File

@ -69,9 +69,9 @@ class DBBackend(BaseBackend):
self.queryset = self.queryset.filter(union_id=union_id) self.queryset = self.queryset.filter(union_id=union_id)
def _filter_assets(self, assets): def _filter_assets(self, assets):
assets_id = self.make_assets_as_id(assets) asset_ids = self.make_assets_as_ids(assets)
if assets_id: if asset_ids:
self.queryset = self.queryset.filter(asset_id__in=assets_id) self.queryset = self.queryset.filter(asset_id__in=asset_ids)
def _filter_node(self, node): def _filter_node(self, node):
pass pass

20
apps/assets/locks.py Normal file
View File

@ -0,0 +1,20 @@
from orgs.utils import current_org
from common.utils.lock import DistributedLock
class NodeTreeUpdateLock(DistributedLock):
name_template = 'assets.node.tree.update.<org_id:{org_id}>'
def get_name(self):
if current_org:
org_id = current_org.id
else:
org_id = 'current_org_is_null'
name = self.name_template.format(
org_id=org_id
)
return name
def __init__(self):
name = self.get_name()
super().__init__(name=name, release_on_transaction_commit=True, reentrant=True)

View File

@ -0,0 +1,17 @@
# Generated by Django 3.1 on 2021-02-08 10:02
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('assets', '0065_auto_20210121_1549'),
]
operations = [
migrations.AlterModelOptions(
name='asset',
options={'ordering': ['hostname'], 'verbose_name': 'Asset'},
),
]

View File

@ -0,0 +1,48 @@
# Generated by Django 3.1 on 2021-03-11 03:13
import django.core.validators
from django.db import migrations, models
def migrate_cmd_filter_priority(apps, schema_editor):
cmd_filter_rule_model = apps.get_model('assets', 'CommandFilterRule')
cmd_filter_rules = cmd_filter_rule_model.objects.all()
for cmd_filter_rule in cmd_filter_rules:
cmd_filter_rule.priority = 100 - cmd_filter_rule.priority + 1
cmd_filter_rule_model.objects.bulk_update(cmd_filter_rules, fields=['priority'])
def migrate_system_user_priority(apps, schema_editor):
system_user_model = apps.get_model('assets', 'SystemUser')
system_users = system_user_model.objects.all()
for system_user in system_users:
system_user.priority = 100 - system_user.priority + 1
system_user_model.objects.bulk_update(system_users, fields=['priority'])
class Migration(migrations.Migration):
dependencies = [
('assets', '0066_auto_20210208_1802'),
]
operations = [
migrations.RunPython(migrate_cmd_filter_priority),
migrations.RunPython(migrate_system_user_priority),
migrations.AlterModelOptions(
name='commandfilterrule',
options={'ordering': ('priority', 'action'), 'verbose_name': 'Command filter rule'},
),
migrations.AlterField(
model_name='commandfilterrule',
name='priority',
field=models.IntegerField(default=50, help_text='1-100, the lower the value will be match first', validators=[django.core.validators.MinValueValidator(1), django.core.validators.MaxValueValidator(100)], verbose_name='Priority'),
),
migrations.AlterField(
model_name='systemuser',
name='priority',
field=models.IntegerField(default=20, help_text='1-100, the lower the value will be match first', validators=[django.core.validators.MinValueValidator(1), django.core.validators.MaxValueValidator(100)], verbose_name='Priority'),
),
]

View File

@ -17,7 +17,7 @@ from orgs.mixins.models import OrgModelMixin, OrgManager
from .base import ConnectivityMixin from .base import ConnectivityMixin
from .utils import Connectivity from .utils import Connectivity
__all__ = ['Asset', 'ProtocolsMixin', 'Platform'] __all__ = ['Asset', 'ProtocolsMixin', 'Platform', 'AssetQuerySet']
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,13 +41,6 @@ def default_node():
class AssetManager(OrgManager): class AssetManager(OrgManager):
def get_queryset(self):
return super().get_queryset().annotate(
platform_base=models.F('platform__base')
)
class AssetOrgManager(OrgManager):
pass pass
@ -230,7 +223,6 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
comment = models.TextField(default='', blank=True, verbose_name=_('Comment')) comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
objects = AssetManager.from_queryset(AssetQuerySet)() objects = AssetManager.from_queryset(AssetQuerySet)()
org_objects = AssetOrgManager.from_queryset(AssetQuerySet)()
_connectivity = None _connectivity = None
def __str__(self): def __str__(self):
@ -361,4 +353,4 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
class Meta: class Meta:
unique_together = [('org_id', 'hostname')] unique_together = [('org_id', 'hostname')]
verbose_name = _("Asset") verbose_name = _("Asset")
ordering = ["hostname", "ip"] ordering = ["hostname", ]

View File

@ -11,9 +11,12 @@ from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.conf import settings from django.conf import settings
from common.db.models import ChoiceSet
from common.utils import random_string
from common.utils import ( from common.utils import (
ssh_key_string_to_obj, ssh_key_gen, get_logger, lazyproperty ssh_key_string_to_obj, ssh_key_gen, get_logger, lazyproperty
) )
from common.utils.encode import ssh_pubkey_gen
from common.validators import alphanumeric from common.validators import alphanumeric
from common import fields from common import fields
from orgs.mixins.models import OrgModelMixin from orgs.mixins.models import OrgModelMixin
@ -105,6 +108,19 @@ class AuthMixin:
username = '' username = ''
_prefer = 'system_user' _prefer = 'system_user'
@property
def ssh_key_fingerprint(self):
if self.public_key:
public_key = self.public_key
elif self.private_key:
public_key = ssh_pubkey_gen(self.private_key, self.password)
else:
return ''
public_key_obj = sshpubkeys.SSHKey(public_key)
fingerprint = public_key_obj.hash_md5()
return fingerprint
@property @property
def private_key_obj(self): def private_key_obj(self):
if self.private_key: if self.private_key:
@ -204,8 +220,8 @@ class AuthMixin:
self.save() self.save()
@staticmethod @staticmethod
def gen_password(): def gen_password(length=36):
return str(uuid.uuid4()) return random_string(length, special_char=True)
@staticmethod @staticmethod
def gen_key(username): def gen_key(username):

View File

@ -50,7 +50,7 @@ class CommandFilterRule(OrgModelMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True) id = models.UUIDField(default=uuid.uuid4, primary_key=True)
filter = models.ForeignKey('CommandFilter', on_delete=models.CASCADE, verbose_name=_("Filter"), related_name='rules') filter = models.ForeignKey('CommandFilter', on_delete=models.CASCADE, verbose_name=_("Filter"), related_name='rules')
type = models.CharField(max_length=16, default=TYPE_COMMAND, choices=TYPE_CHOICES, verbose_name=_("Type")) type = models.CharField(max_length=16, default=TYPE_COMMAND, choices=TYPE_CHOICES, verbose_name=_("Type"))
priority = models.IntegerField(default=50, verbose_name=_("Priority"), help_text=_("1-100, the higher will be match first"), priority = models.IntegerField(default=50, verbose_name=_("Priority"), help_text=_("1-100, the lower the value will be match first"),
validators=[MinValueValidator(1), MaxValueValidator(100)]) validators=[MinValueValidator(1), MaxValueValidator(100)])
content = models.TextField(verbose_name=_("Content"), help_text=_("One line one command")) content = models.TextField(verbose_name=_("Content"), help_text=_("One line one command"))
action = models.IntegerField(default=ACTION_DENY, choices=ACTION_CHOICES, verbose_name=_("Action")) action = models.IntegerField(default=ACTION_DENY, choices=ACTION_CHOICES, verbose_name=_("Action"))
@ -60,7 +60,7 @@ class CommandFilterRule(OrgModelMixin):
created_by = models.CharField(max_length=128, blank=True, default='', verbose_name=_('Created by')) created_by = models.CharField(max_length=128, blank=True, default='', verbose_name=_('Created by'))
class Meta: class Meta:
ordering = ('-priority', 'action') ordering = ('priority', 'action')
verbose_name = _("Command filter rule") verbose_name = _("Command filter rule")
@lazyproperty @lazyproperty

View File

@ -16,17 +16,5 @@ class FavoriteAsset(CommonModelMixin):
unique_together = ('user', 'asset') unique_together = ('user', 'asset')
@classmethod @classmethod
def get_user_favorite_assets_id(cls, user): def get_user_favorite_asset_ids(cls, user):
return cls.objects.filter(user=user).values_list('asset', flat=True) return cls.objects.filter(user=user).values_list('asset', flat=True)
@classmethod
def get_user_favorite_assets(cls, user, asset_perms_id=None):
from assets.models import Asset
from perms.utils.asset.user_permission import get_user_granted_all_assets
asset_ids = get_user_granted_all_assets(
user,
via_mapping_node=False,
asset_perms_id=asset_perms_id
).values_list('id', flat=True)
query_name = cls.asset.field.related_query_name()
return Asset.org_objects.filter(**{f'{query_name}__user_id': user.id}, id__in=asset_ids).distinct()

View File

@ -1,23 +1,32 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import uuid
import re import re
import time
import uuid
import threading
import os
import time
import uuid
from collections import defaultdict
from django.db import models, transaction from django.db import models, transaction
from django.db.models import Q from django.db.models import Q, Manager
from django.db.utils import IntegrityError from django.db.utils import IntegrityError
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ugettext from django.utils.translation import ugettext
from django.db.transaction import atomic from django.db.transaction import atomic
from django.core.cache import cache
from common.utils.lock import DistributedLock
from common.utils.common import timeit
from common.db.models import output_as_string
from common.utils import get_logger from common.utils import get_logger
from common.utils.common import lazyproperty
from orgs.mixins.models import OrgModelMixin, OrgManager from orgs.mixins.models import OrgModelMixin, OrgManager
from orgs.utils import get_current_org, tmp_to_org from orgs.utils import get_current_org, tmp_to_org
from orgs.models import Organization from orgs.models import Organization
__all__ = ['Node', 'FamilyMixin', 'compute_parent_key'] __all__ = ['Node', 'FamilyMixin', 'compute_parent_key', 'NodeQuerySet']
logger = get_logger(__name__) logger = get_logger(__name__)
@ -247,9 +256,147 @@ class FamilyMixin:
return [*tuple(ancestors), self, *tuple(children)] return [*tuple(ancestors), self, *tuple(children)]
class NodeAssetsMixin: class NodeAllAssetsMappingMixin:
# Use a new plan
# { org_id: { node_key: [ asset1_id, asset2_id ] } }
orgid_nodekey_assetsid_mapping = defaultdict(dict)
locks_for_get_mapping_from_cache = defaultdict(threading.Lock)
@classmethod
def get_lock(cls, org_id):
lock = cls.locks_for_get_mapping_from_cache[str(org_id)]
return lock
@classmethod
def get_node_all_asset_ids_mapping(cls, org_id):
_mapping = cls.get_node_all_asset_ids_mapping_from_memory(org_id)
if _mapping:
return _mapping
logger.debug(f'Get node asset mapping from memory failed, acquire thread lock: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
with cls.get_lock(org_id):
logger.debug(f'Acquired thread lock ok. check if mapping is in memory now: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
_mapping = cls.get_node_all_asset_ids_mapping_from_memory(org_id)
if _mapping:
logger.debug(f'Mapping is already in memory now: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
return _mapping
_mapping = cls.get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(org_id)
cls.set_node_all_asset_ids_mapping_to_memory(org_id, mapping=_mapping)
return _mapping
# from memory
@classmethod
def get_node_all_asset_ids_mapping_from_memory(cls, org_id):
mapping = cls.orgid_nodekey_assetsid_mapping.get(org_id, {})
return mapping
@classmethod
def set_node_all_asset_ids_mapping_to_memory(cls, org_id, mapping):
cls.orgid_nodekey_assetsid_mapping[org_id] = mapping
@classmethod
def expire_node_all_asset_ids_mapping_from_memory(cls, org_id):
org_id = str(org_id)
cls.orgid_nodekey_assetsid_mapping.pop(org_id, None)
# get order: from memory -> (from cache -> to generate)
@classmethod
def get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(cls, org_id):
mapping = cls.get_node_all_asset_ids_mapping_from_cache(org_id)
if mapping:
return mapping
lock_key = f'KEY_LOCK_GENERATE_ORG_{org_id}_NODE_ALL_ASSET_ids_MAPPING'
with DistributedLock(lock_key):
# 这里使用无限期锁,原因是如果这里卡住了,就卡在数据库了,说明
# 数据库繁忙,所以不应该再有线程执行这个操作,使数据库忙上加忙
_mapping = cls.get_node_all_asset_ids_mapping_from_cache(org_id)
if _mapping:
return _mapping
_mapping = cls.generate_node_all_asset_ids_mapping(org_id)
cls.set_node_all_asset_ids_mapping_to_cache(org_id=org_id, mapping=_mapping)
return _mapping
@classmethod
def get_node_all_asset_ids_mapping_from_cache(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
mapping = cache.get(cache_key)
logger.info(f'Get node asset mapping from cache {bool(mapping)}: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
return mapping
@classmethod
def set_node_all_asset_ids_mapping_to_cache(cls, org_id, mapping):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.set(cache_key, mapping, timeout=None)
@classmethod
def expire_node_all_asset_ids_mapping_from_cache(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.delete(cache_key)
@staticmethod
def _get_cache_key_for_node_all_asset_ids_mapping(org_id):
return 'ASSETS_ORG_NODE_ALL_ASSET_ids_MAPPING_{}'.format(org_id)
@classmethod
def generate_node_all_asset_ids_mapping(cls, org_id):
from .asset import Asset
logger.info(f'Generate node asset mapping: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
t1 = time.time()
with tmp_to_org(org_id):
node_ids_key = Node.objects.annotate(
char_id=output_as_string('id')
).values_list('char_id', 'key')
# * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢)
nodes_asset_ids = Asset.nodes.through.objects.all() \
.annotate(char_node_id=output_as_string('node_id')) \
.annotate(char_asset_id=output_as_string('asset_id')) \
.values_list('char_node_id', 'char_asset_id')
node_id_ancestor_keys_mapping = {
node_id: cls.get_node_ancestor_keys(node_key, with_self=True)
for node_id, node_key in node_ids_key
}
nodeid_assetsid_mapping = defaultdict(set)
for node_id, asset_id in nodes_asset_ids:
nodeid_assetsid_mapping[node_id].add(asset_id)
t2 = time.time()
mapping = defaultdict(set)
for node_id, node_key in node_ids_key:
asset_ids = nodeid_assetsid_mapping[node_id]
node_ancestor_keys = node_id_ancestor_keys_mapping[node_id]
for ancestor_key in node_ancestor_keys:
mapping[ancestor_key].update(asset_ids)
t3 = time.time()
logger.info('t1-t2(DB Query): {} s, t3-t2(Generate mapping): {} s'.format(t2-t1, t3-t2))
return mapping
class NodeAssetsMixin(NodeAllAssetsMappingMixin):
org_id: str
key = '' key = ''
id = None id = None
objects: Manager
def get_all_assets(self): def get_all_assets(self):
from .asset import Asset from .asset import Asset
@ -263,8 +410,7 @@ class NodeAssetsMixin:
# 可是 startswith 会导致表关联时 Asset 索引失效 # 可是 startswith 会导致表关联时 Asset 索引失效
from .asset import Asset from .asset import Asset
node_ids = cls.objects.filter( node_ids = cls.objects.filter(
Q(key__startswith=f'{key}:') | Q(key__startswith=f'{key}:') | Q(key=key)
Q(key=key)
).values_list('id', flat=True).distinct() ).values_list('id', flat=True).distinct()
assets = Asset.objects.filter( assets = Asset.objects.filter(
nodes__id__in=list(node_ids) nodes__id__in=list(node_ids)
@ -283,29 +429,34 @@ class NodeAssetsMixin:
return self.get_all_assets().valid() return self.get_all_assets().valid()
@classmethod @classmethod
def get_nodes_all_assets_ids(cls, nodes_keys): def get_nodes_all_asset_ids_by_keys(cls, nodes_keys):
assets_ids = cls.get_nodes_all_assets(nodes_keys).values_list('id', flat=True) nodes = Node.objects.filter(key__in=nodes_keys)
return assets_ids asset_ids = cls.get_nodes_all_assets(*nodes).values_list('id', flat=True)
return asset_ids
@classmethod @classmethod
def get_nodes_all_assets(cls, nodes_keys, extra_assets_ids=None): def get_nodes_all_assets(cls, *nodes):
from .asset import Asset from .asset import Asset
nodes_keys = cls.clean_children_keys(nodes_keys) node_ids = set()
q = Q() descendant_node_query = Q()
node_ids = () for n in nodes:
for key in nodes_keys: node_ids.add(n.id)
q |= Q(key__startswith=f'{key}:') descendant_node_query |= Q(key__istartswith=f'{n.key}:')
q |= Q(key=key) if descendant_node_query:
if q: _ids = Node.objects.order_by().filter(descendant_node_query).values_list('id', flat=True)
node_ids = Node.objects.filter(q).distinct().values_list('id', flat=True) node_ids.update(_ids)
return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct()
q = Q(nodes__id__in=list(node_ids)) def get_all_asset_ids(self):
if extra_assets_ids: asset_ids = self.get_all_asset_ids_by_node_key(org_id=self.org_id, node_key=self.key)
q |= Q(id__in=extra_assets_ids) return set(asset_ids)
if q:
return Asset.org_objects.filter(q).distinct() @classmethod
else: def get_all_asset_ids_by_node_key(cls, org_id, node_key):
return Asset.objects.none() org_id = str(org_id)
nodekey_assetsid_mapping = cls.get_node_all_asset_ids_mapping(org_id)
asset_ids = nodekey_assetsid_mapping.get(node_key, [])
return set(asset_ids)
class SomeNodesMixin: class SomeNodesMixin:
@ -317,8 +468,9 @@ class SomeNodesMixin:
@classmethod @classmethod
def default_node(cls): def default_node(cls):
with tmp_to_org(Organization.default()): default_org = Organization.default()
defaults = {'value': cls.default_value} with tmp_to_org(default_org):
defaults = {'value': default_org.name}
try: try:
obj, created = cls.objects.get_or_create( obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key, defaults=defaults, key=cls.default_key,
@ -353,25 +505,40 @@ class SomeNodesMixin:
@classmethod @classmethod
def create_org_root_node(cls): def create_org_root_node(cls):
# 如果使用current_org 在set_current_org时会死循环
ori_org = get_current_org() ori_org = get_current_org()
with transaction.atomic(): with transaction.atomic():
if not ori_org.is_real():
return cls.default_node()
key = cls.get_next_org_root_node_key() key = cls.get_next_org_root_node_key()
root = cls.objects.create(key=key, value=ori_org.name) root = cls.objects.create(key=key, value=ori_org.name)
return root return root
@classmethod @classmethod
def org_root(cls): def org_root_nodes(cls):
root = cls.objects.filter(parent_key='')\ nodes = cls.objects.filter(parent_key='') \
.filter(key__regex=r'^[0-9]+$') \ .filter(key__regex=r'^[0-9]+$') \
.exclude(key__startswith='-') \ .exclude(key__startswith='-') \
.order_by('key') .order_by('key')
if root: return nodes
return root[0]
@classmethod
def org_root(cls):
# 如果使用current_org 在set_current_org时会死循环
ori_org = get_current_org()
if ori_org and ori_org.is_default():
return cls.default_node()
if ori_org and ori_org.is_root():
return None
org_roots = cls.org_root_nodes()
org_roots_length = len(org_roots)
if org_roots_length == 1:
return org_roots[0]
elif org_roots_length == 0:
root = cls.create_org_root_node()
return root
else: else:
return cls.create_org_root_node() raise ValueError('Current org root node not 1, get {}'.format(org_roots_length))
@classmethod @classmethod
def initial_some_nodes(cls): def initial_some_nodes(cls):
@ -390,8 +557,9 @@ class SomeNodesMixin:
if not node_key1: if not node_key1:
logger.info("Not found node that `key` = 1") logger.info("Not found node that `key` = 1")
return return
if not node_key1.org.is_real(): if node_key1.org_id == '':
logger.info("Org is not real for node that `key` = 1") node_key1.org_id = str(Organization.default().id)
node_key1.save()
return return
with transaction.atomic(): with transaction.atomic():

View File

@ -116,7 +116,7 @@ class SystemUser(BaseUser):
assets = models.ManyToManyField('assets.Asset', blank=True, verbose_name=_("Assets")) assets = models.ManyToManyField('assets.Asset', blank=True, verbose_name=_("Assets"))
users = models.ManyToManyField('users.User', blank=True, verbose_name=_("Users")) users = models.ManyToManyField('users.User', blank=True, verbose_name=_("Users"))
groups = models.ManyToManyField('users.UserGroup', blank=True, verbose_name=_("User groups")) groups = models.ManyToManyField('users.UserGroup', blank=True, verbose_name=_("User groups"))
priority = models.IntegerField(default=20, verbose_name=_("Priority"), validators=[MinValueValidator(1), MaxValueValidator(100)]) priority = models.IntegerField(default=20, verbose_name=_("Priority"), help_text=_("1-100, the lower the value will be match first"), validators=[MinValueValidator(1), MaxValueValidator(100)])
protocol = models.CharField(max_length=16, choices=PROTOCOL_CHOICES, default='ssh', verbose_name=_('Protocol')) protocol = models.CharField(max_length=16, choices=PROTOCOL_CHOICES, default='ssh', verbose_name=_('Protocol'))
auto_push = models.BooleanField(default=True, verbose_name=_('Auto push')) auto_push = models.BooleanField(default=True, verbose_name=_('Auto push'))
sudo = models.TextField(default='/bin/whoami', verbose_name=_('Sudo')) sudo = models.TextField(default='/bin/whoami', verbose_name=_('Sudo'))
@ -198,10 +198,10 @@ class SystemUser(BaseUser):
def get_all_assets(self): def get_all_assets(self):
from assets.models import Node from assets.models import Node
nodes_keys = self.nodes.all().values_list('key', flat=True) nodes_keys = self.nodes.all().values_list('key', flat=True)
assets_ids = set(self.assets.all().values_list('id', flat=True)) asset_ids = set(self.assets.all().values_list('id', flat=True))
nodes_assets_ids = Node.get_nodes_all_assets_ids(nodes_keys) nodes_asset_ids = Node.get_nodes_all_asset_ids_by_keys(nodes_keys)
assets_ids.update(nodes_assets_ids) asset_ids.update(nodes_asset_ids)
assets = Asset.objects.filter(id__in=assets_ids) assets = Asset.objects.filter(id__in=asset_ids)
return assets return assets
@classmethod @classmethod

View File

@ -1,39 +1,52 @@
from rest_framework.pagination import LimitOffsetPagination from rest_framework.pagination import LimitOffsetPagination
from rest_framework.request import Request from rest_framework.request import Request
from common.utils import get_logger
from assets.models import Node from assets.models import Node
logger = get_logger(__name__)
class AssetPaginationBase(LimitOffsetPagination):
def init_attrs(self, queryset, request: Request, view=None):
self._request = request
self._view = view
self._user = request.user
def paginate_queryset(self, queryset, request: Request, view=None):
self.init_attrs(queryset, request, view)
return super().paginate_queryset(queryset, request, view=None)
class AssetLimitOffsetPagination(LimitOffsetPagination):
"""
需要与 `assets.api.mixin.FilterAssetByNodeMixin` 配合使用
"""
def get_count(self, queryset): def get_count(self, queryset):
"""
1. 如果查询节点下的所有资产 count 使用 Node.assets_amount
2. 如果有其他过滤条件使用 super
3. 如果只查询该节点下的资产使用 super
"""
exclude_query_params = { exclude_query_params = {
self.limit_query_param, self.limit_query_param,
self.offset_query_param, self.offset_query_param,
'node', 'all', 'show_current_asset', 'key', 'all', 'show_current_asset',
'node_id', 'display', 'draw', 'fields_size', 'cache_policy', 'display', 'draw',
'order', 'node', 'node_id', 'fields_size',
} }
for k, v in self._request.query_params.items(): for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None: if k not in exclude_query_params and v is not None:
logger.warn(f'Not hit node.assets_amount because find a unknow query_param `{k}` -> {self._request.get_full_path()}')
return super().get_count(queryset) return super().get_count(queryset)
node_assets_count = self.get_count_from_nodes(queryset)
if node_assets_count is None:
return super().get_count(queryset)
return node_assets_count
def get_count_from_nodes(self, queryset):
raise NotImplementedError
class NodeAssetTreePagination(AssetPaginationBase):
def get_count_from_nodes(self, queryset):
is_query_all = self._view.is_query_node_all_assets is_query_all = self._view.is_query_node_all_assets
if is_query_all: if is_query_all:
node = self._view.node node = self._view.node
if not node: if not node:
node = Node.org_root() node = Node.org_root()
if node:
logger.debug(f'Hit node.assets_amount[{node.assets_amount}] -> {self._request.get_full_path()}')
return node.assets_amount return node.assets_amount
return super().get_count(queryset) return None
def paginate_queryset(self, queryset, request: Request, view=None):
self._request = request
self._view = view
return super().paginate_queryset(queryset, request, view=None)

View File

@ -3,8 +3,6 @@
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from common.drf.serializers import AdaptedBulkListSerializer
from ..models import Node, AdminUser from ..models import Node, AdminUser
from orgs.mixins.serializers import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
@ -17,7 +15,6 @@ class AdminUserSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer):
""" """
class Meta: class Meta:
list_serializer_class = AdaptedBulkListSerializer
model = AdminUser model = AdminUser
fields = [ fields = [
'id', 'name', 'username', 'password', 'private_key', 'public_key', 'id', 'name', 'username', 'password', 'private_key', 'public_key',
@ -33,6 +30,11 @@ class AdminUserSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer):
} }
class AdminUserDetailSerializer(AdminUserSerializer):
class Meta(AdminUserSerializer.Meta):
fields = AdminUserSerializer.Meta.fields + ['ssh_key_fingerprint']
class AdminUserAuthSerializer(AuthSerializer): class AdminUserAuthSerializer(AuthSerializer):
class Meta: class Meta:

View File

@ -111,7 +111,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
@classmethod @classmethod
def setup_eager_loading(cls, queryset): def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """ """ Perform necessary eager loading of data. """
queryset = queryset.select_related('admin_user', 'domain', 'platform') queryset = queryset.prefetch_related('admin_user', 'domain', 'platform')
queryset = queryset.prefetch_related('nodes', 'labels') queryset = queryset.prefetch_related('nodes', 'labels')
return queryset return queryset
@ -166,16 +166,9 @@ class AssetDisplaySerializer(AssetSerializer):
'connectivity', 'connectivity',
] ]
@classmethod
def setup_eager_loading(cls, queryset):
queryset = super().setup_eager_loading(queryset)
queryset = queryset\
.annotate(admin_user_username=F('admin_user__username'))
return queryset
class PlatformSerializer(serializers.ModelSerializer): class PlatformSerializer(serializers.ModelSerializer):
meta = serializers.DictField(required=False, allow_null=True) meta = serializers.DictField(required=False, allow_null=True, label=_('Meta'))
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

View File

@ -41,10 +41,6 @@ class AuthSerializerMixin:
def validate_private_key(self, private_key): def validate_private_key(self, private_key):
if not private_key: if not private_key:
return return
if 'OPENSSH' in private_key:
msg = _("Not support openssh format key, using "
"ssh-keygen -t rsa -m pem to generate")
raise serializers.ValidationError(msg)
password = self.initial_data.get("password") password = self.initial_data.get("password")
valid = validate_ssh_private_key(private_key, password) valid = validate_ssh_private_key(private_key, password)
if not valid: if not valid:

View File

@ -77,8 +77,6 @@ class GatewayWithAuthSerializer(GatewaySerializer):
return fields return fields
class DomainWithGatewaySerializer(BulkOrgResourceModelSerializer): class DomainWithGatewaySerializer(BulkOrgResourceModelSerializer):
gateways = GatewayWithAuthSerializer(many=True, read_only=True) gateways = GatewayWithAuthSerializer(many=True, read_only=True)

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from rest_framework import serializers from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from common.drf.serializers import AdaptedBulkListSerializer from common.drf.serializers import AdaptedBulkListSerializer
from orgs.mixins.serializers import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
@ -9,16 +10,17 @@ from ..models import Label
class LabelSerializer(BulkOrgResourceModelSerializer): class LabelSerializer(BulkOrgResourceModelSerializer):
asset_count = serializers.SerializerMethodField() asset_count = serializers.SerializerMethodField(label=_("Assets amount"))
category_display = serializers.ReadOnlyField(source='get_category_display', label=_('Category display'))
class Meta: class Meta:
model = Label model = Label
fields = [ fields = [
'id', 'name', 'value', 'category', 'is_active', 'comment', 'id', 'name', 'value', 'category', 'is_active', 'comment',
'date_created', 'asset_count', 'assets', 'get_category_display' 'date_created', 'asset_count', 'assets', 'category_display'
] ]
read_only_fields = ( read_only_fields = (
'category', 'date_created', 'asset_count', 'get_category_display' 'category', 'date_created', 'asset_count',
) )
extra_kwargs = { extra_kwargs = {
'assets': {'required': False} 'assets': {'required': False}

View File

@ -33,7 +33,7 @@ class SystemUserSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer):
'priority', 'username_same_with_user', 'priority', 'username_same_with_user',
'auto_push', 'cmd_filters', 'sudo', 'shell', 'comment', 'auto_push', 'cmd_filters', 'sudo', 'shell', 'comment',
'auto_generate_key', 'sftp_root', 'token', 'auto_generate_key', 'sftp_root', 'token',
'assets_amount', 'date_created', 'created_by', 'assets_amount', 'date_created', 'date_updated', 'created_by',
'home', 'system_groups', 'ad_domain' 'home', 'system_groups', 'ad_domain'
] ]
extra_kwargs = { extra_kwargs = {
@ -155,7 +155,8 @@ class SystemUserListSerializer(SystemUserSerializer):
'auto_push', 'sudo', 'shell', 'comment', 'auto_push', 'sudo', 'shell', 'comment',
"assets_amount", 'home', 'system_groups', "assets_amount", 'home', 'system_groups',
'auto_generate_key', 'ad_domain', 'auto_generate_key', 'ad_domain',
'sftp_root', 'sftp_root', 'created_by', 'date_created',
'date_updated',
] ]
extra_kwargs = { extra_kwargs = {
'password': {"write_only": True}, 'password': {"write_only": True},

View File

@ -0,0 +1,3 @@
from .common import *
from .node_assets_amount import *
from .node_assets_mapping import *

View File

@ -1,21 +1,17 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from operator import add, sub
from assets.utils import is_asset_exists_in_node
from django.db.models.signals import ( from django.db.models.signals import (
post_save, m2m_changed, pre_delete, post_delete, pre_save post_save, m2m_changed, pre_delete, post_delete, pre_save
) )
from django.db.models import Q, F
from django.dispatch import receiver from django.dispatch import receiver
from common.exceptions import M2MReverseNotAllowed from common.exceptions import M2MReverseNotAllowed
from common.const.signals import PRE_ADD, POST_ADD, POST_REMOVE, PRE_CLEAR, PRE_REMOVE from common.const.signals import POST_ADD, POST_REMOVE, PRE_REMOVE
from common.utils import get_logger from common.utils import get_logger
from common.decorator import on_transaction_commit from common.decorator import on_transaction_commit
from .models import Asset, SystemUser, Node, compute_parent_key from assets.models import Asset, SystemUser, Node
from users.models import User from users.models import User
from .tasks import ( from assets.tasks import (
update_assets_hardware_info_util, update_assets_hardware_info_util,
test_asset_connectivity_util, test_asset_connectivity_util,
push_system_user_to_assets_manual, push_system_user_to_assets_manual,
@ -23,7 +19,6 @@ from .tasks import (
add_nodes_assets_to_system_users add_nodes_assets_to_system_users
) )
logger = get_logger(__file__) logger = get_logger(__file__)
@ -87,13 +82,13 @@ def on_system_user_assets_change(instance, action, model, pk_set, **kwargs):
return return
logger.debug("System user assets change signal recv: {}".format(instance)) logger.debug("System user assets change signal recv: {}".format(instance))
if model == Asset: if model == Asset:
system_users_id = [instance.id] system_user_ids = [instance.id]
assets_id = pk_set asset_ids = pk_set
else: else:
system_users_id = pk_set system_user_ids = pk_set
assets_id = [instance.id] asset_ids = [instance.id]
for system_user_id in system_users_id: for system_user_id in system_user_ids:
push_system_user_to_assets.delay(system_user_id, assets_id) push_system_user_to_assets.delay(system_user_id, asset_ids)
@receiver(m2m_changed, sender=SystemUser.users.through) @receiver(m2m_changed, sender=SystemUser.users.through)
@ -202,134 +197,6 @@ def on_asset_nodes_add(instance, action, reverse, pk_set, **kwargs):
m2m_model.objects.bulk_create(to_create) m2m_model.objects.bulk_create(to_create)
def _update_node_assets_amount(node: Node, asset_pk_set: set, operator=add):
"""
一个节点与多个资产关系变化时更新计数
:param node: 节点实例
:param asset_pk_set: 资产的`id`集合, 内部不会修改该值
:param operator: 操作
* -> Node
# -> Asset
* [3]
/ \
* * [2]
/ \
* * [1]
/ / \
* [a] # # [b]
"""
# 获取节点[1]祖先节点的 `key` 含自己,也就是[1, 2, 3]节点的`key`
ancestor_keys = node.get_ancestor_keys(with_self=True)
ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
to_update = []
for ancestor in ancestors:
# 迭代祖先节点的`key`,顺序是 [1] -> [2] -> [3]
# 查询该节点及其后代节点是否包含要操作的资产,将包含的从要操作的
# 资产集合中去掉,他们是重复节点,无论增加或删除都不会影响节点的资产数量
asset_pk_set -= set(Asset.objects.filter(
id__in=asset_pk_set
).filter(
Q(nodes__key__istartswith=f'{ancestor.key}:') |
Q(nodes__key=ancestor.key)
).distinct().values_list('id', flat=True))
if not asset_pk_set:
# 要操作的资产集合为空,说明都是重复资产,不用改变节点资产数量
# 而且既然它包含了,它的祖先节点肯定也包含了,所以祖先节点都不用
# 处理了
break
ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
to_update.append(ancestor)
Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))
def _remove_ancestor_keys(ancestor_key, tree_set):
# 这里判断 `ancestor_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while ancestor_key and ancestor_key in tree_set:
tree_set.remove(ancestor_key)
ancestor_key = compute_parent_key(ancestor_key)
def _update_nodes_asset_amount(node_keys, asset_pk, operator):
"""
一个资产与多个节点关系变化时更新计数
:param node_keys: 节点 id 的集合
:param asset_pk: 资产 id
:param operator: 操作
"""
# 所有相关节点的祖先节点,组成一棵局部树
ancestor_keys = set()
for key in node_keys:
ancestor_keys.update(Node.get_node_ancestor_keys(key))
# 相关节点可能是其他相关节点的祖先节点,如果是从相关节点里干掉
node_keys -= ancestor_keys
to_update_keys = []
for key in node_keys:
# 遍历相关节点,处理它及其祖先节点
# 查询该节点是否包含待处理资产
exists = is_asset_exists_in_node(asset_pk, key)
parent_key = compute_parent_key(key)
if exists:
# 如果资产在该节点,那么他及其祖先节点都不用处理
_remove_ancestor_keys(parent_key, ancestor_keys)
continue
else:
# 不存在,要更新本节点
to_update_keys.append(key)
# 这里判断 `parent_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while parent_key and parent_key in ancestor_keys:
exists = is_asset_exists_in_node(asset_pk, parent_key)
if exists:
_remove_ancestor_keys(parent_key, ancestor_keys)
break
else:
to_update_keys.append(parent_key)
ancestor_keys.remove(parent_key)
parent_key = compute_parent_key(parent_key)
Node.objects.filter(key__in=to_update_keys).update(
assets_amount=operator(F('assets_amount'), 1)
)
@receiver(m2m_changed, sender=Asset.nodes.through)
def update_nodes_assets_amount(action, instance, reverse, pk_set, **kwargs):
# 不允许 `pre_clear` ,因为该信号没有 `pk_set`
# [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
refused = (PRE_CLEAR,)
if action in refused:
raise ValueError
mapper = {
PRE_ADD: add,
POST_REMOVE: sub
}
if action not in mapper:
return
operator = mapper[action]
if reverse:
node: Node = instance
asset_pk_set = set(pk_set)
_update_node_assets_amount(node, asset_pk_set, operator)
else:
asset_pk = instance.id
# 与资产直接关联的节点
node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
_update_nodes_asset_amount(node_keys, asset_pk, operator)
RELATED_NODE_IDS = '_related_node_ids' RELATED_NODE_IDS = '_related_node_ids'

View File

@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
#
from operator import add, sub
from django.db.models import Q, F
from django.dispatch import receiver
from django.db.models.signals import (
m2m_changed
)
from orgs.utils import ensure_in_real_or_default_org
from common.const.signals import PRE_ADD, POST_REMOVE, PRE_CLEAR
from common.utils import get_logger
from assets.models import Asset, Node, compute_parent_key
from assets.locks import NodeTreeUpdateLock
logger = get_logger(__file__)
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
# 不允许 `pre_clear` ,因为该信号没有 `pk_set`
# [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
refused = (PRE_CLEAR,)
if action in refused:
raise ValueError
mapper = {
PRE_ADD: add,
POST_REMOVE: sub
}
if action not in mapper:
return
operator = mapper[action]
if reverse:
node: Node = instance
asset_pk_set = set(pk_set)
NodeAssetsAmountUtils.update_node_assets_amount(node, asset_pk_set, operator)
else:
asset_pk = instance.id
# 与资产直接关联的节点
node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
NodeAssetsAmountUtils.update_nodes_asset_amount(node_keys, asset_pk, operator)
class NodeAssetsAmountUtils:
@classmethod
def _remove_ancestor_keys(cls, ancestor_key, tree_set):
# 这里判断 `ancestor_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while ancestor_key and ancestor_key in tree_set:
tree_set.remove(ancestor_key)
ancestor_key = compute_parent_key(ancestor_key)
@classmethod
def _is_asset_exists_in_node(cls, asset_pk, node_key):
exists = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
).filter(id=asset_pk).exists()
return exists
@classmethod
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_nodes_asset_amount(cls, node_keys, asset_pk, operator):
"""
一个资产与多个节点关系变化时更新计数
:param node_keys: 节点 id 的集合
:param asset_pk: 资产 id
:param operator: 操作
"""
# 所有相关节点的祖先节点,组成一棵局部树
ancestor_keys = set()
for key in node_keys:
ancestor_keys.update(Node.get_node_ancestor_keys(key))
# 相关节点可能是其他相关节点的祖先节点,如果是从相关节点里干掉
node_keys -= ancestor_keys
to_update_keys = []
for key in node_keys:
# 遍历相关节点,处理它及其祖先节点
# 查询该节点是否包含待处理资产
exists = cls._is_asset_exists_in_node(asset_pk, key)
parent_key = compute_parent_key(key)
if exists:
# 如果资产在该节点,那么他及其祖先节点都不用处理
cls._remove_ancestor_keys(parent_key, ancestor_keys)
continue
else:
# 不存在,要更新本节点
to_update_keys.append(key)
# 这里判断 `parent_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while parent_key and parent_key in ancestor_keys:
exists = cls._is_asset_exists_in_node(asset_pk, parent_key)
if exists:
cls._remove_ancestor_keys(parent_key, ancestor_keys)
break
else:
to_update_keys.append(parent_key)
ancestor_keys.remove(parent_key)
parent_key = compute_parent_key(parent_key)
Node.objects.filter(key__in=to_update_keys).update(
assets_amount=operator(F('assets_amount'), 1)
)
@classmethod
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_node_assets_amount(cls, node: Node, asset_pk_set: set, operator=add):
"""
一个节点与多个资产关系变化时更新计数
:param node: 节点实例
:param asset_pk_set: 资产的`id`集合, 内部不会修改该值
:param operator: 操作
* -> Node
# -> Asset
* [3]
/ \
* * [2]
/ \
* * [1]
/ / \
* [a] # # [b]
"""
# 获取节点[1]祖先节点的 `key` 含自己,也就是[1, 2, 3]节点的`key`
ancestor_keys = node.get_ancestor_keys(with_self=True)
ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
to_update = []
for ancestor in ancestors:
# 迭代祖先节点的`key`,顺序是 [1] -> [2] -> [3]
# 查询该节点及其后代节点是否包含要操作的资产,将包含的从要操作的
# 资产集合中去掉,他们是重复节点,无论增加或删除都不会影响节点的资产数量
asset_pk_set -= set(Asset.objects.filter(
id__in=asset_pk_set
).filter(
Q(nodes__key__istartswith=f'{ancestor.key}:') |
Q(nodes__key=ancestor.key)
).distinct().values_list('id', flat=True))
if not asset_pk_set:
# 要操作的资产集合为空,说明都是重复资产,不用改变节点资产数量
# 而且既然它包含了,它的祖先节点肯定也包含了,所以祖先节点都不用
# 处理了
break
ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
to_update.append(ancestor)
Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))

View File

@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
#
import os
import threading
from django.db.models.signals import (
m2m_changed, post_save, post_delete
)
from django.dispatch import receiver
from django.utils.functional import LazyObject
from common.signals import django_ready
from common.utils.connection import RedisPubSub
from common.utils import get_logger
from assets.models import Asset, Node
logger = get_logger(__file__)
# clear node assets mapping for memory
# ------------------------------------
def get_node_assets_mapping_for_memory_pub_sub():
return RedisPubSub('fm.node_all_asset_ids_memory_mapping')
class NodeAssetsMappingForMemoryPubSub(LazyObject):
def _setup(self):
self._wrapped = get_node_assets_mapping_for_memory_pub_sub()
node_assets_mapping_for_memory_pub_sub = NodeAssetsMappingForMemoryPubSub()
def expire_node_assets_mapping_for_memory(org_id):
# 所有进程清除(自己的 memory 数据)
org_id = str(org_id)
node_assets_mapping_for_memory_pub_sub.publish(org_id)
# 当前进程清除(cache 数据)
logger.debug(
"Expire node assets id mapping from cache of org={}, pid={}"
"".format(org_id, os.getpid())
)
Node.expire_node_all_asset_ids_mapping_from_cache(org_id)
@receiver(post_save, sender=Node)
def on_node_post_create(sender, instance, created, update_fields, **kwargs):
if created:
need_expire = True
elif update_fields and 'key' in update_fields:
need_expire = True
else:
need_expire = False
if need_expire:
expire_node_assets_mapping_for_memory(instance.org_id)
@receiver(post_delete, sender=Node)
def on_node_post_delete(sender, instance, **kwargs):
expire_node_assets_mapping_for_memory(instance.org_id)
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, instance, **kwargs):
expire_node_assets_mapping_for_memory(instance.org_id)
@receiver(django_ready)
def subscribe_node_assets_mapping_expire(sender, **kwargs):
logger.debug("Start subscribe for expire node assets id mapping from memory")
def keep_subscribe():
subscribe = node_assets_mapping_for_memory_pub_sub.subscribe()
for message in subscribe.listen():
if message["type"] != "message":
continue
org_id = message['data'].decode()
Node.expire_node_all_asset_ids_mapping_from_memory(org_id)
logger.debug(
"Expire node assets id mapping from memory of org={}, pid={}"
"".format(str(org_id), os.getpid())
)
t = threading.Thread(target=keep_subscribe)
t.daemon = True
t.start()

View File

@ -12,6 +12,7 @@ __all__ = ['add_nodes_assets_to_system_users']
@tmp_to_root_org() @tmp_to_root_org()
def add_nodes_assets_to_system_users(nodes_keys, system_users): def add_nodes_assets_to_system_users(nodes_keys, system_users):
from ..models import Node from ..models import Node
assets = Node.get_nodes_all_assets(nodes_keys).values_list('id', flat=True) nodes = Node.objects.filter(key__in=nodes_keys)
assets = Node.get_nodes_all_assets(*nodes)
for system_user in system_users: for system_user in system_users:
system_user.assets.add(*tuple(assets)) system_user.assets.add(*tuple(assets))

View File

@ -141,7 +141,8 @@ def gather_asset_users(assets, task_name=None):
@shared_task(queue="ansible") @shared_task(queue="ansible")
def gather_nodes_asset_users(nodes_key): def gather_nodes_asset_users(nodes_key):
assets = Node.get_nodes_all_assets(nodes_key) nodes = Node.objects.filter(key__in=nodes_key)
assets = Node.get_nodes_all_assets(*nodes)
assets_groups_by_100 = [assets[i:i+100] for i in range(0, len(assets), 100)] assets_groups_by_100 = [assets[i:i+100] for i in range(0, len(assets), 100)]
for _assets in assets_groups_by_100: for _assets in assets_groups_by_100:
gather_asset_users(_assets) gather_asset_users(_assets)

View File

@ -12,16 +12,24 @@ from common.utils import get_logger
logger = get_logger(__file__) logger = get_logger(__file__)
@shared_task(queue='celery_heavy_tasks') @shared_task
def check_node_assets_amount_task(org_id=Organization.ROOT_ID): def check_node_assets_amount_task(org_id=None):
if org_id is None:
orgs = Organization.objects.all()
else:
orgs = [Organization.get_instance(org_id)]
for org in orgs:
try: try:
with tmp_to_org(Organization.get_instance(org_id)): with tmp_to_org(org):
check_node_assets_amount() check_node_assets_amount()
except AcquireFailed: except AcquireFailed:
logger.error(_('The task of self-checking is already running and cannot be started repeatedly')) error = _('The task of self-checking is already running '
'and cannot be started repeatedly')
logger.error(error)
@register_as_period_task(crontab='0 2 * * *') @register_as_period_task(crontab='0 2 * * *')
@shared_task(queue='celery_heavy_tasks') @shared_task
def check_node_assets_amount_period_task(): def check_node_assets_amount_period_task():
check_node_assets_amount_task() check_node_assets_amount_task()

View File

@ -32,11 +32,19 @@ def _dump_args(args: dict):
def get_push_unixlike_system_user_tasks(system_user, username=None): def get_push_unixlike_system_user_tasks(system_user, username=None):
comment = system_user.name
if username is None: if username is None:
username = system_user.username username = system_user.username
if system_user.username_same_with_user:
from users.models import User
user = User.objects.filter(username=username).only('name', 'username').first()
if user:
comment = f'{system_user.name}[{str(user)}]'
password = system_user.password password = system_user.password
public_key = system_user.public_key public_key = system_user.public_key
comment = system_user.name
groups = _split_by_comma(system_user.system_groups) groups = _split_by_comma(system_user.system_groups)
@ -225,18 +233,18 @@ def push_system_user_util(system_user, assets, task_name, username=None):
print(_("Hosts count: {}").format(len(_assets))) print(_("Hosts count: {}").format(len(_assets)))
id_asset_map = {_asset.id: _asset for _asset in _assets} id_asset_map = {_asset.id: _asset for _asset in _assets}
assets_id = id_asset_map.keys() asset_ids = id_asset_map.keys()
no_special_auth = [] no_special_auth = []
special_auth_set = set() special_auth_set = set()
auth_books = AuthBook.objects.filter(username__in=usernames, asset_id__in=assets_id) auth_books = AuthBook.objects.filter(username__in=usernames, asset_id__in=asset_ids)
for auth_book in auth_books: for auth_book in auth_books:
special_auth_set.add((auth_book.username, auth_book.asset_id)) special_auth_set.add((auth_book.username, auth_book.asset_id))
for _username in usernames: for _username in usernames:
no_special_assets = [] no_special_assets = []
for asset_id in assets_id: for asset_id in asset_ids:
if (_username, asset_id) not in special_auth_set: if (_username, asset_id) not in special_auth_set:
no_special_assets.append(id_asset_map[asset_id]) no_special_assets.append(id_asset_map[asset_id])
if no_special_assets: if no_special_assets:
@ -281,12 +289,12 @@ def push_system_user_a_asset_manual(system_user, asset, username=None):
@shared_task(queue="ansible") @shared_task(queue="ansible")
@tmp_to_root_org() @tmp_to_root_org()
def push_system_user_to_assets(system_user_id, assets_id, username=None): def push_system_user_to_assets(system_user_id, asset_ids, username=None):
""" """
推送系统用户到指定的若干资产上 推送系统用户到指定的若干资产上
""" """
system_user = SystemUser.objects.get(id=system_user_id) system_user = SystemUser.objects.get(id=system_user_id)
assets = get_objects(Asset, assets_id) assets = get_objects(Asset, asset_ids)
task_name = _("Push system users to assets: {}").format(system_user.name) task_name = _("Push system users to assets: {}").format(system_user.name)
return push_system_user_util(system_user, assets, task_name, username=username) return push_system_user_util(system_user, assets, task_name, username=username)

33
apps/assets/tests/tree.py Normal file
View File

@ -0,0 +1,33 @@
from assets.tree import Tree
def test():
from orgs.models import Organization
from assets.models import Node, Asset
import time
Organization.objects.get(id='1863cf22-f666-474e-94aa-935fe175203c').change_to()
t1 = time.time()
nodes = list(Node.objects.exclude(key__startswith='-').only('id', 'key', 'parent_key'))
node_asset_id_pairs = Asset.nodes.through.objects.all().values_list('node_id', 'asset_id')
t2 = time.time()
node_asset_id_pairs = list(node_asset_id_pairs)
tree = Tree(nodes, node_asset_id_pairs)
tree.build_tree()
tree.nodes = None
tree.node_asset_id_pairs = None
import pickle
d = pickle.dumps(tree)
print('------------', len(d))
return tree
tree.compute_tree_node_assets_amount()
print(f'校对算法准确性 ......')
for node in nodes:
tree_node = tree.key_tree_node_mapper[node.key]
if tree_node.assets_amount != node.assets_amount:
print(f'ERROR: {tree_node.assets_amount} {node.assets_amount}')
# print(f'OK {tree_node.asset_amount} {node.assets_amount}')
print(f'数据库时间: {t2 - t1}')
return tree

View File

@ -2,7 +2,6 @@
from django.urls import path, re_path from django.urls import path, re_path
from rest_framework_nested import routers from rest_framework_nested import routers
from rest_framework_bulk.routes import BulkRouter from rest_framework_bulk.routes import BulkRouter
from django.db.transaction import non_atomic_requests
from common import api as capi from common import api as capi
@ -57,9 +56,9 @@ urlpatterns = [
path('nodes/children/', api.NodeChildrenApi.as_view(), name='node-children-2'), path('nodes/children/', api.NodeChildrenApi.as_view(), name='node-children-2'),
path('nodes/<uuid:pk>/children/add/', api.NodeAddChildrenApi.as_view(), name='node-add-children'), path('nodes/<uuid:pk>/children/add/', api.NodeAddChildrenApi.as_view(), name='node-add-children'),
path('nodes/<uuid:pk>/assets/', api.NodeAssetsApi.as_view(), name='node-assets'), path('nodes/<uuid:pk>/assets/', api.NodeAssetsApi.as_view(), name='node-assets'),
path('nodes/<uuid:pk>/assets/add/', non_atomic_requests(api.NodeAddAssetsApi.as_view()), name='node-add-assets'), path('nodes/<uuid:pk>/assets/add/', api.NodeAddAssetsApi.as_view(), name='node-add-assets'),
path('nodes/<uuid:pk>/assets/replace/', non_atomic_requests(api.MoveAssetsToNodeApi.as_view()), name='node-replace-assets'), path('nodes/<uuid:pk>/assets/replace/', api.MoveAssetsToNodeApi.as_view(), name='node-replace-assets'),
path('nodes/<uuid:pk>/assets/remove/', non_atomic_requests(api.NodeRemoveAssetsApi.as_view()), name='node-remove-assets'), path('nodes/<uuid:pk>/assets/remove/', api.NodeRemoveAssetsApi.as_view(), name='node-remove-assets'),
path('nodes/<uuid:pk>/tasks/', api.NodeTaskCreateApi.as_view(), name='node-task-create'), path('nodes/<uuid:pk>/tasks/', api.NodeTaskCreateApi.as_view(), name='node-task-create'),
path('gateways/<uuid:pk>/test-connective/', api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'), path('gateways/<uuid:pk>/test-connective/', api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'),

View File

@ -1,41 +1,47 @@
# ~*~ coding: utf-8 ~*~ # ~*~ coding: utf-8 ~*~
# #
import time from collections import defaultdict
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none, timeit
from django.db.models import Q
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none
from common.utils.lock import DistributedLock
from common.http import is_true from common.http import is_true
from .models import Asset, Node from common.struct import Stack
from common.db.models import output_as_string
from orgs.utils import ensure_in_real_or_default_org, current_org
from .locks import NodeTreeUpdateLock
from .models import Node, Asset
logger = get_logger(__file__) logger = get_logger(__file__)
@DistributedLock(name="assets.node.check_node_assets_amount", blocking=False) @NodeTreeUpdateLock()
@ensure_in_real_or_default_org
def check_node_assets_amount(): def check_node_assets_amount():
for node in Node.objects.all(): logger.info(f'Check node assets amount {current_org}')
logger.info(f'Check node assets amount: {node}') nodes = list(Node.objects.all().only('id', 'key', 'assets_amount'))
assets_amount = Asset.objects.filter( nodeid_assetid_pairs = list(Asset.nodes.through.objects.all().values_list('node_id', 'asset_id'))
Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes=node)
).distinct().count()
nodekey_assetids_mapper = defaultdict(set)
nodeid_nodekey_mapper = {}
for node in nodes:
nodeid_nodekey_mapper[node.id] = node.key
for nodeid, assetid in nodeid_assetid_pairs:
if nodeid not in nodeid_nodekey_mapper:
continue
nodekey = nodeid_nodekey_mapper[nodeid]
nodekey_assetids_mapper[nodekey].add(assetid)
util = NodeAssetsUtil(nodes, nodekey_assetids_mapper)
util.generate()
to_updates = []
for node in nodes:
assets_amount = util.get_assets_amount(node.key)
if node.assets_amount != assets_amount: if node.assets_amount != assets_amount:
logger.warn(f'Node wrong assets amount <Node:{node.key}> ' logger.error(f'Node[{node.key}] assets amount error {node.assets_amount} != {assets_amount}')
f'{node.assets_amount} right is {assets_amount}')
node.assets_amount = assets_amount node.assets_amount = assets_amount
node.save() to_updates.append(node)
# 防止自检程序给数据库的压力太大 Node.objects.bulk_update(to_updates, fields=('assets_amount',))
time.sleep(0.1)
def is_asset_exists_in_node(asset_pk, node_key):
return Asset.objects.filter(
id=asset_pk
).filter(
Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
).exists()
def is_query_node_all_assets(request): def is_query_node_all_assets(request):
@ -57,3 +63,77 @@ def get_node(request):
else: else:
node = get_object_or_none(Node, key=node_id) node = get_object_or_none(Node, key=node_id)
return node return node
class NodeAssetsInfo:
__slots__ = ('key', 'assets_amount', 'assets')
def __init__(self, key, assets_amount, assets):
self.key = key
self.assets_amount = assets_amount
self.assets = assets
def __str__(self):
return self.key
class NodeAssetsUtil:
def __init__(self, nodes, nodekey_assetsid_mapper):
"""
:param nodes: 节点
:param nodekey_assetsid_mapper: 节点直接资产id的映射 {"key1": set(), "key2": set()}
"""
self.nodes = nodes
# node_id --> set(asset_id1, asset_id2)
self.nodekey_assetsid_mapper = nodekey_assetsid_mapper
self.nodekey_assetsinfo_mapper = {}
@timeit
def generate(self):
# 准备排序好的资产信息数据
infos = []
for node in self.nodes:
assets = self.nodekey_assetsid_mapper.get(node.key, set())
info = NodeAssetsInfo(key=node.key, assets_amount=0, assets=assets)
infos.append(info)
infos = sorted(infos, key=lambda i: [int(i) for i in i.key.split(':')])
# 这个守卫需要添加一下,避免最后一个无法出栈
guarder = NodeAssetsInfo(key='', assets_amount=0, assets=set())
infos.append(guarder)
stack = Stack()
for info in infos:
# 如果栈顶的不是这个节点的父祖节点,那么可以出栈了,可以计算资产数量了
while stack.top and not info.key.startswith(f'{stack.top.key}:'):
pop_info = stack.pop()
pop_info.assets_amount = len(pop_info.assets)
self.nodekey_assetsinfo_mapper[pop_info.key] = pop_info
if not stack.top:
continue
stack.top.assets.update(pop_info.assets)
stack.push(info)
def get_assets_by_key(self, key):
info = self.nodekey_assetsinfo_mapper[key]
return info['assets']
def get_assets_amount(self, key):
info = self.nodekey_assetsinfo_mapper[key]
return info.assets_amount
@classmethod
def test_it(cls):
from assets.models import Node, Asset
nodes = list(Node.objects.all())
nodes_assets = Asset.nodes.through.objects.all()\
.annotate(aid=output_as_string('asset_id'))\
.values_list('node__key', 'aid')
mapping = defaultdict(set)
for key, asset_id in nodes_assets:
mapping[key].add(asset_id)
util = cls(nodes, mapping)
util.generate()
return util

View File

@ -1,10 +1,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.db.models.signals import post_save, post_delete from django.db.models.signals import post_save, post_delete
from django.dispatch import receiver from django.dispatch import receiver
from django.conf import settings
from django.db import transaction from django.db import transaction
from django.utils import timezone from django.utils import timezone
from django.utils.functional import LazyObject
from django.contrib.auth import BACKEND_SESSION_KEY from django.contrib.auth import BACKEND_SESSION_KEY
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.renderers import JSONRenderer from rest_framework.renderers import JSONRenderer
@ -34,17 +35,22 @@ MODELS_NEED_RECORD = (
) )
LOGIN_BACKEND = { class AuthBackendLabelMapping(LazyObject):
'PublicKeyAuthBackend': _('SSH Key'), @staticmethod
'RadiusBackend': User.Source.radius.label, def get_login_backends():
'RadiusRealmBackend': User.Source.radius.label, backend_label_mapping = {}
'LDAPAuthorizationBackend': User.Source.ldap.label, for source, backends in User.SOURCE_BACKEND_MAPPING.items():
'ModelBackend': _('Password'), for backend in backends:
'SSOAuthentication': _('SSO'), backend_label_mapping[backend] = source.label
'CASBackend': User.Source.cas.label, backend_label_mapping[settings.AUTH_BACKEND_PUBKEY] = _('SSH Key')
'OIDCAuthCodeBackend': User.Source.openid.label, backend_label_mapping[settings.AUTH_BACKEND_MODEL] = _('Password')
'OIDCAuthPasswordBackend': User.Source.openid.label, return backend_label_mapping
}
def _setup(self):
self._wrapped = self.get_login_backends()
AUTH_BACKEND_LABEL_MAPPING = AuthBackendLabelMapping()
def create_operate_log(action, sender, resource): def create_operate_log(action, sender, resource):
@ -70,6 +76,7 @@ def create_operate_log(action, sender, resource):
@receiver(post_save) @receiver(post_save)
def on_object_created_or_update(sender, instance=None, created=False, update_fields=None, **kwargs): def on_object_created_or_update(sender, instance=None, created=False, update_fields=None, **kwargs):
# last_login 改变是最后登录日期, 每次登录都会改变
if instance._meta.object_name == 'User' and \ if instance._meta.object_name == 'User' and \
update_fields and 'last_login' in update_fields: update_fields and 'last_login' in update_fields:
return return
@ -125,14 +132,13 @@ def on_audits_log_create(sender, instance=None, **kwargs):
def get_login_backend(request): def get_login_backend(request):
backend = request.session.get('auth_backend', '') or request.session.get(BACKEND_SESSION_KEY, '') backend = request.session.get('auth_backend', '') or \
request.session.get(BACKEND_SESSION_KEY, '')
backend = backend.rsplit('.', maxsplit=1)[-1] backend_label = AUTH_BACKEND_LABEL_MAPPING.get(backend, None)
if backend in LOGIN_BACKEND: if backend_label is None:
return LOGIN_BACKEND[backend] backend_label = ''
else: return backend_label
logger.warn(f'LOGIN_BACKEND_NOT_FOUND: {backend}')
return ''
def generate_data(username, request): def generate_data(username, request):

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from .auth import * from .connection_token import *
from .token import * from .token import *
from .mfa import * from .mfa import *
from .access_key import * from .access_key import *

View File

@ -1,55 +0,0 @@
# -*- coding: utf-8 -*-
#
import uuid
from django.core.cache import cache
from django.shortcuts import get_object_or_404
from rest_framework.response import Response
from rest_framework.views import APIView
from common.utils import get_logger
from common.permissions import IsOrgAdminOrAppUser
from orgs.mixins.api import RootOrgViewMixin
from users.models import User
from assets.models import Asset, SystemUser
logger = get_logger(__name__)
__all__ = [
'UserConnectionTokenApi',
]
class UserConnectionTokenApi(RootOrgViewMixin, APIView):
permission_classes = (IsOrgAdminOrAppUser,)
def post(self, request):
user_id = request.data.get('user', '')
asset_id = request.data.get('asset', '')
system_user_id = request.data.get('system_user', '')
token = str(uuid.uuid4())
user = get_object_or_404(User, id=user_id)
asset = get_object_or_404(Asset, id=asset_id)
system_user = get_object_or_404(SystemUser, id=system_user_id)
value = {
'user': user_id,
'username': user.username,
'asset': asset_id,
'hostname': asset.hostname,
'system_user': system_user_id,
'system_user_name': system_user.name
}
cache.set(token, value, timeout=20)
return Response({"token": token}, status=201)
def get(self, request):
token = request.query_params.get('token')
user_only = request.query_params.get('user-only', None)
value = cache.get(token, None)
if not value:
return Response('', status=404)
if not user_only:
return Response(value)
else:
return Response({'user': value['user']})

View File

@ -0,0 +1,242 @@
# -*- coding: utf-8 -*-
#
from django.conf import settings
from django.core.cache import cache
from django.shortcuts import get_object_or_404
from django.http import HttpResponse
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet
from rest_framework.decorators import action
from rest_framework.exceptions import PermissionDenied
from common.utils import get_logger, random_string
from common.drf.api import SerializerMixin2
from common.permissions import IsSuperUserOrAppUser, IsValidUser, IsSuperUser
from orgs.mixins.api import RootOrgViewMixin
from ..serializers import (
ConnectionTokenSerializer, ConnectionTokenSecretSerializer,
RDPFileSerializer
)
logger = get_logger(__name__)
__all__ = ['UserConnectionTokenViewSet']
class UserConnectionTokenViewSet(RootOrgViewMixin, SerializerMixin2, GenericViewSet):
permission_classes = (IsSuperUserOrAppUser,)
serializer_classes = {
'default': ConnectionTokenSerializer,
'get_secret_detail': ConnectionTokenSecretSerializer,
'get_rdp_file': RDPFileSerializer
}
CACHE_KEY_PREFIX = 'CONNECTION_TOKEN_{}'
@staticmethod
def check_resource_permission(user, asset, application, system_user):
from perms.utils.asset import has_asset_system_permission
from perms.utils.application import has_application_system_permission
if asset and not has_asset_system_permission(user, asset, system_user):
error = f'User not has this asset and system user permission: ' \
f'user={user.id} system_user={system_user.id} asset={asset.id}'
raise PermissionDenied(error)
if application and not has_application_system_permission(user, application, system_user):
error = f'User not has this application and system user permission: ' \
f'user={user.id} system_user={system_user.id} application={application.id}'
raise PermissionDenied(error)
return True
def create_token(self, user, asset, application, system_user):
if not settings.CONNECTION_TOKEN_ENABLED:
raise PermissionDenied('Connection token disabled')
if not user:
user = self.request.user
if not self.request.user.is_superuser and user != self.request.user:
raise PermissionDenied('Only super user can create user token')
self.check_resource_permission(user, asset, application, system_user)
token = random_string(36)
value = {
'user': str(user.id),
'username': user.username,
'system_user': str(system_user.id),
'system_user_name': system_user.name
}
if asset:
value.update({
'type': 'asset',
'asset': str(asset.id),
'hostname': asset.hostname,
})
elif application:
value.update({
'type': 'application',
'application': application.id,
'application_name': str(application)
})
key = self.CACHE_KEY_PREFIX.format(token)
cache.set(key, value, timeout=20)
return token
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
asset = serializer.validated_data.get('asset')
application = serializer.validated_data.get('application')
system_user = serializer.validated_data['system_user']
user = serializer.validated_data.get('user')
token = self.create_token(user, asset, application, system_user)
return Response({"token": token}, status=201)
@action(methods=['POST', 'GET'], detail=False, url_path='rdp/file')
def get_rdp_file(self, request, *args, **kwargs):
options = {
'full address:s': '',
'username:s': '',
'screen mode id:i': '0',
'desktopwidth:i': '1280',
'desktopheight:i': '800',
'use multimon:i': '1',
'session bpp:i': '24',
'audiomode:i': '0',
'disable wallpaper:i': '0',
'disable full window drag:i': '0',
'disable menu anims:i': '0',
'disable themes:i': '0',
'alternate shell:s': '',
'shell working directory:s': '',
'authentication level:i': '2',
'connect to console:i': '0',
'disable cursor setting:i': '0',
'allow font smoothing:i': '1',
'allow desktop composition:i': '1',
'redirectprinters:i': '0',
'prompt for credentials on client:i': '0',
'autoreconnection enabled:i': '1',
'bookmarktype:i': '3',
'use redirection server name:i': '0',
# 'alternate shell:s:': '||MySQLWorkbench',
# 'remoteapplicationname:s': 'Firefox',
# 'remoteapplicationcmdline:s': '',
}
if self.request.method == 'GET':
data = self.request.query_params
else:
data = request.data
serializer = self.get_serializer(data=data)
serializer.is_valid(raise_exception=True)
asset = serializer.validated_data.get('asset')
application = serializer.validated_data.get('application')
system_user = serializer.validated_data['system_user']
user = serializer.validated_data.get('user')
height = serializer.validated_data.get('height')
width = serializer.validated_data.get('width')
token = self.create_token(user, asset, application, system_user)
# Todo: 上线后地址是 JumpServerAddr:3389
address = self.request.query_params.get('address') or '1.1.1.1'
options['full address:s'] = address
options['username:s'] = '{}@{}'.format(user.username, token)
options['desktopwidth:i'] = width
options['desktopheight:i'] = height
data = ''
for k, v in options.items():
data += f'{k}:{v}\n'
response = HttpResponse(data, content_type='text/plain')
filename = "{}-{}-jumpserver.rdp".format(user.username, asset.hostname)
response['Content-Disposition'] = 'attachment; filename={}'.format(filename)
return response
@staticmethod
def _get_application_secret_detail(value):
from applications.models import Application
from perms.models import Action
application = get_object_or_404(Application, id=value.get('application'))
gateway = None
if not application.category_remote_app:
actions = Action.NONE
remote_app = {}
asset = None
domain = application.domain
else:
remote_app = application.get_rdp_remote_app_setting()
actions = Action.CONNECT
asset = application.get_remote_app_asset()
domain = asset.domain
if domain and domain.has_gateway():
gateway = domain.random_gateway()
return {
'asset': asset,
'application': application,
'gateway': gateway,
'remote_app': remote_app,
'actions': actions
}
@staticmethod
def _get_asset_secret_detail(value, user, system_user):
from assets.models import Asset
from perms.utils.asset import get_asset_system_user_ids_with_actions_by_user
asset = get_object_or_404(Asset, id=value.get('asset'))
systemuserid_actions_mapper = get_asset_system_user_ids_with_actions_by_user(user, asset)
actions = systemuserid_actions_mapper.get(system_user.id, [])
gateway = None
if asset and asset.domain and asset.domain.has_gateway():
gateway = asset.domain.random_gateway()
return {
'asset': asset,
'application': None,
'gateway': gateway,
'remote_app': None,
'actions': actions,
}
@action(methods=['POST'], detail=False, permission_classes=[IsSuperUserOrAppUser], url_path='secret-info/detail')
def get_secret_detail(self, request, *args, **kwargs):
from users.models import User
from assets.models import SystemUser
token = request.data.get('token', '')
key = self.CACHE_KEY_PREFIX.format(token)
value = cache.get(key, None)
if not value:
return Response(status=404)
user = get_object_or_404(User, id=value.get('user'))
system_user = get_object_or_404(SystemUser, id=value.get('system_user'))
data = dict(user=user, system_user=system_user)
if value.get('type') == 'asset':
asset_detail = self._get_asset_secret_detail(value, user=user, system_user=system_user)
data['type'] = 'asset'
data.update(asset_detail)
else:
app_detail = self._get_application_secret_detail(value)
data['type'] = 'application'
data.update(app_detail)
serializer = self.get_serializer(data)
return Response(data=serializer.data, status=200)
def get_permissions(self):
if self.action in ["create", "get_rdp_file"]:
if self.request.data.get('user', None):
self.permission_classes = (IsSuperUser,)
else:
self.permission_classes = (IsValidUser,)
return super().get_permissions()
def get(self, request):
token = request.query_params.get('token')
key = self.CACHE_KEY_PREFIX.format(token)
value = cache.get(key, None)
if not value:
return Response('', status=404)
return Response(value)

View File

@ -3,10 +3,10 @@
from rest_framework.generics import UpdateAPIView from rest_framework.generics import UpdateAPIView
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.permissions import AllowAny
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.utils.translation import ugettext as _
from common.utils import get_logger, get_object_or_none from common.utils import get_logger
from common.permissions import IsOrgAdmin from common.permissions import IsOrgAdmin
from ..models import LoginConfirmSetting from ..models import LoginConfirmSetting
from ..serializers import LoginConfirmSettingSerializer from ..serializers import LoginConfirmSettingSerializer
@ -32,7 +32,7 @@ class LoginConfirmSettingUpdateApi(UpdateAPIView):
class TicketStatusApi(mixins.AuthMixin, APIView): class TicketStatusApi(mixins.AuthMixin, APIView):
permission_classes = () permission_classes = (AllowAny,)
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
try: try:

View File

@ -7,6 +7,7 @@ from django.http.response import HttpResponseRedirect
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.permissions import AllowAny
from common.utils.timezone import utcnow from common.utils.timezone import utcnow
from common.const.http import POST, GET from common.const.http import POST, GET
@ -31,6 +32,7 @@ class SSOViewSet(AuthMixin, JmsGenericViewSet):
'login_url': SSOTokenSerializer, 'login_url': SSOTokenSerializer,
'login': EmptySerializer 'login': EmptySerializer
} }
permission_classes = (IsSuperUser,)
@action(methods=[POST], detail=False, permission_classes=[IsSuperUser], url_path='login-url') @action(methods=[POST], detail=False, permission_classes=[IsSuperUser], url_path='login-url')
def login_url(self, request, *args, **kwargs): def login_url(self, request, *args, **kwargs):
@ -54,7 +56,7 @@ class SSOViewSet(AuthMixin, JmsGenericViewSet):
login_url = '%s?%s' % (reverse('api-auth:sso-login', external=True), urlencode(query)) login_url = '%s?%s' % (reverse('api-auth:sso-login', external=True), urlencode(query))
return Response(data={'login_url': login_url}) return Response(data={'login_url': login_url})
@action(methods=[GET], detail=False, filter_backends=[AuthKeyQueryDeclaration], permission_classes=[]) @action(methods=[GET], detail=False, filter_backends=[AuthKeyQueryDeclaration], permission_classes=[AllowAny])
def login(self, request: Request, *args, **kwargs): def login(self, request: Request, *args, **kwargs):
""" """
此接口违反了 `Restful` 的规范 此接口违反了 `Restful` 的规范

View File

@ -2,7 +2,7 @@
# #
import traceback import traceback
from django.contrib.auth import get_user_model, authenticate from django.contrib.auth import get_user_model
from radiusauth.backends import RADIUSBackend, RADIUSRealmBackend from radiusauth.backends import RADIUSBackend, RADIUSRealmBackend
from django.conf import settings from django.conf import settings

View File

@ -18,6 +18,8 @@ reason_user_not_exist = 'user_not_exist'
reason_password_expired = 'password_expired' reason_password_expired = 'password_expired'
reason_user_invalid = 'user_invalid' reason_user_invalid = 'user_invalid'
reason_user_inactive = 'user_inactive' reason_user_inactive = 'user_inactive'
reason_backend_not_match = 'backend_not_match'
reason_acl_not_allow = 'acl_not_allow'
reason_choices = { reason_choices = {
reason_password_failed: _('Username/password check failed'), reason_password_failed: _('Username/password check failed'),
@ -27,7 +29,9 @@ reason_choices = {
reason_user_not_exist: _("Username does not exist"), reason_user_not_exist: _("Username does not exist"),
reason_password_expired: _("Password expired"), reason_password_expired: _("Password expired"),
reason_user_invalid: _('Disabled or expired'), reason_user_invalid: _('Disabled or expired'),
reason_user_inactive: _("This account is inactive.") reason_user_inactive: _("This account is inactive."),
reason_backend_not_match: _("Auth backend not match"),
reason_acl_not_allow: _("ACL is not allowed")
} }
old_reason_choices = { old_reason_choices = {
'0': '-', '0': '-',

View File

@ -8,11 +8,26 @@ from captcha.fields import CaptchaField, CaptchaTextInput
class UserLoginForm(forms.Form): class UserLoginForm(forms.Form):
username = forms.CharField(label=_('Username'), max_length=100) days_auto_login = int(settings.SESSION_COOKIE_AGE / 3600 / 24)
disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE or days_auto_login < 1
username = forms.CharField(
label=_('Username'), max_length=100,
widget=forms.TextInput(attrs={
'placeholder': _("Username"),
'autofocus': 'autofocus'
})
)
password = forms.CharField( password = forms.CharField(
label=_('Password'), widget=forms.PasswordInput, label=_('Password'), widget=forms.PasswordInput,
max_length=1024, strip=False max_length=1024, strip=False
) )
auto_login = forms.BooleanField(
label=_("{} days auto login").format(days_auto_login or 1),
required=False, initial=False, widget=forms.CheckboxInput(
attrs={'disabled': disable_days_auto_login}
)
)
def confirm_login_allowed(self, user): def confirm_login_allowed(self, user):
if not user.is_staff: if not user.is_staff:
@ -35,8 +50,13 @@ class CaptchaMixin(forms.Form):
class ChallengeMixin(forms.Form): class ChallengeMixin(forms.Form):
challenge = forms.CharField(label=_('MFA code'), max_length=6, challenge = forms.CharField(
required=False) label=_('MFA code'), max_length=6, required=False,
widget=forms.TextInput(attrs={
'placeholder': _("MFA code"),
'style': 'width: 50%'
})
)
def get_user_login_form_cls(*, captcha=False): def get_user_login_form_cls(*, captcha=False):

View File

@ -9,7 +9,7 @@ from django.contrib.auth import authenticate
from django.shortcuts import reverse from django.shortcuts import reverse
from django.contrib.auth import BACKEND_SESSION_KEY from django.contrib.auth import BACKEND_SESSION_KEY
from common.utils import get_object_or_none, get_request_ip, get_logger from common.utils import get_object_or_none, get_request_ip, get_logger, bulk_get
from users.models import User from users.models import User
from users.utils import ( from users.utils import (
is_block_login, clean_failed_count is_block_login, clean_failed_count
@ -24,6 +24,7 @@ logger = get_logger(__name__)
class AuthMixin: class AuthMixin:
request = None request = None
partial_credential_error = None
def get_user_from_session(self): def get_user_from_session(self):
if self.request.session.is_empty(): if self.request.session.is_empty():
@ -75,49 +76,84 @@ class AuthMixin:
return rsa_decrypt(raw_passwd, rsa_private_key) return rsa_decrypt(raw_passwd, rsa_private_key)
except Exception as e: except Exception as e:
logger.error(e, exc_info=True) logger.error(e, exc_info=True)
logger.error(f'Decrypt password faild: password[{raw_passwd}] rsa_private_key[{rsa_private_key}]') logger.error(f'Decrypt password failed: password[{raw_passwd}] '
f'rsa_private_key[{rsa_private_key}]')
return None return None
return raw_passwd return raw_passwd
def check_user_auth(self, decrypt_passwd=False): def raise_credential_error(self, error):
self.check_is_block() raise self.partial_credential_error(error=error)
def get_auth_data(self, decrypt_passwd=False):
request = self.request request = self.request
if hasattr(request, 'data'): if hasattr(request, 'data'):
data = request.data data = request.data
else: else:
data = request.POST data = request.POST
username = data.get('username', '')
password = data.get('password', '')
challenge = data.get('challenge', '')
public_key = data.get('public_key', '')
ip = self.get_request_ip()
CredentialError = partial(errors.CredentialError, username=username, ip=ip, request=request) items = ['username', 'password', 'challenge', 'public_key', 'auto_login']
username, password, challenge, public_key, auto_login = bulk_get(data, *items, default='')
password = password + challenge.strip()
ip = self.get_request_ip()
self.partial_credential_error = partial(errors.CredentialError, username=username, ip=ip, request=request)
if decrypt_passwd: if decrypt_passwd:
password = self.decrypt_passwd(password) password = self.decrypt_passwd(password)
if not password: if not password:
raise CredentialError(error=errors.reason_password_decrypt_failed) self.raise_credential_error(errors.reason_password_decrypt_failed)
return username, password, public_key, ip, auto_login
user = authenticate(request, def _check_only_allow_exists_user_auth(self, username):
username=username, # 仅允许预先存在的用户认证
password=password + challenge.strip(), if settings.ONLY_ALLOW_EXIST_USER_AUTH:
public_key=public_key) exist = User.objects.filter(username=username).exists()
if not exist:
logger.error(f"Only allow exist user auth, login failed: {username}")
self.raise_credential_error(errors.reason_user_not_exist)
def _check_auth_user_is_valid(self, username, password, public_key):
user = authenticate(self.request, username=username, password=password, public_key=public_key)
if not user: if not user:
raise CredentialError(error=errors.reason_password_failed) self.raise_credential_error(errors.reason_password_failed)
elif user.is_expired: elif user.is_expired:
raise CredentialError(error=errors.reason_user_inactive) self.raise_credential_error(errors.reason_user_inactive)
elif not user.is_active: elif not user.is_active:
raise CredentialError(error=errors.reason_user_inactive) self.raise_credential_error(errors.reason_user_inactive)
return user
def _check_auth_source_is_valid(self, user, auth_backend):
# 限制只能从认证来源登录
if settings.ONLY_ALLOW_AUTH_FROM_SOURCE:
auth_backends_allowed = user.SOURCE_BACKEND_MAPPING.get(user.source)
if auth_backend not in auth_backends_allowed:
self.raise_credential_error(error=errors.reason_backend_not_match)
def _check_login_acl(self, user, ip):
# ACL 限制用户登录
from acls.models import LoginACL
is_allowed = LoginACL.allow_user_to_login(user, ip)
if not is_allowed:
raise self.raise_credential_error(error=errors.reason_acl_not_allow)
def check_user_auth(self, decrypt_passwd=False):
self.check_is_block()
request = self.request
username, password, public_key, ip, auto_login = self.get_auth_data(decrypt_passwd=decrypt_passwd)
self._check_only_allow_exists_user_auth(username)
user = self._check_auth_user_is_valid(username, password, public_key)
# 校验login-acl规则
self._check_login_acl(user, ip)
# 限制只能从认证来源登录
auth_backend = getattr(user, 'backend', 'django.contrib.auth.backends.ModelBackend')
self._check_auth_source_is_valid(user, auth_backend)
self._check_password_require_reset_or_not(user) self._check_password_require_reset_or_not(user)
self._check_passwd_is_too_simple(user, password) self._check_passwd_is_too_simple(user, password)
clean_failed_count(username, ip) clean_failed_count(username, ip)
request.session['auth_password'] = 1 request.session['auth_password'] = 1
request.session['user_id'] = str(user.id) request.session['user_id'] = str(user.id)
auth_backend = getattr(user, 'backend', 'django.contrib.auth.backends.ModelBackend') request.session['auto_login'] = auto_login
request.session['auth_backend'] = auth_backend request.session['auth_backend'] = auth_backend
return user return user

View File

@ -4,13 +4,17 @@ from rest_framework import serializers
from common.utils import get_object_or_none from common.utils import get_object_or_none
from users.models import User from users.models import User
from assets.models import Asset, SystemUser, Gateway
from applications.models import Application
from users.serializers import UserProfileSerializer from users.serializers import UserProfileSerializer
from perms.serializers.asset.permission import ActionsField
from .models import AccessKey, LoginConfirmSetting, SSOToken from .models import AccessKey, LoginConfirmSetting, SSOToken
__all__ = [ __all__ = [
'AccessKeySerializer', 'OtpVerifySerializer', 'BearerTokenSerializer', 'AccessKeySerializer', 'OtpVerifySerializer', 'BearerTokenSerializer',
'MFAChallengeSerializer', 'LoginConfirmSettingSerializer', 'SSOTokenSerializer', 'MFAChallengeSerializer', 'LoginConfirmSettingSerializer', 'SSOTokenSerializer',
'ConnectionTokenSerializer', 'ConnectionTokenSecretSerializer', 'RDPFileSerializer'
] ]
@ -82,3 +86,103 @@ class SSOTokenSerializer(serializers.Serializer):
username = serializers.CharField(write_only=True) username = serializers.CharField(write_only=True)
login_url = serializers.CharField(read_only=True) login_url = serializers.CharField(read_only=True)
next = serializers.CharField(write_only=True, allow_blank=True, required=False, allow_null=True) next = serializers.CharField(write_only=True, allow_blank=True, required=False, allow_null=True)
class ConnectionTokenSerializer(serializers.Serializer):
user = serializers.CharField(max_length=128, required=False, allow_blank=True)
system_user = serializers.CharField(max_length=128, required=True)
asset = serializers.CharField(max_length=128, required=False)
application = serializers.CharField(max_length=128, required=False)
@staticmethod
def validate_user(user_id):
from users.models import User
user = User.objects.filter(id=user_id).first()
if user is None:
raise serializers.ValidationError('user id not exist')
return user
@staticmethod
def validate_system_user(system_user_id):
from assets.models import SystemUser
system_user = SystemUser.objects.filter(id=system_user_id).first()
if system_user is None:
raise serializers.ValidationError('system_user id not exist')
return system_user
@staticmethod
def validate_asset(asset_id):
from assets.models import Asset
asset = Asset.objects.filter(id=asset_id).first()
if asset is None:
raise serializers.ValidationError('asset id not exist')
return asset
@staticmethod
def validate_application(app_id):
from applications.models import Application
app = Application.objects.filter(id=app_id).first()
if app is None:
raise serializers.ValidationError('app id not exist')
return app
def validate(self, attrs):
asset = attrs.get('asset')
application = attrs.get('application')
if not asset and not application:
raise serializers.ValidationError('asset or application required')
if asset and application:
raise serializers.ValidationError('asset and application should only one')
return super().validate(attrs)
class ConnectionTokenUserSerializer(serializers.ModelSerializer):
class Meta:
model = User
fields = ['id', 'name', 'username', 'email']
class ConnectionTokenAssetSerializer(serializers.ModelSerializer):
class Meta:
model = Asset
fields = ['id', 'hostname', 'ip', 'port', 'org_id']
class ConnectionTokenSystemUserSerializer(serializers.ModelSerializer):
class Meta:
model = SystemUser
fields = ['id', 'name', 'username', 'password', 'private_key']
class ConnectionTokenGatewaySerializer(serializers.ModelSerializer):
class Meta:
model = Gateway
fields = ['id', 'ip', 'port', 'username', 'password', 'private_key']
class ConnectionTokenRemoteAppSerializer(serializers.Serializer):
program = serializers.CharField()
working_directory = serializers.CharField()
parameters = serializers.CharField()
class ConnectionTokenApplicationSerializer(serializers.ModelSerializer):
class Meta:
model = Application
fields = ['id', 'name', 'category', 'type']
class ConnectionTokenSecretSerializer(serializers.Serializer):
type = serializers.ChoiceField(choices=[('application', 'Application'), ('asset', 'Asset')])
user = ConnectionTokenUserSerializer(read_only=True)
asset = ConnectionTokenAssetSerializer(read_only=True)
remote_app = ConnectionTokenRemoteAppSerializer(read_only=True)
application = ConnectionTokenApplicationSerializer(read_only=True)
system_user = ConnectionTokenSystemUserSerializer(read_only=True)
gateway = ConnectionTokenGatewaySerializer(read_only=True)
actions = ActionsField()
class RDPFileSerializer(ConnectionTokenSerializer):
width = serializers.IntegerField(default=1280)
height = serializers.IntegerField(default=800)

View File

@ -24,7 +24,7 @@ def on_user_auth_login_success(sender, user, request, **kwargs):
@receiver(openid_user_login_success) @receiver(openid_user_login_success)
def on_oidc_user_login_success(sender, request, user, **kwargs): def on_oidc_user_login_success(sender, request, user, create=False, **kwargs):
request.session[BACKEND_SESSION_KEY] = 'OIDCAuthCodeBackend' request.session[BACKEND_SESSION_KEY] = 'OIDCAuthCodeBackend'
post_auth_success.send(sender, user=user, request=request) post_auth_success.send(sender, user=user, request=request)

View File

@ -1,9 +1,8 @@
{% load static %}
{% load i18n %} {% load i18n %}
{% load bootstrap3 %}
{% load static %}
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<!--/*@thymesVar id="LoginConstants" type="com.fit2cloud.support.common.constants.LoginConstants"*/-->
<!--/*@thymesVar id="message" type="java.lang.String"*/-->
<head> <head>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8"> <meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
<link rel="shortcut icon" href="{{ FAVICON_URL }}" type="image/x-icon"> <link rel="shortcut icon" href="{{ FAVICON_URL }}" type="image/x-icon">
@ -16,6 +15,8 @@
<link href="{% static 'css/font-awesome.min.css' %}" rel="stylesheet"> <link href="{% static 'css/font-awesome.min.css' %}" rel="stylesheet">
<link href="{% static 'css/bootstrap-style.css' %}" rel="stylesheet"> <link href="{% static 'css/bootstrap-style.css' %}" rel="stylesheet">
<link href="{% static 'css/login-style.css' %}" rel="stylesheet"> <link href="{% static 'css/login-style.css' %}" rel="stylesheet">
<link href="{% static 'css/style.css' %}" rel="stylesheet">
<link href="{% static 'css/jumpserver.css' %}" rel="stylesheet">
<!-- scripts --> <!-- scripts -->
<script src="{% static 'js/jquery-3.1.1.min.js' %}"></script> <script src="{% static 'js/jquery-3.1.1.min.js' %}"></script>
@ -24,26 +25,54 @@
<script src="{% static 'js/plugins/datatables/datatables.min.js' %}"></script> <script src="{% static 'js/plugins/datatables/datatables.min.js' %}"></script>
<style> <style>
.login-content {
box-shadow: 0 5px 5px -3px rgb(0 0 0 / 20%), 0 8px 10px 1px rgb(0 0 0 / 14%), 0 3px 14px 2px rgb(0 0 0 / 12%);
}
.box-1{ .help-block {
margin: 0;
text-align: left;
}
form label {
color: #737373;
font-size: 13px;
font-weight: normal;
}
.hr-line-dashed {
border-top: 1px dashed #e7eaec;
color: #ffffff;
background-color: #ffffff;
height: 1px;
margin: 20px 0;
}
.login-content {
height: 472px; height: 472px;
width: 984px; width: 984px;
margin-right: auto; margin-right: auto;
margin-left: auto; margin-left: auto;
margin-top: calc((100vh - 470px)/2); margin-top: calc((100vh - 470px) / 3);
} }
.box-2{ body {
background-color: #f2f2f2;
height: calc(100vh - (100vh - 470px) / 3);
}
.right-image-box {
height: 100%; height: 100%;
width: 50%; width: 50%;
float: right; float: right;
} }
.box-3{
.left-form-box {
text-align: center; text-align: center;
background-color: white; background-color: white;
height: 100%; height: 100%;
width: 50%; width: 50%;
} }
.captcha { .captcha {
float: right; float: right;
} }
@ -56,82 +85,99 @@
text-align: left; text-align: left;
} }
.form-group.has-error {
margin-bottom: 0;
}
.captch-field .has-error .help-block {
margin-top: -8px !important;
}
.no-captcha-challenge .form-group {
margin-bottom: 20px;
}
.jms-title {
padding: 40px 10px 10px;
}
.no-captcha-challenge .jms-title {
padding: 60px 10px 10px;
}
.no-captcha-challenge .welcome-message {
padding-top: 10px;
}
.radio, .checkbox {
margin: 0;
}
#github_star {
float: right;
margin: 10px 10px 0 0;
}
</style> </style>
</head> </head>
<body style="height: 100%;font-size: 13px"> <body>
<div> <div class="login-content ">
<div class="box-1"> <div class="right-image-box">
<div class="box-2"> <a href="{% if not XPACK_ENABLED %}https://github.com/jumpserver/jumpserver{% endif %}">
<img src="{{ LOGIN_IMAGE_URL }}" style="height: 100%; width: 100%"/> <img src="{{ LOGIN_IMAGE_URL }}" style="height: 100%; width: 100%"/>
</a>
</div> </div>
<div class="box-3"> <div class="left-form-box {% if not form.challenge and not form.captcha %} no-captcha-challenge {% endif %}">
<div style="background-color: white"> <div style="background-color: white">
{% if form.challenge %} <div class="jms-title">
<div style="margin-top: 20px;padding-top: 30px;padding-left: 20px;padding-right: 20px;height: 60px">
{% else %}
<div style="margin-top: 20px;padding-top: 40px;padding-left: 20px;padding-right: 20px;height: 80px">
{% endif %}
<span style="font-size: 21px;font-weight:400;color: #151515;letter-spacing: 0;">{{ JMS_TITLE }}</span> <span style="font-size: 21px;font-weight:400;color: #151515;letter-spacing: 0;">{{ JMS_TITLE }}</span>
</div> </div>
<div style="font-size: 12px;color: #999999;letter-spacing: 0;line-height: 18px;margin-top: 18px"> <div class="contact-form col-md-10 col-md-offset-1">
{% trans 'Welcome back, please enter username and password to login' %} <form id="login-form" action="" method="post" role="form" novalidate="novalidate">
</div>
<div style="margin-bottom: 0px">
<div>
<div class="col-md-1"></div>
<div class="contact-form col-md-10" style="margin-top: 0px;height: 35px">
<form id="contact-form" action="" method="post" role="form" novalidate="novalidate">
{% csrf_token %} {% csrf_token %}
<div style="line-height: 17px;margin-bottom: 20px;color: #999999;">
{% if form.errors %}
<p class="red-fonts" style="color: red">
{% if form.non_field_errors %} {% if form.non_field_errors %}
{% if form.challenge %} {{ form.non_field_errors.as_text }}
<div style="height: 50px;color: red;line-height: 17px;"> {% endif %}
{% else %} </p>
<div style="height: 70px;color: red;line-height: 17px;"> {% else %}
<p class="welcome-message">
{% trans 'Welcome back, please enter username and password to login' %}
</p>
{% endif %} {% endif %}
<p class="red-fonts">{{ form.non_field_errors.as_text }}</p>
</div> </div>
{% elif form.errors.captcha %}
<p class="red-fonts">{% trans 'Captcha invalid' %}</p>
{% else %}
<div style="height: 50px"></div>
{% endif %}
<div class="form-group"> {% bootstrap_field form.username show_label=False %}
<input type="text" class="form-control" name="{{ form.username.html_name }}" placeholder="{% trans 'Username' %}" required="" value="{% if form.username.value %}{{ form.username.value }}{% endif %}" style="height: 35px">
{% if form.errors.username %}
<div class="help-block field-error">
<p class="red-fonts">{{ form.errors.username.as_text }}</p>
</div>
{% endif %}
</div>
<div class="form-group"> <div class="form-group">
<input type="password" class="form-control" id="password" placeholder="{% trans 'Password' %}" required=""> <input type="password" class="form-control" id="password" placeholder="{% trans 'Password' %}" required="">
<input id="password-hidden" type="text" style="display:none" name="{{ form.password.html_name }}"> <input id="password-hidden" type="text" style="display:none" name="{{ form.password.html_name }}">
{% if form.errors.password %}
<div class="help-block field-error">
<p class="red-fonts">{{ form.errors.password.as_text }}</p>
</div>
{% endif %}
</div> </div>
{% if form.challenge %} {% if form.challenge %}
<div class="form-group"> {% bootstrap_field form.challenge show_label=False %}
<input type="challenge" class="form-control" id="challenge" name="{{ form.challenge.html_name }}" placeholder="{% trans 'MFA code' %}" > {% elif form.captcha %}
{% if form.errors.challenge %} <div class="captch-field">
<div class="help-block field-error"> {% bootstrap_field form.captcha show_label=False %}
<p class="red-fonts">{{ form.errors.challenge.as_text }}</p>
</div> </div>
{% endif %} {% endif %}
</div> <div class="form-group" style="padding-top: 5px; margin-bottom: 10px">
<div class="row">
<div class="col-md-6" style="text-align: left">
{% if form.auto_login %}
{% bootstrap_field form.auto_login form_group_class='' %}
{% endif %} {% endif %}
{% if form.captcha %}
<div class="form-group" style="height: 50px;margin-bottom: 0;font-size: 13px">
{{ form.captcha }}
</div> </div>
{% else %} <div class="col-md-6">
<div class="form-group" style="height: 25px;margin-bottom: 0;font-size: 13px"></div> <a id="forgot_password" href="{{ forgot_password_url }}" style="float: right">
{% endif %} <small>{% trans 'Forgot password' %}?</small>
<div class="form-group" style="margin-top: 10px"> </a>
</div>
</div>
</div>
<div class="form-group" style="">
<button type="submit" class="btn btn-transparent" onclick="doLogin();return false;">{% trans 'Login' %}</button> <button type="submit" class="btn btn-transparent" onclick="doLogin();return false;">{% trans 'Login' %}</button>
</div> </div>
@ -141,36 +187,26 @@
<div style="display: inline-block; float: left"> <div style="display: inline-block; float: left">
<b class="text-muted text-left" style="margin-right: 10px">{% trans "More login options" %}</b> <b class="text-muted text-left" style="margin-right: 10px">{% trans "More login options" %}</b>
{% if AUTH_OPENID %} {% if AUTH_OPENID %}
<a href="{% url 'authentication:openid:login' %}"> <a href="{% url 'authentication:openid:login' %}" class="more-login-item">
<i class="fa fa-openid"></i> {% trans 'OpenID' %} <i class="fa fa-openid"></i> {% trans 'OpenID' %}
</a> </a>
{% endif %} {% endif %}
{% if AUTH_CAS %} {% if AUTH_CAS %}
<a href="{% url 'authentication:cas:cas-login' %}"> <a href="{% url 'authentication:cas:cas-login' %}" class="more-login-item">
<i class="fa"><img src="{{ LOGIN_CAS_LOGO_URL }}" height="13" width="13"></i> {% trans 'CAS' %} <i class="fa"><img src="{{ LOGIN_CAS_LOGO_URL }}" height="13" width="13"></i> {% trans 'CAS' %}
</a> </a>
{% endif %} {% endif %}
</div> </div>
<div class="text-center" style="display: inline-block; float: right">
{% else %} {% else %}
<div class="text-center" style="display: inline-block;"> <div class="text-center" style="display: inline-block;">
{% endif %} {% endif %}
<a id="forgot_password" href="{% url 'authentication:forgot-password' %}">
<small>{% trans 'Forgot password' %}?</small>
</a>
</div> </div>
</div> </div>
</form> </form>
</div> </div>
<div class="col-md-1"></div>
</div> </div>
</div> </div>
</div> </div>
</div>
</div>
</div>
</div>
</body> </body>
<script type="text/javascript" src="/static/js/plugins/jsencrypt/jsencrypt.min.js"></script> <script type="text/javascript" src="/static/js/plugins/jsencrypt/jsencrypt.min.js"></script>
<script> <script>
@ -179,13 +215,14 @@
jsencrypt.setPublicKey(rsaPublicKey); // 设置密钥 jsencrypt.setPublicKey(rsaPublicKey); // 设置密钥
return jsencrypt.encrypt(password); //加密 return jsencrypt.encrypt(password); //加密
} }
function doLogin() { function doLogin() {
//公钥加密 //公钥加密
var rsaPublicKey = "{{ rsa_public_key }}" var rsaPublicKey = "{{ rsa_public_key }}"
var password = $('#password').val(); //明文密码 var password = $('#password').val(); //明文密码
var passwordEncrypted = encryptLoginPassword(password, rsaPublicKey) var passwordEncrypted = encryptLoginPassword(password, rsaPublicKey)
$('#password-hidden').val(passwordEncrypted); //返回给密码输入input $('#password-hidden').val(passwordEncrypted); //返回给密码输入input
$('#contact-form').submit();//post提交 $('#login-form').submit();//post提交
} }
$(document).ready(function () { $(document).ready(function () {

View File

@ -9,6 +9,7 @@ app_name = 'authentication'
router = DefaultRouter() router = DefaultRouter()
router.register('access-keys', api.AccessKeyViewSet, 'access-key') router.register('access-keys', api.AccessKeyViewSet, 'access-key')
router.register('sso', api.SSOViewSet, 'sso') router.register('sso', api.SSOViewSet, 'sso')
router.register('connection-token', api.UserConnectionTokenViewSet, 'connection-token')
urlpatterns = [ urlpatterns = [
@ -16,8 +17,6 @@ urlpatterns = [
path('auth/', api.TokenCreateApi.as_view(), name='user-auth'), path('auth/', api.TokenCreateApi.as_view(), name='user-auth'),
path('tokens/', api.TokenCreateApi.as_view(), name='auth-token'), path('tokens/', api.TokenCreateApi.as_view(), name='auth-token'),
path('mfa/challenge/', api.MFAChallengeApi.as_view(), name='mfa-challenge'), path('mfa/challenge/', api.MFAChallengeApi.as_view(), name='mfa-challenge'),
path('connection-token/',
api.UserConnectionTokenApi.as_view(), name='connection-token'),
path('otp/verify/', api.UserOtpVerifyApi.as_view(), name='user-otp-verify'), path('otp/verify/', api.UserOtpVerifyApi.as_view(), name='user-otp-verify'),
path('login-confirm-ticket/status/', api.TicketStatusApi.as_view(), name='login-confirm-ticket-status'), path('login-confirm-ticket/status/', api.TicketStatusApi.as_view(), name='login-confirm-ticket-status'),
path('login-confirm-settings/<uuid:user_id>/', api.LoginConfirmSettingUpdateApi.as_view(), name='login-confirm-setting-update') path('login-confirm-settings/<uuid:user_id>/', api.LoginConfirmSettingUpdateApi.as_view(), name='login-confirm-setting-update')

View File

@ -1,9 +1,9 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import base64 import base64
from Crypto.PublicKey import RSA from Cryptodome.PublicKey import RSA
from Crypto.Cipher import PKCS1_v1_5 from Cryptodome.Cipher import PKCS1_v1_5
from Crypto import Random from Cryptodome import Random
from common.utils import get_logger from common.utils import get_logger

View File

@ -45,9 +45,10 @@ class UserLoginView(mixins.AuthMixin, FormView):
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
if request.user.is_staff: if request.user.is_staff:
return redirect(redirect_user_first_login_or_index( first_login_url = redirect_user_first_login_or_index(
request, self.redirect_field_name) request, self.redirect_field_name
) )
return redirect(first_login_url)
request.session.set_test_cookie() request.session.set_test_cookie()
return super().get(request, *args, **kwargs) return super().get(request, *args, **kwargs)
@ -99,11 +100,17 @@ class UserLoginView(mixins.AuthMixin, FormView):
self.request.session[RSA_PRIVATE_KEY] = rsa_private_key self.request.session[RSA_PRIVATE_KEY] = rsa_private_key
self.request.session[RSA_PUBLIC_KEY] = rsa_public_key self.request.session[RSA_PUBLIC_KEY] = rsa_public_key
forgot_password_url = reverse('authentication:forgot-password')
has_other_auth_backend = settings.AUTHENTICATION_BACKENDS[0] != settings.AUTH_BACKEND_MODEL
if has_other_auth_backend and settings.FORGOT_PASSWORD_URL:
forgot_password_url = settings.FORGOT_PASSWORD_URL
context = { context = {
'demo_mode': os.environ.get("DEMO_MODE"), 'demo_mode': os.environ.get("DEMO_MODE"),
'AUTH_OPENID': settings.AUTH_OPENID, 'AUTH_OPENID': settings.AUTH_OPENID,
'AUTH_CAS': settings.AUTH_CAS, 'AUTH_CAS': settings.AUTH_CAS,
'rsa_public_key': rsa_public_key, 'rsa_public_key': rsa_public_key,
'forgot_password_url': forgot_password_url
} }
kwargs.update(context) kwargs.update(context)
return super().get_context_data(**kwargs) return super().get_context_data(**kwargs)
@ -121,6 +128,13 @@ class UserLoginGuardView(mixins.AuthMixin, RedirectView):
url = "%s?%s" % (url, args) url = "%s?%s" % (url, args)
return url return url
def login_it(self, user):
auth_login(self.request, user)
# 如果设置了自动登录,那需要设置 session_id cookie 的有效期
if self.request.session.get('auto_login'):
age = self.request.session.get_expiry_age()
self.request.session.set_expiry(age)
def get_redirect_url(self, *args, **kwargs): def get_redirect_url(self, *args, **kwargs):
try: try:
user = self.check_user_auth_if_need() user = self.check_user_auth_if_need()
@ -137,7 +151,7 @@ class UserLoginGuardView(mixins.AuthMixin, RedirectView):
except errors.PasswdTooSimple as e: except errors.PasswdTooSimple as e:
return e.url return e.url
else: else:
auth_login(self.request, user) self.login_it(user)
self.send_auth_signal(success=True, user=user) self.send_auth_signal(success=True, user=user)
self.clear_auth_mark() self.clear_auth_mark()
url = redirect_user_first_login_or_index( url = redirect_user_first_login_or_index(

View File

@ -13,7 +13,7 @@ from rest_framework.viewsets import GenericViewSet
from common.permissions import IsValidUser from common.permissions import IsValidUser
from .http import HttpResponseTemporaryRedirect from .http import HttpResponseTemporaryRedirect
from .const import KEY_CACHE_RESOURCES_ID from .const import KEY_CACHE_RESOURCE_IDS
from .utils import get_logger from .utils import get_logger
from .mixins import CommonApiMixin from .mixins import CommonApiMixin
@ -93,7 +93,7 @@ class ResourcesIDCacheApi(APIView):
spm = str(uuid.uuid4()) spm = str(uuid.uuid4())
resources = request.data.get('resources') resources = request.data.get('resources')
if resources is not None: if resources is not None:
cache_key = KEY_CACHE_RESOURCES_ID.format(spm) cache_key = KEY_CACHE_RESOURCE_IDS.format(spm)
cache.set(cache_key, resources, 300) cache.set(cache_key, resources, 300)
return Response({'spm': spm}) return Response({'spm': spm})

View File

@ -1,13 +1,24 @@
import json import time
from django.core.cache import cache
from redis import Redis
from common.utils.lock import DistributedLock from common.utils.lock import DistributedLock
from common.utils import lazyproperty from common.utils import lazyproperty
from common.utils import get_logger from common.utils import get_logger
from jumpserver.const import CONFIG
logger = get_logger(__file__) logger = get_logger(__file__)
class ComputeLock(DistributedLock):
"""
需要重建缓存的时候加上该锁避免重复计算
"""
def __init__(self, key):
name = f'compute:{key}'
super().__init__(name=name)
class CacheFieldBase: class CacheFieldBase:
field_type = str field_type = str
@ -25,7 +36,7 @@ class IntegerField(CacheFieldBase):
field_type = int field_type = int
class CacheBase(type): class CacheType(type):
def __new__(cls, name, bases, attrs: dict): def __new__(cls, name, bases, attrs: dict):
to_update = {} to_update = {}
field_desc_mapper = {} field_desc_mapper = {}
@ -41,12 +52,31 @@ class CacheBase(type):
return type.__new__(cls, name, bases, attrs) return type.__new__(cls, name, bases, attrs)
class Cache(metaclass=CacheBase): class Cache(metaclass=CacheType):
field_desc_mapper: dict field_desc_mapper: dict
timeout = None timeout = None
def __init__(self): def __init__(self):
self._data = None self._data = None
self.redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
def __getitem__(self, item):
return self.field_desc_mapper[item]
def __contains__(self, item):
return item in self.field_desc_mapper
def get_field(self, name):
return self.field_desc_mapper[name]
@property
def fields(self):
return self.field_desc_mapper.values()
@property
def field_names(self):
names = self.field_desc_mapper.keys()
return names
@lazyproperty @lazyproperty
def key_suffix(self): def key_suffix(self):
@ -64,81 +94,75 @@ class Cache(metaclass=CacheBase):
@property @property
def data(self): def data(self):
if self._data is None: if self._data is None:
data = self.get_data() data = self.load_data_from_db()
if data is None: if not data:
with ComputeLock(self.key):
data = self.load_data_from_db()
if not data:
# 缓存中没有数据时,去数据库获取 # 缓存中没有数据时,去数据库获取
self.compute_and_set_all_data() self.init_all_values()
return self._data return self._data
def get_data(self) -> dict: def to_internal_value(self, data: dict):
data = cache.get(self.key) internal_data = {}
logger.debug(f'CACHE: get {self.key} = {data}') for k, v in data.items():
if data is not None: field = k.decode()
data = json.loads(data) if field in self:
value = self[field].to_internal_value(v.decode())
internal_data[field] = value
else:
logger.warn(f'Cache got invalid field: '
f'key={self.key} '
f'invalid_field={field} '
f'valid_fields={self.field_names}')
return internal_data
def load_data_from_db(self) -> dict:
data = self.redis.hgetall(self.key)
logger.debug(f'Get data from cache: key={self.key} data={data}')
if data:
data = self.to_internal_value(data)
self._data = data self._data = data
return data return data
def set_data(self, data): def save_data_to_db(self, data):
self._data = data logger.info(f'Set data to cache: key={self.key} data={data}')
to_json = json.dumps(data) self.redis.hset(self.key, mapping=data)
logger.info(f'CACHE: set {self.key} = {to_json}, timeout={self.timeout}') self.load_data_from_db()
cache.set(self.key, to_json, timeout=self.timeout)
def compute_data(self, *fields): def compute_values(self, *fields):
field_descs = [] field_objs = []
if not fields:
field_descs = self.field_desc_mapper.values()
else:
for field in fields: for field in fields:
assert field in self.field_desc_mapper, f'{field} is not a valid field' field_objs.append(self[field])
field_descs.append(self.field_desc_mapper[field])
data = { data = {
field_desc.field_name: field_desc.compute_value(self) field_obj.field_name: field_obj.compute_value(self)
for field_desc in field_descs for field_obj in field_objs
} }
return data return data
def compute_and_set_all_data(self, computed_data: dict = None): def init_all_values(self):
""" t_start = time.time()
TODO 怎样防止并发更新全部数据浪费数据库资源 logger.info(f'Start init cache: key={self.key}')
""" data = self.compute_values(*self.field_names)
uncomputed_keys = () self.save_data_to_db(data)
if computed_data: logger.info(f'End init cache: cost={time.time()-t_start} key={self.key}')
computed_keys = computed_data.keys()
all_keys = self.field_desc_mapper.keys()
uncomputed_keys = all_keys - computed_keys
else:
computed_data = {}
data = self.compute_data(*uncomputed_keys)
data.update(computed_data)
self.set_data(data)
return data
def refresh_part_data_with_lock(self, refresh_data):
with DistributedLock(name=f'{self.key}.refresh'):
data = self.get_data()
if data is not None:
data.update(refresh_data)
self.set_data(data)
return data return data
def refresh(self, *fields): def refresh(self, *fields):
if not fields: if not fields:
# 没有指定 field 要刷新所有的值 # 没有指定 field 要刷新所有的值
self.compute_and_set_all_data() self.init_all_values()
return return
data = self.get_data() data = self.load_data_from_db()
if data is None: if not data:
# 缓存中没有数据,设置所有的值 # 缓存中没有数据,设置所有的值
self.compute_and_set_all_data() self.init_all_values()
return return
refresh_data = self.compute_data(*fields) refresh_values = self.compute_values(*fields)
if not self.refresh_part_data_with_lock(refresh_data): self.save_data_to_db(refresh_values)
# 刷新部分失败,缓存中没有数据,更新所有的值
self.compute_and_set_all_data(refresh_data)
return
def get_key_suffix(self): def get_key_suffix(self):
raise NotImplementedError raise NotImplementedError
@ -146,10 +170,14 @@ class Cache(metaclass=CacheBase):
def reload(self): def reload(self):
self._data = None self._data = None
def delete(self): def expire(self, *fields):
self._data = None self._data = None
logger.info(f'CACHE: delete {self.key}') if not fields:
cache.delete(self.key) logger.info(f'Delete cached key: key={self.key}')
self.redis.delete(self.key)
else:
self.redis.hdel(self.key, *fields)
logger.info(f'Expire cached fields: key={self.key} fields={fields}')
class CacheValueDesc: class CacheValueDesc:
@ -167,10 +195,13 @@ class CacheValueDesc:
return self return self
if self.field_name not in instance.data: if self.field_name not in instance.data:
instance.refresh(self.field_name) instance.refresh(self.field_name)
value = instance.data[self.field_name] # 防止边界情况没有值,报错
value = instance.data.get(self.field_name)
return value return value
def compute_value(self, instance: Cache): def compute_value(self, instance: Cache):
t_start = time.time()
logger.info(f'Start compute cache field: field={self.field_name} key={instance.key}')
if self.field_type.queryset is not None: if self.field_type.queryset is not None:
new_value = self.field_type.queryset.count() new_value = self.field_type.queryset.count()
else: else:
@ -183,5 +214,8 @@ class CacheValueDesc:
new_value = compute_func() new_value = compute_func()
new_value = self.field_type.field_type(new_value) new_value = self.field_type.field_type(new_value)
logger.info(f'CACHE: compute {instance.key}.{self.field_name} = {new_value}') logger.info(f'End compute cache field: cost={time.time()-t_start} field={self.field_name} value={new_value} key={instance.key}')
return new_value return new_value
def to_internal_value(self, value):
return self.field_type.field_type(value)

View File

@ -7,7 +7,7 @@ create_success_msg = _("%(name)s was created successfully")
update_success_msg = _("%(name)s was updated successfully") update_success_msg = _("%(name)s was updated successfully")
FILE_END_GUARD = ">>> Content End <<<" FILE_END_GUARD = ">>> Content End <<<"
celery_task_pre_key = "CELERY_" celery_task_pre_key = "CELERY_"
KEY_CACHE_RESOURCES_ID = "RESOURCES_ID_{}" KEY_CACHE_RESOURCE_IDS = "RESOURCE_IDS_{}"
# AD User AccountDisable # AD User AccountDisable
# https://blog.csdn.net/bytxl/article/details/17763975 # https://blog.csdn.net/bytxl/article/details/17763975

View File

@ -1,2 +0,0 @@
UPDATE_NODE_TREE_LOCK_KEY = 'org_level_transaction_lock_{org_id}_assets_update_node_tree'
UPDATE_MAPPING_NODE_TASK_LOCK_KEY = 'org_level_transaction_lock_{user_id}_update_mapping_node_task'

View File

@ -10,8 +10,11 @@
""" """
import uuid import uuid
from functools import reduce, partial
import inspect
from django.db.models import * from django.db.models import *
from django.db.models import QuerySet
from django.db.models.functions import Concat from django.db.models.functions import Concat
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -82,3 +85,88 @@ class JMSModel(JMSBaseModel):
def concated_display(name1, name2): def concated_display(name1, name2):
return Concat(F(name1), Value('('), F(name2), Value(')')) return Concat(F(name1), Value('('), F(name2), Value(')'))
def output_as_string(field_name):
return ExpressionWrapper(F(field_name), output_field=CharField())
class UnionQuerySet(QuerySet):
after_union = ['order_by']
not_return_qs = [
'query', 'get', 'create', 'get_or_create',
'update_or_create', 'bulk_create', 'count',
'latest', 'earliest', 'first', 'last', 'aggregate',
'exists', 'update', 'delete', 'as_manager', 'explain',
]
def __init__(self, *queryset_list):
self.queryset_list = queryset_list
self.after_union_items = []
self.before_union_items = []
def __execute(self):
queryset_list = []
for qs in self.queryset_list:
for attr, args, kwargs in self.before_union_items:
qs = getattr(qs, attr)(*args, **kwargs)
queryset_list.append(qs)
union_qs = reduce(lambda x, y: x.union(y), queryset_list)
for attr, args, kwargs in self.after_union_items:
union_qs = getattr(union_qs, attr)(*args, **kwargs)
return union_qs
def __before_union_perform(self, item, *args, **kwargs):
self.before_union_items.append((item, args, kwargs))
return self.__clone(*self.queryset_list)
def __after_union_perform(self, item, *args, **kwargs):
self.after_union_items.append((item, args, kwargs))
return self.__clone(*self.queryset_list)
def __clone(self, *queryset_list):
uqs = UnionQuerySet(*queryset_list)
uqs.after_union_items = self.after_union_items
uqs.before_union_items = self.before_union_items
return uqs
def __getattribute__(self, item):
if item.startswith('__') or item in UnionQuerySet.__dict__ or item in [
'queryset_list', 'after_union_items', 'before_union_items'
]:
return object.__getattribute__(self, item)
if item in UnionQuerySet.not_return_qs:
return getattr(self.__execute(), item)
origin_item = object.__getattribute__(self, 'queryset_list')[0]
origin_attr = getattr(origin_item, item, None)
if not inspect.ismethod(origin_attr):
return getattr(self.__execute(), item)
if item in UnionQuerySet.after_union:
attr = partial(self.__after_union_perform, item)
else:
attr = partial(self.__before_union_perform, item)
return attr
def __getitem__(self, item):
return self.__execute()[item]
def __iter__(self):
return iter(self.__execute())
def __str__(self):
return str(self.__execute())
def __repr__(self):
return repr(self.__execute())
@classmethod
def test_it(cls):
from assets.models import Asset
assets1 = Asset.objects.filter(hostname__startswith='a')
assets2 = Asset.objects.filter(hostname__startswith='b')
qs = cls(assets1, assets2)
return qs

View File

@ -3,39 +3,35 @@ from rest_framework_bulk import BulkModelViewSet
from ..mixins.api import ( from ..mixins.api import (
SerializerMixin2, QuerySetMixin, ExtraFilterFieldsMixin, PaginatedResponseMixin, SerializerMixin2, QuerySetMixin, ExtraFilterFieldsMixin, PaginatedResponseMixin,
RelationMixin, AllowBulkDestoryMixin RelationMixin, AllowBulkDestoryMixin, RenderToJsonMixin,
) )
class JmsGenericViewSet(SerializerMixin2, class CommonMixin(SerializerMixin2,
QuerySetMixin, QuerySetMixin,
ExtraFilterFieldsMixin, ExtraFilterFieldsMixin,
PaginatedResponseMixin, PaginatedResponseMixin,
RenderToJsonMixin):
pass
class JmsGenericViewSet(CommonMixin,
GenericViewSet): GenericViewSet):
pass pass
class JMSModelViewSet(SerializerMixin2, class JMSModelViewSet(CommonMixin,
QuerySetMixin,
ExtraFilterFieldsMixin,
PaginatedResponseMixin,
ModelViewSet): ModelViewSet):
pass pass
class JMSBulkModelViewSet(SerializerMixin2, class JMSBulkModelViewSet(CommonMixin,
QuerySetMixin,
ExtraFilterFieldsMixin,
PaginatedResponseMixin,
AllowBulkDestoryMixin, AllowBulkDestoryMixin,
BulkModelViewSet): BulkModelViewSet):
pass pass
class JMSBulkRelationModelViewSet(SerializerMixin2, class JMSBulkRelationModelViewSet(CommonMixin,
QuerySetMixin,
ExtraFilterFieldsMixin,
PaginatedResponseMixin,
RelationMixin, RelationMixin,
AllowBulkDestoryMixin, AllowBulkDestoryMixin,
BulkModelViewSet): BulkModelViewSet):

View File

@ -19,7 +19,10 @@ def extract_object_name(exc, index=0):
`No User matches the given query.` `No User matches the given query.`
提取 `User``index=1` 提取 `User``index=1`
""" """
(msg, *_) = exc.args if exc.args:
(msg, *others) = exc.args
else:
return gettext('Object')
return gettext(msg.split(sep=' ', maxsplit=index + 1)[index]) return gettext(msg.split(sep=' ', maxsplit=index + 1)[index])

View File

@ -6,11 +6,25 @@ from rest_framework.serializers import ValidationError
from rest_framework.compat import coreapi, coreschema from rest_framework.compat import coreapi, coreschema
from django.core.cache import cache from django.core.cache import cache
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django_filters import rest_framework as drf_filters
import logging import logging
from common import const from common import const
__all__ = ["DatetimeRangeFilter", "IDSpmFilter", 'IDInFilter', "CustomFilter"] __all__ = [
"DatetimeRangeFilter", "IDSpmFilter", 'IDInFilter', "CustomFilter",
"BaseFilterSet"
]
class BaseFilterSet(drf_filters.FilterSet):
def do_nothing(self, queryset, name, value):
return queryset
def get_query_param(self, k, default=None):
if k in self.form.data:
return self.form.cleaned_data[k]
return default
class DatetimeRangeFilter(filters.BaseFilterBackend): class DatetimeRangeFilter(filters.BaseFilterBackend):
@ -94,11 +108,11 @@ class IDSpmFilter(filters.BaseFilterBackend):
spm = request.query_params.get('spm') spm = request.query_params.get('spm')
if not spm: if not spm:
return queryset return queryset
cache_key = const.KEY_CACHE_RESOURCES_ID.format(spm) cache_key = const.KEY_CACHE_RESOURCE_IDS.format(spm)
resources_id = cache.get(cache_key) resource_ids = cache.get(cache_key)
if resources_id is None or not isinstance(resources_id, list): if resource_ids is None or not isinstance(resource_ids, list):
return queryset return queryset
queryset = queryset.filter(id__in=resources_id) queryset = queryset.filter(id__in=resource_ids)
return queryset return queryset

View File

@ -3,6 +3,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from collections import OrderedDict from collections import OrderedDict
import datetime
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.http import Http404 from django.http import Http404
@ -21,7 +22,7 @@ class SimpleMetadataWithFilters(SimpleMetadata):
attrs = [ attrs = [
'read_only', 'label', 'help_text', 'read_only', 'label', 'help_text',
'min_length', 'max_length', 'min_length', 'max_length',
'min_value', 'max_value', "write_only" 'min_value', 'max_value', "write_only",
] ]
def determine_actions(self, request, view): def determine_actions(self, request, view):
@ -59,8 +60,9 @@ class SimpleMetadataWithFilters(SimpleMetadata):
field_info['type'] = self.label_lookup[field] field_info['type'] = self.label_lookup[field]
field_info['required'] = getattr(field, 'required', False) field_info['required'] = getattr(field, 'required', False)
default = getattr(field, 'default', False) default = getattr(field, 'default', None)
if default and isinstance(default, (str, int)): if default is not None and default != empty:
if isinstance(default, (str, int, bool, datetime.datetime, list)):
field_info['default'] = default field_info['default'] = default
for attr in self.attrs: for attr in self.attrs:
@ -95,6 +97,8 @@ class SimpleMetadataWithFilters(SimpleMetadata):
fields = view.filterset_fields fields = view.filterset_fields
elif hasattr(view, 'get_filterset_fields'): elif hasattr(view, 'get_filterset_fields'):
fields = view.get_filterset_fields(request) fields = view.get_filterset_fields(request)
elif hasattr(view, 'filterset_class'):
fields = view.filterset_class.Meta.fields
if isinstance(fields, dict): if isinstance(fields, dict):
fields = list(fields.keys()) fields = list(fields.keys())

View File

@ -22,6 +22,7 @@ class BaseFileParser(BaseParser):
FILE_CONTENT_MAX_LENGTH = 1024 * 1024 * 10 FILE_CONTENT_MAX_LENGTH = 1024 * 1024 * 10
serializer_cls = None serializer_cls = None
serializer_fields = None
def check_content_length(self, meta): def check_content_length(self, meta):
content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0))) content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
@ -45,7 +46,7 @@ class BaseFileParser(BaseParser):
def convert_to_field_names(self, column_titles): def convert_to_field_names(self, column_titles):
fields_map = {} fields_map = {}
fields = self.serializer_cls().fields fields = self.serializer_fields
fields_map.update({v.label: k for k, v in fields.items()}) fields_map.update({v.label: k for k, v in fields.items()})
fields_map.update({k: k for k, _ in fields.items()}) fields_map.update({k: k for k, _ in fields.items()})
field_names = [ field_names = [
@ -89,7 +90,7 @@ class BaseFileParser(BaseParser):
构建json数据后的行数据处理 构建json数据后的行数据处理
""" """
new_row_data = {} new_row_data = {}
serializer_fields = self.serializer_cls().fields serializer_fields = self.serializer_fields
for k, v in row_data.items(): for k, v in row_data.items():
if isinstance(v, list) or isinstance(v, dict) or isinstance(v, str) and k.strip() and v.strip(): if isinstance(v, list) or isinstance(v, dict) or isinstance(v, str) and k.strip() and v.strip():
# 解决类似disk_info为字符串的'{}'的问题 # 解决类似disk_info为字符串的'{}'的问题
@ -111,12 +112,15 @@ class BaseFileParser(BaseParser):
return data return data
def parse(self, stream, media_type=None, parser_context=None): def parse(self, stream, media_type=None, parser_context=None):
parser_context = parser_context or {} assert parser_context is not None, '`parser_context` should not be `None`'
view = parser_context['view']
request = view.request
try: try:
view = parser_context['view'] meta = request.META
meta = view.request.META
self.serializer_cls = view.get_serializer_class() self.serializer_cls = view.get_serializer_class()
self.serializer_fields = self.serializer_cls().fields
except Exception as e: except Exception as e:
logger.debug(e, exc_info=True) logger.debug(e, exc_info=True)
raise ParseError('The resource does not support imports!') raise ParseError('The resource does not support imports!')
@ -128,6 +132,13 @@ class BaseFileParser(BaseParser):
rows = self.generate_rows(stream_data) rows = self.generate_rows(stream_data)
column_titles = self.get_column_titles(rows) column_titles = self.get_column_titles(rows)
field_names = self.convert_to_field_names(column_titles) field_names = self.convert_to_field_names(column_titles)
# 给 `common.mixins.api.RenderToJsonMixin` 提供,暂时只能耦合
column_title_field_pairs = list(zip(column_titles, field_names))
if not hasattr(request, 'jms_context'):
request.jms_context = {}
request.jms_context['column_title_field_pairs'] = column_title_field_pairs
data = self.generate_data(field_names, rows) data = self.generate_data(field_names, rows)
return data return data
except Exception as e: except Exception as e:

View File

@ -1,40 +1,7 @@
# -*- coding: utf-8 -*- from werkzeug.local import Local
#
from jumpserver.const import DYNAMIC
from werkzeug.local import Local, LocalProxy
thread_local = Local() thread_local = Local()
def _find(attr): def _find(attr):
return getattr(thread_local, attr, None) return getattr(thread_local, attr, None)
class _Settings:
pass
def get_dynamic_cfg_from_thread_local():
KEY = 'dynamic_config'
try:
cfg = getattr(thread_local, KEY)
except AttributeError:
cfg = _Settings()
setattr(thread_local, KEY, cfg)
return cfg
class DynamicDefaultLocalProxy(LocalProxy):
def __getattr__(self, item):
try:
value = super().__getattr__(item)
except AttributeError:
value = getattr(DYNAMIC, item)()
setattr(self, item, value)
return value
LOCAL_DYNAMIC_SETTINGS = DynamicDefaultLocalProxy(get_dynamic_cfg_from_thread_local)

View File

@ -11,13 +11,16 @@ from django.core.cache import cache
from django.http import JsonResponse from django.http import JsonResponse
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.decorators import action
from rest_framework.request import Request
from common.const.http import POST
from common.drf.filters import IDSpmFilter, CustomFilter, IDInFilter from common.drf.filters import IDSpmFilter, CustomFilter, IDInFilter
from ..utils import lazyproperty from ..utils import lazyproperty
__all__ = [ __all__ = [
'JSONResponseMixin', 'CommonApiMixin', 'AsyncApiMixin', 'RelationMixin', 'JSONResponseMixin', 'CommonApiMixin', 'AsyncApiMixin', 'RelationMixin',
'SerializerMixin2', 'QuerySetMixin', 'ExtraFilterFieldsMixin' 'SerializerMixin2', 'QuerySetMixin', 'ExtraFilterFieldsMixin', 'RenderToJsonMixin',
] ]
@ -32,6 +35,21 @@ class JSONResponseMixin(object):
# ---------------------- # ----------------------
class RenderToJsonMixin:
@action(methods=[POST], detail=False, url_path='render-to-json')
def render_to_json(self, request: Request):
data = {
'title': (),
'data': request.data,
}
jms_context = getattr(request, 'jms_context', {})
column_title_field_pairs = jms_context.get('column_title_field_pairs', ())
data['title'] = column_title_field_pairs
return Response(data=data)
class SerializerMixin: class SerializerMixin:
""" 根据用户请求动作的不同,获取不同的 `serializer_class `""" """ 根据用户请求动作的不同,获取不同的 `serializer_class `"""
@ -98,7 +116,7 @@ class PaginatedResponseMixin:
return Response(serializer.data) return Response(serializer.data)
class CommonApiMixin(SerializerMixin, ExtraFilterFieldsMixin): class CommonApiMixin(SerializerMixin, ExtraFilterFieldsMixin, RenderToJsonMixin):
pass pass

View File

@ -2,7 +2,7 @@
# #
from collections import Iterable from collections import Iterable
from django.db.models import Prefetch, F from django.db.models import Prefetch, F, NOT_PROVIDED
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from rest_framework.utils import html from rest_framework.utils import html
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -71,7 +71,7 @@ class BulkListSerializerMixin(object):
""" """
List of dicts of native values <- List of dicts of primitive datatypes. List of dicts of native values <- List of dicts of primitive datatypes.
""" """
if not self.instance: if self.instance is None:
return super().to_internal_value(data) return super().to_internal_value(data)
if html.is_html_input(data): if html.is_html_input(data):
@ -106,7 +106,7 @@ class BulkListSerializerMixin(object):
pk = item["pk"] pk = item["pk"]
else: else:
raise ValidationError("id or pk not in data") raise ValidationError("id or pk not in data")
child = self.instance.get(id=pk) if self.instance else None child = self.instance.get(id=pk)
self.child.instance = child self.child.instance = child
self.child.initial_data = item self.child.initial_data = item
# raw # raw
@ -228,7 +228,43 @@ class SizedModelFieldsMixin(BaseDynamicFieldsPlugin):
return fields_to_drop return fields_to_drop
class DefaultValueFieldsMixin:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.set_fields_default_value()
def set_fields_default_value(self):
if not hasattr(self, 'Meta'):
return
if not hasattr(self.Meta, 'model'):
return
model = self.Meta.model
for name, serializer_field in self.fields.items():
if serializer_field.default != empty or serializer_field.required:
continue
model_field = getattr(model, name, None)
if model_field is None:
continue
if not hasattr(model_field, 'field') \
or not hasattr(model_field.field, 'default') \
or model_field.field.default == NOT_PROVIDED:
continue
if name == 'id':
continue
default = model_field.field.default
if callable(default):
default = default()
if default == '':
continue
# print(f"Set default value: {name}: {default}")
serializer_field.default = default
class DynamicFieldsMixin: class DynamicFieldsMixin:
"""
可以控制显示不同的字段mini 最少small 不包含关系
"""
dynamic_fields_plugins = [QueryFieldsMixin, SizedModelFieldsMixin] dynamic_fields_plugins = [QueryFieldsMixin, SizedModelFieldsMixin]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -256,7 +292,7 @@ class EagerLoadQuerySetFields:
return queryset return queryset
class CommonSerializerMixin(DynamicFieldsMixin): class CommonSerializerMixin(DynamicFieldsMixin, DefaultValueFieldsMixin):
pass pass

View File

@ -1,7 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import time import time
from rest_framework import permissions from rest_framework import permissions
from django.contrib.auth.mixins import UserPassesTestMixin from django.contrib.auth.mixins import UserPassesTestMixin
from django.conf import settings from django.conf import settings
@ -97,7 +96,7 @@ class WithBootstrapToken(permissions.BasePermission):
class PermissionsMixin(UserPassesTestMixin): class PermissionsMixin(UserPassesTestMixin):
permission_classes = [] permission_classes = [permissions.IsAuthenticated]
def get_permissions(self): def get_permissions(self):
return self.permission_classes return self.permission_classes
@ -110,12 +109,17 @@ class PermissionsMixin(UserPassesTestMixin):
return True return True
class UserCanUpdatePassword: class UserCanUseCurrentOrg(permissions.BasePermission):
def has_permission(self, request, view):
return current_org.can_use_by(request.user)
class UserCanUpdatePassword(permissions.BasePermission):
def has_permission(self, request, view): def has_permission(self, request, view):
return request.user.can_update_password() return request.user.can_update_password()
class UserCanUpdateSSHKey: class UserCanUpdateSSHKey(permissions.BasePermission):
def has_permission(self, request, view): def has_permission(self, request, view):
return request.user.can_update_ssh_key() return request.user.can_update_ssh_key()
@ -188,3 +192,12 @@ class IsObjectOwner(IsValidUser):
def has_object_permission(self, request, view, obj): def has_object_permission(self, request, view, obj):
return (super().has_object_permission(request, view, obj) and return (super().has_object_permission(request, view, obj) and
request.user == getattr(obj, 'user', None)) request.user == getattr(obj, 'user', None))
class HasQueryParamsUserAndIsCurrentOrgMember(permissions.BasePermission):
def has_permission(self, request, view):
query_user_id = request.query_params.get('user')
if not query_user_id:
return False
query_user = current_org.get_members().filter(id=query_user_id).first()
return bool(query_user)

View File

@ -5,16 +5,12 @@ import os
import logging import logging
from collections import defaultdict from collections import defaultdict
from django.conf import settings from django.conf import settings
from django.dispatch import receiver
from django.core.signals import request_finished from django.core.signals import request_finished
from django.db import connection from django.db import connection
from django.conf import LazySettings
from django.db.utils import ProgrammingError, OperationalError
from jumpserver.utils import get_current_request from jumpserver.utils import get_current_request
from .local import thread_local from .local import thread_local
from .signals import django_ready
pattern = re.compile(r'FROM `(\w+)`') pattern = re.compile(r'FROM `(\w+)`')
logger = logging.getLogger("jumpserver.common") logger = logging.getLogger("jumpserver.common")
@ -74,17 +70,3 @@ if settings.DEBUG and DEBUG_DB:
request_finished.connect(on_request_finished_logging_db_query) request_finished.connect(on_request_finished_logging_db_query)
else: else:
request_finished.connect(on_request_finished_release_local) request_finished.connect(on_request_finished_release_local)
@receiver(django_ready)
def monkey_patch_settings(sender, **kwargs):
def monkey_patch_getattr(self, name):
val = getattr(self._wrapped, name)
if callable(val):
val = val()
return val
try:
LazySettings.__getattr__ = monkey_patch_getattr
except (ProgrammingError, OperationalError):
pass

View File

@ -7,3 +7,4 @@ from .encode import *
from .http import * from .http import *
from .ipip import * from .ipip import *
from .crypto import * from .crypto import *
from .random import *

View File

@ -1,18 +1,17 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import re import re
import data_tree
from collections import OrderedDict from collections import OrderedDict
from itertools import chain from itertools import chain
import logging import logging
import datetime import datetime
import uuid import uuid
from functools import wraps from functools import wraps
import string
import random
import time import time
import ipaddress import ipaddress
import psutil import psutil
from django.utils.translation import ugettext_lazy as _
from ..exceptions import JMSException
UUID_PATTERN = re.compile(r'\w{8}(-\w{4}){3}-\w{12}') UUID_PATTERN = re.compile(r'\w{8}(-\w{4}){3}-\w{12}')
@ -143,7 +142,7 @@ def is_uuid(seq):
elif isinstance(seq, str) and UUID_PATTERN.match(seq): elif isinstance(seq, str) and UUID_PATTERN.match(seq):
return True return True
elif isinstance(seq, (list, tuple)): elif isinstance(seq, (list, tuple)):
all([is_uuid(x) for x in seq]) return all([is_uuid(x) for x in seq])
return False return False
@ -194,23 +193,17 @@ def with_cache(func):
return wrapper return wrapper
def random_string(length):
import string
import random
charset = string.ascii_letters + string.digits
s = [random.choice(charset) for i in range(length)]
return ''.join(s)
logger = get_logger(__name__) logger = get_logger(__name__)
def timeit(func): def timeit(func):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if hasattr(func, '__name__'):
name = func.__name__
else:
name = func name = func
for attr in ('__qualname__', '__name__'):
if hasattr(func, attr):
name = getattr(func, attr)
break
logger.debug("Start call: {}".format(name)) logger.debug("Start call: {}".format(name))
now = time.time() now = time.time()
result = func(*args, **kwargs) result = func(*args, **kwargs)
@ -254,3 +247,29 @@ def get_disk_usage():
mount_points = [p.mountpoint for p in partitions] mount_points = [p.mountpoint for p in partitions]
usages = {p: psutil.disk_usage(p) for p in mount_points} usages = {p: psutil.disk_usage(p) for p in mount_points}
return usages return usages
class Time:
def __init__(self):
self._timestamps = []
self._msgs = []
def begin(self):
self._timestamps.append(time.time())
def time(self, msg):
self._timestamps.append(time.time())
self._msgs.append(msg)
def print(self):
last, *timestamps = self._timestamps
for timestamp, msg in zip(timestamps, self._msgs):
logger.debug(f'TIME_IT: {msg} {timestamp-last}')
last = timestamp
def bulk_get(d, *keys, default=None):
values = []
for key in keys:
values.append(d.get(key, default))
return values

View File

@ -0,0 +1,27 @@
import redis
from django.conf import settings
def get_redis_client(db):
rc = redis.StrictRedis(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
password=settings.REDIS_PASSWORD,
db=db
)
return rc
class RedisPubSub:
def __init__(self, ch, db=10):
self.ch = ch
self.redis = get_redis_client(db)
def subscribe(self):
ps = self.redis.pubsub()
ps.subscribe(self.ch)
return ps
def publish(self, data):
self.redis.publish(self.ch, data)
return True

View File

@ -1,7 +1,7 @@
import base64 import base64
from Crypto.Cipher import AES from Cryptodome.Cipher import AES
from Crypto.Util.Padding import pad from Cryptodome.Util.Padding import pad
from Crypto.Random import get_random_bytes from Cryptodome.Random import get_random_bytes
from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT
from django.conf import settings from django.conf import settings

View File

@ -1,12 +1,14 @@
from functools import wraps from functools import wraps
import threading import threading
from redis_lock import Lock as RedisLock from redis_lock import Lock as RedisLock, NotAcquired
from redis import Redis from redis import Redis
from django.db import transaction
from common.utils import get_logger from common.utils import get_logger
from common.utils.inspect import copy_function_args from common.utils.inspect import copy_function_args
from apps.jumpserver.const import CONFIG from jumpserver.const import CONFIG
from common.local import thread_local
logger = get_logger(__file__) logger = get_logger(__file__)
@ -15,37 +17,49 @@ class AcquireFailed(RuntimeError):
pass pass
class LockHasTimeOut(RuntimeError):
pass
class DistributedLock(RedisLock): class DistributedLock(RedisLock):
def __init__(self, name, blocking=True, expire=60*2, auto_renewal=True): def __init__(self, name, *, expire=None, release_on_transaction_commit=False,
reentrant=False, release_raise_exc=False, auto_renewal_seconds=60):
""" """
使用 redis 构造的分布式锁 使用 redis 构造的分布式锁
:param name: :param name:
锁的名字要全局唯一 锁的名字要全局唯一
:param blocking:
该参数只在锁作为装饰器或者 `with` 时有效
:param expire: :param expire:
锁的过期时间注意不一定是锁到这个时间就释放了分两种情况 锁的过期时间
`auto_renewal=False` 锁会释放 :param release_on_transaction_commit:
`auto_renewal=True` 如果过期之前程序还没释放锁我们会延长锁的存活时间 是否在当前事务结束后再释放锁
这里的作用是防止程序意外终止没有释放锁导致死锁 :param release_raise_exc:
释放锁时如果没有持有锁是否抛异常或静默
:param auto_renewal_seconds:
当持有一个无限期锁的时候刷新锁的时间具体参考 `redis_lock.Lock#auto_renewal`
:param reentrant:
是否可重入
""" """
self.kwargs_copy = copy_function_args(self.__init__, locals()) self.kwargs_copy = copy_function_args(self.__init__, locals())
redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD) redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
if expire is None:
expire = auto_renewal_seconds
auto_renewal = True
else:
auto_renewal = False
super().__init__(redis_client=redis, name=name, expire=expire, auto_renewal=auto_renewal) super().__init__(redis_client=redis, name=name, expire=expire, auto_renewal=auto_renewal)
self._blocking = blocking self._release_on_transaction_commit = release_on_transaction_commit
self._release_raise_exc = release_raise_exc
self._reentrant = reentrant
self._acquired_reentrant_lock = False
self._thread_id = threading.current_thread().ident
def __enter__(self): def __enter__(self):
thread_id = threading.current_thread().ident acquired = self.acquire(blocking=True)
logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> attempt to acquire <lock:{self._name}> ...')
acquired = self.acquire(blocking=self._blocking)
if self._blocking and not acquired:
logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> was not acquired <lock:{self._name}>, but blocking=True')
raise EnvironmentError("Lock wasn't acquired, but blocking=True")
if not acquired: if not acquired:
logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> acquire <lock:{self._name}> failed')
raise AcquireFailed raise AcquireFailed
logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> acquire <lock:{self._name}> ok')
return self return self
def __exit__(self, exc_type=None, exc_value=None, traceback=None): def __exit__(self, exc_type=None, exc_value=None, traceback=None):
@ -57,5 +71,114 @@ class DistributedLock(RedisLock):
# 要创建一个新的锁对象 # 要创建一个新的锁对象
with self.__class__(**self.kwargs_copy): with self.__class__(**self.kwargs_copy):
return func(*args, **kwds) return func(*args, **kwds)
return inner return inner
def locked_by_me(self):
if self.locked():
if self.get_owner_id() == self.id:
return True
return False
def locked_by_current_thread(self):
if self.locked():
owner_id = self.get_owner_id()
local_owner_id = getattr(thread_local, self.name, None)
if local_owner_id and owner_id == local_owner_id:
return True
return False
def acquire(self, blocking=True, timeout=None):
if self._reentrant:
if self.locked_by_current_thread():
self._acquired_reentrant_lock = True
logger.debug(
f'Reentry lock ok: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name} thread={self._thread_id}')
return True
logger.debug(f'Attempt acquire reentrant-lock: lock_id={self.id} lock={self.name} thread={self._thread_id}')
acquired = super().acquire(blocking=blocking, timeout=timeout)
if acquired:
logger.debug(f'Acquired reentrant-lock ok: lock_id={self.id} lock={self.name} thread={self._thread_id}')
setattr(thread_local, self.name, self.id)
else:
logger.debug(f'Acquired reentrant-lock failed: lock_id={self.id} lock={self.name} thread={self._thread_id}')
return acquired
else:
logger.debug(f'Attempt acquire lock: lock_id={self.id} lock={self.name} thread={self._thread_id}')
acquired = super().acquire(blocking=blocking, timeout=timeout)
logger.debug(f'Acquired lock: ok={acquired} lock_id={self.id} lock={self.name} thread={self._thread_id}')
return acquired
@property
def name(self):
return self._name
def _raise_exc_with_log(self, msg, *, exc_cls=NotAcquired):
e = exc_cls(msg)
logger.error(msg)
self._raise_exc(e)
def _raise_exc(self, e):
if self._release_raise_exc:
raise e
def _release_on_reentrant_locked_by_brother(self):
if self._acquired_reentrant_lock:
self._acquired_reentrant_lock = False
logger.debug(f'Released reentrant-lock: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name} thread={self._thread_id}')
return
else:
self._raise_exc_with_log(f'Reentrant-lock is not acquired: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name} thread={self._thread_id}')
def _release_on_reentrant_locked_by_me(self):
logger.debug(f'Release reentrant-lock locked by me: lock_id={self.id} lock={self.name} thread={self._thread_id}')
id = getattr(thread_local, self.name, None)
if id != self.id:
raise PermissionError(f'Reentrant-lock is not locked by me: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name} thread={self._thread_id}')
try:
# 这里要保证先删除 thread_local 的标记,
delattr(thread_local, self.name)
except AttributeError:
pass
finally:
try:
# 这里处理的是边界情况,
# 判断锁是我的 -> 锁超时 -> 释放锁报错
# 此时的报错应该被静默
self._release_redis_lock()
except NotAcquired:
pass
def _release_redis_lock(self):
# 最底层 api
super().release()
def _release(self):
try:
self._release_redis_lock()
logger.debug(f'Released lock: lock_id={self.id} lock={self.name} thread={self._thread_id}')
except NotAcquired as e:
logger.error(f'Release lock failed: lock_id={self.id} lock={self.name} thread={self._thread_id} error: {e}')
self._raise_exc(e)
def release(self):
_release = self._release
# 处理可重入锁
if self._reentrant:
if self.locked_by_current_thread():
if self.locked_by_me():
_release = self._release_on_reentrant_locked_by_me
else:
_release = self._release_on_reentrant_locked_by_brother
else:
self._raise_exc_with_log(f'Reentrant-lock is not acquired: lock_id={self.id} lock={self.name} thread={self._thread_id}')
# 处理是否在事务提交时才释放锁
if self._release_on_transaction_commit:
logger.debug(f'Release lock on transaction commit ... :lock_id={self.id} lock={self.name} thread={self._thread_id}')
transaction.on_commit(_release)
else:
_release()

View File

@ -1,8 +1,13 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import socket
import struct import struct
import random import random
import socket
import string
import secrets
string_punctuation = '!#$%&()*+,-.:;<=>?@[]^_{}~'
def random_datetime(date_start, date_end): def random_datetime(date_start, date_end):
@ -14,6 +19,29 @@ def random_ip():
return socket.inet_ntoa(struct.pack('>I', random.randint(1, 0xffffffff))) return socket.inet_ntoa(struct.pack('>I', random.randint(1, 0xffffffff)))
def random_string(length, lower=True, upper=True, digit=True, special_char=False):
chars = string.ascii_letters
if digit:
chars += string.digits
while True:
password = list(random.choice(chars) for i in range(length))
if upper and not any(c.upper() for c in password):
continue
if lower and not any(c.lower() for c in password):
continue
if digit and not any(c.isdigit() for c in password):
continue
break
if special_char:
spc = random.choice(string_punctuation)
i = random.choice(range(len(password)))
password[i] = spc
password = ''.join(password)
return password
# def strTimeProp(start, end, prop, fmt): # def strTimeProp(start, end, prop, fmt):
# time_start = time.mktime(time.strptime(start, fmt)) # time_start = time.mktime(time.strptime(start, fmt))

View File

@ -4,6 +4,7 @@ from django.utils.timesince import timesince
from django.db.models import Count, Max from django.db.models import Count, Max
from django.http.response import JsonResponse, HttpResponse from django.http.response import JsonResponse, HttpResponse
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.permissions import AllowAny
from collections import Counter from collections import Counter
from users.models import User from users.models import User
@ -307,7 +308,7 @@ class IndexApi(TotalCountMixin, DatesLoginMetricMixin, APIView):
class PrometheusMetricsApi(APIView): class PrometheusMetricsApi(APIView):
permission_classes = () permission_classes = (AllowAny,)
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
util = ComponentsPrometheusMetricsUtil() util = ComponentsPrometheusMetricsUtil()

View File

@ -280,7 +280,14 @@ class Config(dict):
'SESSION_COOKIE_SECURE': False, 'SESSION_COOKIE_SECURE': False,
'CSRF_COOKIE_SECURE': False, 'CSRF_COOKIE_SECURE': False,
'REFERER_CHECK_ENABLED': False, 'REFERER_CHECK_ENABLED': False,
'SERVER_REPLAY_STORAGE': {} 'SERVER_REPLAY_STORAGE': {},
'CONNECTION_TOKEN_ENABLED': False,
'ONLY_ALLOW_EXIST_USER_AUTH': False,
'ONLY_ALLOW_AUTH_FROM_SOURCE': True,
'DISK_CHECK_ENABLED': True,
'SESSION_SAVE_EVERY_REQUEST': True,
'SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE': False,
'FORGOT_PASSWORD_URL': '',
} }
def compatible_auth_openid_of_key(self): def compatible_auth_openid_of_key(self):
@ -426,98 +433,6 @@ class Config(dict):
return self.get(item) return self.get(item)
class DynamicConfig:
def __init__(self, static_config):
self.static_config = static_config
self.db_setting = None
def __getitem__(self, item):
return self.dynamic(item)
def __getattr__(self, item):
return self.dynamic(item)
def dynamic(self, item):
return lambda: self.get(item)
def LOGIN_URL(self):
return self.get('LOGIN_URL')
def AUTHENTICATION_BACKENDS(self):
backends = [
'authentication.backends.pubkey.PublicKeyAuthBackend',
'django.contrib.auth.backends.ModelBackend',
]
if self.get('AUTH_LDAP'):
backends.insert(0, 'authentication.backends.ldap.LDAPAuthorizationBackend')
if self.static_config.get('AUTH_CAS'):
backends.insert(0, 'authentication.backends.cas.CASBackend')
if self.static_config.get('AUTH_OPENID'):
backends.insert(0, 'jms_oidc_rp.backends.OIDCAuthPasswordBackend')
backends.insert(0, 'jms_oidc_rp.backends.OIDCAuthCodeBackend')
if self.static_config.get('AUTH_RADIUS'):
backends.insert(0, 'authentication.backends.radius.RadiusBackend')
if self.static_config.get('AUTH_SSO'):
backends.insert(0, 'authentication.backends.api.SSOAuthentication')
return backends
def XPACK_LICENSE_IS_VALID(self):
if not HAS_XPACK:
return False
try:
from xpack.plugins.license.models import License
return License.has_valid_license()
except:
return False
def XPACK_INTERFACE_LOGIN_TITLE(self):
default_title = _('Welcome to the JumpServer open source fortress')
if not HAS_XPACK:
return default_title
try:
from xpack.plugins.interface.models import Interface
return Interface.get_login_title()
except:
return default_title
def LOGO_URLS(self):
logo_urls = {'logo_logout': static('img/logo.png'),
'logo_index': static('img/logo_text.png'),
'login_image': static('img/login_image.png'),
'favicon': static('img/facio.ico')}
if not HAS_XPACK:
return logo_urls
try:
from xpack.plugins.interface.models import Interface
obj = Interface.interface()
if obj:
if obj.logo_logout:
logo_urls.update({'logo_logout': obj.logo_logout.url})
if obj.logo_index:
logo_urls.update({'logo_index': obj.logo_index.url})
if obj.login_image:
logo_urls.update({'login_image': obj.login_image.url})
if obj.favicon:
logo_urls.update({'favicon': obj.favicon.url})
except:
pass
return logo_urls
def get_from_db(self, item):
if self.db_setting is not None:
value = self.db_setting.get(item)
if value is not None:
return value
return None
def get(self, item):
# 先从数据库中获取
value = self.get_from_db(item)
if value is not None:
return value
return self.static_config.get(item)
class ConfigManager: class ConfigManager:
config_class = Config config_class = Config
@ -694,7 +609,3 @@ class ConfigManager:
# 对config进行兼容处理 # 对config进行兼容处理
config.compatible() config.compatible()
return config return config
@classmethod
def get_dynamic_config(cls, config):
return DynamicConfig(config)

View File

@ -4,12 +4,11 @@ import os
from .conf import ConfigManager from .conf import ConfigManager
__all__ = ['BASE_DIR', 'PROJECT_DIR', 'VERSION', 'CONFIG', 'DYNAMIC'] __all__ = ['BASE_DIR', 'PROJECT_DIR', 'VERSION', 'CONFIG']
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
PROJECT_DIR = os.path.dirname(BASE_DIR) PROJECT_DIR = os.path.dirname(BASE_DIR)
VERSION = '2.0.0' VERSION = '2.0.0'
CONFIG = ConfigManager.load_user_config() CONFIG = ConfigManager.load_user_config()
DYNAMIC = ConfigManager.get_dynamic_config(CONFIG)

Some files were not shown because too many files have changed in this diff Show More