Merge pull request #5728 from jumpserver/dev

v2.8 发版
pull/5813/head
Jiangjie.Bai 2021-03-11 21:17:39 +08:00 committed by GitHub
commit 174cc16980
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
234 changed files with 7528 additions and 4481 deletions

View File

@ -23,6 +23,7 @@ RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list \
&& sed -i 's/security.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list \
&& apt update \
&& grep -v '^#' ./requirements/deb_buster_requirements.txt | xargs apt -y install \
&& rm -rf /var/lib/apt/lists/* \
&& localedef -c -f UTF-8 -i zh_CN zh_CN.UTF-8 \
&& cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime

3
apps/acls/admin.py Normal file
View File

@ -0,0 +1,3 @@
from django.contrib import admin
# Register your models here.

View File

@ -0,0 +1,3 @@
from .login_acl import *
from .login_asset_acl import *
from .login_asset_check import *

View File

@ -0,0 +1,19 @@
from common.permissions import IsOrgAdmin, HasQueryParamsUserAndIsCurrentOrgMember
from common.drf.api import JMSBulkModelViewSet
from ..models import LoginACL
from .. import serializers
__all__ = ['LoginACLViewSet', ]
class LoginACLViewSet(JMSBulkModelViewSet):
queryset = LoginACL.objects.all()
filterset_fields = ('name', 'user', )
search_fields = filterset_fields
permission_classes = (IsOrgAdmin, )
serializer_class = serializers.LoginACLSerializer
def get_permissions(self):
if self.action in ["retrieve", "list"]:
self.permission_classes = (IsOrgAdmin, HasQueryParamsUserAndIsCurrentOrgMember)
return super().get_permissions()

View File

@ -0,0 +1,15 @@
from orgs.mixins.api import OrgBulkModelViewSet
from common.permissions import IsOrgAdmin
from .. import models, serializers
__all__ = ['LoginAssetACLViewSet']
class LoginAssetACLViewSet(OrgBulkModelViewSet):
model = models.LoginAssetACL
filterset_fields = ('name', )
search_fields = filterset_fields
permission_classes = (IsOrgAdmin, )
serializer_class = serializers.LoginAssetACLSerializer

View File

@ -0,0 +1,105 @@
from django.shortcuts import get_object_or_404
from rest_framework.response import Response
from rest_framework.generics import CreateAPIView, RetrieveDestroyAPIView
from common.permissions import IsAppUser
from common.utils import reverse, lazyproperty
from orgs.utils import tmp_to_org, tmp_to_root_org
from tickets.models import Ticket
from ..models import LoginAssetACL
from .. import serializers
__all__ = ['LoginAssetCheckAPI', 'LoginAssetConfirmStatusAPI']
class LoginAssetCheckAPI(CreateAPIView):
permission_classes = (IsAppUser, )
serializer_class = serializers.LoginAssetCheckSerializer
def create(self, request, *args, **kwargs):
is_need_confirm, response_data = self.check_if_need_confirm()
return Response(data=response_data, status=200)
def check_if_need_confirm(self):
queries = {
'user': self.serializer.user, 'asset': self.serializer.asset,
'system_user': self.serializer.system_user,
'action': LoginAssetACL.ActionChoices.login_confirm
}
with tmp_to_org(self.serializer.org):
acl = LoginAssetACL.filter(**queries).valid().first()
if not acl:
is_need_confirm = False
response_data = {}
else:
is_need_confirm = True
response_data = self._get_response_data_of_need_confirm(acl)
response_data['need_confirm'] = is_need_confirm
return is_need_confirm, response_data
def _get_response_data_of_need_confirm(self, acl):
ticket = LoginAssetACL.create_login_asset_confirm_ticket(
user=self.serializer.user,
asset=self.serializer.asset,
system_user=self.serializer.system_user,
assignees=acl.reviewers.all(),
org_id=self.serializer.org.id
)
confirm_status_url = reverse(
view_name='acls:login-asset-confirm-status',
kwargs={'pk': str(ticket.id)}
)
ticket_detail_url = reverse(
view_name='api-tickets:ticket-detail',
kwargs={'pk': str(ticket.id)},
external=True, api_to_ui=True
)
ticket_detail_url = '{url}?type={type}'.format(url=ticket_detail_url, type=ticket.type)
data = {
'check_confirm_status': {'method': 'GET', 'url': confirm_status_url},
'close_confirm': {'method': 'DELETE', 'url': confirm_status_url},
'ticket_detail_url': ticket_detail_url,
'reviewers': [str(user) for user in ticket.assignees.all()],
}
return data
@lazyproperty
def serializer(self):
serializer = self.get_serializer(data=self.request.data)
serializer.is_valid(raise_exception=True)
return serializer
class LoginAssetConfirmStatusAPI(RetrieveDestroyAPIView):
permission_classes = (IsAppUser, )
def retrieve(self, request, *args, **kwargs):
if self.ticket.action_open:
status = 'await'
elif self.ticket.action_approve:
status = 'approve'
else:
status = 'reject'
data = {
'status': status,
'action': self.ticket.action,
'processor': self.ticket.processor_display
}
return Response(data=data, status=200)
def destroy(self, request, *args, **kwargs):
if self.ticket.status_open:
self.ticket.close(processor=self.ticket.applicant)
data = {
'action': self.ticket.action,
'status': self.ticket.status,
'processor': self.ticket.processor_display
}
return Response(data=data, status=200)
@lazyproperty
def ticket(self):
with tmp_to_root_org():
return get_object_or_404(Ticket, pk=self.kwargs['pk'])

5
apps/acls/apps.py Normal file
View File

@ -0,0 +1,5 @@
from django.apps import AppConfig
class AclsConfig(AppConfig):
name = 'acls'

9
apps/acls/const.py Normal file
View File

@ -0,0 +1,9 @@
from django.utils.translation import ugettext as _
common_help_text = _('Format for comma-delimited string, with * indicating a match all. ')
ip_group_help_text = common_help_text + _(
'Such as: '
'192.168.10.1, 192.168.1.0/24, 10.1.1.1-10.1.1.20, 2001:db8:2de::e13, 2001:db8:1a:1110::/64 '
)

View File

@ -0,0 +1,61 @@
# Generated by Django 3.1 on 2021-03-11 09:53
from django.conf import settings
import django.core.validators
from django.db import migrations, models
import django.db.models.deletion
import uuid
class Migration(migrations.Migration):
initial = True
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name='LoginACL',
fields=[
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('name', models.CharField(max_length=128, verbose_name='Name')),
('priority', models.IntegerField(default=50, help_text='1-100, the lower the value will be match first', validators=[django.core.validators.MinValueValidator(1), django.core.validators.MaxValueValidator(100)], verbose_name='Priority')),
('is_active', models.BooleanField(default=True, verbose_name='Active')),
('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
('ip_group', models.JSONField(default=list, verbose_name='Login IP')),
('action', models.CharField(choices=[('reject', 'Reject'), ('allow', 'Allow')], default='reject', max_length=64, verbose_name='Action')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='login_acls', to=settings.AUTH_USER_MODEL, verbose_name='User')),
],
options={
'ordering': ('priority', '-date_updated', 'name'),
},
),
migrations.CreateModel(
name='LoginAssetACL',
fields=[
('org_id', models.CharField(blank=True, db_index=True, default='', max_length=36, verbose_name='Organization')),
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('name', models.CharField(max_length=128, verbose_name='Name')),
('priority', models.IntegerField(default=50, help_text='1-100, the lower the value will be match first', validators=[django.core.validators.MinValueValidator(1), django.core.validators.MaxValueValidator(100)], verbose_name='Priority')),
('is_active', models.BooleanField(default=True, verbose_name='Active')),
('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
('users', models.JSONField(verbose_name='User')),
('system_users', models.JSONField(verbose_name='System User')),
('assets', models.JSONField(verbose_name='Asset')),
('action', models.CharField(choices=[('login_confirm', 'Login confirm')], default='login_confirm', max_length=64, verbose_name='Action')),
('reviewers', models.ManyToManyField(blank=True, related_name='review_login_asset_acls', to=settings.AUTH_USER_MODEL, verbose_name='Reviewers')),
],
options={
'ordering': ('priority', '-date_updated', 'name'),
'unique_together': {('name', 'org_id')},
},
),
]

View File

View File

@ -0,0 +1,2 @@
from .login_acl import *
from .login_asset_acl import *

35
apps/acls/models/base.py Normal file
View File

@ -0,0 +1,35 @@
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.core.validators import MinValueValidator, MaxValueValidator
from common.mixins import CommonModelMixin
__all__ = ['BaseACL', 'BaseACLQuerySet']
class BaseACLQuerySet(models.QuerySet):
def active(self):
return self.filter(is_active=True)
def inactive(self):
return self.filter(is_active=False)
def valid(self):
return self.active()
def invalid(self):
return self.inactive()
class BaseACL(CommonModelMixin):
name = models.CharField(max_length=128, verbose_name=_('Name'))
priority = models.IntegerField(
default=50, verbose_name=_("Priority"),
help_text=_("1-100, the lower the value will be match first"),
validators=[MinValueValidator(1), MaxValueValidator(100)]
)
is_active = models.BooleanField(default=True, verbose_name=_("Active"))
comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
class Meta:
abstract = True

View File

@ -0,0 +1,54 @@
from django.db import models
from django.utils.translation import ugettext_lazy as _
from .base import BaseACL, BaseACLQuerySet
from ..utils import contains_ip
class ACLManager(models.Manager):
def valid(self):
return self.get_queryset().valid()
class LoginACL(BaseACL):
class ActionChoices(models.TextChoices):
reject = 'reject', _('Reject')
allow = 'allow', _('Allow')
# 条件
ip_group = models.JSONField(default=list, verbose_name=_('Login IP'))
# 动作
action = models.CharField(
max_length=64, choices=ActionChoices.choices, default=ActionChoices.reject,
verbose_name=_('Action')
)
# 关联
user = models.ForeignKey(
'users.User', on_delete=models.CASCADE, related_name='login_acls', verbose_name=_('User')
)
objects = ACLManager.from_queryset(BaseACLQuerySet)()
class Meta:
ordering = ('priority', '-date_updated', 'name')
@property
def action_reject(self):
return self.action == self.ActionChoices.reject
@property
def action_allow(self):
return self.action == self.ActionChoices.allow
@staticmethod
def allow_user_to_login(user, ip):
acl = user.login_acls.valid().first()
if not acl:
return True
is_contained = contains_ip(ip, acl.ip_group)
if acl.action_allow and is_contained:
return True
if acl.action_reject and not is_contained:
return True
return False

View File

@ -0,0 +1,99 @@
from django.db import models
from django.db.models import Q
from django.utils.translation import ugettext_lazy as _
from orgs.mixins.models import OrgModelMixin, OrgManager
from .base import BaseACL, BaseACLQuerySet
from ..utils import contains_ip
class ACLManager(OrgManager):
def valid(self):
return self.get_queryset().valid()
class LoginAssetACL(BaseACL, OrgModelMixin):
class ActionChoices(models.TextChoices):
login_confirm = 'login_confirm', _('Login confirm')
# 条件
users = models.JSONField(verbose_name=_('User'))
system_users = models.JSONField(verbose_name=_('System User'))
assets = models.JSONField(verbose_name=_('Asset'))
# 动作
action = models.CharField(
max_length=64, choices=ActionChoices.choices, default=ActionChoices.login_confirm,
verbose_name=_('Action')
)
# 动作: 附加字段
# - login_confirm
reviewers = models.ManyToManyField(
'users.User', related_name='review_login_asset_acls', blank=True,
verbose_name=_("Reviewers")
)
objects = ACLManager.from_queryset(BaseACLQuerySet)()
class Meta:
unique_together = ('name', 'org_id')
ordering = ('priority', '-date_updated', 'name')
@classmethod
def filter(cls, user, asset, system_user, action):
queryset = cls.objects.filter(action=action)
queryset = cls.filter_user(user, queryset)
queryset = cls.filter_asset(asset, queryset)
queryset = cls.filter_system_user(system_user, queryset)
return queryset
@classmethod
def filter_user(cls, user, queryset):
queryset = queryset.filter(
Q(users__username_group__contains=user.username) |
Q(users__username_group__contains='*')
)
return queryset
@classmethod
def filter_asset(cls, asset, queryset):
queryset = queryset.filter(
Q(assets__hostname_group__contains=asset.hostname) |
Q(assets__hostname_group__contains='*')
)
ids = [q.id for q in queryset if contains_ip(asset.ip, q.assets.get('ip_group', []))]
queryset = cls.objects.filter(id__in=ids)
return queryset
@classmethod
def filter_system_user(cls, system_user, queryset):
queryset = queryset.filter(
Q(system_users__name_group__contains=system_user.name) |
Q(system_users__name_group__contains='*')
).filter(
Q(system_users__username_group__contains=system_user.username) |
Q(system_users__username_group__contains='*')
).filter(
Q(system_users__protocol_group__contains=system_user.protocol) |
Q(system_users__protocol_group__contains='*')
)
return queryset
@classmethod
def create_login_asset_confirm_ticket(cls, user, asset, system_user, assignees, org_id):
from tickets.const import TicketTypeChoices
from tickets.models import Ticket
data = {
'title': _('Login asset confirm') + ' ({})'.format(user),
'type': TicketTypeChoices.login_asset_confirm,
'meta': {
'apply_login_user': str(user),
'apply_login_asset': str(asset),
'apply_login_system_user': str(system_user),
},
'org_id': org_id,
}
ticket = Ticket.objects.create(**data)
ticket.assignees.set(assignees)
ticket.open(applicant=user)
return ticket

View File

@ -0,0 +1,3 @@
from .login_acl import *
from .login_asset_acl import *
from .login_asset_check import *

View File

@ -0,0 +1,49 @@
from django.utils.translation import ugettext as _
from rest_framework import serializers
from common.drf.serializers import BulkModelSerializer
from orgs.utils import current_org
from ..models import LoginACL
from ..utils import is_ip_address, is_ip_network, is_ip_segment
from .. import const
__all__ = ['LoginACLSerializer', ]
def ip_group_child_validator(ip_group_child):
is_valid = ip_group_child == '*' \
or is_ip_address(ip_group_child) \
or is_ip_network(ip_group_child) \
or is_ip_segment(ip_group_child)
if not is_valid:
error = _('IP address invalid: `{}`').format(ip_group_child)
raise serializers.ValidationError(error)
class LoginACLSerializer(BulkModelSerializer):
ip_group = serializers.ListField(
default=['*'], label=_('IP'), help_text=const.ip_group_help_text,
child=serializers.CharField(max_length=1024, validators=[ip_group_child_validator])
)
user_display = serializers.ReadOnlyField(source='user.name', label=_('User'))
action_display = serializers.ReadOnlyField(source='get_action_display', label=_('Action'))
class Meta:
model = LoginACL
fields = [
'id', 'name', 'priority', 'ip_group', 'user', 'user_display', 'action',
'action_display', 'is_active', 'comment', 'created_by', 'date_created', 'date_updated'
]
extra_kwargs = {
'priority': {'default': 50},
'is_active': {'default': True},
}
@staticmethod
def validate_user(user):
if user not in current_org.get_members():
error = _('The user `{}` is not in the current organization: `{}`').format(
user, current_org
)
raise serializers.ValidationError(error)
return user

View File

@ -0,0 +1,87 @@
from rest_framework import serializers
from django.utils.translation import ugettext as _
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from assets.models import SystemUser
from acls import models
from orgs.models import Organization
from .. import const
__all__ = ['LoginAssetACLSerializer']
class LoginAssetACLUsersSerializer(serializers.Serializer):
username_group = serializers.ListField(
default=['*'], child=serializers.CharField(max_length=128), label=_('Username'),
help_text=const.common_help_text
)
class LoginAssetACLAssestsSerializer(serializers.Serializer):
ip_group = serializers.ListField(
default=['*'], child=serializers.CharField(max_length=1024), label=_('IP'),
help_text=const.ip_group_help_text + _('(Domain name support)')
)
hostname_group = serializers.ListField(
default=['*'], child=serializers.CharField(max_length=128), label=_('Hostname'),
help_text=const.common_help_text
)
class LoginAssetACLSystemUsersSerializer(serializers.Serializer):
name_group = serializers.ListField(
default=['*'], child=serializers.CharField(max_length=128), label=_('Name'),
help_text=const.common_help_text
)
username_group = serializers.ListField(
default=['*'], child=serializers.CharField(max_length=128), label=_('Username'),
help_text=const.common_help_text
)
protocol_group = serializers.ListField(
default=['*'], child=serializers.CharField(max_length=16), label=_('Protocol'),
help_text=const.common_help_text + _('Protocol options: {}').format(
', '.join(SystemUser.ASSET_CATEGORY_PROTOCOLS)
)
)
@staticmethod
def validate_protocol_group(protocol_group):
unsupported_protocols = set(protocol_group) - set(SystemUser.ASSET_CATEGORY_PROTOCOLS + ['*'])
if unsupported_protocols:
error = _('Unsupported protocols: {}').format(unsupported_protocols)
raise serializers.ValidationError(error)
return protocol_group
class LoginAssetACLSerializer(BulkOrgResourceModelSerializer):
users = LoginAssetACLUsersSerializer()
assets = LoginAssetACLAssestsSerializer()
system_users = LoginAssetACLSystemUsersSerializer()
reviewers_amount = serializers.IntegerField(read_only=True, source='reviewers.count')
action_display = serializers.ReadOnlyField(source='get_action_display', label=_('Action'))
class Meta:
model = models.LoginAssetACL
fields = [
'id', 'name', 'priority', 'users', 'system_users', 'assets', 'action', 'action_display',
'is_active', 'comment', 'reviewers', 'reviewers_amount', 'created_by', 'date_created',
'date_updated', 'org_id'
]
extra_kwargs = {
"reviewers": {'allow_null': False, 'required': True},
'priority': {'default': 50},
'is_active': {'default': True},
}
def validate_reviewers(self, reviewers):
org_id = self.fields['org_id'].default()
org = Organization.get_instance(org_id)
if not org:
error = _('The organization `{}` does not exist'.format(org_id))
raise serializers.ValidationError(error)
users = org.get_members()
valid_reviewers = list(set(reviewers) & set(users))
if not valid_reviewers:
error = _('None of the reviewers belong to Organization `{}`'.format(org.name))
raise serializers.ValidationError(error)
return valid_reviewers

View File

@ -0,0 +1,71 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from orgs.utils import tmp_to_root_org
from common.utils import get_object_or_none, lazyproperty
from users.models import User
from assets.models import Asset, SystemUser
__all__ = ['LoginAssetCheckSerializer']
class LoginAssetCheckSerializer(serializers.Serializer):
user_id = serializers.UUIDField(required=True, allow_null=False)
asset_id = serializers.UUIDField(required=True, allow_null=False)
system_user_id = serializers.UUIDField(required=True, allow_null=False)
system_user_username = serializers.CharField(max_length=128, default='')
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.user = None
self.asset = None
self._system_user = None
self._system_user_username = None
def validate_user_id(self, user_id):
self.user = self.validate_object_exist(User, user_id)
return user_id
def validate_asset_id(self, asset_id):
self.asset = self.validate_object_exist(Asset, asset_id)
return asset_id
def validate_system_user_id(self, system_user_id):
self._system_user = self.validate_object_exist(SystemUser, system_user_id)
return system_user_id
def validate_system_user_username(self, system_user_username):
system_user_id = self.initial_data.get('system_user_id')
system_user = self.validate_object_exist(SystemUser, system_user_id)
if self._system_user.login_mode == SystemUser.LOGIN_MANUAL \
and not system_user.username \
and not system_user.username_same_with_user \
and not system_user_username:
error = 'Missing parameter: system_user_username'
raise serializers.ValidationError(error)
self._system_user_username = system_user_username
return system_user_username
@staticmethod
def validate_object_exist(model, field_id):
with tmp_to_root_org():
obj = get_object_or_none(model, pk=field_id)
if not obj:
error = '{} Model object does not exist'.format(model.__name__)
raise serializers.ValidationError(error)
return obj
@lazyproperty
def system_user(self):
if self._system_user.username_same_with_user:
username = self.user.username
elif self._system_user.login_mode == SystemUser.LOGIN_MANUAL:
username = self._system_user_username
else:
username = self._system_user.username
self._system_user.username = username
return self._system_user
@lazyproperty
def org(self):
return self.asset.org

3
apps/acls/tests.py Normal file
View File

@ -0,0 +1,3 @@
from django.test import TestCase
# Create your tests here.

View File

@ -0,0 +1 @@
from .api_urls import *

View File

@ -0,0 +1,18 @@
from django.urls import path
from rest_framework_bulk.routes import BulkRouter
from .. import api
app_name = 'acls'
router = BulkRouter()
router.register(r'login-acls', api.LoginACLViewSet, 'login-acl')
router.register(r'login-asset-acls', api.LoginAssetACLViewSet, 'login-asset-acl')
urlpatterns = [
path('login-asset/check/', api.LoginAssetCheckAPI.as_view(), name='login-asset-check'),
path('login-asset-confirm/<uuid:pk>/status/', api.LoginAssetConfirmStatusAPI.as_view(), name='login-asset-confirm-status')
]
urlpatterns += router.urls

68
apps/acls/utils.py Normal file
View File

@ -0,0 +1,68 @@
from ipaddress import ip_network, ip_address
def is_ip_address(address):
""" 192.168.10.1 """
try:
ip_address(address)
except ValueError:
return False
else:
return True
def is_ip_network(ip):
""" 192.168.1.0/24 """
try:
ip_network(ip)
except ValueError:
return False
else:
return True
def is_ip_segment(ip):
""" 10.1.1.1-10.1.1.20 """
if '-' not in ip:
return False
ip_address1, ip_address2 = ip.split('-')
return is_ip_address(ip_address1) and is_ip_address(ip_address2)
def in_ip_segment(ip, ip_segment):
ip1, ip2 = ip_segment.split('-')
ip1 = int(ip_address(ip1))
ip2 = int(ip_address(ip2))
ip = int(ip_address(ip))
return min(ip1, ip2) <= ip <= max(ip1, ip2)
def contains_ip(ip, ip_group):
"""
ip_group:
[192.168.10.1, 192.168.1.0/24, 10.1.1.1-10.1.1.20, 2001:db8:2de::e13, 2001:db8:1a:1110::/64.]
"""
if '*' in ip_group:
return True
for _ip in ip_group:
if is_ip_address(_ip):
# 192.168.10.1
if ip == _ip:
return True
elif is_ip_network(_ip) and is_ip_address(ip):
# 192.168.1.0/24
if ip_address(ip) in ip_network(_ip):
return True
elif is_ip_segment(_ip) and is_ip_address(ip):
# 10.1.1.1-10.1.1.20
if in_ip_segment(ip, _ip):
return True
else:
# is domain name
if ip == _ip:
return True
return False

View File

@ -77,8 +77,8 @@ class SerializeApplicationToTreeNodeMixin:
@staticmethod
def filter_organizations(applications):
organizations_id = set(applications.values_list('org_id', flat=True))
organizations = [Organization.get_instance(org_id) for org_id in organizations_id]
organization_ids = set(applications.values_list('org_id', flat=True))
organizations = [Organization.get_instance(org_id) for org_id in organization_ids]
return organizations
def serialize_applications_with_org(self, applications):

View File

@ -3,6 +3,7 @@ from django.utils.translation import ugettext_lazy as _
from orgs.mixins.models import OrgModelMixin
from common.mixins import CommonModelMixin
from assets.models import Asset
from .. import const
@ -35,3 +36,35 @@ class Application(CommonModelMixin, OrgModelMixin):
@property
def category_remote_app(self):
return self.category == const.ApplicationCategoryChoices.remote_app.value
def get_rdp_remote_app_setting(self):
from applications.serializers.attrs import get_serializer_class_by_application_type
if not self.category_remote_app:
raise ValueError(f"Not a remote app application: {self.name}")
serializer_class = get_serializer_class_by_application_type(self.type)
fields = serializer_class().get_fields()
parameters = [self.type]
for field_name in list(fields.keys()):
if field_name in ['asset']:
continue
value = self.attrs.get(field_name)
if not value:
continue
if field_name == 'path':
value = '\"%s\"' % value
parameters.append(str(value))
parameters = ' '.join(parameters)
return {
'program': '||jmservisor',
'working_directory': '',
'parameters': parameters
}
def get_remote_app_asset(self):
asset_id = self.attrs.get('asset')
if not asset_id:
raise ValueError("Remote App not has asset attr")
asset = Asset.objects.filter(id=asset_id).first()
return asset

View File

@ -27,31 +27,5 @@ class RemoteAppConnectionInfoSerializer(serializers.ModelSerializer):
return obj.attrs.get('asset')
@staticmethod
def get_parameters(obj):
"""
返回Guacamole需要的RemoteApp配置参数信息中的parameters参数
"""
from .attrs import get_serializer_class_by_application_type
serializer_class = get_serializer_class_by_application_type(obj.type)
fields = serializer_class().get_fields()
parameters = [obj.type]
for field_name in list(fields.keys()):
if field_name in ['asset']:
continue
value = obj.attrs.get(field_name)
if not value:
continue
if field_name == 'path':
value = '\"%s\"' % value
parameters.append(str(value))
parameters = ' '.join(parameters)
return parameters
def get_parameter_remote_app(self, obj):
return {
'program': '||jmservisor',
'working_directory': '',
'parameters': self.get_parameters(obj)
}
def get_parameter_remote_app(obj):
return obj.get_rdp_remote_app_setting()

View File

@ -33,6 +33,10 @@ class AdminUserViewSet(OrgBulkModelViewSet):
search_fields = filterset_fields
serializer_class = serializers.AdminUserSerializer
permission_classes = (IsOrgAdmin,)
serializer_classes = {
'default': serializers.AdminUserSerializer,
'retrieve': serializers.AdminUserDetailSerializer,
}
def get_queryset(self):
queryset = super().get_queryset()

View File

@ -3,8 +3,6 @@
from assets.api import FilterAssetByNodeMixin
from rest_framework.viewsets import ModelViewSet
from rest_framework.generics import RetrieveAPIView
from rest_framework.response import Response
from rest_framework import status
from django.shortcuts import get_object_or_404
from common.utils import get_logger, get_object_or_none

View File

@ -1,14 +1,15 @@
from typing import List
from common.utils.common import timeit
from assets.models import Node, Asset
from assets.pagination import AssetLimitOffsetPagination
from common.utils import lazyproperty, dict_get_any, is_uuid, get_object_or_none
from assets.pagination import NodeAssetTreePagination
from common.utils import lazyproperty
from assets.utils import get_node, is_query_node_all_assets
class SerializeToTreeNodeMixin:
permission_classes = ()
@timeit
def serialize_nodes(self, nodes: List[Node], with_asset_amount=False):
if with_asset_amount:
def _name(node: Node):
@ -45,6 +46,7 @@ class SerializeToTreeNodeMixin:
return platform
return default
@timeit
def serialize_assets(self, assets, node_key=None):
if node_key is None:
get_pid = lambda asset: getattr(asset, 'parent_key', '')
@ -79,7 +81,7 @@ class SerializeToTreeNodeMixin:
class FilterAssetByNodeMixin:
pagination_class = AssetLimitOffsetPagination
pagination_class = NodeAssetTreePagination
@lazyproperty
def is_query_node_all_assets(self):

View File

@ -8,7 +8,6 @@ from rest_framework.response import Response
from rest_framework.decorators import action
from django.utils.translation import ugettext_lazy as _
from django.shortcuts import get_object_or_404, Http404
from django.utils.decorators import method_decorator
from django.db.models.signals import m2m_changed
from common.const.http import POST
@ -17,17 +16,15 @@ from common.const.signals import PRE_REMOVE, POST_REMOVE
from assets.models import Asset
from common.utils import get_logger, get_object_or_none
from common.tree import TreeNodeSerializer
from common.const.distributed_lock_key import UPDATE_NODE_TREE_LOCK_KEY
from orgs.mixins.api import OrgModelViewSet
from orgs.mixins import generics
from orgs.lock import org_level_transaction_lock
from orgs.utils import current_org
from assets.tasks import check_node_assets_amount_task
from ..hands import IsOrgAdmin
from ..models import Node
from ..tasks import (
update_node_assets_hardware_info_manual,
test_node_assets_connectivity_manual,
check_node_assets_amount_task
)
from .. import serializers
from .mixin import SerializeToTreeNodeMixin
@ -50,17 +47,17 @@ class NodeViewSet(OrgModelViewSet):
permission_classes = (IsOrgAdmin,)
serializer_class = serializers.NodeSerializer
@action(methods=[POST], detail=False, url_name='launch-check-assets-amount-task')
def launch_check_assets_amount_task(self, request):
task = check_node_assets_amount_task.delay(current_org.id)
return Response(data={'task': task.id})
# 仅支持根节点指直接创建子节点下的节点需要通过children接口创建
def perform_create(self, serializer):
child_key = Node.org_root().get_next_child_key()
serializer.validated_data["key"] = child_key
serializer.save()
@action(methods=[POST], detail=False, url_path='check_assets_amount_task')
def check_assets_amount_task(self, request):
task = check_node_assets_amount_task.delay(current_org.id)
return Response(data={'task': task.id})
def perform_update(self, serializer):
node = self.get_object()
if node.is_org_root() and node.value != serializer.validated_data['value']:
@ -130,9 +127,13 @@ class NodeChildrenApi(generics.ListCreateAPIView):
def get_object(self):
pk = self.kwargs.get('pk') or self.request.query_params.get('id')
key = self.request.query_params.get("key")
if not pk and not key:
node = Node.org_root()
self.is_initial = True
if current_org.is_root():
node = None
else:
node = Node.org_root()
return node
if pk:
node = get_object_or_404(Node, pk=pk)
@ -140,16 +141,26 @@ class NodeChildrenApi(generics.ListCreateAPIView):
node = get_object_or_404(Node, key=key)
return node
def get_org_root_queryset(self, query_all):
if query_all:
return Node.objects.all()
else:
return Node.org_root_nodes()
def get_queryset(self):
query_all = self.request.query_params.get("all", "0") == "all"
if not self.instance:
return Node.objects.none()
if self.is_initial and current_org.is_root():
return self.get_org_root_queryset(query_all)
if self.is_initial:
with_self = True
else:
with_self = False
if not self.instance:
return Node.objects.none()
if query_all:
queryset = self.instance.get_all_children(with_self=with_self)
else:
@ -181,12 +192,12 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
def get_assets(self):
include_assets = self.request.query_params.get('assets', '0') == '1'
if not include_assets:
if not self.instance or not include_assets:
return []
assets = self.instance.get_assets().only(
"id", "hostname", "ip", "os",
"org_id", "protocols", "is_active"
)
"id", "hostname", "ip", "os", "platform_id",
"org_id", "protocols", "is_active",
).prefetch_related('platform')
return self.serialize_assets(assets, self.instance.key)
@ -212,15 +223,13 @@ class NodeAddChildrenApi(generics.UpdateAPIView):
def put(self, request, *args, **kwargs):
instance = self.get_object()
nodes_id = request.data.get("nodes")
children = Node.objects.filter(id__in=nodes_id)
node_ids = request.data.get("nodes")
children = Node.objects.filter(id__in=node_ids)
for node in children:
node.parent = instance
return Response("OK")
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class NodeAddAssetsApi(generics.UpdateAPIView):
model = Node
serializer_class = serializers.NodeAssetsSerializer
@ -233,8 +242,6 @@ class NodeAddAssetsApi(generics.UpdateAPIView):
instance.assets.add(*tuple(assets))
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class NodeRemoveAssetsApi(generics.UpdateAPIView):
model = Node
serializer_class = serializers.NodeAssetsSerializer
@ -247,12 +254,13 @@ class NodeRemoveAssetsApi(generics.UpdateAPIView):
node.assets.remove(*assets)
# 把孤儿资产添加到 root 节点
orphan_assets = Asset.objects.filter(id__in=[a.id for a in assets], nodes__isnull=True).distinct()
orphan_assets = Asset.objects.filter(
id__in=[a.id for a in assets],
nodes__isnull=True
).distinct()
Node.org_root().assets.add(*orphan_assets)
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='patch')
@method_decorator(org_level_transaction_lock(UPDATE_NODE_TREE_LOCK_KEY), name='put')
class MoveAssetsToNodeApi(generics.UpdateAPIView):
model = Node
serializer_class = serializers.NodeAssetsSerializer

View File

@ -87,13 +87,13 @@ class SystemUserTaskApi(generics.CreateAPIView):
permission_classes = (IsOrgAdmin,)
serializer_class = serializers.SystemUserTaskSerializer
def do_push(self, system_user, assets_id=None):
if assets_id is None:
def do_push(self, system_user, asset_ids=None):
if asset_ids is None:
task = push_system_user_to_assets_manual.delay(system_user)
else:
username = self.request.query_params.get('username')
task = push_system_user_to_assets.delay(
system_user.id, assets_id, username=username
system_user.id, asset_ids, username=username
)
return task
@ -114,9 +114,9 @@ class SystemUserTaskApi(generics.CreateAPIView):
system_user = self.get_object()
if action == 'push':
assets = [asset] if asset else assets
assets_id = [asset.id for asset in assets]
assets_id = assets_id if assets_id else None
task = self.do_push(system_user, assets_id)
asset_ids = [asset.id for asset in assets]
asset_ids = asset_ids if asset_ids else None
task = self.do_push(system_user, asset_ids)
else:
task = self.do_test(system_user)
data = getattr(serializer, '_data', {})

View File

@ -40,7 +40,7 @@ class BaseBackend:
return values
@staticmethod
def make_assets_as_id(assets):
def make_assets_as_ids(assets):
if not assets:
return []
if isinstance(assets[0], Asset):

View File

@ -69,9 +69,9 @@ class DBBackend(BaseBackend):
self.queryset = self.queryset.filter(union_id=union_id)
def _filter_assets(self, assets):
assets_id = self.make_assets_as_id(assets)
if assets_id:
self.queryset = self.queryset.filter(asset_id__in=assets_id)
asset_ids = self.make_assets_as_ids(assets)
if asset_ids:
self.queryset = self.queryset.filter(asset_id__in=asset_ids)
def _filter_node(self, node):
pass

20
apps/assets/locks.py Normal file
View File

@ -0,0 +1,20 @@
from orgs.utils import current_org
from common.utils.lock import DistributedLock
class NodeTreeUpdateLock(DistributedLock):
name_template = 'assets.node.tree.update.<org_id:{org_id}>'
def get_name(self):
if current_org:
org_id = current_org.id
else:
org_id = 'current_org_is_null'
name = self.name_template.format(
org_id=org_id
)
return name
def __init__(self):
name = self.get_name()
super().__init__(name=name, release_on_transaction_commit=True, reentrant=True)

View File

@ -0,0 +1,17 @@
# Generated by Django 3.1 on 2021-02-08 10:02
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('assets', '0065_auto_20210121_1549'),
]
operations = [
migrations.AlterModelOptions(
name='asset',
options={'ordering': ['hostname'], 'verbose_name': 'Asset'},
),
]

View File

@ -0,0 +1,48 @@
# Generated by Django 3.1 on 2021-03-11 03:13
import django.core.validators
from django.db import migrations, models
def migrate_cmd_filter_priority(apps, schema_editor):
cmd_filter_rule_model = apps.get_model('assets', 'CommandFilterRule')
cmd_filter_rules = cmd_filter_rule_model.objects.all()
for cmd_filter_rule in cmd_filter_rules:
cmd_filter_rule.priority = 100 - cmd_filter_rule.priority + 1
cmd_filter_rule_model.objects.bulk_update(cmd_filter_rules, fields=['priority'])
def migrate_system_user_priority(apps, schema_editor):
system_user_model = apps.get_model('assets', 'SystemUser')
system_users = system_user_model.objects.all()
for system_user in system_users:
system_user.priority = 100 - system_user.priority + 1
system_user_model.objects.bulk_update(system_users, fields=['priority'])
class Migration(migrations.Migration):
dependencies = [
('assets', '0066_auto_20210208_1802'),
]
operations = [
migrations.RunPython(migrate_cmd_filter_priority),
migrations.RunPython(migrate_system_user_priority),
migrations.AlterModelOptions(
name='commandfilterrule',
options={'ordering': ('priority', 'action'), 'verbose_name': 'Command filter rule'},
),
migrations.AlterField(
model_name='commandfilterrule',
name='priority',
field=models.IntegerField(default=50, help_text='1-100, the lower the value will be match first', validators=[django.core.validators.MinValueValidator(1), django.core.validators.MaxValueValidator(100)], verbose_name='Priority'),
),
migrations.AlterField(
model_name='systemuser',
name='priority',
field=models.IntegerField(default=20, help_text='1-100, the lower the value will be match first', validators=[django.core.validators.MinValueValidator(1), django.core.validators.MaxValueValidator(100)], verbose_name='Priority'),
),
]

View File

@ -17,7 +17,7 @@ from orgs.mixins.models import OrgModelMixin, OrgManager
from .base import ConnectivityMixin
from .utils import Connectivity
__all__ = ['Asset', 'ProtocolsMixin', 'Platform']
__all__ = ['Asset', 'ProtocolsMixin', 'Platform', 'AssetQuerySet']
logger = logging.getLogger(__name__)
@ -41,13 +41,6 @@ def default_node():
class AssetManager(OrgManager):
def get_queryset(self):
return super().get_queryset().annotate(
platform_base=models.F('platform__base')
)
class AssetOrgManager(OrgManager):
pass
@ -230,7 +223,6 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
comment = models.TextField(default='', blank=True, verbose_name=_('Comment'))
objects = AssetManager.from_queryset(AssetQuerySet)()
org_objects = AssetOrgManager.from_queryset(AssetQuerySet)()
_connectivity = None
def __str__(self):
@ -361,4 +353,4 @@ class Asset(ProtocolsMixin, NodesRelationMixin, OrgModelMixin):
class Meta:
unique_together = [('org_id', 'hostname')]
verbose_name = _("Asset")
ordering = ["hostname", "ip"]
ordering = ["hostname", ]

View File

@ -11,9 +11,12 @@ from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.conf import settings
from common.db.models import ChoiceSet
from common.utils import random_string
from common.utils import (
ssh_key_string_to_obj, ssh_key_gen, get_logger, lazyproperty
)
from common.utils.encode import ssh_pubkey_gen
from common.validators import alphanumeric
from common import fields
from orgs.mixins.models import OrgModelMixin
@ -105,6 +108,19 @@ class AuthMixin:
username = ''
_prefer = 'system_user'
@property
def ssh_key_fingerprint(self):
if self.public_key:
public_key = self.public_key
elif self.private_key:
public_key = ssh_pubkey_gen(self.private_key, self.password)
else:
return ''
public_key_obj = sshpubkeys.SSHKey(public_key)
fingerprint = public_key_obj.hash_md5()
return fingerprint
@property
def private_key_obj(self):
if self.private_key:
@ -204,8 +220,8 @@ class AuthMixin:
self.save()
@staticmethod
def gen_password():
return str(uuid.uuid4())
def gen_password(length=36):
return random_string(length, special_char=True)
@staticmethod
def gen_key(username):

View File

@ -50,7 +50,7 @@ class CommandFilterRule(OrgModelMixin):
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
filter = models.ForeignKey('CommandFilter', on_delete=models.CASCADE, verbose_name=_("Filter"), related_name='rules')
type = models.CharField(max_length=16, default=TYPE_COMMAND, choices=TYPE_CHOICES, verbose_name=_("Type"))
priority = models.IntegerField(default=50, verbose_name=_("Priority"), help_text=_("1-100, the higher will be match first"),
priority = models.IntegerField(default=50, verbose_name=_("Priority"), help_text=_("1-100, the lower the value will be match first"),
validators=[MinValueValidator(1), MaxValueValidator(100)])
content = models.TextField(verbose_name=_("Content"), help_text=_("One line one command"))
action = models.IntegerField(default=ACTION_DENY, choices=ACTION_CHOICES, verbose_name=_("Action"))
@ -60,7 +60,7 @@ class CommandFilterRule(OrgModelMixin):
created_by = models.CharField(max_length=128, blank=True, default='', verbose_name=_('Created by'))
class Meta:
ordering = ('-priority', 'action')
ordering = ('priority', 'action')
verbose_name = _("Command filter rule")
@lazyproperty

View File

@ -16,17 +16,5 @@ class FavoriteAsset(CommonModelMixin):
unique_together = ('user', 'asset')
@classmethod
def get_user_favorite_assets_id(cls, user):
def get_user_favorite_asset_ids(cls, user):
return cls.objects.filter(user=user).values_list('asset', flat=True)
@classmethod
def get_user_favorite_assets(cls, user, asset_perms_id=None):
from assets.models import Asset
from perms.utils.asset.user_permission import get_user_granted_all_assets
asset_ids = get_user_granted_all_assets(
user,
via_mapping_node=False,
asset_perms_id=asset_perms_id
).values_list('id', flat=True)
query_name = cls.asset.field.related_query_name()
return Asset.org_objects.filter(**{f'{query_name}__user_id': user.id}, id__in=asset_ids).distinct()

View File

@ -1,23 +1,32 @@
# -*- coding: utf-8 -*-
#
import uuid
import re
import time
import uuid
import threading
import os
import time
import uuid
from collections import defaultdict
from django.db import models, transaction
from django.db.models import Q
from django.db.models import Q, Manager
from django.db.utils import IntegrityError
from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ugettext
from django.db.transaction import atomic
from django.core.cache import cache
from common.utils.lock import DistributedLock
from common.utils.common import timeit
from common.db.models import output_as_string
from common.utils import get_logger
from common.utils.common import lazyproperty
from orgs.mixins.models import OrgModelMixin, OrgManager
from orgs.utils import get_current_org, tmp_to_org
from orgs.models import Organization
__all__ = ['Node', 'FamilyMixin', 'compute_parent_key']
__all__ = ['Node', 'FamilyMixin', 'compute_parent_key', 'NodeQuerySet']
logger = get_logger(__name__)
@ -247,9 +256,147 @@ class FamilyMixin:
return [*tuple(ancestors), self, *tuple(children)]
class NodeAssetsMixin:
class NodeAllAssetsMappingMixin:
# Use a new plan
# { org_id: { node_key: [ asset1_id, asset2_id ] } }
orgid_nodekey_assetsid_mapping = defaultdict(dict)
locks_for_get_mapping_from_cache = defaultdict(threading.Lock)
@classmethod
def get_lock(cls, org_id):
lock = cls.locks_for_get_mapping_from_cache[str(org_id)]
return lock
@classmethod
def get_node_all_asset_ids_mapping(cls, org_id):
_mapping = cls.get_node_all_asset_ids_mapping_from_memory(org_id)
if _mapping:
return _mapping
logger.debug(f'Get node asset mapping from memory failed, acquire thread lock: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
with cls.get_lock(org_id):
logger.debug(f'Acquired thread lock ok. check if mapping is in memory now: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
_mapping = cls.get_node_all_asset_ids_mapping_from_memory(org_id)
if _mapping:
logger.debug(f'Mapping is already in memory now: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
return _mapping
_mapping = cls.get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(org_id)
cls.set_node_all_asset_ids_mapping_to_memory(org_id, mapping=_mapping)
return _mapping
# from memory
@classmethod
def get_node_all_asset_ids_mapping_from_memory(cls, org_id):
mapping = cls.orgid_nodekey_assetsid_mapping.get(org_id, {})
return mapping
@classmethod
def set_node_all_asset_ids_mapping_to_memory(cls, org_id, mapping):
cls.orgid_nodekey_assetsid_mapping[org_id] = mapping
@classmethod
def expire_node_all_asset_ids_mapping_from_memory(cls, org_id):
org_id = str(org_id)
cls.orgid_nodekey_assetsid_mapping.pop(org_id, None)
# get order: from memory -> (from cache -> to generate)
@classmethod
def get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(cls, org_id):
mapping = cls.get_node_all_asset_ids_mapping_from_cache(org_id)
if mapping:
return mapping
lock_key = f'KEY_LOCK_GENERATE_ORG_{org_id}_NODE_ALL_ASSET_ids_MAPPING'
with DistributedLock(lock_key):
# 这里使用无限期锁,原因是如果这里卡住了,就卡在数据库了,说明
# 数据库繁忙,所以不应该再有线程执行这个操作,使数据库忙上加忙
_mapping = cls.get_node_all_asset_ids_mapping_from_cache(org_id)
if _mapping:
return _mapping
_mapping = cls.generate_node_all_asset_ids_mapping(org_id)
cls.set_node_all_asset_ids_mapping_to_cache(org_id=org_id, mapping=_mapping)
return _mapping
@classmethod
def get_node_all_asset_ids_mapping_from_cache(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
mapping = cache.get(cache_key)
logger.info(f'Get node asset mapping from cache {bool(mapping)}: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
return mapping
@classmethod
def set_node_all_asset_ids_mapping_to_cache(cls, org_id, mapping):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.set(cache_key, mapping, timeout=None)
@classmethod
def expire_node_all_asset_ids_mapping_from_cache(cls, org_id):
cache_key = cls._get_cache_key_for_node_all_asset_ids_mapping(org_id)
cache.delete(cache_key)
@staticmethod
def _get_cache_key_for_node_all_asset_ids_mapping(org_id):
return 'ASSETS_ORG_NODE_ALL_ASSET_ids_MAPPING_{}'.format(org_id)
@classmethod
def generate_node_all_asset_ids_mapping(cls, org_id):
from .asset import Asset
logger.info(f'Generate node asset mapping: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
t1 = time.time()
with tmp_to_org(org_id):
node_ids_key = Node.objects.annotate(
char_id=output_as_string('id')
).values_list('char_id', 'key')
# * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢)
nodes_asset_ids = Asset.nodes.through.objects.all() \
.annotate(char_node_id=output_as_string('node_id')) \
.annotate(char_asset_id=output_as_string('asset_id')) \
.values_list('char_node_id', 'char_asset_id')
node_id_ancestor_keys_mapping = {
node_id: cls.get_node_ancestor_keys(node_key, with_self=True)
for node_id, node_key in node_ids_key
}
nodeid_assetsid_mapping = defaultdict(set)
for node_id, asset_id in nodes_asset_ids:
nodeid_assetsid_mapping[node_id].add(asset_id)
t2 = time.time()
mapping = defaultdict(set)
for node_id, node_key in node_ids_key:
asset_ids = nodeid_assetsid_mapping[node_id]
node_ancestor_keys = node_id_ancestor_keys_mapping[node_id]
for ancestor_key in node_ancestor_keys:
mapping[ancestor_key].update(asset_ids)
t3 = time.time()
logger.info('t1-t2(DB Query): {} s, t3-t2(Generate mapping): {} s'.format(t2-t1, t3-t2))
return mapping
class NodeAssetsMixin(NodeAllAssetsMappingMixin):
org_id: str
key = ''
id = None
objects: Manager
def get_all_assets(self):
from .asset import Asset
@ -263,8 +410,7 @@ class NodeAssetsMixin:
# 可是 startswith 会导致表关联时 Asset 索引失效
from .asset import Asset
node_ids = cls.objects.filter(
Q(key__startswith=f'{key}:') |
Q(key=key)
Q(key__startswith=f'{key}:') | Q(key=key)
).values_list('id', flat=True).distinct()
assets = Asset.objects.filter(
nodes__id__in=list(node_ids)
@ -283,29 +429,34 @@ class NodeAssetsMixin:
return self.get_all_assets().valid()
@classmethod
def get_nodes_all_assets_ids(cls, nodes_keys):
assets_ids = cls.get_nodes_all_assets(nodes_keys).values_list('id', flat=True)
return assets_ids
def get_nodes_all_asset_ids_by_keys(cls, nodes_keys):
nodes = Node.objects.filter(key__in=nodes_keys)
asset_ids = cls.get_nodes_all_assets(*nodes).values_list('id', flat=True)
return asset_ids
@classmethod
def get_nodes_all_assets(cls, nodes_keys, extra_assets_ids=None):
def get_nodes_all_assets(cls, *nodes):
from .asset import Asset
nodes_keys = cls.clean_children_keys(nodes_keys)
q = Q()
node_ids = ()
for key in nodes_keys:
q |= Q(key__startswith=f'{key}:')
q |= Q(key=key)
if q:
node_ids = Node.objects.filter(q).distinct().values_list('id', flat=True)
node_ids = set()
descendant_node_query = Q()
for n in nodes:
node_ids.add(n.id)
descendant_node_query |= Q(key__istartswith=f'{n.key}:')
if descendant_node_query:
_ids = Node.objects.order_by().filter(descendant_node_query).values_list('id', flat=True)
node_ids.update(_ids)
return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct()
q = Q(nodes__id__in=list(node_ids))
if extra_assets_ids:
q |= Q(id__in=extra_assets_ids)
if q:
return Asset.org_objects.filter(q).distinct()
else:
return Asset.objects.none()
def get_all_asset_ids(self):
asset_ids = self.get_all_asset_ids_by_node_key(org_id=self.org_id, node_key=self.key)
return set(asset_ids)
@classmethod
def get_all_asset_ids_by_node_key(cls, org_id, node_key):
org_id = str(org_id)
nodekey_assetsid_mapping = cls.get_node_all_asset_ids_mapping(org_id)
asset_ids = nodekey_assetsid_mapping.get(node_key, [])
return set(asset_ids)
class SomeNodesMixin:
@ -317,8 +468,9 @@ class SomeNodesMixin:
@classmethod
def default_node(cls):
with tmp_to_org(Organization.default()):
defaults = {'value': cls.default_value}
default_org = Organization.default()
with tmp_to_org(default_org):
defaults = {'value': default_org.name}
try:
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
@ -353,25 +505,40 @@ class SomeNodesMixin:
@classmethod
def create_org_root_node(cls):
# 如果使用current_org 在set_current_org时会死循环
ori_org = get_current_org()
with transaction.atomic():
if not ori_org.is_real():
return cls.default_node()
key = cls.get_next_org_root_node_key()
root = cls.objects.create(key=key, value=ori_org.name)
return root
@classmethod
def org_root(cls):
root = cls.objects.filter(parent_key='')\
.filter(key__regex=r'^[0-9]+$')\
.exclude(key__startswith='-')\
def org_root_nodes(cls):
nodes = cls.objects.filter(parent_key='') \
.filter(key__regex=r'^[0-9]+$') \
.exclude(key__startswith='-') \
.order_by('key')
if root:
return root[0]
return nodes
@classmethod
def org_root(cls):
# 如果使用current_org 在set_current_org时会死循环
ori_org = get_current_org()
if ori_org and ori_org.is_default():
return cls.default_node()
if ori_org and ori_org.is_root():
return None
org_roots = cls.org_root_nodes()
org_roots_length = len(org_roots)
if org_roots_length == 1:
return org_roots[0]
elif org_roots_length == 0:
root = cls.create_org_root_node()
return root
else:
return cls.create_org_root_node()
raise ValueError('Current org root node not 1, get {}'.format(org_roots_length))
@classmethod
def initial_some_nodes(cls):
@ -390,8 +557,9 @@ class SomeNodesMixin:
if not node_key1:
logger.info("Not found node that `key` = 1")
return
if not node_key1.org.is_real():
logger.info("Org is not real for node that `key` = 1")
if node_key1.org_id == '':
node_key1.org_id = str(Organization.default().id)
node_key1.save()
return
with transaction.atomic():

View File

@ -116,7 +116,7 @@ class SystemUser(BaseUser):
assets = models.ManyToManyField('assets.Asset', blank=True, verbose_name=_("Assets"))
users = models.ManyToManyField('users.User', blank=True, verbose_name=_("Users"))
groups = models.ManyToManyField('users.UserGroup', blank=True, verbose_name=_("User groups"))
priority = models.IntegerField(default=20, verbose_name=_("Priority"), validators=[MinValueValidator(1), MaxValueValidator(100)])
priority = models.IntegerField(default=20, verbose_name=_("Priority"), help_text=_("1-100, the lower the value will be match first"), validators=[MinValueValidator(1), MaxValueValidator(100)])
protocol = models.CharField(max_length=16, choices=PROTOCOL_CHOICES, default='ssh', verbose_name=_('Protocol'))
auto_push = models.BooleanField(default=True, verbose_name=_('Auto push'))
sudo = models.TextField(default='/bin/whoami', verbose_name=_('Sudo'))
@ -198,10 +198,10 @@ class SystemUser(BaseUser):
def get_all_assets(self):
from assets.models import Node
nodes_keys = self.nodes.all().values_list('key', flat=True)
assets_ids = set(self.assets.all().values_list('id', flat=True))
nodes_assets_ids = Node.get_nodes_all_assets_ids(nodes_keys)
assets_ids.update(nodes_assets_ids)
assets = Asset.objects.filter(id__in=assets_ids)
asset_ids = set(self.assets.all().values_list('id', flat=True))
nodes_asset_ids = Node.get_nodes_all_asset_ids_by_keys(nodes_keys)
asset_ids.update(nodes_asset_ids)
assets = Asset.objects.filter(id__in=asset_ids)
return assets
@classmethod

View File

@ -1,39 +1,52 @@
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.request import Request
from common.utils import get_logger
from assets.models import Node
logger = get_logger(__name__)
class AssetPaginationBase(LimitOffsetPagination):
def init_attrs(self, queryset, request: Request, view=None):
self._request = request
self._view = view
self._user = request.user
def paginate_queryset(self, queryset, request: Request, view=None):
self.init_attrs(queryset, request, view)
return super().paginate_queryset(queryset, request, view=None)
class AssetLimitOffsetPagination(LimitOffsetPagination):
"""
需要与 `assets.api.mixin.FilterAssetByNodeMixin` 配合使用
"""
def get_count(self, queryset):
"""
1. 如果查询节点下的所有资产 count 使用 Node.assets_amount
2. 如果有其他过滤条件使用 super
3. 如果只查询该节点下的资产使用 super
"""
exclude_query_params = {
self.limit_query_param,
self.offset_query_param,
'node', 'all', 'show_current_asset',
'node_id', 'display', 'draw', 'fields_size',
'key', 'all', 'show_current_asset',
'cache_policy', 'display', 'draw',
'order', 'node', 'node_id', 'fields_size',
}
for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None:
logger.warn(f'Not hit node.assets_amount because find a unknow query_param `{k}` -> {self._request.get_full_path()}')
return super().get_count(queryset)
node_assets_count = self.get_count_from_nodes(queryset)
if node_assets_count is None:
return super().get_count(queryset)
return node_assets_count
def get_count_from_nodes(self, queryset):
raise NotImplementedError
class NodeAssetTreePagination(AssetPaginationBase):
def get_count_from_nodes(self, queryset):
is_query_all = self._view.is_query_node_all_assets
if is_query_all:
node = self._view.node
if not node:
node = Node.org_root()
return node.assets_amount
return super().get_count(queryset)
def paginate_queryset(self, queryset, request: Request, view=None):
self._request = request
self._view = view
return super().paginate_queryset(queryset, request, view=None)
if node:
logger.debug(f'Hit node.assets_amount[{node.assets_amount}] -> {self._request.get_full_path()}')
return node.assets_amount
return None

View File

@ -3,8 +3,6 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from common.drf.serializers import AdaptedBulkListSerializer
from ..models import Node, AdminUser
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
@ -17,7 +15,6 @@ class AdminUserSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer):
"""
class Meta:
list_serializer_class = AdaptedBulkListSerializer
model = AdminUser
fields = [
'id', 'name', 'username', 'password', 'private_key', 'public_key',
@ -33,6 +30,11 @@ class AdminUserSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer):
}
class AdminUserDetailSerializer(AdminUserSerializer):
class Meta(AdminUserSerializer.Meta):
fields = AdminUserSerializer.Meta.fields + ['ssh_key_fingerprint']
class AdminUserAuthSerializer(AuthSerializer):
class Meta:

View File

@ -111,7 +111,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
@classmethod
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
queryset = queryset.select_related('admin_user', 'domain', 'platform')
queryset = queryset.prefetch_related('admin_user', 'domain', 'platform')
queryset = queryset.prefetch_related('nodes', 'labels')
return queryset
@ -166,16 +166,9 @@ class AssetDisplaySerializer(AssetSerializer):
'connectivity',
]
@classmethod
def setup_eager_loading(cls, queryset):
queryset = super().setup_eager_loading(queryset)
queryset = queryset\
.annotate(admin_user_username=F('admin_user__username'))
return queryset
class PlatformSerializer(serializers.ModelSerializer):
meta = serializers.DictField(required=False, allow_null=True)
meta = serializers.DictField(required=False, allow_null=True, label=_('Meta'))
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@ -41,10 +41,6 @@ class AuthSerializerMixin:
def validate_private_key(self, private_key):
if not private_key:
return
if 'OPENSSH' in private_key:
msg = _("Not support openssh format key, using "
"ssh-keygen -t rsa -m pem to generate")
raise serializers.ValidationError(msg)
password = self.initial_data.get("password")
valid = validate_ssh_private_key(private_key, password)
if not valid:

View File

@ -77,8 +77,6 @@ class GatewayWithAuthSerializer(GatewaySerializer):
return fields
class DomainWithGatewaySerializer(BulkOrgResourceModelSerializer):
gateways = GatewayWithAuthSerializer(many=True, read_only=True)

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
#
from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from common.drf.serializers import AdaptedBulkListSerializer
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
@ -9,16 +10,17 @@ from ..models import Label
class LabelSerializer(BulkOrgResourceModelSerializer):
asset_count = serializers.SerializerMethodField()
asset_count = serializers.SerializerMethodField(label=_("Assets amount"))
category_display = serializers.ReadOnlyField(source='get_category_display', label=_('Category display'))
class Meta:
model = Label
fields = [
'id', 'name', 'value', 'category', 'is_active', 'comment',
'date_created', 'asset_count', 'assets', 'get_category_display'
'date_created', 'asset_count', 'assets', 'category_display'
]
read_only_fields = (
'category', 'date_created', 'asset_count', 'get_category_display'
'category', 'date_created', 'asset_count',
)
extra_kwargs = {
'assets': {'required': False}

View File

@ -33,7 +33,7 @@ class SystemUserSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer):
'priority', 'username_same_with_user',
'auto_push', 'cmd_filters', 'sudo', 'shell', 'comment',
'auto_generate_key', 'sftp_root', 'token',
'assets_amount', 'date_created', 'created_by',
'assets_amount', 'date_created', 'date_updated', 'created_by',
'home', 'system_groups', 'ad_domain'
]
extra_kwargs = {
@ -155,7 +155,8 @@ class SystemUserListSerializer(SystemUserSerializer):
'auto_push', 'sudo', 'shell', 'comment',
"assets_amount", 'home', 'system_groups',
'auto_generate_key', 'ad_domain',
'sftp_root',
'sftp_root', 'created_by', 'date_created',
'date_updated',
]
extra_kwargs = {
'password': {"write_only": True},

View File

@ -0,0 +1,3 @@
from .common import *
from .node_assets_amount import *
from .node_assets_mapping import *

View File

@ -1,21 +1,17 @@
# -*- coding: utf-8 -*-
#
from operator import add, sub
from assets.utils import is_asset_exists_in_node
from django.db.models.signals import (
post_save, m2m_changed, pre_delete, post_delete, pre_save
)
from django.db.models import Q, F
from django.dispatch import receiver
from common.exceptions import M2MReverseNotAllowed
from common.const.signals import PRE_ADD, POST_ADD, POST_REMOVE, PRE_CLEAR, PRE_REMOVE
from common.const.signals import POST_ADD, POST_REMOVE, PRE_REMOVE
from common.utils import get_logger
from common.decorator import on_transaction_commit
from .models import Asset, SystemUser, Node, compute_parent_key
from assets.models import Asset, SystemUser, Node
from users.models import User
from .tasks import (
from assets.tasks import (
update_assets_hardware_info_util,
test_asset_connectivity_util,
push_system_user_to_assets_manual,
@ -23,7 +19,6 @@ from .tasks import (
add_nodes_assets_to_system_users
)
logger = get_logger(__file__)
@ -87,13 +82,13 @@ def on_system_user_assets_change(instance, action, model, pk_set, **kwargs):
return
logger.debug("System user assets change signal recv: {}".format(instance))
if model == Asset:
system_users_id = [instance.id]
assets_id = pk_set
system_user_ids = [instance.id]
asset_ids = pk_set
else:
system_users_id = pk_set
assets_id = [instance.id]
for system_user_id in system_users_id:
push_system_user_to_assets.delay(system_user_id, assets_id)
system_user_ids = pk_set
asset_ids = [instance.id]
for system_user_id in system_user_ids:
push_system_user_to_assets.delay(system_user_id, asset_ids)
@receiver(m2m_changed, sender=SystemUser.users.through)
@ -202,134 +197,6 @@ def on_asset_nodes_add(instance, action, reverse, pk_set, **kwargs):
m2m_model.objects.bulk_create(to_create)
def _update_node_assets_amount(node: Node, asset_pk_set: set, operator=add):
"""
一个节点与多个资产关系变化时更新计数
:param node: 节点实例
:param asset_pk_set: 资产的`id`集合, 内部不会修改该值
:param operator: 操作
* -> Node
# -> Asset
* [3]
/ \
* * [2]
/ \
* * [1]
/ / \
* [a] # # [b]
"""
# 获取节点[1]祖先节点的 `key` 含自己,也就是[1, 2, 3]节点的`key`
ancestor_keys = node.get_ancestor_keys(with_self=True)
ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
to_update = []
for ancestor in ancestors:
# 迭代祖先节点的`key`,顺序是 [1] -> [2] -> [3]
# 查询该节点及其后代节点是否包含要操作的资产,将包含的从要操作的
# 资产集合中去掉,他们是重复节点,无论增加或删除都不会影响节点的资产数量
asset_pk_set -= set(Asset.objects.filter(
id__in=asset_pk_set
).filter(
Q(nodes__key__istartswith=f'{ancestor.key}:') |
Q(nodes__key=ancestor.key)
).distinct().values_list('id', flat=True))
if not asset_pk_set:
# 要操作的资产集合为空,说明都是重复资产,不用改变节点资产数量
# 而且既然它包含了,它的祖先节点肯定也包含了,所以祖先节点都不用
# 处理了
break
ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
to_update.append(ancestor)
Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))
def _remove_ancestor_keys(ancestor_key, tree_set):
# 这里判断 `ancestor_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while ancestor_key and ancestor_key in tree_set:
tree_set.remove(ancestor_key)
ancestor_key = compute_parent_key(ancestor_key)
def _update_nodes_asset_amount(node_keys, asset_pk, operator):
"""
一个资产与多个节点关系变化时更新计数
:param node_keys: 节点 id 的集合
:param asset_pk: 资产 id
:param operator: 操作
"""
# 所有相关节点的祖先节点,组成一棵局部树
ancestor_keys = set()
for key in node_keys:
ancestor_keys.update(Node.get_node_ancestor_keys(key))
# 相关节点可能是其他相关节点的祖先节点,如果是从相关节点里干掉
node_keys -= ancestor_keys
to_update_keys = []
for key in node_keys:
# 遍历相关节点,处理它及其祖先节点
# 查询该节点是否包含待处理资产
exists = is_asset_exists_in_node(asset_pk, key)
parent_key = compute_parent_key(key)
if exists:
# 如果资产在该节点,那么他及其祖先节点都不用处理
_remove_ancestor_keys(parent_key, ancestor_keys)
continue
else:
# 不存在,要更新本节点
to_update_keys.append(key)
# 这里判断 `parent_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while parent_key and parent_key in ancestor_keys:
exists = is_asset_exists_in_node(asset_pk, parent_key)
if exists:
_remove_ancestor_keys(parent_key, ancestor_keys)
break
else:
to_update_keys.append(parent_key)
ancestor_keys.remove(parent_key)
parent_key = compute_parent_key(parent_key)
Node.objects.filter(key__in=to_update_keys).update(
assets_amount=operator(F('assets_amount'), 1)
)
@receiver(m2m_changed, sender=Asset.nodes.through)
def update_nodes_assets_amount(action, instance, reverse, pk_set, **kwargs):
# 不允许 `pre_clear` ,因为该信号没有 `pk_set`
# [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
refused = (PRE_CLEAR,)
if action in refused:
raise ValueError
mapper = {
PRE_ADD: add,
POST_REMOVE: sub
}
if action not in mapper:
return
operator = mapper[action]
if reverse:
node: Node = instance
asset_pk_set = set(pk_set)
_update_node_assets_amount(node, asset_pk_set, operator)
else:
asset_pk = instance.id
# 与资产直接关联的节点
node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
_update_nodes_asset_amount(node_keys, asset_pk, operator)
RELATED_NODE_IDS = '_related_node_ids'

View File

@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
#
from operator import add, sub
from django.db.models import Q, F
from django.dispatch import receiver
from django.db.models.signals import (
m2m_changed
)
from orgs.utils import ensure_in_real_or_default_org
from common.const.signals import PRE_ADD, POST_REMOVE, PRE_CLEAR
from common.utils import get_logger
from assets.models import Asset, Node, compute_parent_key
from assets.locks import NodeTreeUpdateLock
logger = get_logger(__file__)
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
# 不允许 `pre_clear` ,因为该信号没有 `pk_set`
# [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
refused = (PRE_CLEAR,)
if action in refused:
raise ValueError
mapper = {
PRE_ADD: add,
POST_REMOVE: sub
}
if action not in mapper:
return
operator = mapper[action]
if reverse:
node: Node = instance
asset_pk_set = set(pk_set)
NodeAssetsAmountUtils.update_node_assets_amount(node, asset_pk_set, operator)
else:
asset_pk = instance.id
# 与资产直接关联的节点
node_keys = set(Node.objects.filter(id__in=pk_set).values_list('key', flat=True))
NodeAssetsAmountUtils.update_nodes_asset_amount(node_keys, asset_pk, operator)
class NodeAssetsAmountUtils:
@classmethod
def _remove_ancestor_keys(cls, ancestor_key, tree_set):
# 这里判断 `ancestor_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while ancestor_key and ancestor_key in tree_set:
tree_set.remove(ancestor_key)
ancestor_key = compute_parent_key(ancestor_key)
@classmethod
def _is_asset_exists_in_node(cls, asset_pk, node_key):
exists = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
).filter(id=asset_pk).exists()
return exists
@classmethod
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_nodes_asset_amount(cls, node_keys, asset_pk, operator):
"""
一个资产与多个节点关系变化时更新计数
:param node_keys: 节点 id 的集合
:param asset_pk: 资产 id
:param operator: 操作
"""
# 所有相关节点的祖先节点,组成一棵局部树
ancestor_keys = set()
for key in node_keys:
ancestor_keys.update(Node.get_node_ancestor_keys(key))
# 相关节点可能是其他相关节点的祖先节点,如果是从相关节点里干掉
node_keys -= ancestor_keys
to_update_keys = []
for key in node_keys:
# 遍历相关节点,处理它及其祖先节点
# 查询该节点是否包含待处理资产
exists = cls._is_asset_exists_in_node(asset_pk, key)
parent_key = compute_parent_key(key)
if exists:
# 如果资产在该节点,那么他及其祖先节点都不用处理
cls._remove_ancestor_keys(parent_key, ancestor_keys)
continue
else:
# 不存在,要更新本节点
to_update_keys.append(key)
# 这里判断 `parent_key` 不能是空,防止数据错误导致的死循环
# 判断是否在集合里,来区分是否已被处理过
while parent_key and parent_key in ancestor_keys:
exists = cls._is_asset_exists_in_node(asset_pk, parent_key)
if exists:
cls._remove_ancestor_keys(parent_key, ancestor_keys)
break
else:
to_update_keys.append(parent_key)
ancestor_keys.remove(parent_key)
parent_key = compute_parent_key(parent_key)
Node.objects.filter(key__in=to_update_keys).update(
assets_amount=operator(F('assets_amount'), 1)
)
@classmethod
@ensure_in_real_or_default_org
@NodeTreeUpdateLock()
def update_node_assets_amount(cls, node: Node, asset_pk_set: set, operator=add):
"""
一个节点与多个资产关系变化时更新计数
:param node: 节点实例
:param asset_pk_set: 资产的`id`集合, 内部不会修改该值
:param operator: 操作
* -> Node
# -> Asset
* [3]
/ \
* * [2]
/ \
* * [1]
/ / \
* [a] # # [b]
"""
# 获取节点[1]祖先节点的 `key` 含自己,也就是[1, 2, 3]节点的`key`
ancestor_keys = node.get_ancestor_keys(with_self=True)
ancestors = Node.objects.filter(key__in=ancestor_keys).order_by('-key')
to_update = []
for ancestor in ancestors:
# 迭代祖先节点的`key`,顺序是 [1] -> [2] -> [3]
# 查询该节点及其后代节点是否包含要操作的资产,将包含的从要操作的
# 资产集合中去掉,他们是重复节点,无论增加或删除都不会影响节点的资产数量
asset_pk_set -= set(Asset.objects.filter(
id__in=asset_pk_set
).filter(
Q(nodes__key__istartswith=f'{ancestor.key}:') |
Q(nodes__key=ancestor.key)
).distinct().values_list('id', flat=True))
if not asset_pk_set:
# 要操作的资产集合为空,说明都是重复资产,不用改变节点资产数量
# 而且既然它包含了,它的祖先节点肯定也包含了,所以祖先节点都不用
# 处理了
break
ancestor.assets_amount = operator(F('assets_amount'), len(asset_pk_set))
to_update.append(ancestor)
Node.objects.bulk_update(to_update, fields=('assets_amount', 'parent_key'))

View File

@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
#
import os
import threading
from django.db.models.signals import (
m2m_changed, post_save, post_delete
)
from django.dispatch import receiver
from django.utils.functional import LazyObject
from common.signals import django_ready
from common.utils.connection import RedisPubSub
from common.utils import get_logger
from assets.models import Asset, Node
logger = get_logger(__file__)
# clear node assets mapping for memory
# ------------------------------------
def get_node_assets_mapping_for_memory_pub_sub():
return RedisPubSub('fm.node_all_asset_ids_memory_mapping')
class NodeAssetsMappingForMemoryPubSub(LazyObject):
def _setup(self):
self._wrapped = get_node_assets_mapping_for_memory_pub_sub()
node_assets_mapping_for_memory_pub_sub = NodeAssetsMappingForMemoryPubSub()
def expire_node_assets_mapping_for_memory(org_id):
# 所有进程清除(自己的 memory 数据)
org_id = str(org_id)
node_assets_mapping_for_memory_pub_sub.publish(org_id)
# 当前进程清除(cache 数据)
logger.debug(
"Expire node assets id mapping from cache of org={}, pid={}"
"".format(org_id, os.getpid())
)
Node.expire_node_all_asset_ids_mapping_from_cache(org_id)
@receiver(post_save, sender=Node)
def on_node_post_create(sender, instance, created, update_fields, **kwargs):
if created:
need_expire = True
elif update_fields and 'key' in update_fields:
need_expire = True
else:
need_expire = False
if need_expire:
expire_node_assets_mapping_for_memory(instance.org_id)
@receiver(post_delete, sender=Node)
def on_node_post_delete(sender, instance, **kwargs):
expire_node_assets_mapping_for_memory(instance.org_id)
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, instance, **kwargs):
expire_node_assets_mapping_for_memory(instance.org_id)
@receiver(django_ready)
def subscribe_node_assets_mapping_expire(sender, **kwargs):
logger.debug("Start subscribe for expire node assets id mapping from memory")
def keep_subscribe():
subscribe = node_assets_mapping_for_memory_pub_sub.subscribe()
for message in subscribe.listen():
if message["type"] != "message":
continue
org_id = message['data'].decode()
Node.expire_node_all_asset_ids_mapping_from_memory(org_id)
logger.debug(
"Expire node assets id mapping from memory of org={}, pid={}"
"".format(str(org_id), os.getpid())
)
t = threading.Thread(target=keep_subscribe)
t.daemon = True
t.start()

View File

@ -12,6 +12,7 @@ __all__ = ['add_nodes_assets_to_system_users']
@tmp_to_root_org()
def add_nodes_assets_to_system_users(nodes_keys, system_users):
from ..models import Node
assets = Node.get_nodes_all_assets(nodes_keys).values_list('id', flat=True)
nodes = Node.objects.filter(key__in=nodes_keys)
assets = Node.get_nodes_all_assets(*nodes)
for system_user in system_users:
system_user.assets.add(*tuple(assets))

View File

@ -141,7 +141,8 @@ def gather_asset_users(assets, task_name=None):
@shared_task(queue="ansible")
def gather_nodes_asset_users(nodes_key):
assets = Node.get_nodes_all_assets(nodes_key)
nodes = Node.objects.filter(key__in=nodes_key)
assets = Node.get_nodes_all_assets(*nodes)
assets_groups_by_100 = [assets[i:i+100] for i in range(0, len(assets), 100)]
for _assets in assets_groups_by_100:
gather_asset_users(_assets)

View File

@ -12,16 +12,24 @@ from common.utils import get_logger
logger = get_logger(__file__)
@shared_task(queue='celery_heavy_tasks')
def check_node_assets_amount_task(org_id=Organization.ROOT_ID):
try:
with tmp_to_org(Organization.get_instance(org_id)):
check_node_assets_amount()
except AcquireFailed:
logger.error(_('The task of self-checking is already running and cannot be started repeatedly'))
@shared_task
def check_node_assets_amount_task(org_id=None):
if org_id is None:
orgs = Organization.objects.all()
else:
orgs = [Organization.get_instance(org_id)]
for org in orgs:
try:
with tmp_to_org(org):
check_node_assets_amount()
except AcquireFailed:
error = _('The task of self-checking is already running '
'and cannot be started repeatedly')
logger.error(error)
@register_as_period_task(crontab='0 2 * * *')
@shared_task(queue='celery_heavy_tasks')
@shared_task
def check_node_assets_amount_period_task():
check_node_assets_amount_task()

View File

@ -32,11 +32,19 @@ def _dump_args(args: dict):
def get_push_unixlike_system_user_tasks(system_user, username=None):
comment = system_user.name
if username is None:
username = system_user.username
if system_user.username_same_with_user:
from users.models import User
user = User.objects.filter(username=username).only('name', 'username').first()
if user:
comment = f'{system_user.name}[{str(user)}]'
password = system_user.password
public_key = system_user.public_key
comment = system_user.name
groups = _split_by_comma(system_user.system_groups)
@ -225,18 +233,18 @@ def push_system_user_util(system_user, assets, task_name, username=None):
print(_("Hosts count: {}").format(len(_assets)))
id_asset_map = {_asset.id: _asset for _asset in _assets}
assets_id = id_asset_map.keys()
asset_ids = id_asset_map.keys()
no_special_auth = []
special_auth_set = set()
auth_books = AuthBook.objects.filter(username__in=usernames, asset_id__in=assets_id)
auth_books = AuthBook.objects.filter(username__in=usernames, asset_id__in=asset_ids)
for auth_book in auth_books:
special_auth_set.add((auth_book.username, auth_book.asset_id))
for _username in usernames:
no_special_assets = []
for asset_id in assets_id:
for asset_id in asset_ids:
if (_username, asset_id) not in special_auth_set:
no_special_assets.append(id_asset_map[asset_id])
if no_special_assets:
@ -281,12 +289,12 @@ def push_system_user_a_asset_manual(system_user, asset, username=None):
@shared_task(queue="ansible")
@tmp_to_root_org()
def push_system_user_to_assets(system_user_id, assets_id, username=None):
def push_system_user_to_assets(system_user_id, asset_ids, username=None):
"""
推送系统用户到指定的若干资产上
"""
system_user = SystemUser.objects.get(id=system_user_id)
assets = get_objects(Asset, assets_id)
assets = get_objects(Asset, asset_ids)
task_name = _("Push system users to assets: {}").format(system_user.name)
return push_system_user_util(system_user, assets, task_name, username=username)

33
apps/assets/tests/tree.py Normal file
View File

@ -0,0 +1,33 @@
from assets.tree import Tree
def test():
from orgs.models import Organization
from assets.models import Node, Asset
import time
Organization.objects.get(id='1863cf22-f666-474e-94aa-935fe175203c').change_to()
t1 = time.time()
nodes = list(Node.objects.exclude(key__startswith='-').only('id', 'key', 'parent_key'))
node_asset_id_pairs = Asset.nodes.through.objects.all().values_list('node_id', 'asset_id')
t2 = time.time()
node_asset_id_pairs = list(node_asset_id_pairs)
tree = Tree(nodes, node_asset_id_pairs)
tree.build_tree()
tree.nodes = None
tree.node_asset_id_pairs = None
import pickle
d = pickle.dumps(tree)
print('------------', len(d))
return tree
tree.compute_tree_node_assets_amount()
print(f'校对算法准确性 ......')
for node in nodes:
tree_node = tree.key_tree_node_mapper[node.key]
if tree_node.assets_amount != node.assets_amount:
print(f'ERROR: {tree_node.assets_amount} {node.assets_amount}')
# print(f'OK {tree_node.asset_amount} {node.assets_amount}')
print(f'数据库时间: {t2 - t1}')
return tree

View File

@ -2,7 +2,6 @@
from django.urls import path, re_path
from rest_framework_nested import routers
from rest_framework_bulk.routes import BulkRouter
from django.db.transaction import non_atomic_requests
from common import api as capi
@ -57,9 +56,9 @@ urlpatterns = [
path('nodes/children/', api.NodeChildrenApi.as_view(), name='node-children-2'),
path('nodes/<uuid:pk>/children/add/', api.NodeAddChildrenApi.as_view(), name='node-add-children'),
path('nodes/<uuid:pk>/assets/', api.NodeAssetsApi.as_view(), name='node-assets'),
path('nodes/<uuid:pk>/assets/add/', non_atomic_requests(api.NodeAddAssetsApi.as_view()), name='node-add-assets'),
path('nodes/<uuid:pk>/assets/replace/', non_atomic_requests(api.MoveAssetsToNodeApi.as_view()), name='node-replace-assets'),
path('nodes/<uuid:pk>/assets/remove/', non_atomic_requests(api.NodeRemoveAssetsApi.as_view()), name='node-remove-assets'),
path('nodes/<uuid:pk>/assets/add/', api.NodeAddAssetsApi.as_view(), name='node-add-assets'),
path('nodes/<uuid:pk>/assets/replace/', api.MoveAssetsToNodeApi.as_view(), name='node-replace-assets'),
path('nodes/<uuid:pk>/assets/remove/', api.NodeRemoveAssetsApi.as_view(), name='node-remove-assets'),
path('nodes/<uuid:pk>/tasks/', api.NodeTaskCreateApi.as_view(), name='node-task-create'),
path('gateways/<uuid:pk>/test-connective/', api.GatewayTestConnectionApi.as_view(), name='test-gateway-connective'),

View File

@ -1,41 +1,47 @@
# ~*~ coding: utf-8 ~*~
#
import time
from django.db.models import Q
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none
from common.utils.lock import DistributedLock
from collections import defaultdict
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none, timeit
from common.http import is_true
from .models import Asset, Node
from common.struct import Stack
from common.db.models import output_as_string
from orgs.utils import ensure_in_real_or_default_org, current_org
from .locks import NodeTreeUpdateLock
from .models import Node, Asset
logger = get_logger(__file__)
@DistributedLock(name="assets.node.check_node_assets_amount", blocking=False)
@NodeTreeUpdateLock()
@ensure_in_real_or_default_org
def check_node_assets_amount():
for node in Node.objects.all():
logger.info(f'Check node assets amount: {node}')
assets_amount = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes=node)
).distinct().count()
logger.info(f'Check node assets amount {current_org}')
nodes = list(Node.objects.all().only('id', 'key', 'assets_amount'))
nodeid_assetid_pairs = list(Asset.nodes.through.objects.all().values_list('node_id', 'asset_id'))
nodekey_assetids_mapper = defaultdict(set)
nodeid_nodekey_mapper = {}
for node in nodes:
nodeid_nodekey_mapper[node.id] = node.key
for nodeid, assetid in nodeid_assetid_pairs:
if nodeid not in nodeid_nodekey_mapper:
continue
nodekey = nodeid_nodekey_mapper[nodeid]
nodekey_assetids_mapper[nodekey].add(assetid)
util = NodeAssetsUtil(nodes, nodekey_assetids_mapper)
util.generate()
to_updates = []
for node in nodes:
assets_amount = util.get_assets_amount(node.key)
if node.assets_amount != assets_amount:
logger.warn(f'Node wrong assets amount <Node:{node.key}> '
f'{node.assets_amount} right is {assets_amount}')
logger.error(f'Node[{node.key}] assets amount error {node.assets_amount} != {assets_amount}')
node.assets_amount = assets_amount
node.save()
# 防止自检程序给数据库的压力太大
time.sleep(0.1)
def is_asset_exists_in_node(asset_pk, node_key):
return Asset.objects.filter(
id=asset_pk
).filter(
Q(nodes__key__istartswith=f'{node_key}:') | Q(nodes__key=node_key)
).exists()
to_updates.append(node)
Node.objects.bulk_update(to_updates, fields=('assets_amount',))
def is_query_node_all_assets(request):
@ -57,3 +63,77 @@ def get_node(request):
else:
node = get_object_or_none(Node, key=node_id)
return node
class NodeAssetsInfo:
__slots__ = ('key', 'assets_amount', 'assets')
def __init__(self, key, assets_amount, assets):
self.key = key
self.assets_amount = assets_amount
self.assets = assets
def __str__(self):
return self.key
class NodeAssetsUtil:
def __init__(self, nodes, nodekey_assetsid_mapper):
"""
:param nodes: 节点
:param nodekey_assetsid_mapper: 节点直接资产id的映射 {"key1": set(), "key2": set()}
"""
self.nodes = nodes
# node_id --> set(asset_id1, asset_id2)
self.nodekey_assetsid_mapper = nodekey_assetsid_mapper
self.nodekey_assetsinfo_mapper = {}
@timeit
def generate(self):
# 准备排序好的资产信息数据
infos = []
for node in self.nodes:
assets = self.nodekey_assetsid_mapper.get(node.key, set())
info = NodeAssetsInfo(key=node.key, assets_amount=0, assets=assets)
infos.append(info)
infos = sorted(infos, key=lambda i: [int(i) for i in i.key.split(':')])
# 这个守卫需要添加一下,避免最后一个无法出栈
guarder = NodeAssetsInfo(key='', assets_amount=0, assets=set())
infos.append(guarder)
stack = Stack()
for info in infos:
# 如果栈顶的不是这个节点的父祖节点,那么可以出栈了,可以计算资产数量了
while stack.top and not info.key.startswith(f'{stack.top.key}:'):
pop_info = stack.pop()
pop_info.assets_amount = len(pop_info.assets)
self.nodekey_assetsinfo_mapper[pop_info.key] = pop_info
if not stack.top:
continue
stack.top.assets.update(pop_info.assets)
stack.push(info)
def get_assets_by_key(self, key):
info = self.nodekey_assetsinfo_mapper[key]
return info['assets']
def get_assets_amount(self, key):
info = self.nodekey_assetsinfo_mapper[key]
return info.assets_amount
@classmethod
def test_it(cls):
from assets.models import Node, Asset
nodes = list(Node.objects.all())
nodes_assets = Asset.nodes.through.objects.all()\
.annotate(aid=output_as_string('asset_id'))\
.values_list('node__key', 'aid')
mapping = defaultdict(set)
for key, asset_id in nodes_assets:
mapping[key].add(asset_id)
util = cls(nodes, mapping)
util.generate()
return util

View File

@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-
#
from django.db.models.signals import post_save, post_delete
from django.dispatch import receiver
from django.conf import settings
from django.db import transaction
from django.utils import timezone
from django.utils.functional import LazyObject
from django.contrib.auth import BACKEND_SESSION_KEY
from django.utils.translation import ugettext_lazy as _
from rest_framework.renderers import JSONRenderer
@ -34,17 +35,22 @@ MODELS_NEED_RECORD = (
)
LOGIN_BACKEND = {
'PublicKeyAuthBackend': _('SSH Key'),
'RadiusBackend': User.Source.radius.label,
'RadiusRealmBackend': User.Source.radius.label,
'LDAPAuthorizationBackend': User.Source.ldap.label,
'ModelBackend': _('Password'),
'SSOAuthentication': _('SSO'),
'CASBackend': User.Source.cas.label,
'OIDCAuthCodeBackend': User.Source.openid.label,
'OIDCAuthPasswordBackend': User.Source.openid.label,
}
class AuthBackendLabelMapping(LazyObject):
@staticmethod
def get_login_backends():
backend_label_mapping = {}
for source, backends in User.SOURCE_BACKEND_MAPPING.items():
for backend in backends:
backend_label_mapping[backend] = source.label
backend_label_mapping[settings.AUTH_BACKEND_PUBKEY] = _('SSH Key')
backend_label_mapping[settings.AUTH_BACKEND_MODEL] = _('Password')
return backend_label_mapping
def _setup(self):
self._wrapped = self.get_login_backends()
AUTH_BACKEND_LABEL_MAPPING = AuthBackendLabelMapping()
def create_operate_log(action, sender, resource):
@ -70,6 +76,7 @@ def create_operate_log(action, sender, resource):
@receiver(post_save)
def on_object_created_or_update(sender, instance=None, created=False, update_fields=None, **kwargs):
# last_login 改变是最后登录日期, 每次登录都会改变
if instance._meta.object_name == 'User' and \
update_fields and 'last_login' in update_fields:
return
@ -125,14 +132,13 @@ def on_audits_log_create(sender, instance=None, **kwargs):
def get_login_backend(request):
backend = request.session.get('auth_backend', '') or request.session.get(BACKEND_SESSION_KEY, '')
backend = request.session.get('auth_backend', '') or \
request.session.get(BACKEND_SESSION_KEY, '')
backend = backend.rsplit('.', maxsplit=1)[-1]
if backend in LOGIN_BACKEND:
return LOGIN_BACKEND[backend]
else:
logger.warn(f'LOGIN_BACKEND_NOT_FOUND: {backend}')
return ''
backend_label = AUTH_BACKEND_LABEL_MAPPING.get(backend, None)
if backend_label is None:
backend_label = ''
return backend_label
def generate_data(username, request):

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
#
from .auth import *
from .connection_token import *
from .token import *
from .mfa import *
from .access_key import *

View File

@ -1,55 +0,0 @@
# -*- coding: utf-8 -*-
#
import uuid
from django.core.cache import cache
from django.shortcuts import get_object_or_404
from rest_framework.response import Response
from rest_framework.views import APIView
from common.utils import get_logger
from common.permissions import IsOrgAdminOrAppUser
from orgs.mixins.api import RootOrgViewMixin
from users.models import User
from assets.models import Asset, SystemUser
logger = get_logger(__name__)
__all__ = [
'UserConnectionTokenApi',
]
class UserConnectionTokenApi(RootOrgViewMixin, APIView):
permission_classes = (IsOrgAdminOrAppUser,)
def post(self, request):
user_id = request.data.get('user', '')
asset_id = request.data.get('asset', '')
system_user_id = request.data.get('system_user', '')
token = str(uuid.uuid4())
user = get_object_or_404(User, id=user_id)
asset = get_object_or_404(Asset, id=asset_id)
system_user = get_object_or_404(SystemUser, id=system_user_id)
value = {
'user': user_id,
'username': user.username,
'asset': asset_id,
'hostname': asset.hostname,
'system_user': system_user_id,
'system_user_name': system_user.name
}
cache.set(token, value, timeout=20)
return Response({"token": token}, status=201)
def get(self, request):
token = request.query_params.get('token')
user_only = request.query_params.get('user-only', None)
value = cache.get(token, None)
if not value:
return Response('', status=404)
if not user_only:
return Response(value)
else:
return Response({'user': value['user']})

View File

@ -0,0 +1,242 @@
# -*- coding: utf-8 -*-
#
from django.conf import settings
from django.core.cache import cache
from django.shortcuts import get_object_or_404
from django.http import HttpResponse
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet
from rest_framework.decorators import action
from rest_framework.exceptions import PermissionDenied
from common.utils import get_logger, random_string
from common.drf.api import SerializerMixin2
from common.permissions import IsSuperUserOrAppUser, IsValidUser, IsSuperUser
from orgs.mixins.api import RootOrgViewMixin
from ..serializers import (
ConnectionTokenSerializer, ConnectionTokenSecretSerializer,
RDPFileSerializer
)
logger = get_logger(__name__)
__all__ = ['UserConnectionTokenViewSet']
class UserConnectionTokenViewSet(RootOrgViewMixin, SerializerMixin2, GenericViewSet):
permission_classes = (IsSuperUserOrAppUser,)
serializer_classes = {
'default': ConnectionTokenSerializer,
'get_secret_detail': ConnectionTokenSecretSerializer,
'get_rdp_file': RDPFileSerializer
}
CACHE_KEY_PREFIX = 'CONNECTION_TOKEN_{}'
@staticmethod
def check_resource_permission(user, asset, application, system_user):
from perms.utils.asset import has_asset_system_permission
from perms.utils.application import has_application_system_permission
if asset and not has_asset_system_permission(user, asset, system_user):
error = f'User not has this asset and system user permission: ' \
f'user={user.id} system_user={system_user.id} asset={asset.id}'
raise PermissionDenied(error)
if application and not has_application_system_permission(user, application, system_user):
error = f'User not has this application and system user permission: ' \
f'user={user.id} system_user={system_user.id} application={application.id}'
raise PermissionDenied(error)
return True
def create_token(self, user, asset, application, system_user):
if not settings.CONNECTION_TOKEN_ENABLED:
raise PermissionDenied('Connection token disabled')
if not user:
user = self.request.user
if not self.request.user.is_superuser and user != self.request.user:
raise PermissionDenied('Only super user can create user token')
self.check_resource_permission(user, asset, application, system_user)
token = random_string(36)
value = {
'user': str(user.id),
'username': user.username,
'system_user': str(system_user.id),
'system_user_name': system_user.name
}
if asset:
value.update({
'type': 'asset',
'asset': str(asset.id),
'hostname': asset.hostname,
})
elif application:
value.update({
'type': 'application',
'application': application.id,
'application_name': str(application)
})
key = self.CACHE_KEY_PREFIX.format(token)
cache.set(key, value, timeout=20)
return token
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
asset = serializer.validated_data.get('asset')
application = serializer.validated_data.get('application')
system_user = serializer.validated_data['system_user']
user = serializer.validated_data.get('user')
token = self.create_token(user, asset, application, system_user)
return Response({"token": token}, status=201)
@action(methods=['POST', 'GET'], detail=False, url_path='rdp/file')
def get_rdp_file(self, request, *args, **kwargs):
options = {
'full address:s': '',
'username:s': '',
'screen mode id:i': '0',
'desktopwidth:i': '1280',
'desktopheight:i': '800',
'use multimon:i': '1',
'session bpp:i': '24',
'audiomode:i': '0',
'disable wallpaper:i': '0',
'disable full window drag:i': '0',
'disable menu anims:i': '0',
'disable themes:i': '0',
'alternate shell:s': '',
'shell working directory:s': '',
'authentication level:i': '2',
'connect to console:i': '0',
'disable cursor setting:i': '0',
'allow font smoothing:i': '1',
'allow desktop composition:i': '1',
'redirectprinters:i': '0',
'prompt for credentials on client:i': '0',
'autoreconnection enabled:i': '1',
'bookmarktype:i': '3',
'use redirection server name:i': '0',
# 'alternate shell:s:': '||MySQLWorkbench',
# 'remoteapplicationname:s': 'Firefox',
# 'remoteapplicationcmdline:s': '',
}
if self.request.method == 'GET':
data = self.request.query_params
else:
data = request.data
serializer = self.get_serializer(data=data)
serializer.is_valid(raise_exception=True)
asset = serializer.validated_data.get('asset')
application = serializer.validated_data.get('application')
system_user = serializer.validated_data['system_user']
user = serializer.validated_data.get('user')
height = serializer.validated_data.get('height')
width = serializer.validated_data.get('width')
token = self.create_token(user, asset, application, system_user)
# Todo: 上线后地址是 JumpServerAddr:3389
address = self.request.query_params.get('address') or '1.1.1.1'
options['full address:s'] = address
options['username:s'] = '{}@{}'.format(user.username, token)
options['desktopwidth:i'] = width
options['desktopheight:i'] = height
data = ''
for k, v in options.items():
data += f'{k}:{v}\n'
response = HttpResponse(data, content_type='text/plain')
filename = "{}-{}-jumpserver.rdp".format(user.username, asset.hostname)
response['Content-Disposition'] = 'attachment; filename={}'.format(filename)
return response
@staticmethod
def _get_application_secret_detail(value):
from applications.models import Application
from perms.models import Action
application = get_object_or_404(Application, id=value.get('application'))
gateway = None
if not application.category_remote_app:
actions = Action.NONE
remote_app = {}
asset = None
domain = application.domain
else:
remote_app = application.get_rdp_remote_app_setting()
actions = Action.CONNECT
asset = application.get_remote_app_asset()
domain = asset.domain
if domain and domain.has_gateway():
gateway = domain.random_gateway()
return {
'asset': asset,
'application': application,
'gateway': gateway,
'remote_app': remote_app,
'actions': actions
}
@staticmethod
def _get_asset_secret_detail(value, user, system_user):
from assets.models import Asset
from perms.utils.asset import get_asset_system_user_ids_with_actions_by_user
asset = get_object_or_404(Asset, id=value.get('asset'))
systemuserid_actions_mapper = get_asset_system_user_ids_with_actions_by_user(user, asset)
actions = systemuserid_actions_mapper.get(system_user.id, [])
gateway = None
if asset and asset.domain and asset.domain.has_gateway():
gateway = asset.domain.random_gateway()
return {
'asset': asset,
'application': None,
'gateway': gateway,
'remote_app': None,
'actions': actions,
}
@action(methods=['POST'], detail=False, permission_classes=[IsSuperUserOrAppUser], url_path='secret-info/detail')
def get_secret_detail(self, request, *args, **kwargs):
from users.models import User
from assets.models import SystemUser
token = request.data.get('token', '')
key = self.CACHE_KEY_PREFIX.format(token)
value = cache.get(key, None)
if not value:
return Response(status=404)
user = get_object_or_404(User, id=value.get('user'))
system_user = get_object_or_404(SystemUser, id=value.get('system_user'))
data = dict(user=user, system_user=system_user)
if value.get('type') == 'asset':
asset_detail = self._get_asset_secret_detail(value, user=user, system_user=system_user)
data['type'] = 'asset'
data.update(asset_detail)
else:
app_detail = self._get_application_secret_detail(value)
data['type'] = 'application'
data.update(app_detail)
serializer = self.get_serializer(data)
return Response(data=serializer.data, status=200)
def get_permissions(self):
if self.action in ["create", "get_rdp_file"]:
if self.request.data.get('user', None):
self.permission_classes = (IsSuperUser,)
else:
self.permission_classes = (IsValidUser,)
return super().get_permissions()
def get(self, request):
token = request.query_params.get('token')
key = self.CACHE_KEY_PREFIX.format(token)
value = cache.get(key, None)
if not value:
return Response('', status=404)
return Response(value)

View File

@ -3,10 +3,10 @@
from rest_framework.generics import UpdateAPIView
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.permissions import AllowAny
from django.shortcuts import get_object_or_404
from django.utils.translation import ugettext as _
from common.utils import get_logger, get_object_or_none
from common.utils import get_logger
from common.permissions import IsOrgAdmin
from ..models import LoginConfirmSetting
from ..serializers import LoginConfirmSettingSerializer
@ -32,7 +32,7 @@ class LoginConfirmSettingUpdateApi(UpdateAPIView):
class TicketStatusApi(mixins.AuthMixin, APIView):
permission_classes = ()
permission_classes = (AllowAny,)
def get(self, request, *args, **kwargs):
try:

View File

@ -7,6 +7,7 @@ from django.http.response import HttpResponseRedirect
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.permissions import AllowAny
from common.utils.timezone import utcnow
from common.const.http import POST, GET
@ -31,6 +32,7 @@ class SSOViewSet(AuthMixin, JmsGenericViewSet):
'login_url': SSOTokenSerializer,
'login': EmptySerializer
}
permission_classes = (IsSuperUser,)
@action(methods=[POST], detail=False, permission_classes=[IsSuperUser], url_path='login-url')
def login_url(self, request, *args, **kwargs):
@ -54,7 +56,7 @@ class SSOViewSet(AuthMixin, JmsGenericViewSet):
login_url = '%s?%s' % (reverse('api-auth:sso-login', external=True), urlencode(query))
return Response(data={'login_url': login_url})
@action(methods=[GET], detail=False, filter_backends=[AuthKeyQueryDeclaration], permission_classes=[])
@action(methods=[GET], detail=False, filter_backends=[AuthKeyQueryDeclaration], permission_classes=[AllowAny])
def login(self, request: Request, *args, **kwargs):
"""
此接口违反了 `Restful` 的规范

View File

@ -2,7 +2,7 @@
#
import traceback
from django.contrib.auth import get_user_model, authenticate
from django.contrib.auth import get_user_model
from radiusauth.backends import RADIUSBackend, RADIUSRealmBackend
from django.conf import settings

View File

@ -18,6 +18,8 @@ reason_user_not_exist = 'user_not_exist'
reason_password_expired = 'password_expired'
reason_user_invalid = 'user_invalid'
reason_user_inactive = 'user_inactive'
reason_backend_not_match = 'backend_not_match'
reason_acl_not_allow = 'acl_not_allow'
reason_choices = {
reason_password_failed: _('Username/password check failed'),
@ -27,7 +29,9 @@ reason_choices = {
reason_user_not_exist: _("Username does not exist"),
reason_password_expired: _("Password expired"),
reason_user_invalid: _('Disabled or expired'),
reason_user_inactive: _("This account is inactive.")
reason_user_inactive: _("This account is inactive."),
reason_backend_not_match: _("Auth backend not match"),
reason_acl_not_allow: _("ACL is not allowed")
}
old_reason_choices = {
'0': '-',

View File

@ -8,11 +8,26 @@ from captcha.fields import CaptchaField, CaptchaTextInput
class UserLoginForm(forms.Form):
username = forms.CharField(label=_('Username'), max_length=100)
days_auto_login = int(settings.SESSION_COOKIE_AGE / 3600 / 24)
disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE or days_auto_login < 1
username = forms.CharField(
label=_('Username'), max_length=100,
widget=forms.TextInput(attrs={
'placeholder': _("Username"),
'autofocus': 'autofocus'
})
)
password = forms.CharField(
label=_('Password'), widget=forms.PasswordInput,
max_length=1024, strip=False
)
auto_login = forms.BooleanField(
label=_("{} days auto login").format(days_auto_login or 1),
required=False, initial=False, widget=forms.CheckboxInput(
attrs={'disabled': disable_days_auto_login}
)
)
def confirm_login_allowed(self, user):
if not user.is_staff:
@ -35,8 +50,13 @@ class CaptchaMixin(forms.Form):
class ChallengeMixin(forms.Form):
challenge = forms.CharField(label=_('MFA code'), max_length=6,
required=False)
challenge = forms.CharField(
label=_('MFA code'), max_length=6, required=False,
widget=forms.TextInput(attrs={
'placeholder': _("MFA code"),
'style': 'width: 50%'
})
)
def get_user_login_form_cls(*, captcha=False):

View File

@ -9,7 +9,7 @@ from django.contrib.auth import authenticate
from django.shortcuts import reverse
from django.contrib.auth import BACKEND_SESSION_KEY
from common.utils import get_object_or_none, get_request_ip, get_logger
from common.utils import get_object_or_none, get_request_ip, get_logger, bulk_get
from users.models import User
from users.utils import (
is_block_login, clean_failed_count
@ -24,6 +24,7 @@ logger = get_logger(__name__)
class AuthMixin:
request = None
partial_credential_error = None
def get_user_from_session(self):
if self.request.session.is_empty():
@ -75,49 +76,84 @@ class AuthMixin:
return rsa_decrypt(raw_passwd, rsa_private_key)
except Exception as e:
logger.error(e, exc_info=True)
logger.error(f'Decrypt password faild: password[{raw_passwd}] rsa_private_key[{rsa_private_key}]')
logger.error(f'Decrypt password failed: password[{raw_passwd}] '
f'rsa_private_key[{rsa_private_key}]')
return None
return raw_passwd
def check_user_auth(self, decrypt_passwd=False):
self.check_is_block()
def raise_credential_error(self, error):
raise self.partial_credential_error(error=error)
def get_auth_data(self, decrypt_passwd=False):
request = self.request
if hasattr(request, 'data'):
data = request.data
else:
data = request.POST
username = data.get('username', '')
password = data.get('password', '')
challenge = data.get('challenge', '')
public_key = data.get('public_key', '')
ip = self.get_request_ip()
CredentialError = partial(errors.CredentialError, username=username, ip=ip, request=request)
items = ['username', 'password', 'challenge', 'public_key', 'auto_login']
username, password, challenge, public_key, auto_login = bulk_get(data, *items, default='')
password = password + challenge.strip()
ip = self.get_request_ip()
self.partial_credential_error = partial(errors.CredentialError, username=username, ip=ip, request=request)
if decrypt_passwd:
password = self.decrypt_passwd(password)
if not password:
raise CredentialError(error=errors.reason_password_decrypt_failed)
self.raise_credential_error(errors.reason_password_decrypt_failed)
return username, password, public_key, ip, auto_login
user = authenticate(request,
username=username,
password=password + challenge.strip(),
public_key=public_key)
def _check_only_allow_exists_user_auth(self, username):
# 仅允许预先存在的用户认证
if settings.ONLY_ALLOW_EXIST_USER_AUTH:
exist = User.objects.filter(username=username).exists()
if not exist:
logger.error(f"Only allow exist user auth, login failed: {username}")
self.raise_credential_error(errors.reason_user_not_exist)
def _check_auth_user_is_valid(self, username, password, public_key):
user = authenticate(self.request, username=username, password=password, public_key=public_key)
if not user:
raise CredentialError(error=errors.reason_password_failed)
self.raise_credential_error(errors.reason_password_failed)
elif user.is_expired:
raise CredentialError(error=errors.reason_user_inactive)
self.raise_credential_error(errors.reason_user_inactive)
elif not user.is_active:
raise CredentialError(error=errors.reason_user_inactive)
self.raise_credential_error(errors.reason_user_inactive)
return user
def _check_auth_source_is_valid(self, user, auth_backend):
# 限制只能从认证来源登录
if settings.ONLY_ALLOW_AUTH_FROM_SOURCE:
auth_backends_allowed = user.SOURCE_BACKEND_MAPPING.get(user.source)
if auth_backend not in auth_backends_allowed:
self.raise_credential_error(error=errors.reason_backend_not_match)
def _check_login_acl(self, user, ip):
# ACL 限制用户登录
from acls.models import LoginACL
is_allowed = LoginACL.allow_user_to_login(user, ip)
if not is_allowed:
raise self.raise_credential_error(error=errors.reason_acl_not_allow)
def check_user_auth(self, decrypt_passwd=False):
self.check_is_block()
request = self.request
username, password, public_key, ip, auto_login = self.get_auth_data(decrypt_passwd=decrypt_passwd)
self._check_only_allow_exists_user_auth(username)
user = self._check_auth_user_is_valid(username, password, public_key)
# 校验login-acl规则
self._check_login_acl(user, ip)
# 限制只能从认证来源登录
auth_backend = getattr(user, 'backend', 'django.contrib.auth.backends.ModelBackend')
self._check_auth_source_is_valid(user, auth_backend)
self._check_password_require_reset_or_not(user)
self._check_passwd_is_too_simple(user, password)
clean_failed_count(username, ip)
request.session['auth_password'] = 1
request.session['user_id'] = str(user.id)
auth_backend = getattr(user, 'backend', 'django.contrib.auth.backends.ModelBackend')
request.session['auto_login'] = auto_login
request.session['auth_backend'] = auth_backend
return user

View File

@ -4,13 +4,17 @@ from rest_framework import serializers
from common.utils import get_object_or_none
from users.models import User
from assets.models import Asset, SystemUser, Gateway
from applications.models import Application
from users.serializers import UserProfileSerializer
from perms.serializers.asset.permission import ActionsField
from .models import AccessKey, LoginConfirmSetting, SSOToken
__all__ = [
'AccessKeySerializer', 'OtpVerifySerializer', 'BearerTokenSerializer',
'MFAChallengeSerializer', 'LoginConfirmSettingSerializer', 'SSOTokenSerializer',
'ConnectionTokenSerializer', 'ConnectionTokenSecretSerializer', 'RDPFileSerializer'
]
@ -82,3 +86,103 @@ class SSOTokenSerializer(serializers.Serializer):
username = serializers.CharField(write_only=True)
login_url = serializers.CharField(read_only=True)
next = serializers.CharField(write_only=True, allow_blank=True, required=False, allow_null=True)
class ConnectionTokenSerializer(serializers.Serializer):
user = serializers.CharField(max_length=128, required=False, allow_blank=True)
system_user = serializers.CharField(max_length=128, required=True)
asset = serializers.CharField(max_length=128, required=False)
application = serializers.CharField(max_length=128, required=False)
@staticmethod
def validate_user(user_id):
from users.models import User
user = User.objects.filter(id=user_id).first()
if user is None:
raise serializers.ValidationError('user id not exist')
return user
@staticmethod
def validate_system_user(system_user_id):
from assets.models import SystemUser
system_user = SystemUser.objects.filter(id=system_user_id).first()
if system_user is None:
raise serializers.ValidationError('system_user id not exist')
return system_user
@staticmethod
def validate_asset(asset_id):
from assets.models import Asset
asset = Asset.objects.filter(id=asset_id).first()
if asset is None:
raise serializers.ValidationError('asset id not exist')
return asset
@staticmethod
def validate_application(app_id):
from applications.models import Application
app = Application.objects.filter(id=app_id).first()
if app is None:
raise serializers.ValidationError('app id not exist')
return app
def validate(self, attrs):
asset = attrs.get('asset')
application = attrs.get('application')
if not asset and not application:
raise serializers.ValidationError('asset or application required')
if asset and application:
raise serializers.ValidationError('asset and application should only one')
return super().validate(attrs)
class ConnectionTokenUserSerializer(serializers.ModelSerializer):
class Meta:
model = User
fields = ['id', 'name', 'username', 'email']
class ConnectionTokenAssetSerializer(serializers.ModelSerializer):
class Meta:
model = Asset
fields = ['id', 'hostname', 'ip', 'port', 'org_id']
class ConnectionTokenSystemUserSerializer(serializers.ModelSerializer):
class Meta:
model = SystemUser
fields = ['id', 'name', 'username', 'password', 'private_key']
class ConnectionTokenGatewaySerializer(serializers.ModelSerializer):
class Meta:
model = Gateway
fields = ['id', 'ip', 'port', 'username', 'password', 'private_key']
class ConnectionTokenRemoteAppSerializer(serializers.Serializer):
program = serializers.CharField()
working_directory = serializers.CharField()
parameters = serializers.CharField()
class ConnectionTokenApplicationSerializer(serializers.ModelSerializer):
class Meta:
model = Application
fields = ['id', 'name', 'category', 'type']
class ConnectionTokenSecretSerializer(serializers.Serializer):
type = serializers.ChoiceField(choices=[('application', 'Application'), ('asset', 'Asset')])
user = ConnectionTokenUserSerializer(read_only=True)
asset = ConnectionTokenAssetSerializer(read_only=True)
remote_app = ConnectionTokenRemoteAppSerializer(read_only=True)
application = ConnectionTokenApplicationSerializer(read_only=True)
system_user = ConnectionTokenSystemUserSerializer(read_only=True)
gateway = ConnectionTokenGatewaySerializer(read_only=True)
actions = ActionsField()
class RDPFileSerializer(ConnectionTokenSerializer):
width = serializers.IntegerField(default=1280)
height = serializers.IntegerField(default=800)

View File

@ -24,7 +24,7 @@ def on_user_auth_login_success(sender, user, request, **kwargs):
@receiver(openid_user_login_success)
def on_oidc_user_login_success(sender, request, user, **kwargs):
def on_oidc_user_login_success(sender, request, user, create=False, **kwargs):
request.session[BACKEND_SESSION_KEY] = 'OIDCAuthCodeBackend'
post_auth_success.send(sender, user=user, request=request)

View File

@ -1,12 +1,11 @@
{% load static %}
{% load i18n %}
{% load bootstrap3 %}
{% load static %}
<!DOCTYPE html>
<html>
<!--/*@thymesVar id="LoginConstants" type="com.fit2cloud.support.common.constants.LoginConstants"*/-->
<!--/*@thymesVar id="message" type="java.lang.String"*/-->
<head>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
<link rel="shortcut icon" href="{{ FAVICON_URL }}" type="image/x-icon">
<link rel="shortcut icon" href="{{ FAVICON_URL }}" type="image/x-icon">
<title>
{{ JMS_TITLE }}
</title>
@ -16,6 +15,8 @@
<link href="{% static 'css/font-awesome.min.css' %}" rel="stylesheet">
<link href="{% static 'css/bootstrap-style.css' %}" rel="stylesheet">
<link href="{% static 'css/login-style.css' %}" rel="stylesheet">
<link href="{% static 'css/style.css' %}" rel="stylesheet">
<link href="{% static 'css/jumpserver.css' %}" rel="stylesheet">
<!-- scripts -->
<script src="{% static 'js/jquery-3.1.1.min.js' %}"></script>
@ -24,26 +25,54 @@
<script src="{% static 'js/plugins/datatables/datatables.min.js' %}"></script>
<style>
.login-content {
box-shadow: 0 5px 5px -3px rgb(0 0 0 / 20%), 0 8px 10px 1px rgb(0 0 0 / 14%), 0 3px 14px 2px rgb(0 0 0 / 12%);
}
.box-1{
.help-block {
margin: 0;
text-align: left;
}
form label {
color: #737373;
font-size: 13px;
font-weight: normal;
}
.hr-line-dashed {
border-top: 1px dashed #e7eaec;
color: #ffffff;
background-color: #ffffff;
height: 1px;
margin: 20px 0;
}
.login-content {
height: 472px;
width: 984px;
margin-right: auto;
margin-left: auto;
margin-top: calc((100vh - 470px)/2);
margin-top: calc((100vh - 470px) / 3);
}
.box-2{
body {
background-color: #f2f2f2;
height: calc(100vh - (100vh - 470px) / 3);
}
.right-image-box {
height: 100%;
width: 50%;
float: right;
}
.box-3{
.left-form-box {
text-align: center;
background-color: white;
height: 100%;
width: 50%;
}
.captcha {
float: right;
}
@ -56,136 +85,144 @@
text-align: left;
}
.form-group.has-error {
margin-bottom: 0;
}
.captch-field .has-error .help-block {
margin-top: -8px !important;
}
.no-captcha-challenge .form-group {
margin-bottom: 20px;
}
.jms-title {
padding: 40px 10px 10px;
}
.no-captcha-challenge .jms-title {
padding: 60px 10px 10px;
}
.no-captcha-challenge .welcome-message {
padding-top: 10px;
}
.radio, .checkbox {
margin: 0;
}
#github_star {
float: right;
margin: 10px 10px 0 0;
}
</style>
</head>
<body style="height: 100%;font-size: 13px">
<div>
<div class="box-1">
<div class="box-2">
<img src="{{ LOGIN_IMAGE_URL }}" style="height: 100%; width: 100%"/>
<body>
<div class="login-content ">
<div class="right-image-box">
<a href="{% if not XPACK_ENABLED %}https://github.com/jumpserver/jumpserver{% endif %}">
<img src="{{ LOGIN_IMAGE_URL }}" style="height: 100%; width: 100%"/>
</a>
</div>
<div class="left-form-box {% if not form.challenge and not form.captcha %} no-captcha-challenge {% endif %}">
<div style="background-color: white">
<div class="jms-title">
<span style="font-size: 21px;font-weight:400;color: #151515;letter-spacing: 0;">{{ JMS_TITLE }}</span>
</div>
<div class="box-3">
<div style="background-color: white">
{% if form.challenge %}
<div style="margin-top: 20px;padding-top: 30px;padding-left: 20px;padding-right: 20px;height: 60px">
<div class="contact-form col-md-10 col-md-offset-1">
<form id="login-form" action="" method="post" role="form" novalidate="novalidate">
{% csrf_token %}
<div style="line-height: 17px;margin-bottom: 20px;color: #999999;">
{% if form.errors %}
<p class="red-fonts" style="color: red">
{% if form.non_field_errors %}
{{ form.non_field_errors.as_text }}
{% endif %}
</p>
{% else %}
<div style="margin-top: 20px;padding-top: 40px;padding-left: 20px;padding-right: 20px;height: 80px">
<p class="welcome-message">
{% trans 'Welcome back, please enter username and password to login' %}
</p>
{% endif %}
<span style="font-size: 21px;font-weight:400;color: #151515;letter-spacing: 0;">{{ JMS_TITLE }}</span>
</div>
<div style="font-size: 12px;color: #999999;letter-spacing: 0;line-height: 18px;margin-top: 18px">
{% trans 'Welcome back, please enter username and password to login' %}
</div>
<div style="margin-bottom: 0px">
<div>
<div class="col-md-1"></div>
<div class="contact-form col-md-10" style="margin-top: 0px;height: 35px">
<form id="contact-form" action="" method="post" role="form" novalidate="novalidate">
{% csrf_token %}
{% if form.non_field_errors %}
{% if form.challenge %}
<div style="height: 50px;color: red;line-height: 17px;">
{% else %}
<div style="height: 70px;color: red;line-height: 17px;">
{% endif %}
<p class="red-fonts">{{ form.non_field_errors.as_text }}</p>
</div>
{% elif form.errors.captcha %}
<p class="red-fonts">{% trans 'Captcha invalid' %}</p>
{% else %}
<div style="height: 50px"></div>
{% endif %}
<div class="form-group">
<input type="text" class="form-control" name="{{ form.username.html_name }}" placeholder="{% trans 'Username' %}" required="" value="{% if form.username.value %}{{ form.username.value }}{% endif %}" style="height: 35px">
{% if form.errors.username %}
<div class="help-block field-error">
<p class="red-fonts">{{ form.errors.username.as_text }}</p>
</div>
{% endif %}
</div>
<div class="form-group">
<input type="password" class="form-control" id="password" placeholder="{% trans 'Password' %}" required="">
<input id="password-hidden" type="text" style="display:none" name="{{ form.password.html_name }}">
{% if form.errors.password %}
<div class="help-block field-error">
<p class="red-fonts">{{ form.errors.password.as_text }}</p>
</div>
{% endif %}
</div>
{% if form.challenge %}
<div class="form-group">
<input type="challenge" class="form-control" id="challenge" name="{{ form.challenge.html_name }}" placeholder="{% trans 'MFA code' %}" >
{% if form.errors.challenge %}
<div class="help-block field-error">
<p class="red-fonts">{{ form.errors.challenge.as_text }}</p>
</div>
{% endif %}
</div>
{% endif %}
{% if form.captcha %}
<div class="form-group" style="height: 50px;margin-bottom: 0;font-size: 13px">
{{ form.captcha }}
</div>
{% else %}
<div class="form-group" style="height: 25px;margin-bottom: 0;font-size: 13px"></div>
{% endif %}
<div class="form-group" style="margin-top: 10px">
<button type="submit" class="btn btn-transparent" onclick="doLogin();return false;">{% trans 'Login' %}</button>
</div>
<div>
{% if AUTH_OPENID or AUTH_CAS %}
<div class="hr-line-dashed"></div>
<div style="display: inline-block; float: left">
<b class="text-muted text-left" style="margin-right: 10px">{% trans "More login options" %}</b>
{% if AUTH_OPENID %}
<a href="{% url 'authentication:openid:login' %}">
<i class="fa fa-openid"></i> {% trans 'OpenID' %}
</a>
{% endif %}
{% if AUTH_CAS %}
<a href="{% url 'authentication:cas:cas-login' %}">
<i class="fa"><img src="{{ LOGIN_CAS_LOGO_URL }}" height="13" width="13"></i> {% trans 'CAS' %}
</a>
{% endif %}
</div>
<div class="text-center" style="display: inline-block; float: right">
{% else %}
<div class="text-center" style="display: inline-block;">
{% endif %}
<a id="forgot_password" href="{% url 'authentication:forgot-password' %}">
<small>{% trans 'Forgot password' %}?</small>
</a>
</div>
</div>
</form>
{% bootstrap_field form.username show_label=False %}
<div class="form-group">
<input type="password" class="form-control" id="password" placeholder="{% trans 'Password' %}" required="">
<input id="password-hidden" type="text" style="display:none" name="{{ form.password.html_name }}">
</div>
{% if form.challenge %}
{% bootstrap_field form.challenge show_label=False %}
{% elif form.captcha %}
<div class="captch-field">
{% bootstrap_field form.captcha show_label=False %}
</div>
{% endif %}
<div class="form-group" style="padding-top: 5px; margin-bottom: 10px">
<div class="row">
<div class="col-md-6" style="text-align: left">
{% if form.auto_login %}
{% bootstrap_field form.auto_login form_group_class='' %}
{% endif %}
</div>
<div class="col-md-6">
<a id="forgot_password" href="{{ forgot_password_url }}" style="float: right">
<small>{% trans 'Forgot password' %}?</small>
</a>
</div>
<div class="col-md-1"></div>
</div>
</div>
</div>
<div class="form-group" style="">
<button type="submit" class="btn btn-transparent" onclick="doLogin();return false;">{% trans 'Login' %}</button>
</div>
<div>
{% if AUTH_OPENID or AUTH_CAS %}
<div class="hr-line-dashed"></div>
<div style="display: inline-block; float: left">
<b class="text-muted text-left" style="margin-right: 10px">{% trans "More login options" %}</b>
{% if AUTH_OPENID %}
<a href="{% url 'authentication:openid:login' %}" class="more-login-item">
<i class="fa fa-openid"></i> {% trans 'OpenID' %}
</a>
{% endif %}
{% if AUTH_CAS %}
<a href="{% url 'authentication:cas:cas-login' %}" class="more-login-item">
<i class="fa"><img src="{{ LOGIN_CAS_LOGO_URL }}" height="13" width="13"></i> {% trans 'CAS' %}
</a>
{% endif %}
</div>
{% else %}
<div class="text-center" style="display: inline-block;">
{% endif %}
</div>
</div>
</form>
</div>
</div>
</div>
</div>
</div>
</body>
<script type="text/javascript" src="/static/js/plugins/jsencrypt/jsencrypt.min.js"></script>
<script>
function encryptLoginPassword(password, rsaPublicKey){
function encryptLoginPassword(password, rsaPublicKey) {
var jsencrypt = new JSEncrypt(); //加密对象
jsencrypt.setPublicKey(rsaPublicKey); // 设置密钥
return jsencrypt.encrypt(password); //加密
}
function doLogin() {
//公钥加密
var rsaPublicKey = "{{ rsa_public_key }}"
var password =$('#password').val(); //明文密码
var password = $('#password').val(); //明文密码
var passwordEncrypted = encryptLoginPassword(password, rsaPublicKey)
$('#password-hidden').val(passwordEncrypted); //返回给密码输入input
$('#contact-form').submit();//post提交
$('#login-form').submit();//post提交
}
$(document).ready(function () {

View File

@ -9,6 +9,7 @@ app_name = 'authentication'
router = DefaultRouter()
router.register('access-keys', api.AccessKeyViewSet, 'access-key')
router.register('sso', api.SSOViewSet, 'sso')
router.register('connection-token', api.UserConnectionTokenViewSet, 'connection-token')
urlpatterns = [
@ -16,8 +17,6 @@ urlpatterns = [
path('auth/', api.TokenCreateApi.as_view(), name='user-auth'),
path('tokens/', api.TokenCreateApi.as_view(), name='auth-token'),
path('mfa/challenge/', api.MFAChallengeApi.as_view(), name='mfa-challenge'),
path('connection-token/',
api.UserConnectionTokenApi.as_view(), name='connection-token'),
path('otp/verify/', api.UserOtpVerifyApi.as_view(), name='user-otp-verify'),
path('login-confirm-ticket/status/', api.TicketStatusApi.as_view(), name='login-confirm-ticket-status'),
path('login-confirm-settings/<uuid:user_id>/', api.LoginConfirmSettingUpdateApi.as_view(), name='login-confirm-setting-update')

View File

@ -1,9 +1,9 @@
# -*- coding: utf-8 -*-
#
import base64
from Crypto.PublicKey import RSA
from Crypto.Cipher import PKCS1_v1_5
from Crypto import Random
from Cryptodome.PublicKey import RSA
from Cryptodome.Cipher import PKCS1_v1_5
from Cryptodome import Random
from common.utils import get_logger

View File

@ -45,9 +45,10 @@ class UserLoginView(mixins.AuthMixin, FormView):
def get(self, request, *args, **kwargs):
if request.user.is_staff:
return redirect(redirect_user_first_login_or_index(
request, self.redirect_field_name)
first_login_url = redirect_user_first_login_or_index(
request, self.redirect_field_name
)
return redirect(first_login_url)
request.session.set_test_cookie()
return super().get(request, *args, **kwargs)
@ -99,11 +100,17 @@ class UserLoginView(mixins.AuthMixin, FormView):
self.request.session[RSA_PRIVATE_KEY] = rsa_private_key
self.request.session[RSA_PUBLIC_KEY] = rsa_public_key
forgot_password_url = reverse('authentication:forgot-password')
has_other_auth_backend = settings.AUTHENTICATION_BACKENDS[0] != settings.AUTH_BACKEND_MODEL
if has_other_auth_backend and settings.FORGOT_PASSWORD_URL:
forgot_password_url = settings.FORGOT_PASSWORD_URL
context = {
'demo_mode': os.environ.get("DEMO_MODE"),
'AUTH_OPENID': settings.AUTH_OPENID,
'AUTH_CAS': settings.AUTH_CAS,
'rsa_public_key': rsa_public_key,
'forgot_password_url': forgot_password_url
}
kwargs.update(context)
return super().get_context_data(**kwargs)
@ -121,6 +128,13 @@ class UserLoginGuardView(mixins.AuthMixin, RedirectView):
url = "%s?%s" % (url, args)
return url
def login_it(self, user):
auth_login(self.request, user)
# 如果设置了自动登录,那需要设置 session_id cookie 的有效期
if self.request.session.get('auto_login'):
age = self.request.session.get_expiry_age()
self.request.session.set_expiry(age)
def get_redirect_url(self, *args, **kwargs):
try:
user = self.check_user_auth_if_need()
@ -137,7 +151,7 @@ class UserLoginGuardView(mixins.AuthMixin, RedirectView):
except errors.PasswdTooSimple as e:
return e.url
else:
auth_login(self.request, user)
self.login_it(user)
self.send_auth_signal(success=True, user=user)
self.clear_auth_mark()
url = redirect_user_first_login_or_index(

View File

@ -13,7 +13,7 @@ from rest_framework.viewsets import GenericViewSet
from common.permissions import IsValidUser
from .http import HttpResponseTemporaryRedirect
from .const import KEY_CACHE_RESOURCES_ID
from .const import KEY_CACHE_RESOURCE_IDS
from .utils import get_logger
from .mixins import CommonApiMixin
@ -93,7 +93,7 @@ class ResourcesIDCacheApi(APIView):
spm = str(uuid.uuid4())
resources = request.data.get('resources')
if resources is not None:
cache_key = KEY_CACHE_RESOURCES_ID.format(spm)
cache_key = KEY_CACHE_RESOURCE_IDS.format(spm)
cache.set(cache_key, resources, 300)
return Response({'spm': spm})

View File

@ -1,13 +1,24 @@
import json
from django.core.cache import cache
import time
from redis import Redis
from common.utils.lock import DistributedLock
from common.utils import lazyproperty
from common.utils import get_logger
from jumpserver.const import CONFIG
logger = get_logger(__file__)
class ComputeLock(DistributedLock):
"""
需要重建缓存的时候加上该锁避免重复计算
"""
def __init__(self, key):
name = f'compute:{key}'
super().__init__(name=name)
class CacheFieldBase:
field_type = str
@ -25,7 +36,7 @@ class IntegerField(CacheFieldBase):
field_type = int
class CacheBase(type):
class CacheType(type):
def __new__(cls, name, bases, attrs: dict):
to_update = {}
field_desc_mapper = {}
@ -41,12 +52,31 @@ class CacheBase(type):
return type.__new__(cls, name, bases, attrs)
class Cache(metaclass=CacheBase):
class Cache(metaclass=CacheType):
field_desc_mapper: dict
timeout = None
def __init__(self):
self._data = None
self.redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
def __getitem__(self, item):
return self.field_desc_mapper[item]
def __contains__(self, item):
return item in self.field_desc_mapper
def get_field(self, name):
return self.field_desc_mapper[name]
@property
def fields(self):
return self.field_desc_mapper.values()
@property
def field_names(self):
names = self.field_desc_mapper.keys()
return names
@lazyproperty
def key_suffix(self):
@ -64,81 +94,75 @@ class Cache(metaclass=CacheBase):
@property
def data(self):
if self._data is None:
data = self.get_data()
if data is None:
# 缓存中没有数据时,去数据库获取
self.compute_and_set_all_data()
data = self.load_data_from_db()
if not data:
with ComputeLock(self.key):
data = self.load_data_from_db()
if not data:
# 缓存中没有数据时,去数据库获取
self.init_all_values()
return self._data
def get_data(self) -> dict:
data = cache.get(self.key)
logger.debug(f'CACHE: get {self.key} = {data}')
if data is not None:
data = json.loads(data)
def to_internal_value(self, data: dict):
internal_data = {}
for k, v in data.items():
field = k.decode()
if field in self:
value = self[field].to_internal_value(v.decode())
internal_data[field] = value
else:
logger.warn(f'Cache got invalid field: '
f'key={self.key} '
f'invalid_field={field} '
f'valid_fields={self.field_names}')
return internal_data
def load_data_from_db(self) -> dict:
data = self.redis.hgetall(self.key)
logger.debug(f'Get data from cache: key={self.key} data={data}')
if data:
data = self.to_internal_value(data)
self._data = data
return data
def set_data(self, data):
self._data = data
to_json = json.dumps(data)
logger.info(f'CACHE: set {self.key} = {to_json}, timeout={self.timeout}')
cache.set(self.key, to_json, timeout=self.timeout)
def save_data_to_db(self, data):
logger.info(f'Set data to cache: key={self.key} data={data}')
self.redis.hset(self.key, mapping=data)
self.load_data_from_db()
def compute_values(self, *fields):
field_objs = []
for field in fields:
field_objs.append(self[field])
def compute_data(self, *fields):
field_descs = []
if not fields:
field_descs = self.field_desc_mapper.values()
else:
for field in fields:
assert field in self.field_desc_mapper, f'{field} is not a valid field'
field_descs.append(self.field_desc_mapper[field])
data = {
field_desc.field_name: field_desc.compute_value(self)
for field_desc in field_descs
field_obj.field_name: field_obj.compute_value(self)
for field_obj in field_objs
}
return data
def compute_and_set_all_data(self, computed_data: dict = None):
"""
TODO 怎样防止并发更新全部数据浪费数据库资源
"""
uncomputed_keys = ()
if computed_data:
computed_keys = computed_data.keys()
all_keys = self.field_desc_mapper.keys()
uncomputed_keys = all_keys - computed_keys
else:
computed_data = {}
data = self.compute_data(*uncomputed_keys)
data.update(computed_data)
self.set_data(data)
def init_all_values(self):
t_start = time.time()
logger.info(f'Start init cache: key={self.key}')
data = self.compute_values(*self.field_names)
self.save_data_to_db(data)
logger.info(f'End init cache: cost={time.time()-t_start} key={self.key}')
return data
def refresh_part_data_with_lock(self, refresh_data):
with DistributedLock(name=f'{self.key}.refresh'):
data = self.get_data()
if data is not None:
data.update(refresh_data)
self.set_data(data)
return data
def refresh(self, *fields):
if not fields:
# 没有指定 field 要刷新所有的值
self.compute_and_set_all_data()
self.init_all_values()
return
data = self.get_data()
if data is None:
data = self.load_data_from_db()
if not data:
# 缓存中没有数据,设置所有的值
self.compute_and_set_all_data()
self.init_all_values()
return
refresh_data = self.compute_data(*fields)
if not self.refresh_part_data_with_lock(refresh_data):
# 刷新部分失败,缓存中没有数据,更新所有的值
self.compute_and_set_all_data(refresh_data)
return
refresh_values = self.compute_values(*fields)
self.save_data_to_db(refresh_values)
def get_key_suffix(self):
raise NotImplementedError
@ -146,10 +170,14 @@ class Cache(metaclass=CacheBase):
def reload(self):
self._data = None
def delete(self):
def expire(self, *fields):
self._data = None
logger.info(f'CACHE: delete {self.key}')
cache.delete(self.key)
if not fields:
logger.info(f'Delete cached key: key={self.key}')
self.redis.delete(self.key)
else:
self.redis.hdel(self.key, *fields)
logger.info(f'Expire cached fields: key={self.key} fields={fields}')
class CacheValueDesc:
@ -167,10 +195,13 @@ class CacheValueDesc:
return self
if self.field_name not in instance.data:
instance.refresh(self.field_name)
value = instance.data[self.field_name]
# 防止边界情况没有值,报错
value = instance.data.get(self.field_name)
return value
def compute_value(self, instance: Cache):
t_start = time.time()
logger.info(f'Start compute cache field: field={self.field_name} key={instance.key}')
if self.field_type.queryset is not None:
new_value = self.field_type.queryset.count()
else:
@ -183,5 +214,8 @@ class CacheValueDesc:
new_value = compute_func()
new_value = self.field_type.field_type(new_value)
logger.info(f'CACHE: compute {instance.key}.{self.field_name} = {new_value}')
logger.info(f'End compute cache field: cost={time.time()-t_start} field={self.field_name} value={new_value} key={instance.key}')
return new_value
def to_internal_value(self, value):
return self.field_type.field_type(value)

View File

@ -7,7 +7,7 @@ create_success_msg = _("%(name)s was created successfully")
update_success_msg = _("%(name)s was updated successfully")
FILE_END_GUARD = ">>> Content End <<<"
celery_task_pre_key = "CELERY_"
KEY_CACHE_RESOURCES_ID = "RESOURCES_ID_{}"
KEY_CACHE_RESOURCE_IDS = "RESOURCE_IDS_{}"
# AD User AccountDisable
# https://blog.csdn.net/bytxl/article/details/17763975

View File

@ -1,2 +0,0 @@
UPDATE_NODE_TREE_LOCK_KEY = 'org_level_transaction_lock_{org_id}_assets_update_node_tree'
UPDATE_MAPPING_NODE_TASK_LOCK_KEY = 'org_level_transaction_lock_{user_id}_update_mapping_node_task'

View File

@ -10,8 +10,11 @@
"""
import uuid
from functools import reduce, partial
import inspect
from django.db.models import *
from django.db.models import QuerySet
from django.db.models.functions import Concat
from django.utils.translation import ugettext_lazy as _
@ -82,3 +85,88 @@ class JMSModel(JMSBaseModel):
def concated_display(name1, name2):
return Concat(F(name1), Value('('), F(name2), Value(')'))
def output_as_string(field_name):
return ExpressionWrapper(F(field_name), output_field=CharField())
class UnionQuerySet(QuerySet):
after_union = ['order_by']
not_return_qs = [
'query', 'get', 'create', 'get_or_create',
'update_or_create', 'bulk_create', 'count',
'latest', 'earliest', 'first', 'last', 'aggregate',
'exists', 'update', 'delete', 'as_manager', 'explain',
]
def __init__(self, *queryset_list):
self.queryset_list = queryset_list
self.after_union_items = []
self.before_union_items = []
def __execute(self):
queryset_list = []
for qs in self.queryset_list:
for attr, args, kwargs in self.before_union_items:
qs = getattr(qs, attr)(*args, **kwargs)
queryset_list.append(qs)
union_qs = reduce(lambda x, y: x.union(y), queryset_list)
for attr, args, kwargs in self.after_union_items:
union_qs = getattr(union_qs, attr)(*args, **kwargs)
return union_qs
def __before_union_perform(self, item, *args, **kwargs):
self.before_union_items.append((item, args, kwargs))
return self.__clone(*self.queryset_list)
def __after_union_perform(self, item, *args, **kwargs):
self.after_union_items.append((item, args, kwargs))
return self.__clone(*self.queryset_list)
def __clone(self, *queryset_list):
uqs = UnionQuerySet(*queryset_list)
uqs.after_union_items = self.after_union_items
uqs.before_union_items = self.before_union_items
return uqs
def __getattribute__(self, item):
if item.startswith('__') or item in UnionQuerySet.__dict__ or item in [
'queryset_list', 'after_union_items', 'before_union_items'
]:
return object.__getattribute__(self, item)
if item in UnionQuerySet.not_return_qs:
return getattr(self.__execute(), item)
origin_item = object.__getattribute__(self, 'queryset_list')[0]
origin_attr = getattr(origin_item, item, None)
if not inspect.ismethod(origin_attr):
return getattr(self.__execute(), item)
if item in UnionQuerySet.after_union:
attr = partial(self.__after_union_perform, item)
else:
attr = partial(self.__before_union_perform, item)
return attr
def __getitem__(self, item):
return self.__execute()[item]
def __iter__(self):
return iter(self.__execute())
def __str__(self):
return str(self.__execute())
def __repr__(self):
return repr(self.__execute())
@classmethod
def test_it(cls):
from assets.models import Asset
assets1 = Asset.objects.filter(hostname__startswith='a')
assets2 = Asset.objects.filter(hostname__startswith='b')
qs = cls(assets1, assets2)
return qs

View File

@ -3,39 +3,35 @@ from rest_framework_bulk import BulkModelViewSet
from ..mixins.api import (
SerializerMixin2, QuerySetMixin, ExtraFilterFieldsMixin, PaginatedResponseMixin,
RelationMixin, AllowBulkDestoryMixin
RelationMixin, AllowBulkDestoryMixin, RenderToJsonMixin,
)
class JmsGenericViewSet(SerializerMixin2,
QuerySetMixin,
ExtraFilterFieldsMixin,
PaginatedResponseMixin,
class CommonMixin(SerializerMixin2,
QuerySetMixin,
ExtraFilterFieldsMixin,
PaginatedResponseMixin,
RenderToJsonMixin):
pass
class JmsGenericViewSet(CommonMixin,
GenericViewSet):
pass
class JMSModelViewSet(SerializerMixin2,
QuerySetMixin,
ExtraFilterFieldsMixin,
PaginatedResponseMixin,
class JMSModelViewSet(CommonMixin,
ModelViewSet):
pass
class JMSBulkModelViewSet(SerializerMixin2,
QuerySetMixin,
ExtraFilterFieldsMixin,
PaginatedResponseMixin,
class JMSBulkModelViewSet(CommonMixin,
AllowBulkDestoryMixin,
BulkModelViewSet):
pass
class JMSBulkRelationModelViewSet(SerializerMixin2,
QuerySetMixin,
ExtraFilterFieldsMixin,
PaginatedResponseMixin,
class JMSBulkRelationModelViewSet(CommonMixin,
RelationMixin,
AllowBulkDestoryMixin,
BulkModelViewSet):

View File

@ -19,7 +19,10 @@ def extract_object_name(exc, index=0):
`No User matches the given query.`
提取 `User``index=1`
"""
(msg, *_) = exc.args
if exc.args:
(msg, *others) = exc.args
else:
return gettext('Object')
return gettext(msg.split(sep=' ', maxsplit=index + 1)[index])

View File

@ -6,11 +6,25 @@ from rest_framework.serializers import ValidationError
from rest_framework.compat import coreapi, coreschema
from django.core.cache import cache
from django.core.exceptions import ImproperlyConfigured
from django_filters import rest_framework as drf_filters
import logging
from common import const
__all__ = ["DatetimeRangeFilter", "IDSpmFilter", 'IDInFilter', "CustomFilter"]
__all__ = [
"DatetimeRangeFilter", "IDSpmFilter", 'IDInFilter', "CustomFilter",
"BaseFilterSet"
]
class BaseFilterSet(drf_filters.FilterSet):
def do_nothing(self, queryset, name, value):
return queryset
def get_query_param(self, k, default=None):
if k in self.form.data:
return self.form.cleaned_data[k]
return default
class DatetimeRangeFilter(filters.BaseFilterBackend):
@ -94,11 +108,11 @@ class IDSpmFilter(filters.BaseFilterBackend):
spm = request.query_params.get('spm')
if not spm:
return queryset
cache_key = const.KEY_CACHE_RESOURCES_ID.format(spm)
resources_id = cache.get(cache_key)
if resources_id is None or not isinstance(resources_id, list):
cache_key = const.KEY_CACHE_RESOURCE_IDS.format(spm)
resource_ids = cache.get(cache_key)
if resource_ids is None or not isinstance(resource_ids, list):
return queryset
queryset = queryset.filter(id__in=resources_id)
queryset = queryset.filter(id__in=resource_ids)
return queryset

View File

@ -3,6 +3,7 @@
from __future__ import unicode_literals
from collections import OrderedDict
import datetime
from django.core.exceptions import PermissionDenied
from django.http import Http404
@ -21,7 +22,7 @@ class SimpleMetadataWithFilters(SimpleMetadata):
attrs = [
'read_only', 'label', 'help_text',
'min_length', 'max_length',
'min_value', 'max_value', "write_only"
'min_value', 'max_value', "write_only",
]
def determine_actions(self, request, view):
@ -59,9 +60,10 @@ class SimpleMetadataWithFilters(SimpleMetadata):
field_info['type'] = self.label_lookup[field]
field_info['required'] = getattr(field, 'required', False)
default = getattr(field, 'default', False)
if default and isinstance(default, (str, int)):
field_info['default'] = default
default = getattr(field, 'default', None)
if default is not None and default != empty:
if isinstance(default, (str, int, bool, datetime.datetime, list)):
field_info['default'] = default
for attr in self.attrs:
value = getattr(field, attr, None)
@ -95,6 +97,8 @@ class SimpleMetadataWithFilters(SimpleMetadata):
fields = view.filterset_fields
elif hasattr(view, 'get_filterset_fields'):
fields = view.get_filterset_fields(request)
elif hasattr(view, 'filterset_class'):
fields = view.filterset_class.Meta.fields
if isinstance(fields, dict):
fields = list(fields.keys())

View File

@ -22,6 +22,7 @@ class BaseFileParser(BaseParser):
FILE_CONTENT_MAX_LENGTH = 1024 * 1024 * 10
serializer_cls = None
serializer_fields = None
def check_content_length(self, meta):
content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
@ -45,7 +46,7 @@ class BaseFileParser(BaseParser):
def convert_to_field_names(self, column_titles):
fields_map = {}
fields = self.serializer_cls().fields
fields = self.serializer_fields
fields_map.update({v.label: k for k, v in fields.items()})
fields_map.update({k: k for k, _ in fields.items()})
field_names = [
@ -89,7 +90,7 @@ class BaseFileParser(BaseParser):
构建json数据后的行数据处理
"""
new_row_data = {}
serializer_fields = self.serializer_cls().fields
serializer_fields = self.serializer_fields
for k, v in row_data.items():
if isinstance(v, list) or isinstance(v, dict) or isinstance(v, str) and k.strip() and v.strip():
# 解决类似disk_info为字符串的'{}'的问题
@ -111,12 +112,15 @@ class BaseFileParser(BaseParser):
return data
def parse(self, stream, media_type=None, parser_context=None):
parser_context = parser_context or {}
assert parser_context is not None, '`parser_context` should not be `None`'
view = parser_context['view']
request = view.request
try:
view = parser_context['view']
meta = view.request.META
meta = request.META
self.serializer_cls = view.get_serializer_class()
self.serializer_fields = self.serializer_cls().fields
except Exception as e:
logger.debug(e, exc_info=True)
raise ParseError('The resource does not support imports!')
@ -128,6 +132,13 @@ class BaseFileParser(BaseParser):
rows = self.generate_rows(stream_data)
column_titles = self.get_column_titles(rows)
field_names = self.convert_to_field_names(column_titles)
# 给 `common.mixins.api.RenderToJsonMixin` 提供,暂时只能耦合
column_title_field_pairs = list(zip(column_titles, field_names))
if not hasattr(request, 'jms_context'):
request.jms_context = {}
request.jms_context['column_title_field_pairs'] = column_title_field_pairs
data = self.generate_data(field_names, rows)
return data
except Exception as e:

View File

@ -1,40 +1,7 @@
# -*- coding: utf-8 -*-
#
from jumpserver.const import DYNAMIC
from werkzeug.local import Local, LocalProxy
from werkzeug.local import Local
thread_local = Local()
def _find(attr):
return getattr(thread_local, attr, None)
class _Settings:
pass
def get_dynamic_cfg_from_thread_local():
KEY = 'dynamic_config'
try:
cfg = getattr(thread_local, KEY)
except AttributeError:
cfg = _Settings()
setattr(thread_local, KEY, cfg)
return cfg
class DynamicDefaultLocalProxy(LocalProxy):
def __getattr__(self, item):
try:
value = super().__getattr__(item)
except AttributeError:
value = getattr(DYNAMIC, item)()
setattr(self, item, value)
return value
LOCAL_DYNAMIC_SETTINGS = DynamicDefaultLocalProxy(get_dynamic_cfg_from_thread_local)

View File

@ -11,13 +11,16 @@ from django.core.cache import cache
from django.http import JsonResponse
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.decorators import action
from rest_framework.request import Request
from common.const.http import POST
from common.drf.filters import IDSpmFilter, CustomFilter, IDInFilter
from ..utils import lazyproperty
__all__ = [
'JSONResponseMixin', 'CommonApiMixin', 'AsyncApiMixin', 'RelationMixin',
'SerializerMixin2', 'QuerySetMixin', 'ExtraFilterFieldsMixin'
'SerializerMixin2', 'QuerySetMixin', 'ExtraFilterFieldsMixin', 'RenderToJsonMixin',
]
@ -32,6 +35,21 @@ class JSONResponseMixin(object):
# ----------------------
class RenderToJsonMixin:
@action(methods=[POST], detail=False, url_path='render-to-json')
def render_to_json(self, request: Request):
data = {
'title': (),
'data': request.data,
}
jms_context = getattr(request, 'jms_context', {})
column_title_field_pairs = jms_context.get('column_title_field_pairs', ())
data['title'] = column_title_field_pairs
return Response(data=data)
class SerializerMixin:
""" 根据用户请求动作的不同,获取不同的 `serializer_class `"""
@ -98,7 +116,7 @@ class PaginatedResponseMixin:
return Response(serializer.data)
class CommonApiMixin(SerializerMixin, ExtraFilterFieldsMixin):
class CommonApiMixin(SerializerMixin, ExtraFilterFieldsMixin, RenderToJsonMixin):
pass

View File

@ -2,7 +2,7 @@
#
from collections import Iterable
from django.db.models import Prefetch, F
from django.db.models import Prefetch, F, NOT_PROVIDED
from django.core.exceptions import ObjectDoesNotExist
from rest_framework.utils import html
from rest_framework.settings import api_settings
@ -71,7 +71,7 @@ class BulkListSerializerMixin(object):
"""
List of dicts of native values <- List of dicts of primitive datatypes.
"""
if not self.instance:
if self.instance is None:
return super().to_internal_value(data)
if html.is_html_input(data):
@ -106,7 +106,7 @@ class BulkListSerializerMixin(object):
pk = item["pk"]
else:
raise ValidationError("id or pk not in data")
child = self.instance.get(id=pk) if self.instance else None
child = self.instance.get(id=pk)
self.child.instance = child
self.child.initial_data = item
# raw
@ -228,7 +228,43 @@ class SizedModelFieldsMixin(BaseDynamicFieldsPlugin):
return fields_to_drop
class DefaultValueFieldsMixin:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.set_fields_default_value()
def set_fields_default_value(self):
if not hasattr(self, 'Meta'):
return
if not hasattr(self.Meta, 'model'):
return
model = self.Meta.model
for name, serializer_field in self.fields.items():
if serializer_field.default != empty or serializer_field.required:
continue
model_field = getattr(model, name, None)
if model_field is None:
continue
if not hasattr(model_field, 'field') \
or not hasattr(model_field.field, 'default') \
or model_field.field.default == NOT_PROVIDED:
continue
if name == 'id':
continue
default = model_field.field.default
if callable(default):
default = default()
if default == '':
continue
# print(f"Set default value: {name}: {default}")
serializer_field.default = default
class DynamicFieldsMixin:
"""
可以控制显示不同的字段mini 最少small 不包含关系
"""
dynamic_fields_plugins = [QueryFieldsMixin, SizedModelFieldsMixin]
def __init__(self, *args, **kwargs):
@ -256,7 +292,7 @@ class EagerLoadQuerySetFields:
return queryset
class CommonSerializerMixin(DynamicFieldsMixin):
class CommonSerializerMixin(DynamicFieldsMixin, DefaultValueFieldsMixin):
pass

View File

@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
#
import time
from rest_framework import permissions
from django.contrib.auth.mixins import UserPassesTestMixin
from django.conf import settings
@ -97,7 +96,7 @@ class WithBootstrapToken(permissions.BasePermission):
class PermissionsMixin(UserPassesTestMixin):
permission_classes = []
permission_classes = [permissions.IsAuthenticated]
def get_permissions(self):
return self.permission_classes
@ -110,12 +109,17 @@ class PermissionsMixin(UserPassesTestMixin):
return True
class UserCanUpdatePassword:
class UserCanUseCurrentOrg(permissions.BasePermission):
def has_permission(self, request, view):
return current_org.can_use_by(request.user)
class UserCanUpdatePassword(permissions.BasePermission):
def has_permission(self, request, view):
return request.user.can_update_password()
class UserCanUpdateSSHKey:
class UserCanUpdateSSHKey(permissions.BasePermission):
def has_permission(self, request, view):
return request.user.can_update_ssh_key()
@ -188,3 +192,12 @@ class IsObjectOwner(IsValidUser):
def has_object_permission(self, request, view, obj):
return (super().has_object_permission(request, view, obj) and
request.user == getattr(obj, 'user', None))
class HasQueryParamsUserAndIsCurrentOrgMember(permissions.BasePermission):
def has_permission(self, request, view):
query_user_id = request.query_params.get('user')
if not query_user_id:
return False
query_user = current_org.get_members().filter(id=query_user_id).first()
return bool(query_user)

View File

@ -5,16 +5,12 @@ import os
import logging
from collections import defaultdict
from django.conf import settings
from django.dispatch import receiver
from django.core.signals import request_finished
from django.db import connection
from django.conf import LazySettings
from django.db.utils import ProgrammingError, OperationalError
from jumpserver.utils import get_current_request
from .local import thread_local
from .signals import django_ready
pattern = re.compile(r'FROM `(\w+)`')
logger = logging.getLogger("jumpserver.common")
@ -74,17 +70,3 @@ if settings.DEBUG and DEBUG_DB:
request_finished.connect(on_request_finished_logging_db_query)
else:
request_finished.connect(on_request_finished_release_local)
@receiver(django_ready)
def monkey_patch_settings(sender, **kwargs):
def monkey_patch_getattr(self, name):
val = getattr(self._wrapped, name)
if callable(val):
val = val()
return val
try:
LazySettings.__getattr__ = monkey_patch_getattr
except (ProgrammingError, OperationalError):
pass

View File

@ -7,3 +7,4 @@ from .encode import *
from .http import *
from .ipip import *
from .crypto import *
from .random import *

View File

@ -1,18 +1,17 @@
# -*- coding: utf-8 -*-
#
import re
import data_tree
from collections import OrderedDict
from itertools import chain
import logging
import datetime
import uuid
from functools import wraps
import string
import random
import time
import ipaddress
import psutil
from django.utils.translation import ugettext_lazy as _
from ..exceptions import JMSException
UUID_PATTERN = re.compile(r'\w{8}(-\w{4}){3}-\w{12}')
@ -143,7 +142,7 @@ def is_uuid(seq):
elif isinstance(seq, str) and UUID_PATTERN.match(seq):
return True
elif isinstance(seq, (list, tuple)):
all([is_uuid(x) for x in seq])
return all([is_uuid(x) for x in seq])
return False
@ -194,23 +193,17 @@ def with_cache(func):
return wrapper
def random_string(length):
import string
import random
charset = string.ascii_letters + string.digits
s = [random.choice(charset) for i in range(length)]
return ''.join(s)
logger = get_logger(__name__)
def timeit(func):
def wrapper(*args, **kwargs):
if hasattr(func, '__name__'):
name = func.__name__
else:
name = func
name = func
for attr in ('__qualname__', '__name__'):
if hasattr(func, attr):
name = getattr(func, attr)
break
logger.debug("Start call: {}".format(name))
now = time.time()
result = func(*args, **kwargs)
@ -254,3 +247,29 @@ def get_disk_usage():
mount_points = [p.mountpoint for p in partitions]
usages = {p: psutil.disk_usage(p) for p in mount_points}
return usages
class Time:
def __init__(self):
self._timestamps = []
self._msgs = []
def begin(self):
self._timestamps.append(time.time())
def time(self, msg):
self._timestamps.append(time.time())
self._msgs.append(msg)
def print(self):
last, *timestamps = self._timestamps
for timestamp, msg in zip(timestamps, self._msgs):
logger.debug(f'TIME_IT: {msg} {timestamp-last}')
last = timestamp
def bulk_get(d, *keys, default=None):
values = []
for key in keys:
values.append(d.get(key, default))
return values

View File

@ -0,0 +1,27 @@
import redis
from django.conf import settings
def get_redis_client(db):
rc = redis.StrictRedis(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
password=settings.REDIS_PASSWORD,
db=db
)
return rc
class RedisPubSub:
def __init__(self, ch, db=10):
self.ch = ch
self.redis = get_redis_client(db)
def subscribe(self):
ps = self.redis.pubsub()
ps.subscribe(self.ch)
return ps
def publish(self, data):
self.redis.publish(self.ch, data)
return True

View File

@ -1,7 +1,7 @@
import base64
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from Crypto.Random import get_random_bytes
from Cryptodome.Cipher import AES
from Cryptodome.Util.Padding import pad
from Cryptodome.Random import get_random_bytes
from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT
from django.conf import settings

View File

@ -1,12 +1,14 @@
from functools import wraps
import threading
from redis_lock import Lock as RedisLock
from redis_lock import Lock as RedisLock, NotAcquired
from redis import Redis
from django.db import transaction
from common.utils import get_logger
from common.utils.inspect import copy_function_args
from apps.jumpserver.const import CONFIG
from jumpserver.const import CONFIG
from common.local import thread_local
logger = get_logger(__file__)
@ -15,37 +17,49 @@ class AcquireFailed(RuntimeError):
pass
class LockHasTimeOut(RuntimeError):
pass
class DistributedLock(RedisLock):
def __init__(self, name, blocking=True, expire=60*2, auto_renewal=True):
def __init__(self, name, *, expire=None, release_on_transaction_commit=False,
reentrant=False, release_raise_exc=False, auto_renewal_seconds=60):
"""
使用 redis 构造的分布式锁
:param name:
锁的名字要全局唯一
:param blocking:
该参数只在锁作为装饰器或者 `with` 时有效
:param expire:
锁的过期时间注意不一定是锁到这个时间就释放了分两种情况
`auto_renewal=False` 锁会释放
`auto_renewal=True` 如果过期之前程序还没释放锁我们会延长锁的存活时间
这里的作用是防止程序意外终止没有释放锁导致死锁
锁的过期时间
:param release_on_transaction_commit:
是否在当前事务结束后再释放锁
:param release_raise_exc:
释放锁时如果没有持有锁是否抛异常或静默
:param auto_renewal_seconds:
当持有一个无限期锁的时候刷新锁的时间具体参考 `redis_lock.Lock#auto_renewal`
:param reentrant:
是否可重入
"""
self.kwargs_copy = copy_function_args(self.__init__, locals())
redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
if expire is None:
expire = auto_renewal_seconds
auto_renewal = True
else:
auto_renewal = False
super().__init__(redis_client=redis, name=name, expire=expire, auto_renewal=auto_renewal)
self._blocking = blocking
self._release_on_transaction_commit = release_on_transaction_commit
self._release_raise_exc = release_raise_exc
self._reentrant = reentrant
self._acquired_reentrant_lock = False
self._thread_id = threading.current_thread().ident
def __enter__(self):
thread_id = threading.current_thread().ident
logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> attempt to acquire <lock:{self._name}> ...')
acquired = self.acquire(blocking=self._blocking)
if self._blocking and not acquired:
logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> was not acquired <lock:{self._name}>, but blocking=True')
raise EnvironmentError("Lock wasn't acquired, but blocking=True")
acquired = self.acquire(blocking=True)
if not acquired:
logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> acquire <lock:{self._name}> failed')
raise AcquireFailed
logger.debug(f'DISTRIBUTED_LOCK: <thread_id:{thread_id}> acquire <lock:{self._name}> ok')
return self
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
@ -57,5 +71,114 @@ class DistributedLock(RedisLock):
# 要创建一个新的锁对象
with self.__class__(**self.kwargs_copy):
return func(*args, **kwds)
return inner
def locked_by_me(self):
if self.locked():
if self.get_owner_id() == self.id:
return True
return False
def locked_by_current_thread(self):
if self.locked():
owner_id = self.get_owner_id()
local_owner_id = getattr(thread_local, self.name, None)
if local_owner_id and owner_id == local_owner_id:
return True
return False
def acquire(self, blocking=True, timeout=None):
if self._reentrant:
if self.locked_by_current_thread():
self._acquired_reentrant_lock = True
logger.debug(
f'Reentry lock ok: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name} thread={self._thread_id}')
return True
logger.debug(f'Attempt acquire reentrant-lock: lock_id={self.id} lock={self.name} thread={self._thread_id}')
acquired = super().acquire(blocking=blocking, timeout=timeout)
if acquired:
logger.debug(f'Acquired reentrant-lock ok: lock_id={self.id} lock={self.name} thread={self._thread_id}')
setattr(thread_local, self.name, self.id)
else:
logger.debug(f'Acquired reentrant-lock failed: lock_id={self.id} lock={self.name} thread={self._thread_id}')
return acquired
else:
logger.debug(f'Attempt acquire lock: lock_id={self.id} lock={self.name} thread={self._thread_id}')
acquired = super().acquire(blocking=blocking, timeout=timeout)
logger.debug(f'Acquired lock: ok={acquired} lock_id={self.id} lock={self.name} thread={self._thread_id}')
return acquired
@property
def name(self):
return self._name
def _raise_exc_with_log(self, msg, *, exc_cls=NotAcquired):
e = exc_cls(msg)
logger.error(msg)
self._raise_exc(e)
def _raise_exc(self, e):
if self._release_raise_exc:
raise e
def _release_on_reentrant_locked_by_brother(self):
if self._acquired_reentrant_lock:
self._acquired_reentrant_lock = False
logger.debug(f'Released reentrant-lock: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name} thread={self._thread_id}')
return
else:
self._raise_exc_with_log(f'Reentrant-lock is not acquired: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name} thread={self._thread_id}')
def _release_on_reentrant_locked_by_me(self):
logger.debug(f'Release reentrant-lock locked by me: lock_id={self.id} lock={self.name} thread={self._thread_id}')
id = getattr(thread_local, self.name, None)
if id != self.id:
raise PermissionError(f'Reentrant-lock is not locked by me: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name} thread={self._thread_id}')
try:
# 这里要保证先删除 thread_local 的标记,
delattr(thread_local, self.name)
except AttributeError:
pass
finally:
try:
# 这里处理的是边界情况,
# 判断锁是我的 -> 锁超时 -> 释放锁报错
# 此时的报错应该被静默
self._release_redis_lock()
except NotAcquired:
pass
def _release_redis_lock(self):
# 最底层 api
super().release()
def _release(self):
try:
self._release_redis_lock()
logger.debug(f'Released lock: lock_id={self.id} lock={self.name} thread={self._thread_id}')
except NotAcquired as e:
logger.error(f'Release lock failed: lock_id={self.id} lock={self.name} thread={self._thread_id} error: {e}')
self._raise_exc(e)
def release(self):
_release = self._release
# 处理可重入锁
if self._reentrant:
if self.locked_by_current_thread():
if self.locked_by_me():
_release = self._release_on_reentrant_locked_by_me
else:
_release = self._release_on_reentrant_locked_by_brother
else:
self._raise_exc_with_log(f'Reentrant-lock is not acquired: lock_id={self.id} lock={self.name} thread={self._thread_id}')
# 处理是否在事务提交时才释放锁
if self._release_on_transaction_commit:
logger.debug(f'Release lock on transaction commit ... :lock_id={self.id} lock={self.name} thread={self._thread_id}')
transaction.on_commit(_release)
else:
_release()

View File

@ -1,8 +1,13 @@
# -*- coding: utf-8 -*-
#
import socket
import struct
import random
import socket
import string
import secrets
string_punctuation = '!#$%&()*+,-.:;<=>?@[]^_{}~'
def random_datetime(date_start, date_end):
@ -14,6 +19,29 @@ def random_ip():
return socket.inet_ntoa(struct.pack('>I', random.randint(1, 0xffffffff)))
def random_string(length, lower=True, upper=True, digit=True, special_char=False):
chars = string.ascii_letters
if digit:
chars += string.digits
while True:
password = list(random.choice(chars) for i in range(length))
if upper and not any(c.upper() for c in password):
continue
if lower and not any(c.lower() for c in password):
continue
if digit and not any(c.isdigit() for c in password):
continue
break
if special_char:
spc = random.choice(string_punctuation)
i = random.choice(range(len(password)))
password[i] = spc
password = ''.join(password)
return password
# def strTimeProp(start, end, prop, fmt):
# time_start = time.mktime(time.strptime(start, fmt))

View File

@ -4,6 +4,7 @@ from django.utils.timesince import timesince
from django.db.models import Count, Max
from django.http.response import JsonResponse, HttpResponse
from rest_framework.views import APIView
from rest_framework.permissions import AllowAny
from collections import Counter
from users.models import User
@ -307,7 +308,7 @@ class IndexApi(TotalCountMixin, DatesLoginMetricMixin, APIView):
class PrometheusMetricsApi(APIView):
permission_classes = ()
permission_classes = (AllowAny,)
def get(self, request, *args, **kwargs):
util = ComponentsPrometheusMetricsUtil()

View File

@ -280,7 +280,14 @@ class Config(dict):
'SESSION_COOKIE_SECURE': False,
'CSRF_COOKIE_SECURE': False,
'REFERER_CHECK_ENABLED': False,
'SERVER_REPLAY_STORAGE': {}
'SERVER_REPLAY_STORAGE': {},
'CONNECTION_TOKEN_ENABLED': False,
'ONLY_ALLOW_EXIST_USER_AUTH': False,
'ONLY_ALLOW_AUTH_FROM_SOURCE': True,
'DISK_CHECK_ENABLED': True,
'SESSION_SAVE_EVERY_REQUEST': True,
'SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE': False,
'FORGOT_PASSWORD_URL': '',
}
def compatible_auth_openid_of_key(self):
@ -426,98 +433,6 @@ class Config(dict):
return self.get(item)
class DynamicConfig:
def __init__(self, static_config):
self.static_config = static_config
self.db_setting = None
def __getitem__(self, item):
return self.dynamic(item)
def __getattr__(self, item):
return self.dynamic(item)
def dynamic(self, item):
return lambda: self.get(item)
def LOGIN_URL(self):
return self.get('LOGIN_URL')
def AUTHENTICATION_BACKENDS(self):
backends = [
'authentication.backends.pubkey.PublicKeyAuthBackend',
'django.contrib.auth.backends.ModelBackend',
]
if self.get('AUTH_LDAP'):
backends.insert(0, 'authentication.backends.ldap.LDAPAuthorizationBackend')
if self.static_config.get('AUTH_CAS'):
backends.insert(0, 'authentication.backends.cas.CASBackend')
if self.static_config.get('AUTH_OPENID'):
backends.insert(0, 'jms_oidc_rp.backends.OIDCAuthPasswordBackend')
backends.insert(0, 'jms_oidc_rp.backends.OIDCAuthCodeBackend')
if self.static_config.get('AUTH_RADIUS'):
backends.insert(0, 'authentication.backends.radius.RadiusBackend')
if self.static_config.get('AUTH_SSO'):
backends.insert(0, 'authentication.backends.api.SSOAuthentication')
return backends
def XPACK_LICENSE_IS_VALID(self):
if not HAS_XPACK:
return False
try:
from xpack.plugins.license.models import License
return License.has_valid_license()
except:
return False
def XPACK_INTERFACE_LOGIN_TITLE(self):
default_title = _('Welcome to the JumpServer open source fortress')
if not HAS_XPACK:
return default_title
try:
from xpack.plugins.interface.models import Interface
return Interface.get_login_title()
except:
return default_title
def LOGO_URLS(self):
logo_urls = {'logo_logout': static('img/logo.png'),
'logo_index': static('img/logo_text.png'),
'login_image': static('img/login_image.png'),
'favicon': static('img/facio.ico')}
if not HAS_XPACK:
return logo_urls
try:
from xpack.plugins.interface.models import Interface
obj = Interface.interface()
if obj:
if obj.logo_logout:
logo_urls.update({'logo_logout': obj.logo_logout.url})
if obj.logo_index:
logo_urls.update({'logo_index': obj.logo_index.url})
if obj.login_image:
logo_urls.update({'login_image': obj.login_image.url})
if obj.favicon:
logo_urls.update({'favicon': obj.favicon.url})
except:
pass
return logo_urls
def get_from_db(self, item):
if self.db_setting is not None:
value = self.db_setting.get(item)
if value is not None:
return value
return None
def get(self, item):
# 先从数据库中获取
value = self.get_from_db(item)
if value is not None:
return value
return self.static_config.get(item)
class ConfigManager:
config_class = Config
@ -694,7 +609,3 @@ class ConfigManager:
# 对config进行兼容处理
config.compatible()
return config
@classmethod
def get_dynamic_config(cls, config):
return DynamicConfig(config)

View File

@ -4,12 +4,11 @@ import os
from .conf import ConfigManager
__all__ = ['BASE_DIR', 'PROJECT_DIR', 'VERSION', 'CONFIG', 'DYNAMIC']
__all__ = ['BASE_DIR', 'PROJECT_DIR', 'VERSION', 'CONFIG']
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
PROJECT_DIR = os.path.dirname(BASE_DIR)
VERSION = '2.0.0'
CONFIG = ConfigManager.load_user_config()
DYNAMIC = ConfigManager.get_dynamic_config(CONFIG)

Some files were not shown because too many files have changed in this diff Show More