Merge pull request #5213 from jumpserver/dev

chore(merge): 合并dev到master
pull/5308/head
老广 2020-12-10 23:03:35 +08:00 committed by GitHub
commit 2fc6e6cd54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
109 changed files with 2512 additions and 1229 deletions

View File

@ -1,5 +1,6 @@
FROM registry.fit2cloud.com/public/python:v3 as stage-build
MAINTAINER Jumpserver Team <ibuler@qq.com>
# 编译代码
FROM python:3.8.6-slim as stage-build
MAINTAINER JumpServer Team <ibuler@qq.com>
ARG VERSION
ENV VERSION=$VERSION
@ -8,33 +9,38 @@ ADD . .
RUN cd utils && bash -ixeu build.sh
FROM registry.fit2cloud.com/public/python:v3
# 构建运行时环境
FROM python:3.8.6-slim
ARG PIP_MIRROR=https://pypi.douban.com/simple
ENV PIP_MIRROR=$PIP_MIRROR
ARG MYSQL_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/mysql/yum/mysql57-community-el6/
ENV MYSQL_MIRROR=$MYSQL_MIRROR
ARG PIP_JMS_MIRROR=https://pypi.douban.com/simple
ENV PIP_JMS_MIRROR=$PIP_JMS_MIRROR
WORKDIR /opt/jumpserver
COPY ./requirements ./requirements
RUN useradd jumpserver
RUN yum -y install epel-release && \
echo -e "[mysql]\nname=mysql\nbaseurl=${MYSQL_MIRROR}\ngpgcheck=0\nenabled=1" > /etc/yum.repos.d/mysql.repo
RUN yum -y install $(cat requirements/rpm_requirements.txt)
RUN pip install --upgrade pip setuptools==49.6.0 wheel -i ${PIP_MIRROR} && \
pip config set global.index-url ${PIP_MIRROR}
RUN pip install $(grep 'jms' requirements/requirements.txt) -i https://pypi.org/simple
RUN pip install -r requirements/requirements.txt
COPY ./requirements/deb_buster_requirements.txt ./requirements/deb_buster_requirements.txt
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list \
&& sed -i 's/security.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list \
&& apt update \
&& grep -v '^#' ./requirements/deb_buster_requirements.txt | xargs apt -y install \
&& localedef -c -f UTF-8 -i zh_CN zh_CN.UTF-8 \
&& cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime
COPY ./requirements/requirements.txt ./requirements/requirements.txt
RUN pip install --upgrade pip==20.2.4 setuptools==49.6.0 wheel==0.34.2 -i ${PIP_MIRROR} \
&& pip config set global.index-url ${PIP_MIRROR} \
&& pip install --no-cache-dir $(grep 'jms' requirements/requirements.txt) -i ${PIP_JMS_MIRROR} \
&& pip install --no-cache-dir -r requirements/requirements.txt
COPY --from=stage-build /opt/jumpserver/release/jumpserver /opt/jumpserver
RUN mkdir -p /root/.ssh/ && echo -e "Host *\n\tStrictHostKeyChecking no\n\tUserKnownHostsFile /dev/null" > /root/.ssh/config
RUN mkdir -p /root/.ssh/ \
&& echo -e "Host *\n\tStrictHostKeyChecking no\n\tUserKnownHostsFile /dev/null" > /root/.ssh/config
RUN echo > config.yml
VOLUME /opt/jumpserver/data
VOLUME /opt/jumpserver/logs
ENV LANG=zh_CN.UTF-8
ENV LC_ALL=zh_CN.UTF-8
EXPOSE 8070
EXPOSE 8080

View File

@ -1,4 +1,5 @@
from common.exceptions import JMSException
from orgs.models import Organization
from .. import models
@ -85,11 +86,46 @@ class SerializeApplicationToTreeNodeMixin:
'meta': {'type': 'k8s_app'}
}
def _serialize(self, application):
def _serialize_application(self, application):
method_name = f'_serialize_{application.category}'
data = getattr(self, method_name)(application)
data.update({
'pId': application.org.id,
'org_name': application.org_name
})
return data
def serialize_applications(self, applications):
data = [self._serialize(application) for application in applications]
data = [self._serialize_application(application) for application in applications]
return data
@staticmethod
def _serialize_organization(org):
return {
'id': org.id,
'name': org.name,
'title': org.name,
'pId': '',
'open': True,
'isParent': True,
'meta': {
'type': 'node'
}
}
def serialize_organizations(self, organizations):
    """Serialize every organization into its tree-node dict."""
    return list(map(self._serialize_organization, organizations))
@staticmethod
def filter_organizations(applications):
    """Return the distinct Organization instances the applications belong to."""
    # Deduplicate org ids first so each Organization is resolved only once.
    organizations_id = set(applications.values_list('org_id', flat=True))
    organizations = [Organization.get_instance(org_id) for org_id in organizations_id]
    return organizations
def serialize_applications_with_org(self, applications):
    """Serialize organizations first, then applications, as one flat list.

    Organization nodes come first so the tree widget has the parents
    available before their child application nodes.
    """
    orgs = self.filter_organizations(applications)
    serialized = self.serialize_organizations(orgs)
    serialized += self.serialize_applications(applications)
    return serialized

View File

@ -0,0 +1,18 @@
# Generated by Django 3.1 on 2020-11-19 03:10
from django.db import migrations, models
class Migration(migrations.Migration):
    """Switch Application.attrs to Django's built-in JSONField (Django 3.1+)."""

    dependencies = [
        ('applications', '0006_application'),
    ]

    operations = [
        migrations.AlterField(
            model_name='application',
            name='attrs',
            # Replaces the former third-party JSONField on the model;
            # presumably the stored column representation is compatible,
            # since no data migration accompanies this change — TODO confirm.
            field=models.JSONField(),
        ),
    ]

View File

@ -2,7 +2,6 @@ from itertools import chain
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django_mysql.models import JSONField, QuerySet
from orgs.mixins.models import OrgModelMixin
from common.mixins import CommonModelMixin
@ -123,7 +122,7 @@ class Application(CommonModelMixin, OrgModelMixin):
domain = models.ForeignKey('assets.Domain', null=True, blank=True, related_name='applications', verbose_name=_("Domain"), on_delete=models.SET_NULL)
category = models.CharField(max_length=16, choices=Category.choices, verbose_name=_('Category'))
type = models.CharField(max_length=16, choices=Category.get_all_type_choices(), verbose_name=_('Type'))
attrs = JSONField()
attrs = models.JSONField()
comment = models.TextField(
max_length=128, default='', blank=True, verbose_name=_('Comment')
)

View File

@ -27,10 +27,8 @@ class ApplicationSerializer(BulkOrgResourceModelSerializer):
]
def create(self, validated_data):
attrs = validated_data.pop('attrs', {})
validated_data['attrs'] = validated_data.pop('attrs', {})
instance = super().create(validated_data)
instance.attrs = attrs
instance.save()
return instance
def update(self, instance, validated_data):

View File

@ -12,9 +12,8 @@ from .. import models
class DBAttrsSerializer(serializers.Serializer):
host = serializers.CharField(max_length=128, label=_('Host'))
port = serializers.IntegerField(label=_('Port'))
database = serializers.CharField(
max_length=128, required=False, allow_blank=True, allow_null=True, label=_('Database')
)
# 添加allow_null=True兼容之前数据库中database字段为None的情况
database = serializers.CharField(max_length=128, required=True, allow_null=True, label=_('Database'))
class MySQLAttrsSerializer(DBAttrsSerializer):

View File

@ -69,6 +69,7 @@ class SerializeToTreeNodeMixin:
'ip': asset.ip,
'protocols': asset.protocols_as_list,
'platform': asset.platform_base,
'org_name': asset.org_name
},
}
}

View File

@ -5,11 +5,13 @@ from collections import namedtuple, defaultdict
from rest_framework import status
from rest_framework.serializers import ValidationError
from rest_framework.response import Response
from rest_framework.decorators import action
from django.utils.translation import ugettext_lazy as _
from django.shortcuts import get_object_or_404, Http404
from django.utils.decorators import method_decorator
from django.db.models.signals import m2m_changed
from common.const.http import POST
from common.exceptions import SomeoneIsDoingThis
from common.const.signals import PRE_REMOVE, POST_REMOVE
from assets.models import Asset
@ -19,6 +21,8 @@ from common.const.distributed_lock_key import UPDATE_NODE_TREE_LOCK_KEY
from orgs.mixins.api import OrgModelViewSet
from orgs.mixins import generics
from orgs.lock import org_level_transaction_lock
from orgs.utils import current_org
from assets.tasks import check_node_assets_amount_task
from ..hands import IsOrgAdmin
from ..models import Node
from ..tasks import (
@ -46,6 +50,11 @@ class NodeViewSet(OrgModelViewSet):
permission_classes = (IsOrgAdmin,)
serializer_class = serializers.NodeSerializer
@action(methods=[POST], detail=False, url_name='launch-check-assets-amount-task')
def launch_check_assets_amount_task(self, request):
task = check_node_assets_amount_task.delay(current_org.id)
return Response(data={'task': task.id})
# 仅支持根节点指直接创建子节点下的节点需要通过children接口创建
def perform_create(self, serializer):
child_key = Node.org_root().get_next_child_key()
@ -61,6 +70,9 @@ class NodeViewSet(OrgModelViewSet):
def destroy(self, request, *args, **kwargs):
node = self.get_object()
if node.is_org_root():
error = _("You can't delete the root node ({})".format(node.value))
return Response(data={'error': error}, status=status.HTTP_403_FORBIDDEN)
if node.has_children_or_has_assets():
error = _("Deletion failed and the node contains children or assets")
return Response(data={'error': error}, status=status.HTTP_403_FORBIDDEN)
@ -173,7 +185,7 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
return []
assets = self.instance.get_assets().only(
"id", "hostname", "ip", "os",
"org_id", "protocols",
"org_id", "protocols", "is_active"
)
return self.serialize_assets(assets, self.instance.key)
@ -201,10 +213,8 @@ class NodeAddChildrenApi(generics.UpdateAPIView):
def put(self, request, *args, **kwargs):
instance = self.get_object()
nodes_id = request.data.get("nodes")
children = [get_object_or_none(Node, id=pk) for pk in nodes_id]
children = Node.objects.filter(id__in=nodes_id)
for node in children:
if not node:
continue
node.parent = instance
return Response("OK")

View File

@ -3,7 +3,8 @@ from django.shortcuts import get_object_or_404
from rest_framework.response import Response
from common.utils import get_logger
from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser, IsAppUser
from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser
from common.drf.filters import CustomFilter
from orgs.mixins.api import OrgBulkModelViewSet
from orgs.mixins import generics
from orgs.utils import tmp_to_org
@ -12,7 +13,7 @@ from .. import serializers
from ..serializers import SystemUserWithAuthInfoSerializer
from ..tasks import (
push_system_user_to_assets_manual, test_system_user_connectivity_manual,
push_system_user_a_asset_manual,
push_system_user_to_assets
)
@ -82,18 +83,18 @@ class SystemUserTaskApi(generics.CreateAPIView):
permission_classes = (IsOrgAdmin,)
serializer_class = serializers.SystemUserTaskSerializer
def do_push(self, system_user, asset=None):
if asset is None:
def do_push(self, system_user, assets_id=None):
if assets_id is None:
task = push_system_user_to_assets_manual.delay(system_user)
else:
username = self.request.query_params.get('username')
task = push_system_user_a_asset_manual.delay(
system_user, asset, username=username
task = push_system_user_to_assets.delay(
system_user.id, assets_id, username=username
)
return task
@staticmethod
def do_test(system_user, asset=None):
def do_test(system_user):
task = test_system_user_connectivity_manual.delay(system_user)
return task
@ -104,11 +105,16 @@ class SystemUserTaskApi(generics.CreateAPIView):
def perform_create(self, serializer):
action = serializer.validated_data["action"]
asset = serializer.validated_data.get('asset')
assets = serializer.validated_data.get('assets') or []
system_user = self.get_object()
if action == 'push':
task = self.do_push(system_user, asset)
assets = [asset] if asset else assets
assets_id = [asset.id for asset in assets]
assets_id = assets_id if assets_id else None
task = self.do_push(system_user, assets_id)
else:
task = self.do_test(system_user, asset)
task = self.do_test(system_user)
data = getattr(serializer, '_data', {})
data["task"] = task.id
setattr(serializer, '_data', data)

View File

@ -0,0 +1,72 @@
# Generated by Jiangjie.Bai on 2020-12-01 10:47
from django.db import migrations
from django.db.models import Q
default_node_value = 'Default' # Always
old_default_node_key = '0' # Version <= 1.4.3
new_default_node_key = '1' # Version >= 1.4.4
def compute_parent_key(key):
    """Return the parent portion of a colon-separated node key.

    '1:2:3' -> '1:2'; a key with no ':' (a root key) yields ''.
    """
    parent, _sep, _leaf = key.rpartition(':')
    return parent
def migrate_default_node_key(apps, schema_editor):
    """Renumber the existing Default node's key from '0' to '1'.

    In versions <= 1.4.3 the Default node used key '0'; >= 1.4.4 uses '1'.
    If an empty placeholder node with key '1' already exists (no assets, no
    children), it is deleted first; otherwise the migration aborts and leaves
    the data untouched.  Finally the whole '0...' subtree is rewritten to
    '1...' and parent_key recomputed.
    """
    # In version 1.4.3 the Default node's key was 0.
    print('')
    Node = apps.get_model('assets', 'Node')
    Asset = apps.get_model('assets', 'Asset')

    # The node whose key is 0
    old_default_node = Node.objects.filter(key=old_default_node_key, value=default_node_value).first()
    if not old_default_node:
        print(f'Check old default node `key={old_default_node_key} value={default_node_value}` not exists')
        return
    print(f'Check old default node `key={old_default_node_key} value={default_node_value}` exists')

    # The node whose key is 1
    new_default_node = Node.objects.filter(key=new_default_node_key, value=default_node_value).first()
    if new_default_node:
        print(f'Check new default node `key={new_default_node_key} value={default_node_value}` exists')
        # Abort if the new-key node already carries assets or children.
        all_assets = Asset.objects.filter(
            Q(nodes__key__startswith=f'{new_default_node_key}:') | Q(nodes__key=new_default_node_key)
        ).distinct()
        if all_assets:
            print(f'Check new default node has assets (count: {len(all_assets)})')
            return
        all_children = Node.objects.filter(key__startswith=f'{new_default_node_key}:')
        if all_children:
            print(f'Check new default node has children nodes (count: {len(all_children)})')
            return
        print(f'Check new default node not has assets and children nodes, delete it.')
        new_default_node.delete()

    # Perform the rename on the old node and its whole subtree.
    print(f'Modify old default node `key` from `{old_default_node_key}` to `{new_default_node_key}`')
    nodes = Node.objects.filter(
        Q(key__istartswith=f'{old_default_node_key}:') | Q(key=old_default_node_key)
    )
    for node in nodes:
        old_key = node.key
        # Replace only the first key segment ('0') with the new one ('1').
        key_list = old_key.split(':', maxsplit=1)
        key_list[0] = new_default_node_key
        new_key = ':'.join(key_list)
        node.key = new_key
        node.parent_key = compute_parent_key(node.key)

    # Bulk update
    print(f'Bulk update nodes `key` and `parent_key`, (count: {len(nodes)})')
    Node.objects.bulk_update(nodes, ['key', 'parent_key'])
class Migration(migrations.Migration):
    """Data migration: renumber the legacy Default node key from '0' to '1'."""

    dependencies = [
        ('assets', '0062_auto_20201117_1938'),
    ]

    operations = [
        # No reverse function is supplied: this data migration is one-way.
        migrations.RunPython(migrate_default_node_key)
    ]

View File

@ -0,0 +1,17 @@
# Generated by Django 3.1 on 2020-12-03 03:00
from django.db import migrations
class Migration(migrations.Migration):
    """Change Node's default ordering to ['parent_key', 'value']."""

    dependencies = [
        ('assets', '0063_migrate_default_node_key'),
    ]

    operations = [
        migrations.AlterModelOptions(
            name='node',
            options={'ordering': ['parent_key', 'value'], 'verbose_name': 'Node'},
        ),
    ]

View File

@ -103,7 +103,7 @@ class FamilyMixin:
if value is None:
value = child_key
child = self.__class__.objects.create(
id=_id, key=child_key, value=value, parent_key=self.key,
id=_id, key=child_key, value=value
)
return child
@ -354,7 +354,8 @@ class SomeNodesMixin:
def org_root(cls):
root = cls.objects.filter(parent_key='')\
.filter(key__regex=r'^[0-9]+$')\
.exclude(key__startswith='-')
.exclude(key__startswith='-')\
.order_by('key')
if root:
return root[0]
else:
@ -411,7 +412,7 @@ class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
class Meta:
verbose_name = _("Node")
ordering = ['value']
ordering = ['parent_key', 'value']
def __str__(self):
return self.full_value

View File

@ -98,9 +98,6 @@ class AssetSerializer(BulkOrgResourceModelSerializer):
fields_as = list(annotates_fields.keys())
fields = fields_small + fields_fk + fields_m2m + fields_as
read_only_fields = [
'vendor', 'model', 'sn', 'cpu_model', 'cpu_count',
'cpu_cores', 'cpu_vcpus', 'memory', 'disk_total', 'disk_info',
'os', 'os_version', 'os_arch', 'hostname_raw',
'created_by', 'date_created',
] + fields_as

View File

@ -257,4 +257,8 @@ class SystemUserTaskSerializer(serializers.Serializer):
asset = serializers.PrimaryKeyRelatedField(
queryset=Asset.objects, allow_null=True, required=False, write_only=True
)
assets = serializers.PrimaryKeyRelatedField(
queryset=Asset.objects, allow_null=True, required=False, write_only=True,
many=True
)
task = serializers.CharField(read_only=True)

View File

@ -4,7 +4,7 @@ from operator import add, sub
from assets.utils import is_asset_exists_in_node
from django.db.models.signals import (
post_save, m2m_changed, pre_delete, post_delete
post_save, m2m_changed, pre_delete, post_delete, pre_save
)
from django.db.models import Q, F
from django.dispatch import receiver
@ -37,6 +37,11 @@ def test_asset_conn_on_created(asset):
test_asset_connectivity_util.delay([asset])
@receiver(pre_save, sender=Node)
def on_node_pre_save(sender, instance: Node, **kwargs):
    # Recompute the denormalized parent_key from the node's key on every
    # save, so it can never drift out of sync with the key hierarchy.
    instance.parent_key = instance.compute_parent_key()
@receiver(post_save, sender=Asset)
@on_transaction_commit
def on_asset_created_or_update(sender, instance=None, created=False, **kwargs):
@ -73,6 +78,7 @@ def on_system_user_update(instance: SystemUser, created, **kwargs):
@receiver(m2m_changed, sender=SystemUser.assets.through)
@on_transaction_commit
def on_system_user_assets_change(instance, action, model, pk_set, **kwargs):
"""
当系统用户和资产关系发生变化时应该重新推送系统用户到新添加的资产中
@ -91,25 +97,29 @@ def on_system_user_assets_change(instance, action, model, pk_set, **kwargs):
@receiver(m2m_changed, sender=SystemUser.users.through)
def on_system_user_users_change(sender, instance=None, action='', model=None, pk_set=None, **kwargs):
@on_transaction_commit
def on_system_user_users_change(sender, instance: SystemUser, action, model, pk_set, reverse, **kwargs):
"""
当系统用户和用户关系发生变化时应该重新推送系统用户资产中
"""
if action != POST_ADD:
return
if reverse:
raise M2MReverseNotAllowed
if not instance.username_same_with_user:
return
logger.debug("System user users change signal recv: {}".format(instance))
queryset = model.objects.filter(pk__in=pk_set)
if model == SystemUser:
system_users = queryset
else:
system_users = [instance]
for s in system_users:
push_system_user_to_assets_manual.delay(s)
usernames = model.objects.filter(pk__in=pk_set).values_list('username', flat=True)
for username in usernames:
push_system_user_to_assets_manual.delay(instance, username)
@receiver(m2m_changed, sender=SystemUser.nodes.through)
@on_transaction_commit
def on_system_user_nodes_change(sender, instance=None, action=None, model=None, pk_set=None, **kwargs):
"""
当系统用户和节点关系发生变化时应该将节点下资产关联到新的系统用户上

View File

@ -1,14 +1,27 @@
from celery import shared_task
from django.utils.translation import gettext_lazy as _
from orgs.models import Organization
from orgs.utils import tmp_to_org
from ops.celery.decorator import register_as_period_task
from assets.utils import check_node_assets_amount
from common.utils.lock import AcquireFailed
from common.utils import get_logger
from common.utils.timezone import now
logger = get_logger(__file__)
@shared_task()
def check_node_assets_amount_celery_task():
logger.info(f'>>> {now()} begin check_node_assets_amount_celery_task ...')
check_node_assets_amount()
logger.info(f'>>> {now()} end check_node_assets_amount_celery_task ...')
@shared_task(queue='celery_heavy_tasks')
def check_node_assets_amount_task(org_id=Organization.ROOT_ID):
    """Recount node asset amounts within the given organization.

    check_node_assets_amount is guarded by a distributed lock; if another
    run already holds it, AcquireFailed is caught and logged instead of
    propagating, so concurrent launches are harmless.
    """
    try:
        with tmp_to_org(Organization.get_instance(org_id)):
            check_node_assets_amount()
    except AcquireFailed:
        logger.error(_('The task of self-checking is already running and cannot be started repeatedly'))
@register_as_period_task(crontab='0 2 * * *')
@shared_task(queue='celery_heavy_tasks')
def check_node_assets_amount_period_task():
    # Nightly (02:00) periodic self-check; delegates to the shared task
    # using its default (root) organization.
    check_node_assets_amount_task()

View File

@ -2,13 +2,13 @@
from itertools import groupby
from celery import shared_task
from common.db.utils import get_object_if_need, get_objects_if_need, get_objects
from common.db.utils import get_object_if_need, get_objects
from django.utils.translation import ugettext as _
from django.db.models import Empty
from common.utils import encrypt_password, get_logger
from assets.models import SystemUser, Asset
from orgs.utils import org_aware_func
from assets.models import SystemUser, Asset, AuthBook
from orgs.utils import org_aware_func, tmp_to_root_org
from . import const
from .utils import clean_ansible_task_hosts, group_asset_by_platform
@ -190,15 +190,12 @@ def get_push_system_user_tasks(system_user, platform="unixlike", username=None):
@org_aware_func("system_user")
def push_system_user_util(system_user, assets, task_name, username=None):
from ops.utils import update_or_create_ansible_task
hosts = clean_ansible_task_hosts(assets, system_user=system_user)
if not hosts:
assets = clean_ansible_task_hosts(assets, system_user=system_user)
if not assets:
return {}
platform_hosts_map = {}
hosts_sorted = sorted(hosts, key=group_asset_by_platform)
platform_hosts = groupby(hosts_sorted, key=group_asset_by_platform)
for i in platform_hosts:
platform_hosts_map[i[0]] = list(i[1])
assets_sorted = sorted(assets, key=group_asset_by_platform)
platform_hosts = groupby(assets_sorted, key=group_asset_by_platform)
def run_task(_tasks, _hosts):
if not _tasks:
@ -209,27 +206,59 @@ def push_system_user_util(system_user, assets, task_name, username=None):
)
task.run()
for platform, _hosts in platform_hosts_map.items():
if not _hosts:
if system_user.username_same_with_user:
if username is None:
# 动态系统用户,但是没有指定 username
usernames = list(system_user.users.all().values_list('username', flat=True).distinct())
else:
usernames = [username]
else:
# 非动态系统用户指定 username 无效
assert username is None, 'Only Dynamic user can assign `username`'
usernames = [system_user.username]
for platform, _assets in platform_hosts:
_assets = list(_assets)
if not _assets:
continue
print(_("Start push system user for platform: [{}]").format(platform))
print(_("Hosts count: {}").format(len(_hosts)))
print(_("Hosts count: {}").format(len(_assets)))
# 如果没有特殊密码设置,就不需要单独推送某台机器了
if not system_user.has_special_auth(username=username):
logger.debug("System user not has special auth")
tasks = get_push_system_user_tasks(system_user, platform, username=username)
run_task(tasks, _hosts)
continue
id_asset_map = {_asset.id: _asset for _asset in _assets}
assets_id = id_asset_map.keys()
no_special_auth = []
special_auth_set = set()
for _host in _hosts:
system_user.load_asset_special_auth(_host, username=username)
tasks = get_push_system_user_tasks(system_user, platform, username=username)
run_task(tasks, [_host])
auth_books = AuthBook.objects.filter(username__in=usernames, asset_id__in=assets_id)
for auth_book in auth_books:
special_auth_set.add((auth_book.username, auth_book.asset_id))
for _username in usernames:
no_special_assets = []
for asset_id in assets_id:
if (_username, asset_id) not in special_auth_set:
no_special_assets.append(id_asset_map[asset_id])
if no_special_assets:
no_special_auth.append((_username, no_special_assets))
for _username, no_special_assets in no_special_auth:
tasks = get_push_system_user_tasks(system_user, platform, username=_username)
run_task(tasks, no_special_assets)
for auth_book in auth_books:
system_user._merge_auth(auth_book)
tasks = get_push_system_user_tasks(system_user, platform, username=auth_book.username)
asset = id_asset_map[auth_book.asset_id]
run_task(tasks, [asset])
@shared_task(queue="ansible")
@tmp_to_root_org()
def push_system_user_to_assets_manual(system_user, username=None):
"""
将系统用户推送到与它关联的所有资产上
"""
system_user = get_object_if_need(SystemUser, system_user)
assets = system_user.get_related_assets()
task_name = _("Push system users to assets: {}").format(system_user.name)
@ -237,7 +266,11 @@ def push_system_user_to_assets_manual(system_user, username=None):
@shared_task(queue="ansible")
@tmp_to_root_org()
def push_system_user_a_asset_manual(system_user, asset, username=None):
"""
将系统用户推送到一个资产上
"""
if username is None:
username = system_user.username
task_name = _("Push system users to asset: {}({}) => {}").format(
@ -247,10 +280,15 @@ def push_system_user_a_asset_manual(system_user, asset, username=None):
@shared_task(queue="ansible")
@tmp_to_root_org()
def push_system_user_to_assets(system_user_id, assets_id, username=None):
"""
推送系统用户到指定的若干资产上
"""
system_user = SystemUser.objects.get(id=system_user_id)
assets = get_objects(Asset, assets_id)
task_name = _("Push system users to assets: {}").format(system_user.name)
return push_system_user_util(system_user, assets, task_name, username=username)
# @shared_task

View File

@ -1,8 +1,11 @@
# ~*~ coding: utf-8 ~*~
#
import time
from django.db.models import Q
from common.utils import get_logger, dict_get_any, is_uuid, get_object_or_none
from common.utils.lock import DistributedLock
from common.http import is_true
from .models import Asset, Node
@ -10,17 +13,21 @@ from .models import Asset, Node
logger = get_logger(__file__)
@DistributedLock(name="assets.node.check_node_assets_amount", blocking=False)
def check_node_assets_amount():
for node in Node.objects.all():
logger.info(f'Check node assets amount: {node}')
assets_amount = Asset.objects.filter(
Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes=node)
).distinct().count()
if node.assets_amount != assets_amount:
print(f'>>> <Node:{node.key}> wrong assets amount '
f'{node.assets_amount} right is {assets_amount}')
logger.warn(f'Node wrong assets amount <Node:{node.key}> '
f'{node.assets_amount} right is {assets_amount}')
node.assets_amount = assets_amount
node.save()
# 防止自检程序给数据库的压力太大
time.sleep(0.1)
def is_asset_exists_in_node(asset_pk, node_key):

View File

@ -0,0 +1,18 @@
# Generated by Django 3.1 on 2020-12-09 03:03
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add UserLoginLog.backend to record which auth backend logged the user in."""

    dependencies = [
        ('audits', '0010_auto_20200811_1122'),
    ]

    operations = [
        migrations.AddField(
            model_name='userloginlog',
            name='backend',
            # Empty default keeps existing rows valid without a data migration.
            field=models.CharField(default='', max_length=32, verbose_name='Login backend'),
        ),
    ]

View File

@ -105,6 +105,7 @@ class UserLoginLog(models.Model):
reason = models.CharField(default='', max_length=128, blank=True, verbose_name=_('Reason'))
status = models.BooleanField(max_length=2, default=True, choices=STATUS_CHOICE, verbose_name=_('Status'))
datetime = models.DateTimeField(default=timezone.now, verbose_name=_('Date login'))
backend = models.CharField(max_length=32, default='', verbose_name=_('Login backend'))
@classmethod
def get_login_logs(cls, date_from=None, date_to=None, user=None, keyword=None):

View File

@ -31,7 +31,8 @@ class UserLoginLogSerializer(serializers.ModelSerializer):
model = models.UserLoginLog
fields = (
'id', 'username', 'type', 'type_display', 'ip', 'city', 'user_agent',
'mfa', 'reason', 'status', 'status_display', 'datetime', 'mfa_display'
'mfa', 'reason', 'status', 'status_display', 'datetime', 'mfa_display',
'backend'
)
extra_kwargs = {
"user_agent": {'label': _('User agent')}

View File

@ -5,6 +5,8 @@ from django.db.models.signals import post_save, post_delete
from django.dispatch import receiver
from django.db import transaction
from django.utils import timezone
from django.contrib.auth import BACKEND_SESSION_KEY
from django.utils.translation import ugettext_lazy as _
from rest_framework.renderers import JSONRenderer
from rest_framework.request import Request
@ -32,6 +34,19 @@ MODELS_NEED_RECORD = (
)
LOGIN_BACKEND = {
'PublicKeyAuthBackend': _('SSH Key'),
'RadiusBackend': User.Source.radius.label,
'RadiusRealmBackend': User.Source.radius.label,
'LDAPAuthorizationBackend': User.Source.ldap.label,
'ModelBackend': _('Password'),
'SSOAuthentication': _('SSO'),
'CASBackend': User.Source.cas.label,
'OIDCAuthCodeBackend': User.Source.openid.label,
'OIDCAuthPasswordBackend': User.Source.openid.label,
}
def create_operate_log(action, sender, resource):
user = current_request.user if current_request else None
if not user or not user.is_authenticated:
@ -109,6 +124,12 @@ def on_audits_log_create(sender, instance=None, **kwargs):
sys_logger.info(msg)
def get_login_backend(request):
    """Return a display label for the auth backend recorded in the session.

    The session stores the backend's dotted import path; only the final
    class name is used for the lookup.  Unknown backends map to ''.
    """
    dotted_path = request.session.get(BACKEND_SESSION_KEY, '')
    class_name = dotted_path[dotted_path.rfind('.') + 1:]
    return LOGIN_BACKEND.get(class_name, '')
def generate_data(username, request):
user_agent = request.META.get('HTTP_USER_AGENT', '')
login_ip = get_request_ip(request) or '0.0.0.0'
@ -122,7 +143,8 @@ def generate_data(username, request):
'ip': login_ip,
'type': login_type,
'user_agent': user_agent,
'datetime': timezone.now()
'datetime': timezone.now(),
'backend': get_login_backend(request)
}
return data

View File

@ -6,7 +6,7 @@ import time
from django.core.cache import cache
from django.utils.translation import ugettext as _
from django.utils.six import text_type
from six import text_type
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
from rest_framework import HTTP_HEADER_ENCODING

View File

@ -23,7 +23,7 @@ class CreateUserMixin:
email_suffix = settings.EMAIL_SUFFIX
email = '{}@{}'.format(username, email_suffix)
user = User(username=username, name=username, email=email)
user.source = user.SOURCE_RADIUS
user.source = user.Source.radius.value
user.save()
return user

View File

@ -218,5 +218,14 @@ class PasswdTooSimple(JMSException):
default_detail = _('Your password is too simple, please change it for security')
def __init__(self, url, *args, **kwargs):
super(PasswdTooSimple, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.url = url
class PasswordRequireResetError(JMSException):
    """Raised at login when the user's password has expired and must be reset.

    Carries the URL the login view should redirect to (the flash-message
    page that forwards on to the reset-password flow).
    """
    default_code = 'passwd_has_expired'
    default_detail = _('Your password has expired, please reset before logging in')

    def __init__(self, url, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Redirect target consumed by the login view's exception handler.
        self.url = url

View File

@ -4,7 +4,7 @@
from django import forms
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from captcha.fields import CaptchaField
from captcha.fields import CaptchaField, CaptchaTextInput
class UserLoginForm(forms.Form):
@ -26,8 +26,12 @@ class UserCheckOtpCodeForm(forms.Form):
otp_code = forms.CharField(label=_('MFA code'), max_length=6)
class CustomCaptchaTextInput(CaptchaTextInput):
    # Render the captcha field with our own template instead of the
    # widget's default markup.
    template_name = 'authentication/_captcha_field.html'
class CaptchaMixin(forms.Form):
captcha = CaptchaField()
captcha = CaptchaField(widget=CustomCaptchaTextInput)
class ChallengeMixin(forms.Form):

View File

@ -110,9 +110,8 @@ class AuthMixin:
raise CredentialError(error=errors.reason_user_inactive)
elif not user.is_active:
raise CredentialError(error=errors.reason_user_inactive)
elif user.password_has_expired:
raise CredentialError(error=errors.reason_password_expired)
self._check_password_require_reset_or_not(user)
self._check_passwd_is_too_simple(user, password)
clean_failed_count(username, ip)
@ -123,20 +122,34 @@ class AuthMixin:
return user
@classmethod
def _check_passwd_is_too_simple(cls, user, password):
def generate_reset_password_url_with_flash_msg(cls, user: User, flash_view_name):
reset_passwd_url = reverse('authentication:reset-password')
query_str = urlencode({
'token': user.generate_reset_token()
})
reset_passwd_url = f'{reset_passwd_url}?{query_str}'
flash_page_url = reverse(flash_view_name)
query_str = urlencode({
'redirect_url': reset_passwd_url
})
return f'{flash_page_url}?{query_str}'
@classmethod
def _check_passwd_is_too_simple(cls, user: User, password):
if user.is_superuser and password == 'admin':
reset_passwd_url = reverse('authentication:reset-password')
query_str = urlencode({
'token': user.generate_reset_token()
})
reset_passwd_url = f'{reset_passwd_url}?{query_str}'
url = cls.generate_reset_password_url_with_flash_msg(
user, 'authentication:passwd-too-simple-flash-msg'
)
raise errors.PasswdTooSimple(url)
flash_page_url = reverse('authentication:passwd-too-simple-flash-msg')
query_str = urlencode({
'redirect_url': reset_passwd_url
})
raise errors.PasswdTooSimple(f'{flash_page_url}?{query_str}')
@classmethod
def _check_password_require_reset_or_not(cls, user: User):
    # When the password has expired, abort authentication with an error
    # carrying the "password has expired" flash page URL, which then
    # forwards the user into the reset-password flow.
    if user.password_has_expired:
        url = cls.generate_reset_password_url_with_flash_msg(
            user, 'authentication:passwd-has-expired-flash-msg'
        )
        raise errors.PasswordRequireResetError(url)
def check_user_auth_if_need(self, decrypt_passwd=False):
request = self.request

View File

@ -1,11 +1,9 @@
import uuid
from functools import partial
from django.utils import timezone
from django.utils.translation import ugettext_lazy as _, ugettext as __
from rest_framework.authtoken.models import Token
from django.conf import settings
from django.utils.crypto import get_random_string
from common.db import models
from common.mixins.models import CommonModelMixin

View File

@ -0,0 +1,29 @@
{% load i18n %}
{% spaceless %}
{# Custom captcha widget: image + hidden key field; clicking the image fetches a fresh challenge. #}
<img src="{{ image }}" alt="captcha" class="captcha" />
<div class="row" style="padding-bottom: 10px">
    <div class="col-sm-6">
        <div class="input-group-prepend">
            {% if audio %}
                {# NOTE(review): this <a> is opened but never closed — confirm the intended markup #}
                <a title="{% trans "Play CAPTCHA as audio file" %}" href="{{ audio }}">
            {% endif %}
        </div>
        {% include "django/forms/widgets/multiwidget.html" %}
    </div>
</div>
<script>
    var placeholder = '{% trans "Captcha" %}'
    // Fetch a new captcha key/image pair and swap them into the form.
    function refresh_captcha() {
        $.getJSON("{% url "captcha-refresh" %}",
            function (result) {
                $('.captcha').attr('src', result['image_url']);
                $('#id_captcha_0').val(result['key'])
            })
    }
    $(document).ready(function () {
        $('.captcha').click(refresh_captcha)
        $('#id_captcha_1').addClass('form-control').attr('placeholder', placeholder)
    })
</script>
{% endspaceless %}

View File

@ -22,6 +22,7 @@ urlpatterns = [
name='forgot-password-sendmail-success'),
path('password/reset/', users_view.UserResetPasswordView.as_view(), name='reset-password'),
path('password/too-simple-flash-msg/', views.FlashPasswdTooSimpleMsgView.as_view(), name='passwd-too-simple-flash-msg'),
path('password/has-expired-msg/', views.FlashPasswdHasExpiredMsgView.as_view(), name='passwd-has-expired-flash-msg'),
path('password/reset/success/', users_view.UserResetPasswordSuccessView.as_view(), name='reset-password-success'),
path('password/verify/', users_view.UserVerifyPasswordView.as_view(), name='user-verify-password'),

View File

@ -32,7 +32,7 @@ from ..forms import get_user_login_form_cls
__all__ = [
'UserLoginView', 'UserLogoutView',
'UserLoginGuardView', 'UserLoginWaitConfirmView',
'FlashPasswdTooSimpleMsgView',
'FlashPasswdTooSimpleMsgView', 'FlashPasswdHasExpiredMsgView'
]
@ -96,7 +96,7 @@ class UserLoginView(mixins.AuthMixin, FormView):
new_form._errors = form.errors
context = self.get_context_data(form=new_form)
return self.render_to_response(context)
except errors.PasswdTooSimple as e:
except (errors.PasswdTooSimple, errors.PasswordRequireResetError) as e:
return redirect(e.url)
self.clear_rsa_key()
return self.redirect_to_guard_view()
@ -250,3 +250,18 @@ class FlashPasswdTooSimpleMsgView(TemplateView):
'auto_redirect': True,
}
return self.render_to_response(context)
@method_decorator(never_cache, name='dispatch')
class FlashPasswdHasExpiredMsgView(TemplateView):
    """Flash page shown when the user's password has expired.

    Displays the message for ``interval`` seconds and then auto-redirects
    to the ``redirect_url`` query parameter (the reset-password view).
    """
    template_name = 'flash_message_standalone.html'

    def get(self, request, *args, **kwargs):
        context = {
            'title': _('Please change your password'),
            'messages': _('Your password has expired, please reset before logging in'),
            # Seconds to show the flash page before redirecting.
            'interval': 5,
            'redirect_url': request.GET.get('redirect_url'),
            'auto_redirect': True,
        }
        return self.render_to_response(context)

View File

@ -1 +1,2 @@
from .csv import *
from .csv import *
from .excel import *

View File

@ -0,0 +1,132 @@
import abc
import json
import codecs
from django.utils.translation import ugettext_lazy as _
from rest_framework.parsers import BaseParser
from rest_framework import status
from rest_framework.exceptions import ParseError, APIException
from common.utils import get_logger
logger = get_logger(__file__)
class FileContentOverflowedError(APIException):
    # 400 response raised when an uploaded file exceeds
    # BaseFileParser.FILE_CONTENT_MAX_LENGTH; default_detail is a template
    # that the raiser formats with the actual limit.
    status_code = status.HTTP_400_BAD_REQUEST
    default_code = 'file_content_overflowed'
    default_detail = _('The file content overflowed (The maximum length `{}` bytes)')
class BaseFileParser(BaseParser):
    """Parse an uploaded file (CSV, Excel, ...) into serializer-ready dicts.

    Subclasses implement :meth:`generate_rows` to turn the raw upload bytes
    into an iterator of rows; this base class maps the header row to
    serializer field names and converts each data row into a dict.
    """

    # Reject uploads larger than 10 MB.
    FILE_CONTENT_MAX_LENGTH = 1024 * 1024 * 10
    # Set in parse() from the view's serializer class.
    serializer_cls = None

    def check_content_length(self, meta):
        """Raise FileContentOverflowedError when the request body is too large."""
        content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
        if content_length > self.FILE_CONTENT_MAX_LENGTH:
            msg = FileContentOverflowedError.default_detail.format(self.FILE_CONTENT_MAX_LENGTH)
            logger.error(msg)
            raise FileContentOverflowedError(msg)

    @staticmethod
    def get_stream_data(stream):
        """Read the whole upload and strip a leading UTF-8 BOM if present."""
        stream_data = stream.read()
        stream_data = stream_data.strip(codecs.BOM_UTF8)
        return stream_data

    @abc.abstractmethod
    def generate_rows(self, stream_data):
        """Yield rows (sequences of cell values) parsed from the raw bytes."""
        # Fix: the original `raise NotImplemented` raises a TypeError at
        # runtime (NotImplemented is not an exception class).
        raise NotImplementedError

    def get_column_titles(self, rows):
        # The first row of the file is the header row.
        return next(rows)

    def convert_to_field_names(self, column_titles):
        """Map header titles (serializer labels or field names) to field names.

        Unknown columns map to '' so their cells are dropped later when the
        row dict is cleaned up.
        """
        fields_map = {}
        fields = self.serializer_cls().fields
        fields_map.update({v.label: k for k, v in fields.items()})
        fields_map.update({k: k for k in fields})
        field_names = [
            # Required columns are exported with a leading '*'; strip it
            # before looking the title up.
            fields_map.get(column_title.strip('*'), '')
            for column_title in column_titles
        ]
        return field_names

    @staticmethod
    def _replace_chinese_quote(s):
        # Fix: the full-width quote keys had degenerated to empty strings,
        # which makes str.maketrans raise ValueError (keys must be length-1
        # strings). Restore the CJK double/single quotes so JSON cells typed
        # with a Chinese IME still parse.
        trans_table = str.maketrans({
            '“': '"',
            '”': '"',
            '‘': '"',
            '’': '"',
            '\'': '"',
        })
        return s.translate(trans_table)

    @classmethod
    def process_row(cls, row):
        """Per-row pre-processing before the row dict is built."""
        new_row = []
        for col in row:
            # Normalize Chinese quotes so json.loads accepts the cell.
            col = cls._replace_chinese_quote(col)
            # Decode list/dict cells that were serialized as JSON text.
            if isinstance(col, str) and (
                (col.startswith('[') and col.endswith(']'))
                or
                (col.startswith('{') and col.endswith('}'))
            ):
                col = json.loads(col)
            new_row.append(col)
        return new_row

    @staticmethod
    def process_row_data(row_data):
        """Per-row post-processing: drop blank/unmapped cells.

        Lists and dicts are always kept; strings are kept only when both the
        key and the value are non-blank. NOTE(review): values of any other
        type (e.g. int) are dropped here — the parentheses below preserve the
        original operator precedence; confirm that behavior is intended.
        """
        new_row_data = {}
        for k, v in row_data.items():
            if isinstance(v, (list, dict)) or (isinstance(v, str) and k.strip() and v.strip()):
                new_row_data[k] = v
        return new_row_data

    def generate_data(self, fields_name, rows):
        """Zip each data row with the field names and clean it up."""
        data = []
        for row in rows:
            # Skip rows whose cells are all empty.
            if not any(row):
                continue
            row = self.process_row(row)
            row_data = dict(zip(fields_name, row))
            row_data = self.process_row_data(row_data)
            data.append(row_data)
        return data

    def parse(self, stream, media_type=None, parser_context=None):
        """DRF entry point: turn the uploaded file into a list of dicts."""
        parser_context = parser_context or {}
        try:
            view = parser_context['view']
            meta = view.request.META
            self.serializer_cls = view.get_serializer_class()
        except Exception as e:
            logger.debug(e, exc_info=True)
            raise ParseError('The resource does not support imports!')

        self.check_content_length(meta)

        try:
            stream_data = self.get_stream_data(stream)
            rows = self.generate_rows(stream_data)
            column_titles = self.get_column_titles(rows)
            field_names = self.convert_to_field_names(column_titles)
            data = self.generate_data(field_names, rows)
            return data
        except Exception as e:
            logger.error(e, exc_info=True)
            raise ParseError('Parse error! ({})'.format(self.media_type))

View File

@ -1,32 +1,13 @@
# ~*~ coding: utf-8 ~*~
#
import json
import chardet
import codecs
import unicodecsv
from django.utils.translation import ugettext as _
from rest_framework.parsers import BaseParser
from rest_framework.exceptions import ParseError, APIException
from rest_framework import status
from common.utils import get_logger
logger = get_logger(__file__)
from .base import BaseFileParser
class CsvDataTooBig(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_code = 'csv_data_too_big'
default_detail = _('The max size of CSV is %d bytes')
class JMSCSVParser(BaseParser):
"""
Parses CSV file to serializer data
"""
CSV_UPLOAD_MAX_SIZE = 1024 * 1024 * 10
class CSVFileParser(BaseFileParser):
media_type = 'text/csv'
@ -38,99 +19,10 @@ class JMSCSVParser(BaseParser):
for line in stream.splitlines():
yield line
@staticmethod
def _gen_rows(csv_data, charset='utf-8', **kwargs):
csv_reader = unicodecsv.reader(csv_data, encoding=charset, **kwargs)
    def generate_rows(self, stream_data):
        # Sniff the byte encoding with chardet so non-UTF-8 CSV files parse.
        detect_result = chardet.detect(stream_data)
        encoding = detect_result.get("encoding", "utf-8")
        lines = self._universal_newlines(stream_data)
        csv_reader = unicodecsv.reader(lines, encoding=encoding)
        for row in csv_reader:
            if not any(row):  # skip rows whose cells are all empty
                continue
            yield row
@staticmethod
def _get_fields_map(serializer_cls):
fields_map = {}
fields = serializer_cls().fields
fields_map.update({v.label: k for k, v in fields.items()})
fields_map.update({k: k for k, _ in fields.items()})
return fields_map
@staticmethod
def _replace_chinese_quot(str_):
trans_table = str.maketrans({
'': '"',
'': '"',
'': '"',
'': '"',
'\'': '"'
})
return str_.translate(trans_table)
@classmethod
def _process_row(cls, row):
"""
构建json数据前的行处理
"""
_row = []
for col in row:
# 列表转换
if isinstance(col, str) and col.startswith('[') and col.endswith(']'):
col = cls._replace_chinese_quot(col)
col = json.loads(col)
# 字典转换
if isinstance(col, str) and col.startswith("{") and col.endswith("}"):
col = cls._replace_chinese_quot(col)
col = json.loads(col)
_row.append(col)
return _row
@staticmethod
def _process_row_data(row_data):
"""
构建json数据后的行数据处理
"""
_row_data = {}
for k, v in row_data.items():
if isinstance(v, list) or isinstance(v, dict)\
or isinstance(v, str) and k.strip() and v.strip():
_row_data[k] = v
return _row_data
def parse(self, stream, media_type=None, parser_context=None):
parser_context = parser_context or {}
try:
view = parser_context['view']
meta = view.request.META
serializer_cls = view.get_serializer_class()
except Exception as e:
logger.debug(e, exc_info=True)
raise ParseError('The resource does not support imports!')
content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
if content_length > self.CSV_UPLOAD_MAX_SIZE:
msg = CsvDataTooBig.default_detail % self.CSV_UPLOAD_MAX_SIZE
logger.error(msg)
raise CsvDataTooBig(msg)
try:
stream_data = stream.read()
stream_data = stream_data.strip(codecs.BOM_UTF8)
detect_result = chardet.detect(stream_data)
encoding = detect_result.get("encoding", "utf-8")
binary = self._universal_newlines(stream_data)
rows = self._gen_rows(binary, charset=encoding)
header = next(rows)
fields_map = self._get_fields_map(serializer_cls)
header = [fields_map.get(name.strip('*'), '') for name in header]
data = []
for row in rows:
row = self._process_row(row)
row_data = dict(zip(header, row))
row_data = self._process_row_data(row_data)
data.append(row_data)
return data
except Exception as e:
logger.error(e, exc_info=True)
raise ParseError('CSV parse error!')

View File

@ -0,0 +1,14 @@
import pyexcel
from .base import BaseFileParser
class ExcelFileParser(BaseFileParser):
    """Parser that reads .xlsx uploads via pyexcel."""

    media_type = 'text/xlsx'

    def generate_rows(self, stream_data):
        book = pyexcel.get_book(file_type='xlsx', file_content=stream_data)
        # Only the first worksheet of the workbook is imported.
        first_sheet = book.sheet_by_index(0)
        return first_sheet.rows()

View File

@ -1,6 +1,7 @@
from rest_framework import renderers
from .csv import *
from .excel import *
class PassthroughRenderer(renderers.BaseRenderer):

View File

@ -0,0 +1,132 @@
import abc
from datetime import datetime
from rest_framework.renderers import BaseRenderer
from rest_framework.utils import encoders, json
from common.utils import get_logger
logger = get_logger(__file__)
class BaseFileRenderer(BaseRenderer):
    """Render serializer data into a downloadable file (CSV, Excel, ...).

    Subclasses implement the writer hooks ``initial_writer``, ``write_row``
    and ``get_rendered_value``; this base class selects the fields, builds
    the header and data rows, and drives the writer.
    """
    # Render template flag; one of ['import', 'update', 'export'];
    # taken from the ?template= query parameter in render().
    template = 'export'
    serializer = None

    @staticmethod
    def _check_validation_data(data):
        # A payload carrying a "detail" key is an error response,
        # not a result set, and must not be rendered as a file.
        detail_key = "detail"
        if detail_key in data:
            return False
        return True

    @staticmethod
    def _json_format_response(response_data):
        return json.dumps(response_data)

    def set_response_disposition(self, response):
        # Name the download "<model>_<timestamp>.<format>" when the
        # serializer exposes a Meta.model.
        serializer = self.serializer
        if response and hasattr(serializer, 'Meta') and hasattr(serializer.Meta, "model"):
            model_name = serializer.Meta.model.__name__.lower()
            now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            filename = "{}_{}.{}".format(model_name, now, self.format)
            disposition = 'attachment; filename="{}"'.format(filename)
            response['Content-Disposition'] = disposition

    def get_rendered_fields(self):
        # import: writable fields minus org_id and id;
        # update: writable fields minus org_id;
        # export (default): readable fields minus org_id.
        fields = self.serializer.fields
        if self.template == 'import':
            return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id']
        elif self.template == 'update':
            return [v for k, v in fields.items() if not v.read_only and k != "org_id"]
        else:
            return [v for k, v in fields.items() if not v.write_only and k != "org_id"]

    @staticmethod
    def get_column_titles(render_fields):
        # Required fields get a leading '*' in the header row.
        return [
            '*{}'.format(field.label) if field.required else str(field.label)
            for field in render_fields
        ]

    def process_data(self, data):
        results = data['results'] if 'results' in data else data
        if isinstance(results, dict):
            results = [results]
        if self.template == 'import':
            # An import template only needs a single example row.
            results = [results[0]] if results else results
        else:
            # Cap the number of exported rows.
            results = results[:10000]
        # Round-trip through JSON to coerce UUIDs etc. to strings.
        results = json.loads(json.dumps(results, cls=encoders.JSONEncoder))
        return results

    @staticmethod
    def generate_rows(data, render_fields):
        for item in data:
            row = []
            for field in render_fields:
                value = item.get(field.field_name)
                # NOTE(review): falsy values (0, False) render as '' —
                # confirm that is intended.
                value = str(value) if value else ''
                row.append(value)
            yield row

    @abc.abstractmethod
    def initial_writer(self):
        # Prepare the subclass-specific writer/buffer state.
        raise NotImplementedError

    def write_column_titles(self, column_titles):
        self.write_row(column_titles)

    def write_rows(self, rows):
        for row in rows:
            self.write_row(row)

    @abc.abstractmethod
    def write_row(self, row):
        raise NotImplementedError

    @abc.abstractmethod
    def get_rendered_value(self):
        # Return the finished file content as bytes.
        raise NotImplementedError

    def render(self, data, accepted_media_type=None, renderer_context=None):
        """DRF entry point: render ``data`` to file bytes (or an error string)."""
        if data is None:
            return bytes()

        if not self._check_validation_data(data):
            # Pass error payloads through as JSON instead of a file.
            return self._json_format_response(data)

        try:
            renderer_context = renderer_context or {}
            request = renderer_context['request']
            response = renderer_context['response']
            view = renderer_context['view']
            self.template = request.query_params.get('template', 'export')
            self.serializer = view.get_serializer()
            self.set_response_disposition(response)
        except Exception as e:
            logger.debug(e, exc_info=True)
            value = 'The resource not support export!'.encode('utf-8')
            return value

        try:
            rendered_fields = self.get_rendered_fields()
            column_titles = self.get_column_titles(rendered_fields)
            data = self.process_data(data)
            rows = self.generate_rows(data, rendered_fields)
            self.initial_writer()
            self.write_column_titles(column_titles)
            self.write_rows(rows)
            value = self.get_rendered_value()
        except Exception as e:
            logger.debug(e, exc_info=True)
            value = 'Render error! ({})'.format(self.media_type).encode('utf-8')
            return value
        return value

View File

@ -1,83 +1,30 @@
# ~*~ coding: utf-8 ~*~
#
import unicodecsv
import codecs
from datetime import datetime
import unicodecsv
from six import BytesIO
from rest_framework.renderers import BaseRenderer
from rest_framework.utils import encoders, json
from common.utils import get_logger
logger = get_logger(__file__)
from .base import BaseFileRenderer
class JMSCSVRender(BaseRenderer):
class CSVFileRenderer(BaseFileRenderer):
media_type = 'text/csv'
format = 'csv'
@staticmethod
def _get_show_fields(fields, template):
if template == 'import':
return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id']
elif template == 'update':
return [v for k, v in fields.items() if not v.read_only and k != "org_id"]
else:
return [v for k, v in fields.items() if not v.write_only and k != "org_id"]
writer = None
buffer = None
@staticmethod
def _gen_table(data, fields):
data = data[:10000]
yield ['*{}'.format(f.label) if f.required else f.label for f in fields]
    def initial_writer(self):
        # Write a UTF-8 BOM first so spreadsheet apps detect the encoding.
        csv_buffer = BytesIO()
        csv_buffer.write(codecs.BOM_UTF8)
        csv_writer = unicodecsv.writer(csv_buffer, encoding='utf-8')
        self.buffer = csv_buffer
        self.writer = csv_writer
for item in data:
row = [item.get(f.field_name) for f in fields]
yield row
def set_response_disposition(self, serializer, context):
response = context.get('response')
if response and hasattr(serializer, 'Meta') and \
hasattr(serializer.Meta, "model"):
model_name = serializer.Meta.model.__name__.lower()
now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
filename = "{}_{}.csv".format(model_name, now)
disposition = 'attachment; filename="{}"'.format(filename)
response['Content-Disposition'] = disposition
def render(self, data, media_type=None, renderer_context=None):
renderer_context = renderer_context or {}
request = renderer_context['request']
template = request.query_params.get('template', 'export')
view = renderer_context['view']
if isinstance(data, dict):
data = data.get("results", [])
if template == 'import':
data = [data[0]] if data else data
data = json.loads(json.dumps(data, cls=encoders.JSONEncoder))
try:
serializer = view.get_serializer()
self.set_response_disposition(serializer, renderer_context)
except Exception as e:
logger.debug(e, exc_info=True)
value = 'The resource not support export!'.encode('utf-8')
else:
fields = serializer.fields
show_fields = self._get_show_fields(fields, template)
table = self._gen_table(data, show_fields)
csv_buffer = BytesIO()
csv_buffer.write(codecs.BOM_UTF8)
csv_writer = unicodecsv.writer(csv_buffer, encoding='utf-8')
for row in table:
csv_writer.writerow(row)
value = csv_buffer.getvalue()
    def write_row(self, row):
        # Append one CSV record to the in-memory buffer.
        self.writer.writerow(row)
    def get_rendered_value(self):
        # Return the full CSV content (including the BOM) as bytes.
        value = self.buffer.getvalue()
        return value

View File

@ -0,0 +1,28 @@
from openpyxl import Workbook
from openpyxl.writer.excel import save_virtual_workbook
from .base import BaseFileRenderer
class ExcelFileRenderer(BaseFileRenderer):
    """Renderer that writes rows into an in-memory .xlsx workbook."""

    media_type = "application/xlsx"
    format = "xlsx"

    wb = None
    ws = None
    row_count = 0

    def initial_writer(self):
        # Fresh workbook; write into its active (first) sheet.
        self.wb = Workbook()
        self.ws = self.wb.active

    def write_row(self, row):
        # openpyxl cells are 1-indexed; advance the row counter first.
        self.row_count += 1
        for column_index, cell_value in enumerate(row, start=1):
            self.ws.cell(row=self.row_count, column=column_index, value=cell_value)

    def get_rendered_value(self):
        # Serialize the workbook to bytes without touching disk.
        return save_virtual_workbook(self.wb)

View File

@ -3,7 +3,7 @@
import json
from django import forms
from django.utils import six
import six
from django.core.exceptions import ValidationError
from django.utils.translation import ugettext as _
from ..utils import signer

View File

@ -31,7 +31,7 @@ class JsonMixin:
def json_encode(data):
return json.dumps(data)
def from_db_value(self, value, expression, connection, context):
def from_db_value(self, value, expression, connection, context=None):
if value is None:
return value
return self.json_decode(value)
@ -54,7 +54,7 @@ class JsonMixin:
class JsonTypeMixin(JsonMixin):
tp = dict
def from_db_value(self, value, expression, connection, context):
def from_db_value(self, value, expression, connection, context=None):
value = super().from_db_value(value, expression, connection, context)
if not isinstance(value, self.tp):
value = self.tp()
@ -116,7 +116,7 @@ class EncryptMixin:
def decrypt_from_signer(self, value):
return signer.unsign(value) or ''
def from_db_value(self, value, expression, connection, context):
def from_db_value(self, value, expression, connection, context=None):
if value is None:
return value
value = force_text(value)

View File

@ -2,7 +2,7 @@
#
from rest_framework import serializers
from django.utils import six
import six
__all__ = [

View File

@ -41,7 +41,7 @@ def timesince(dt, since='', default="just now"):
3 days, 5 hours.
"""
if since is '':
if not since:
since = datetime.datetime.utcnow()
if since is None:

View File

@ -0,0 +1,10 @@
import inspect
def copy_function_args(func, locals_dict: dict):
    """Collect *func*'s parameters out of a ``locals()``-style mapping.

    Returns a dict with one entry per parameter of *func*; parameters not
    present in *locals_dict* map to ``None``.
    """
    parameter_names = inspect.signature(func).parameters.keys()
    return {name: locals_dict.get(name) for name in parameter_names}

55
apps/common/utils/lock.py Normal file
View File

@ -0,0 +1,55 @@
from functools import wraps
from redis_lock import Lock as RedisLock
from redis import Redis
from common.utils import get_logger
from common.utils.inspect import copy_function_args
from apps.jumpserver.const import CONFIG
logger = get_logger(__file__)
class AcquireFailed(RuntimeError):
    """Raised by DistributedLock.__enter__ when a non-blocking acquire fails."""
    pass
class DistributedLock(RedisLock):
    def __init__(self, name, blocking=True, expire=60*2, auto_renewal=True):
        """
        Distributed lock built on redis.

        :param name:
            Lock name; must be globally unique.
        :param blocking:
            Only effective when the lock is used as a decorator or in a
            ``with`` statement.
        :param expire:
            Lock expiry time. Note the lock is not necessarily released at
            that moment; there are two cases:
            * ``auto_renewal=False``: the lock is released.
            * ``auto_renewal=True``: if the program has not released the lock
              before it expires, its lifetime is extended. This guards
              against deadlock when the program dies without releasing it.
        """
        # Snapshot the constructor arguments so __call__ can build a fresh
        # lock object for every decorated invocation.
        self.kwargs_copy = copy_function_args(self.__init__, locals())
        redis = Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
        super().__init__(redis_client=redis, name=name, expire=expire, auto_renewal=auto_renewal)
        self._blocking = blocking

    def __enter__(self):
        acquired = self.acquire(blocking=self._blocking)
        if self._blocking and not acquired:
            raise EnvironmentError("Lock wasn't acquired, but blocking=True")
        if not acquired:
            # Non-blocking acquire failed: the lock is held elsewhere.
            raise AcquireFailed
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        self.release()

    def __call__(self, func):
        @wraps(func)
        def inner(*args, **kwds):
            # A new lock object must be created for each call.
            with self.__class__(**self.kwargs_copy):
                return func(*args, **kwds)
        return inner

View File

@ -14,9 +14,9 @@ alphanumeric = RegexValidator(r'^[0-9a-zA-Z_@\-\.]*$', _('Special char not allow
class ProjectUniqueValidator(UniqueTogetherValidator):
def __call__(self, attrs):
def __call__(self, attrs, serializer):
try:
super().__call__(attrs)
super().__call__(attrs, serializer)
except ValidationError as e:
errors = {}
for field in self.fields:

View File

@ -2,13 +2,14 @@ from django.core.cache import cache
from django.utils import timezone
from django.utils.timesince import timesince
from django.db.models import Count, Max
from django.http.response import JsonResponse
from django.http.response import JsonResponse, HttpResponse
from rest_framework.views import APIView
from collections import Counter
from users.models import User
from assets.models import Asset
from terminal.models import Session
from terminal.utils import ComponentsPrometheusMetricsUtil
from orgs.utils import current_org
from common.permissions import IsOrgAdmin, IsOrgAuditor
from common.utils import lazyproperty
@ -305,3 +306,11 @@ class IndexApi(TotalCountMixin, DatesLoginMetricMixin, APIView):
return JsonResponse(data, status=200)
class PrometheusMetricsApi(APIView):
    """Expose component metrics in the Prometheus text exposition format."""
    # NOTE(review): empty permission_classes makes this endpoint publicly
    # reachable — confirm that is intended for a metrics scrape target.
    permission_classes = ()

    def get(self, request, *args, **kwargs):
        util = ComponentsPrometheusMetricsUtil()
        metrics_text = util.get_prometheus_metrics_text()
        return HttpResponse(metrics_text, content_type='text/plain; version=0.0.4; charset=utf-8')

View File

@ -16,7 +16,7 @@ import json
import yaml
from importlib import import_module
from django.urls import reverse_lazy
from django.contrib.staticfiles.templatetags.staticfiles import static
from django.templatetags.static import static
from urllib.parse import urljoin, urlparse
from django.utils.translation import ugettext_lazy as _

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
from django.contrib.staticfiles.templatetags.staticfiles import static
from django.templatetags.static import static
from django.conf import settings
from django.utils.translation import gettext_lazy as _

View File

@ -64,6 +64,7 @@ INSTALLED_APPS = [
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'django.forms',
]

View File

@ -11,14 +11,17 @@ REST_FRAMEWORK = {
),
'DEFAULT_RENDERER_CLASSES': (
'rest_framework.renderers.JSONRenderer',
'rest_framework.renderers.BrowsableAPIRenderer',
'common.drf.renders.JMSCSVRender',
# 'rest_framework.renderers.BrowsableAPIRenderer',
'common.drf.renders.CSVFileRenderer',
'common.drf.renders.ExcelFileRenderer',
),
'DEFAULT_PARSER_CLASSES': (
'rest_framework.parsers.JSONParser',
'rest_framework.parsers.FormParser',
'rest_framework.parsers.MultiPartParser',
'common.drf.parsers.JMSCSVParser',
'common.drf.parsers.CSVFileParser',
'common.drf.parsers.ExcelFileParser',
'rest_framework.parsers.FileUploadParser',
),
'DEFAULT_AUTHENTICATION_CLASSES': (
@ -61,10 +64,10 @@ SWAGGER_SETTINGS = {
# Captcha settings, more see https://django-simple-captcha.readthedocs.io/en/latest/advanced.html
CAPTCHA_IMAGE_SIZE = (80, 33)
CAPTCHA_IMAGE_SIZE = (140, 34)
CAPTCHA_FOREGROUND_COLOR = '#001100'
CAPTCHA_NOISE_FUNCTIONS = ('captcha.helpers.noise_dots',)
CAPTCHA_TEST_MODE = CONFIG.CAPTCHA_TEST_MODE
CAPTCHA_CHALLENGE_FUNCT = 'captcha.helpers.math_challenge'
# Django bootstrap3 setting, more see http://django-bootstrap3.readthedocs.io/en/latest/settings.html
BOOTSTRAP3 = {

View File

@ -23,6 +23,7 @@ api_v1 = [
path('common/', include('common.urls.api_urls', namespace='api-common')),
path('applications/', include('applications.urls.api_urls', namespace='api-applications')),
path('tickets/', include('tickets.urls.api_urls', namespace='api-tickets')),
path('prometheus/metrics/', api.PrometheusMetricsApi.as_view())
]
api_v2 = [
@ -30,7 +31,6 @@ api_v2 = [
path('users/', include('users.urls.api_urls_v2', namespace='api-users-v2')),
]
app_view_patterns = [
path('auth/', include('authentication.urls.view_urls'), name='auth'),
path('ops/', include('ops.urls.view_urls'), name='ops'),
@ -63,7 +63,7 @@ urlpatterns = [
# External apps url
path('core/auth/captcha/', include('captcha.urls')),
path('core/', include(app_view_patterns)),
path('ui/', views.UIView.as_view())
path('ui/', views.UIView.as_view()),
]
urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) \

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@ -29,16 +29,3 @@ configs["CELERY_ROUTES"] = {
app.namespace = 'CELERY'
app.conf.update(configs)
app.autodiscover_tasks(lambda: [app_config.split('.')[0] for app_config in settings.INSTALLED_APPS])
app.conf.beat_schedule = {
'check-asset-permission-expired': {
'task': 'perms.tasks.check_asset_permission_expired',
'schedule': settings.PERM_EXPIRED_CHECK_PERIODIC,
'args': ()
},
'check-node-assets-amount': {
'task': 'assets.tasks.nodes_amount.check_node_assets_amount_celery_task',
'schedule': crontab(minute=0, hour=0),
'args': ()
},
}

View File

@ -3,6 +3,8 @@
import json
import os
import redis_lock
import redis
from django.conf import settings
from django.utils.timezone import get_current_timezone
from django.db.utils import ProgrammingError, OperationalError
@ -105,3 +107,27 @@ def get_celery_task_log_path(task_id):
path = os.path.join(settings.CELERY_LOG_DIR, rel_path)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
def get_celery_status():
    """Return True when all expected celery workers answer a ping."""
    from . import app
    i = app.control.inspect()
    ping_data = i.ping() or {}
    active_nodes = [k for k, v in ping_data.items() if v.get('ok') == 'pong']
    # Node names look like "<worker>@<host>"; keep the worker part only.
    active_queue_worker = set([n.split('@')[0] for n in active_nodes if n])
    # 5 distinct queue workers are expected — TODO confirm this matches the
    # celery startup configuration.
    if len(active_queue_worker) < 5:
        print("Not all celery worker worked")
        return False
    else:
        return True
def get_beat_status():
    """Return True when the celery-beat startup lock is currently held."""
    CONFIG = settings.CONFIG
    r = redis.Redis(host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, password=CONFIG.REDIS_PASSWORD)
    lock = redis_lock.Lock(r, name="beat-distribute-start-lock")
    try:
        # If the lock is held, a beat process is presumed to be running —
        # TODO confirm against the beat startup code that takes this lock.
        locked = lock.locked()
        return locked
    except redis.ConnectionError:
        # Unreachable redis counts as unhealthy rather than crashing.
        return False

View File

@ -0,0 +1,20 @@
from django.core.management.base import BaseCommand, CommandError
class Command(BaseCommand):
    """Management command that checks celery worker and beat health.

    Exits with a CommandError (non-zero status) when either check fails,
    so it can back a container/process health probe.
    """
    help = 'Ops manage commands'

    def add_arguments(self, parser):
        parser.add_argument('check_celery', nargs='?', help='Check celery health')

    def handle(self, *args, **options):
        # Imported lazily so the command module loads without celery set up.
        from ops.celery.utils import get_celery_status, get_beat_status
        ok = get_celery_status()
        if not ok:
            raise CommandError('Celery worker unhealthy')
        ok = get_beat_status()
        if not ok:
            raise CommandError('Beat unhealthy')

View File

@ -92,6 +92,7 @@ class OrgMemberAdminRelationBulkViewSet(JMSBulkRelationModelViewSet):
serializer_class = OrgMemberAdminSerializer
filterset_class = OrgMemberRelationFilterSet
search_fields = ('user__name', 'user__username', 'org__name')
lookup_field = 'user_id'
def get_queryset(self):
queryset = super().get_queryset()
@ -116,6 +117,7 @@ class OrgMemberUserRelationBulkViewSet(JMSBulkRelationModelViewSet):
serializer_class = OrgMemberUserSerializer
filterset_class = OrgMemberRelationFilterSet
search_fields = ('user__name', 'user__username', 'org__name')
lookup_field = 'user_id'
def get_queryset(self):
queryset = super().get_queryset()

View File

@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
#
from .models import Organization
from .utils import get_org_from_request, set_current_org

View File

@ -8,7 +8,7 @@ from django.core.exceptions import ValidationError
from common.utils import get_logger
from ..utils import (
set_current_org, get_current_org, current_org,
filter_org_queryset
filter_org_queryset, get_org_by_id, get_org_name_by_id
)
from ..models import Organization
@ -70,13 +70,11 @@ class OrgModelMixin(models.Model):
@property
def org(self):
from orgs.models import Organization
org = Organization.get_instance(self.org_id)
return org
return get_org_by_id(self.org_id)
@property
def org_name(self):
return self.org.name
return get_org_name_by_id(self.org_id)
@property
def fullname(self, attr=None):

View File

@ -74,26 +74,29 @@ class OrgMemberSerializer(BulkModelSerializer):
).distinct()
class OrgMemberAdminSerializer(BulkModelSerializer):
class OrgMemberOldBaseSerializer(BulkModelSerializer):
organization = serializers.PrimaryKeyRelatedField(
label=_('Organization'), queryset=Organization.objects.all(), required=True, source='org'
)
def to_internal_value(self, data):
view = self.context['view']
org_id = view.kwargs.get('org_id')
if org_id:
data['organization'] = org_id
return super().to_internal_value(data)
class Meta:
model = OrganizationMember
fields = ('id', 'organization', 'user', 'role')
class OrgMemberAdminSerializer(OrgMemberOldBaseSerializer):
role = serializers.HiddenField(default=ROLE.ADMIN)
organization = serializers.PrimaryKeyRelatedField(
label=_('Organization'), queryset=Organization.objects.all(), required=True, source='org'
)
class Meta:
model = OrganizationMember
fields = ('id', 'organization', 'user', 'role')
class OrgMemberUserSerializer(BulkModelSerializer):
class OrgMemberUserSerializer(OrgMemberOldBaseSerializer):
role = serializers.HiddenField(default=ROLE.USER)
organization = serializers.PrimaryKeyRelatedField(
label=_('Organization'), queryset=Organization.objects.all(), required=True, source='org'
)
class Meta:
model = OrganizationMember
fields = ('id', 'organization', 'user', 'role')
class OrgRetrieveSerializer(OrgReadSerializer):

View File

@ -65,6 +65,47 @@ def get_current_org_id():
return org_id
def construct_org_mapper():
    """Build a ``{org_id(str): Organization}`` lookup over all orgs.

    The special ids (default/root/system, and the empty string) map to the
    corresponding built-in organizations.
    """
    orgs = Organization.objects.all()
    org_mapper = {str(org.id): org for org in orgs}

    default_org = Organization.default()
    org_mapper.update({
        '': default_org,
        Organization.DEFAULT_ID: default_org,
        Organization.ROOT_ID: Organization.root(),
        Organization.SYSTEM_ID: Organization.system()
    })
    return org_mapper
def set_org_mapper(org_mapper):
    """Cache the org-id → Organization mapper on the current thread."""
    thread_local.org_mapper = org_mapper
def get_org_mapper():
    # Lazily build the mapper and cache it on the current thread.
    org_mapper = _find('org_mapper')
    if org_mapper is None:
        org_mapper = construct_org_mapper()
        set_org_mapper(org_mapper)
    return org_mapper
def get_org_by_id(org_id):
    """Look up an Organization by id via the thread-cached mapper.

    Returns None when the id is unknown.
    """
    mapper = get_org_mapper()
    return mapper.get(str(org_id))
def get_org_name_by_id(org_id):
    """Return the organization's name, or 'Not Found' for an unknown id."""
    org = get_org_by_id(org_id)
    return org.name if org else 'Not Found'
def get_current_org_id_for_serializer():
org_id = get_current_org_id()
if org_id == Organization.DEFAULT_ID:

View File

@ -48,7 +48,7 @@ class ApplicationsAsTreeMixin(SerializeApplicationToTreeNodeMixin):
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
data = self.serialize_applications(queryset)
data = self.serialize_applications_with_org(queryset)
return Response(data=data)

View File

@ -32,9 +32,6 @@ class UserGroupMixin:
class UserGroupGrantedAssetsApi(ListAPIView):
"""
获取用户组直接授权的资产
"""
permission_classes = (IsOrgAdminOrAppUser,)
serializer_class = serializers.AssetGrantedSerializer
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
@ -44,11 +41,27 @@ class UserGroupGrantedAssetsApi(ListAPIView):
def get_queryset(self):
user_group_id = self.kwargs.get('pk', '')
return Asset.objects.filter(
Q(granted_by_permissions__user_groups__id=user_group_id)
asset_perms_id = list(AssetPermission.objects.valid().filter(
user_groups__id=user_group_id
).distinct().values_list('id', flat=True))
granted_node_keys = Node.objects.filter(
granted_by_permissions__id__in=asset_perms_id,
).distinct().values_list('key', flat=True)
granted_q = Q()
for _key in granted_node_keys:
granted_q |= Q(nodes__key__startswith=f'{_key}:')
granted_q |= Q(nodes__key=_key)
granted_q |= Q(granted_by_permissions__id__in=asset_perms_id)
assets = Asset.objects.filter(
granted_q
).distinct().only(
*self.only_fields
)
return assets
class UserGroupGrantedNodeAssetsApi(ListAPIView):
@ -66,7 +79,7 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView):
granted = AssetPermission.objects.filter(
user_groups__id=user_group_id,
nodes__id=node_id
).exists()
).valid().exists()
if granted:
assets = Asset.objects.filter(
Q(nodes__key__startswith=f'{node.key}:') |
@ -74,8 +87,12 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView):
)
return assets
else:
asset_perms_id = list(AssetPermission.objects.valid().filter(
user_groups__id=user_group_id
).distinct().values_list('id', flat=True))
granted_node_keys = Node.objects.filter(
granted_by_permissions__user_groups__id=user_group_id,
granted_by_permissions__id__in=asset_perms_id,
key__startswith=f'{node.key}:'
).distinct().values_list('key', flat=True)
@ -85,7 +102,7 @@ class UserGroupGrantedNodeAssetsApi(ListAPIView):
granted_node_q |= Q(nodes__key=_key)
granted_asset_q = (
Q(granted_by_permissions__user_groups__id=user_group_id) &
Q(granted_by_permissions__id__in=asset_perms_id) &
(
Q(nodes__key__startswith=f'{node.key}:') |
Q(nodes__key=node.key)
@ -129,12 +146,16 @@ class UserGroupGrantedNodeChildrenAsTreeApi(SerializeToTreeNodeMixin, ListAPIVie
group_id = self.kwargs.get('pk')
node_key = self.request.query_params.get('key', None)
asset_perms_id = list(AssetPermission.objects.valid().filter(
user_groups__id=group_id
).distinct().values_list('id', flat=True))
granted_keys = Node.objects.filter(
granted_by_permissions__user_groups__id=group_id
granted_by_permissions__id__in=asset_perms_id
).values_list('key', flat=True)
asset_granted_keys = Node.objects.filter(
assets__granted_by_permissions__user_groups__id=group_id
assets__granted_by_permissions__id__in=asset_perms_id
).values_list('key', flat=True)
if node_key is None:

View File

@ -3,6 +3,7 @@
from perms.api.asset.user_permission.mixin import UserNodeGrantStatusDispatchMixin
from rest_framework.generics import ListAPIView
from rest_framework.response import Response
from rest_framework.request import Request
from django.conf import settings
from assets.api.mixin import SerializeToTreeNodeMixin
@ -55,8 +56,12 @@ class AssetsAsTreeMixin(SerializeToTreeNodeMixin):
"""
资产 序列化成树的结构返回
"""
def list(self, request, *args, **kwargs):
def list(self, request: Request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
if request.query_params.get('search'):
# 如果用户搜索的条件不精准,会导致返回大量的无意义数据。
# 这里限制一下返回数据的最大条数
queryset = queryset[:999]
data = self.serialize_assets(queryset, None)
return Response(data=data)

View File

@ -139,11 +139,13 @@ class MyGrantedNodesWithAssetsAsTreeApi(SerializeToTreeNodeMixin, ListAPIView):
return Response(data=data)
class UserGrantedNodeChildrenWithAssetsAsTreeForAdminApi(ForAdminMixin, UserNodeGrantStatusDispatchMixin,
SerializeToTreeNodeMixin, ListAPIView):
class GrantedNodeChildrenWithAssetsAsTreeApiMixin(UserNodeGrantStatusDispatchMixin,
SerializeToTreeNodeMixin,
ListAPIView):
"""
带资产的授权树
"""
user: None
def get_data_on_node_direct_granted(self, key):
nodes = Node.objects.filter(parent_key=key)
@ -203,5 +205,9 @@ class UserGrantedNodeChildrenWithAssetsAsTreeForAdminApi(ForAdminMixin, UserNode
return Response(data=[*tree_nodes, *tree_assets])
class MyGrantedNodeChildrenWithAssetsAsTreeApi(ForUserMixin, UserGrantedNodeChildrenWithAssetsAsTreeForAdminApi):
class UserGrantedNodeChildrenWithAssetsAsTreeApi(ForAdminMixin, GrantedNodeChildrenWithAssetsAsTreeApiMixin):
pass
class MyGrantedNodeChildrenWithAssetsAsTreeApi(ForUserMixin, GrantedNodeChildrenWithAssetsAsTreeApiMixin):
pass

View File

@ -6,7 +6,7 @@ from django.dispatch import receiver
from perms.tasks import create_rebuild_user_tree_task, \
create_rebuild_user_tree_task_by_related_nodes_or_assets
from users.models import User, UserGroup
from assets.models import Asset
from assets.models import Asset, SystemUser
from common.utils import get_logger
from common.exceptions import M2MReverseNotAllowed
from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR
@ -16,6 +16,42 @@ from .models import AssetPermission, RemoteAppPermission
logger = get_logger(__file__)
def handle_rebuild_user_tree(instance, action, reverse, pk_set, **kwargs):
    """Schedule user mapping-tree rebuild tasks on m2m 'post_*' signals."""
    if not action.startswith('post'):
        return
    # reverse=True: signal fired from the related side, pk_set holds the
    # affected user pks; otherwise the instance itself is the user.
    users_id = pk_set if reverse else [instance.id]
    create_rebuild_user_tree_task(users_id)
def handle_bind_groups_systemuser(instance, action, reverse, pk_set, **kwargs):
    """
    When users are added to a UserGroup, bind those users to every dynamic
    SystemUser associated with that group.

    Fix: removed the stray ``user: User`` annotation — no local named
    ``user`` is ever assigned or read in this function, so the annotation
    was dead code that only misled readers.
    """
    if action != POST_ADD:
        return
    if not reverse:
        # instance is a User being added to one or more groups (pk_set).
        users_id = [instance.id]
        system_users = SystemUser.objects.filter(groups__id__in=pk_set).distinct()
    else:
        # instance is a UserGroup receiving one or more users (pk_set).
        users_id = pk_set
        system_users = SystemUser.objects.filter(groups__id=instance.pk).distinct()
    for system_user in system_users:
        system_user.users.add(*users_id)
@receiver(m2m_changed, sender=User.groups.through)
def on_user_groups_change(**kwargs):
    # User <-> UserGroup membership changed: rebuild the authorization tree
    # and bind the affected users to the groups' dynamic system users.
    handle_rebuild_user_tree(**kwargs)
    handle_bind_groups_systemuser(**kwargs)
@receiver([pre_save], sender=AssetPermission)
def on_asset_perm_deactive(instance: AssetPermission, **kwargs):
try:

View File

@ -5,10 +5,12 @@ from datetime import timedelta
from django.db import transaction
from django.db.models import Q
from django.db.transaction import atomic
from django.conf import settings
from celery import shared_task
from common.utils import get_logger
from common.utils.timezone import now, dt_formater, dt_parser
from users.models import User
from ops.celery.decorator import register_as_period_task
from assets.models import Node
from perms.models import RebuildUserTreeTask, AssetPermission
from perms.utils.asset.user_permission import rebuild_user_mapping_nodes_if_need_with_lock, lock
@ -33,7 +35,8 @@ def dispatch_mapping_node_tasks():
rebuild_user_mapping_nodes_celery_task.delay(id)
@shared_task(queue='check_asset_perm_expired')
@register_as_period_task(interval=settings.PERM_EXPIRED_CHECK_PERIODIC)
@shared_task(queue='celery_check_asset_perm_expired')
@atomic()
def check_asset_permission_expired():
"""

View File

@ -21,7 +21,7 @@ user_permission_urlpatterns = [
# ---------------------------------------------------------
# 以 serializer 格式返回
path('<uuid:pk>/assets/', api.UserAllGrantedAssetsApi.as_view(), name='user-assets'),
path('assets/', api.MyAllAssetsAsTreeApi.as_view(), name='my-assets'),
path('assets/', api.MyAllGrantedAssetsApi.as_view(), name='my-assets'),
# Tree Node 的数据格式返回
path('<uuid:pk>/assets/tree/', api.UserDirectGrantedAssetsAsTreeForAdminApi.as_view(), name='user-assets-as-tree'),
@ -56,7 +56,7 @@ user_permission_urlpatterns = [
path('nodes-with-assets/tree/', api.MyGrantedNodesWithAssetsAsTreeApi.as_view(), name='my-nodes-with-assets-as-tree'),
# 主要用于 luna 页面,带资产的节点树
path('<uuid:pk>/nodes/children-with-assets/tree/', api.UserGrantedNodeChildrenWithAssetsAsTreeForAdminApi.as_view(), name='user-nodes-children-with-assets-as-tree'),
path('<uuid:pk>/nodes/children-with-assets/tree/', api.UserGrantedNodeChildrenWithAssetsAsTreeApi.as_view(), name='user-nodes-children-with-assets-as-tree'),
path('nodes/children-with-assets/tree/', api.MyGrantedNodeChildrenWithAssetsAsTreeApi.as_view(), name='my-nodes-children-with-assets-as-tree'),
# 查询授权树上某个节点的所有资产

View File

@ -34,27 +34,6 @@ TMP_ASSET_GRANTED_FIELD = '_asset_granted'
TMP_GRANTED_ASSETS_AMOUNT_FIELD = '_granted_assets_amount'
# 使用场景
# Asset.objects.filter(get_user_resources_q_granted_by_permissions(user))
def get_user_resources_q_granted_by_permissions(user: User):
"""
获取用户关联的 asset permission 或者 用户组关联的 asset permission 获取规则,
前提 AssetPermission 对象中的 related_name granted_by_permissions
:param user:
:return:
"""
_now = now()
return reduce(and_, (
Q(granted_by_permissions__date_start__lt=_now),
Q(granted_by_permissions__date_expired__gt=_now),
Q(granted_by_permissions__is_active=True),
(
Q(granted_by_permissions__users=user) |
Q(granted_by_permissions__user_groups__users=user)
)
))
# 使用场景
# `Node.objects.annotate(**node_annotate_mapping_node)`
node_annotate_mapping_node = {
@ -215,7 +194,7 @@ def compute_tmp_mapping_node_from_perm(user: User, asset_perms_id=None):
return [*leaf_nodes, *ancestors]
def create_mapping_nodes(user, nodes, clear=True):
def create_mapping_nodes(user, nodes):
to_create = []
for node in nodes:
_granted = getattr(node, TMP_GRANTED_FIELD, False)
@ -231,8 +210,6 @@ def create_mapping_nodes(user, nodes, clear=True):
assets_amount=_granted_assets_amount,
))
if clear:
UserGrantedMappingNode.objects.filter(user=user).delete()
UserGrantedMappingNode.objects.bulk_create(to_create)
@ -254,6 +231,9 @@ def set_node_granted_assets_amount(user, node, asset_perms_id=None):
@tmp_to_root_org()
def rebuild_user_mapping_nodes(user):
logger.info(f'>>> {dt_formater(now())} start rebuild {user} mapping nodes')
# 先删除旧的授权树🌲
UserGrantedMappingNode.objects.filter(user=user).delete()
asset_perms_id = get_user_all_assetpermissions_id(user)
if not asset_perms_id:
# 没有授权直接返回
@ -384,7 +364,8 @@ def get_node_all_granted_assets(user: User, key):
if only_asset_granted_nodes_qs:
only_asset_granted_nodes_q = reduce(or_, only_asset_granted_nodes_qs)
only_asset_granted_nodes_q &= get_user_resources_q_granted_by_permissions(user)
asset_perms_id = get_user_all_assetpermissions_id(user)
only_asset_granted_nodes_q &= Q(granted_by_permissions__id__in=list(asset_perms_id))
q.append(only_asset_granted_nodes_q)
if q:
@ -484,6 +465,9 @@ def get_user_all_assetpermissions_id(user: User):
asset_perms_id = AssetPermission.objects.valid().filter(
Q(users=user) | Q(user_groups__users=user)
).distinct().values_list('id', flat=True)
# !!! 这个很重要,必须转换成 list避免 Django 生成嵌套子查询
asset_perms_id = list(asset_perms_id)
return asset_perms_id

View File

@ -333,7 +333,7 @@ class LDAPImportUtil(object):
def update_or_create(self, user):
user['email'] = self.get_user_email(user)
if user['username'] not in ['admin']:
user['source'] = User.SOURCE_LDAP
user['source'] = User.Source.ldap.value
obj, created = User.objects.update_or_create(
username=user['username'], defaults=user
)

View File

@ -1,12 +0,0 @@
{{image}}{{hidden_field}}{{text_field}}
<script>
// Fetch a fresh captcha from the server and swap the image and the
// hidden key field in place.
function refresh_captcha() {
    $.getJSON("{% url "captcha-refresh" %}",
        function (result) {
            $('.captcha').attr('src', result['image_url']);
            $('#id_captcha_0').val(result['key'])
        })
}
// Clicking the captcha image refreshes it.
$('.captcha').click(refresh_captcha)
</script>

View File

@ -1 +0,0 @@
<input id="{{id}}_0" name="{{name}}_0" type="hidden" value="{{key}}" />

View File

@ -1,4 +0,0 @@
{% load i18n %}
{% spaceless %}
{% if audio %}<a title="{% trans "Play CAPTCHA as audio file" %}" href="{{audio}}">{% endif %}<img src="{{image}}" alt="captcha" class="captcha" />{% if audio %}</a>{% endif %}
{% endspaceless %}

View File

@ -1,7 +0,0 @@
{% load i18n %}
<div class="row">
<div class="col-sm-6">
<input autocomplete="off" id="{{id}}_1" class="form-control" name="{{name}}_1" placeholder="{% trans 'Captcha' %}" type="text" />
</div>
</div>
</br>

View File

@ -5,3 +5,4 @@ from .session import *
from .command import *
from .task import *
from .storage import *
from .component import *

View File

@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
#
import logging
from rest_framework import generics, status
from rest_framework.views import Response
from .. import serializers
from ..utils import ComponentsMetricsUtil
from common.permissions import IsAppUser, IsSuperUser
logger = logging.getLogger(__file__)
__all__ = [
'ComponentsStateAPIView', 'ComponentsMetricsAPIView',
]
class ComponentsStateAPIView(generics.CreateAPIView):
    """Receive state reports from components (koko, guacamole, omnidb)."""
    permission_classes = (IsAppUser,)
    serializer_class = serializers.ComponentsStateSerializer
class ComponentsMetricsAPIView(generics.GenericAPIView):
    """Return aggregated component metrics data (superuser only)."""
    permission_classes = (IsSuperUser,)

    def get(self, request, *args, **kwargs):
        # Optional ?type=<component type> limits metrics to one component kind.
        component_type = request.query_params.get('type')
        util = ComponentsMetricsUtil(component_type)
        metrics = util.get_metrics()
        return Response(metrics, status=status.HTTP_200_OK)

View File

@ -27,7 +27,7 @@ class TerminalViewSet(JMSBulkModelViewSet):
queryset = Terminal.objects.filter(is_deleted=False)
serializer_class = serializers.TerminalSerializer
permission_classes = (IsSuperUser,)
filter_fields = ['name', 'remote_addr']
filter_fields = ['name', 'remote_addr', 'type']
def create(self, request, *args, **kwargs):
if isinstance(request.data, list):
@ -60,6 +60,15 @@ class TerminalViewSet(JMSBulkModelViewSet):
logger.error("Register terminal error: {}".format(data))
return Response(data, status=400)
def filter_queryset(self, queryset):
    """Apply the default filters, then optionally narrow by ?status=."""
    queryset = super().filter_queryset(queryset)
    wanted_status = self.request.query_params.get('status')
    if not wanted_status:
        return queryset
    # Terminal status is compared per object in Python here (not via the
    # ORM), then the queryset is narrowed by the matching ids.
    matched_ids = [str(term.id) for term in queryset if term.status == wanted_status]
    return queryset.filter(id__in=matched_ids)
def get_permissions(self):
if self.action == "create":
self.permission_classes = (AllowAny,)
@ -104,15 +113,11 @@ class StatusViewSet(viewsets.ModelViewSet):
task_serializer_class = serializers.TaskSerializer
def create(self, request, *args, **kwargs):
self.handle_status(request)
self.handle_sessions()
tasks = self.request.user.terminal.task_set.filter(is_finished=False)
serializer = self.task_serializer_class(tasks, many=True)
return Response(serializer.data, status=201)
def handle_status(self, request):
request.user.terminal.is_alive = True
def handle_sessions(self):
sessions_id = self.request.data.get('sessions', [])
# guacamole 上报的 session 是字符串

View File

@ -108,3 +108,27 @@ COMMAND_STORAGE_TYPE_CHOICES_EXTENDS = [
COMMAND_STORAGE_TYPE_CHOICES = COMMAND_STORAGE_TYPE_CHOICES_DEFAULT + \
COMMAND_STORAGE_TYPE_CHOICES_EXTENDS
from django.db.models import TextChoices
from django.utils.translation import ugettext_lazy as _
class ComponentStatusChoices(TextChoices):
    """Health levels a component can report."""
    critical = 'critical', _('Critical')
    high = 'high', _('High')
    normal = 'normal', _('Normal')

    @classmethod
    def status(cls):
        # The set of all valid status values.
        return {choice.value for choice in cls}
class TerminalTypeChoices(TextChoices):
    """Known terminal component types."""
    koko = 'koko', 'KoKo'
    guacamole = 'guacamole', 'Guacamole'
    omnidb = 'omnidb', 'OmniDB'

    @classmethod
    def types(cls):
        # The set of all valid type values.
        return {choice.value for choice in cls}

View File

@ -0,0 +1,42 @@
# Generated by Django 3.1 on 2020-12-10 07:05
from django.db import migrations, models
TERMINAL_TYPE_KOKO = 'koko'
TERMINAL_TYPE_GUACAMOLE = 'guacamole'
TERMINAL_TYPE_OMNIDB = 'omnidb'

# Ordered substring markers used to guess a terminal's type from its name.
_TYPE_MARKERS = (
    ('koko', TERMINAL_TYPE_KOKO),
    ('gua', TERMINAL_TYPE_GUACAMOLE),
    ('omnidb', TERMINAL_TYPE_OMNIDB),
)


def migrate_terminal_type(apps, schema_editor):
    """Backfill Terminal.type by matching known markers in each name."""
    terminal_model = apps.get_model("terminal", "Terminal")
    db_alias = schema_editor.connection.alias
    terminals = terminal_model.objects.using(db_alias).all()
    for terminal in terminals:
        lowered = terminal.name.lower()
        # First marker found wins; unrecognized names default to koko.
        terminal.type = next(
            (tp for marker, tp in _TYPE_MARKERS if marker in lowered),
            TERMINAL_TYPE_KOKO,
        )
    terminal_model.objects.bulk_update(terminals, ['type'])
class Migration(migrations.Migration):
    """Add Terminal.type and backfill it from existing terminal names."""

    dependencies = [
        ('terminal', '0029_auto_20201116_1757'),
    ]

    operations = [
        migrations.AddField(
            model_name='terminal',
            name='type',
            field=models.CharField(choices=[('koko', 'KoKo'), ('guacamole', 'Guacamole'), ('omnidb', 'OmniDB')], default='koko', max_length=64, verbose_name='type'),
            # The 'koko' default is only used to fill existing rows here;
            # preserve_default=False keeps it out of the model state.
            preserve_default=False,
        ),
        # Data migration: guess each existing terminal's type from its name.
        migrations.RunPython(migrate_terminal_type)
    ]

View File

@ -1,486 +0,0 @@
from __future__ import unicode_literals
import os
import uuid
import jms_storage
from django.db import models
from django.db.models.signals import post_save
from django.utils.translation import ugettext_lazy as _
from django.utils import timezone
from django.conf import settings
from django.core.files.storage import default_storage
from django.core.cache import cache
from assets.models import Asset
from users.models import User
from orgs.mixins.models import OrgModelMixin
from common.mixins import CommonModelMixin
from common.fields.model import EncryptJsonDictTextField
from common.db.models import ChoiceSet
from .backends import get_multi_command_storage
from .backends.command.models import AbstractSessionCommand
from . import const
class Terminal(models.Model):
    """A registered terminal component instance and its pushed configuration."""
    id = models.UUIDField(default=uuid.uuid4, primary_key=True)
    name = models.CharField(max_length=128, verbose_name=_('Name'))
    remote_addr = models.CharField(max_length=128, blank=True, verbose_name=_('Remote Address'))
    ssh_port = models.IntegerField(verbose_name=_('SSH Port'), default=2222)
    http_port = models.IntegerField(verbose_name=_('HTTP Port'), default=5000)
    command_storage = models.CharField(max_length=128, verbose_name=_("Command storage"), default='default')
    replay_storage = models.CharField(max_length=128, verbose_name=_("Replay storage"), default='default')
    user = models.OneToOneField(User, related_name='terminal', verbose_name='Application User', null=True, on_delete=models.CASCADE)
    is_accepted = models.BooleanField(default=False, verbose_name='Is Accepted')
    is_deleted = models.BooleanField(default=False)
    date_created = models.DateTimeField(auto_now_add=True)
    comment = models.TextField(blank=True, verbose_name=_('Comment'))

    # Cache key prefix for the heartbeat flag (see is_alive below).
    STATUS_KEY_PREFIX = 'terminal_status_'

    @property
    def is_alive(self):
        # Alive while the heartbeat cache key has not expired.
        key = self.STATUS_KEY_PREFIX + str(self.id)
        return bool(cache.get(key))

    @is_alive.setter
    def is_alive(self, value):
        # Heartbeat is cached for 60s, so a silent terminal reads as dead.
        key = self.STATUS_KEY_PREFIX + str(self.id)
        cache.set(key, value, 60)

    @property
    def is_active(self):
        # Active iff the bound application user exists and is active.
        if self.user and self.user.is_active:
            return True
        return False

    @is_active.setter
    def is_active(self, active):
        if self.user:
            self.user.is_active = active
            self.user.save()

    def get_command_storage(self):
        # Look up the named CommandStorage; None if no row matches.
        storage = CommandStorage.objects.filter(name=self.command_storage).first()
        return storage

    def get_command_storage_config(self):
        # Fall back to the server-wide default when no named storage matches.
        s = self.get_command_storage()
        if s:
            config = s.config
        else:
            config = settings.DEFAULT_TERMINAL_COMMAND_STORAGE
        return config

    def get_command_storage_setting(self):
        config = self.get_command_storage_config()
        return {"TERMINAL_COMMAND_STORAGE": config}

    def get_replay_storage(self):
        storage = ReplayStorage.objects.filter(name=self.replay_storage).first()
        return storage

    def get_replay_storage_config(self):
        # Same fallback pattern as command storage.
        s = self.get_replay_storage()
        if s:
            config = s.config
        else:
            config = settings.DEFAULT_TERMINAL_REPLAY_STORAGE
        return config

    def get_replay_storage_setting(self):
        config = self.get_replay_storage_config()
        return {"TERMINAL_REPLAY_STORAGE": config}

    @staticmethod
    def get_login_title_setting():
        # Custom login title is only available with the XPack interface plugin.
        login_title = None
        if settings.XPACK_ENABLED:
            from xpack.plugins.interface.models import Interface
            login_title = Interface.get_login_title()
        return {'TERMINAL_HEADER_TITLE': login_title}

    @property
    def config(self):
        # Collect every TERMINAL* setting, plus storage, login-title and
        # max-idle-time settings, into one dict.
        configs = {}
        for k in dir(settings):
            if not k.startswith('TERMINAL'):
                continue
            configs[k] = getattr(settings, k)
        configs.update(self.get_command_storage_setting())
        configs.update(self.get_replay_storage_setting())
        configs.update(self.get_login_title_setting())
        configs.update({
            'SECURITY_MAX_IDLE_TIME': settings.SECURITY_MAX_IDLE_TIME
        })
        return configs

    @property
    def service_account(self):
        return self.user

    def create_app_user(self):
        # Each terminal authenticates via its own app user + access key;
        # a random suffix keeps the generated user name unique.
        random = uuid.uuid4().hex[:6]
        user, access_key = User.create_app_user(
            name="{}-{}".format(self.name, random), comment=self.comment
        )
        self.user = user
        self.save()
        return user, access_key

    def delete(self, using=None, keep_parents=False):
        # Soft delete: remove the app user but keep the row flagged deleted.
        if self.user:
            self.user.delete()
            self.user = None
        self.is_deleted = True
        self.save()
        return

    def __str__(self):
        status = "Active"
        if not self.is_accepted:
            status = "NotAccept"
        elif self.is_deleted:
            status = "Deleted"
        elif not self.is_active:
            status = "Disable"
        return '%s: %s' % (self.name, status)

    class Meta:
        ordering = ('is_accepted',)
        db_table = "terminal"
class Status(models.Model):
    """A point-in-time load snapshot reported by a terminal."""
    id = models.UUIDField(default=uuid.uuid4, primary_key=True)
    session_online = models.IntegerField(verbose_name=_("Session Online"), default=0)
    cpu_used = models.FloatField(verbose_name=_("CPU Usage"))
    memory_used = models.FloatField(verbose_name=_("Memory Used"))
    connections = models.IntegerField(verbose_name=_("Connections"))
    threads = models.IntegerField(verbose_name=_("Threads"))
    boot_time = models.FloatField(verbose_name=_("Boot Time"))
    terminal = models.ForeignKey(Terminal, null=True, on_delete=models.CASCADE)
    date_created = models.DateTimeField(auto_now_add=True)

    class Meta:
        db_table = 'terminal_status'
        get_latest_by = 'date_created'

    def __str__(self):
        return self.date_created.strftime("%Y-%m-%d %H:%M:%S")
class Session(OrgModelMixin):
    """A recorded connection from a user to an asset through a terminal."""

    class LOGIN_FROM(ChoiceSet):
        ST = 'ST', 'SSH Terminal'
        WT = 'WT', 'Web Terminal'

    class PROTOCOL(ChoiceSet):
        SSH = 'ssh', 'ssh'
        RDP = 'rdp', 'rdp'
        VNC = 'vnc', 'vnc'
        TELNET = 'telnet', 'telnet'
        MYSQL = 'mysql', 'mysql'
        ORACLE = 'oracle', 'oracle'
        MARIADB = 'mariadb', 'mariadb'
        POSTGRESQL = 'postgresql', 'postgresql'
        K8S = 'k8s', 'kubernetes'

    id = models.UUIDField(default=uuid.uuid4, primary_key=True)
    # user/asset/system_user hold display strings; *_id hold uuids as text.
    user = models.CharField(max_length=128, verbose_name=_("User"), db_index=True)
    user_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
    asset = models.CharField(max_length=128, verbose_name=_("Asset"), db_index=True)
    asset_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
    system_user = models.CharField(max_length=128, verbose_name=_("System user"), db_index=True)
    system_user_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
    login_from = models.CharField(max_length=2, choices=LOGIN_FROM.choices, default="ST", verbose_name=_("Login from"))
    remote_addr = models.CharField(max_length=128, verbose_name=_("Remote addr"), blank=True, null=True)
    is_success = models.BooleanField(default=True, db_index=True)
    is_finished = models.BooleanField(default=False, db_index=True)
    has_replay = models.BooleanField(default=False, verbose_name=_("Replay"))
    has_command = models.BooleanField(default=False, verbose_name=_("Command"))
    terminal = models.ForeignKey(Terminal, null=True, on_delete=models.DO_NOTHING, db_constraint=False)
    protocol = models.CharField(choices=PROTOCOL.choices, default='ssh', max_length=16, db_index=True)
    date_start = models.DateTimeField(verbose_name=_("Date start"), db_index=True, default=timezone.now)
    date_end = models.DateTimeField(verbose_name=_("Date end"), null=True)

    # Subdirectory for locally stored replay files.
    upload_to = 'replay'
    ACTIVE_CACHE_KEY_PREFIX = 'SESSION_ACTIVE_{}'
    # Class-level cache for _date_start_first_has_replay_rdp_session.
    _DATE_START_FIRST_HAS_REPLAY_RDP_SESSION = None

    def get_rel_replay_path(self, version=2):
        """
        Return the replay file path relative to the storage root.
        :param version: v1 files end with .gz; v2 unified the suffix to .replay.gz
        :return:
        """
        suffix = '.replay.gz'
        if version == 1:
            suffix = '.gz'
        date = self.date_start.strftime('%Y-%m-%d')
        return os.path.join(date, str(self.id) + suffix)

    def get_local_path(self, version=2):
        # v2 files live under upload_to/; v1 paths are already relative.
        rel_path = self.get_rel_replay_path(version=version)
        if version == 2:
            local_path = os.path.join(self.upload_to, rel_path)
        else:
            local_path = rel_path
        return local_path

    @property
    def asset_obj(self):
        return Asset.objects.get(id=self.asset_id)

    @property
    def _date_start_first_has_replay_rdp_session(self):
        # Cached on the class: start time of the earliest RDP session that
        # has a replay, or one year ago when no such session exists.
        if self.__class__._DATE_START_FIRST_HAS_REPLAY_RDP_SESSION is None:
            instance = self.__class__.objects.filter(
                protocol='rdp', has_replay=True
            ).order_by('date_start').first()
            if not instance:
                date_start = timezone.now() - timezone.timedelta(days=365)
            else:
                date_start = instance.date_start
            self.__class__._DATE_START_FIRST_HAS_REPLAY_RDP_SESSION = date_start
        return self.__class__._DATE_START_FIRST_HAS_REPLAY_RDP_SESSION

    def can_replay(self):
        if self.has_replay:
            return True
        # Sessions older than the first replayed RDP session are also treated
        # as replayable — presumably legacy replays; confirm before relying on it.
        if self.date_start < self._date_start_first_has_replay_rdp_session:
            return True
        return False

    @property
    def can_join(self):
        # Only unfinished ssh/telnet/k8s sessions can be joined.
        _PROTOCOL = self.PROTOCOL
        if self.is_finished:
            return False
        if self.protocol in [_PROTOCOL.SSH, _PROTOCOL.TELNET, _PROTOCOL.K8S]:
            return True
        else:
            return False

    @property
    def db_protocols(self):
        _PROTOCOL = self.PROTOCOL
        return [_PROTOCOL.MYSQL, _PROTOCOL.MARIADB, _PROTOCOL.ORACLE, _PROTOCOL.POSTGRESQL]

    @property
    def can_terminate(self):
        # Database-protocol sessions cannot be terminated.
        _PROTOCOL = self.PROTOCOL
        if self.is_finished:
            return False
        if self.protocol in self.db_protocols:
            return False
        else:
            return True

    def save_replay_to_storage(self, f):
        # Returns (saved_name, None) on success or (None, error) on failure.
        local_path = self.get_local_path()
        try:
            name = default_storage.save(local_path, f)
        except OSError as e:
            return None, e
        if settings.SERVER_REPLAY_STORAGE:
            from .tasks import upload_session_replay_to_external_storage
            upload_session_replay_to_external_storage.delay(str(self.id))
        return name, None

    @classmethod
    def set_sessions_active(cls, sessions_id):
        # Mark sessions alive for 5 minutes via cache keys.
        data = {cls.ACTIVE_CACHE_KEY_PREFIX.format(i): i for i in sessions_id}
        cache.set_many(data, timeout=5*60)

    @classmethod
    def get_active_sessions(cls):
        return cls.objects.filter(is_finished=False)

    def is_active(self):
        # Only these protocols report heartbeats; others count as active.
        if self.protocol in ['ssh', 'telnet', 'rdp', 'mysql']:
            key = self.ACTIVE_CACHE_KEY_PREFIX.format(self.id)
            return bool(cache.get(key))
        return True

    @property
    def command_amount(self):
        command_store = get_multi_command_storage()
        return command_store.count(session=str(self.id))

    @property
    def login_from_display(self):
        return self.get_login_from_display()

    @classmethod
    def generate_fake(cls, count=100, is_finished=True):
        """Create *count* random sessions for demo/testing data."""
        import random
        from orgs.models import Organization
        from users.models import User
        from assets.models import Asset, SystemUser
        from orgs.utils import get_current_org
        from common.utils.random import random_datetime, random_ip

        org = get_current_org()
        if not org or not org.is_real():
            Organization.default().change_to()
        i = 0
        users = User.objects.all()[:100]
        assets = Asset.objects.all()[:100]
        system_users = SystemUser.objects.all()[:100]
        while i < count:
            user_random = random.choices(users, k=10)
            assets_random = random.choices(assets, k=10)
            # NOTE(review): this rebinds `system_users`, shrinking the sample
            # pool to 10 after the first pass — looks unintended; confirm.
            system_users = random.choices(system_users, k=10)
            ziped = zip(user_random, assets_random, system_users)
            sessions = []
            now = timezone.now()
            month_ago = now - timezone.timedelta(days=30)
            for user, asset, system_user in ziped:
                ip = random_ip()
                date_start = random_datetime(month_ago, now)
                date_end = random_datetime(date_start, date_start+timezone.timedelta(hours=2))
                data = dict(
                    user=str(user), user_id=user.id,
                    asset=str(asset), asset_id=asset.id,
                    system_user=str(system_user), system_user_id=system_user.id,
                    remote_addr=ip,
                    date_start=date_start,
                    date_end=date_end,
                    is_finished=is_finished,
                )
                sessions.append(Session(**data))
            cls.objects.bulk_create(sessions)
            i += 10

    class Meta:
        db_table = "terminal_session"
        ordering = ["-date_start"]

    def __str__(self):
        return "{0.id} of {0.user} to {0.asset}".format(self)
class Task(models.Model):
    """A control task queued for a terminal (currently only kill_session)."""
    NAME_CHOICES = (
        ("kill_session", "Kill Session"),
    )

    id = models.UUIDField(default=uuid.uuid4, primary_key=True)
    name = models.CharField(max_length=128, choices=NAME_CHOICES, verbose_name=_("Name"))
    args = models.CharField(max_length=1024, verbose_name=_("Args"))
    terminal = models.ForeignKey(Terminal, null=True, on_delete=models.SET_NULL)
    is_finished = models.BooleanField(default=False)
    date_created = models.DateTimeField(auto_now_add=True)
    date_finished = models.DateTimeField(null=True)

    class Meta:
        db_table = "terminal_task"
class CommandManager(models.Manager):
    def bulk_create(self, objs, **kwargs):
        """bulk_create that still emits post_save for every created object."""
        # Django's bulk_create skips signals; send post_save manually so
        # receivers still run for each created command.
        resp = super().bulk_create(objs, **kwargs)
        for i in objs:
            post_save.send(i.__class__, instance=i, created=True)
        return resp
class Command(AbstractSessionCommand):
    """Session command record stored in the database (newest first)."""
    objects = CommandManager()

    class Meta:
        db_table = "terminal_command"
        ordering = ('-timestamp',)
class CommandStorage(CommonModelMixin):
    """Configuration record for a command storage backend."""
    TYPE_CHOICES = const.COMMAND_STORAGE_TYPE_CHOICES
    # NOTE(review): defaults are read from the REPLAY_* constant here even
    # though COMMAND_STORAGE_TYPE_CHOICES_DEFAULT exists — looks like a
    # copy/paste slip; confirm before changing.
    TYPE_DEFAULTS = dict(const.REPLAY_STORAGE_TYPE_CHOICES_DEFAULT).keys()
    TYPE_SERVER = const.COMMAND_STORAGE_TYPE_SERVER

    name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True)
    type = models.CharField(
        max_length=16, choices=TYPE_CHOICES, verbose_name=_('Type'),
        default=TYPE_SERVER
    )
    # Encrypted at rest; holds backend-specific connection details.
    meta = EncryptJsonDictTextField(default={})
    comment = models.TextField(
        max_length=128, default='', blank=True, verbose_name=_('Comment')
    )

    def __str__(self):
        return self.name

    @property
    def config(self):
        # meta plus the TYPE discriminator expected by jms_storage.
        config = self.meta
        config.update({'TYPE': self.type})
        return config

    def in_defaults(self):
        return self.type in self.TYPE_DEFAULTS

    def is_valid(self):
        # Built-in types are assumed valid; external backends are pinged.
        if self.in_defaults():
            return True
        storage = jms_storage.get_log_storage(self.config)
        return storage.ping()

    def is_using(self):
        return Terminal.objects.filter(command_storage=self.name).exists()
class ReplayStorage(CommonModelMixin):
    """Configuration record for a session replay storage backend."""
    TYPE_CHOICES = const.REPLAY_STORAGE_TYPE_CHOICES
    TYPE_SERVER = const.REPLAY_STORAGE_TYPE_SERVER
    TYPE_DEFAULTS = dict(const.REPLAY_STORAGE_TYPE_CHOICES_DEFAULT).keys()

    name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True)
    type = models.CharField(
        max_length=16, choices=TYPE_CHOICES, verbose_name=_('Type'),
        default=TYPE_SERVER
    )
    # Encrypted at rest; holds backend-specific connection details.
    meta = EncryptJsonDictTextField(default={})
    comment = models.TextField(
        max_length=128, default='', blank=True, verbose_name=_('Comment')
    )

    def __str__(self):
        return self.name

    def convert_type(self):
        # Ceph is driven through the s3 storage type.
        s3_type_list = [const.REPLAY_STORAGE_TYPE_CEPH]
        tp = self.type
        if tp in s3_type_list:
            tp = const.REPLAY_STORAGE_TYPE_S3
        return tp

    def get_extra_config(self):
        extra_config = {'TYPE': self.convert_type()}
        # Swift needs an explicit S3-style signer.
        if self.type == const.REPLAY_STORAGE_TYPE_SWIFT:
            extra_config.update({'signer': 'S3SignerType'})
        return extra_config

    @property
    def config(self):
        config = self.meta
        extra_config = self.get_extra_config()
        config.update(extra_config)
        return config

    def in_defaults(self):
        return self.type in self.TYPE_DEFAULTS

    def is_valid(self):
        # Built-in types are assumed valid; others must round-trip a test file.
        if self.in_defaults():
            return True
        storage = jms_storage.get_object_storage(self.config)
        target = 'tests.py'
        src = os.path.join(settings.BASE_DIR, 'common', target)
        return storage.is_valid(src, target)

    def is_using(self):
        return Terminal.objects.filter(replay_storage=self.name).exists()

View File

@ -0,0 +1,6 @@
from .command import *
from .session import *
from .status import *
from .storage import *
from .task import *
from .terminal import *

View File

@ -0,0 +1,21 @@
from __future__ import unicode_literals
from django.db import models
from django.db.models.signals import post_save
from ..backends.command.models import AbstractSessionCommand
class CommandManager(models.Manager):
    def bulk_create(self, objs, **kwargs):
        """bulk_create that still emits post_save for every created object."""
        # Django's bulk_create skips signals; send post_save manually so
        # receivers still run for each created command.
        resp = super().bulk_create(objs, **kwargs)
        for i in objs:
            post_save.send(i.__class__, instance=i, created=True)
        return resp
class Command(AbstractSessionCommand):
    """Session command record stored in the database (newest first)."""
    objects = CommandManager()

    class Meta:
        db_table = "terminal_command"
        ordering = ('-timestamp',)

View File

@ -0,0 +1,210 @@
from __future__ import unicode_literals
import os
import uuid
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.utils import timezone
from django.conf import settings
from django.core.files.storage import default_storage
from django.core.cache import cache
from assets.models import Asset
from orgs.mixins.models import OrgModelMixin
from common.db.models import ChoiceSet
from ..backends import get_multi_command_storage
from .terminal import Terminal
class Session(OrgModelMixin):
    """
    One user-to-asset connection made through a terminal component.

    User/asset/system-user are stored both as display strings and as raw
    id strings, so the record stays readable after the referenced objects
    are deleted.  Also carries helpers for locating/storing the replay
    file and for cache-based liveness tracking.
    """

    class LOGIN_FROM(ChoiceSet):
        ST = 'ST', 'SSH Terminal'
        WT = 'WT', 'Web Terminal'

    class PROTOCOL(ChoiceSet):
        SSH = 'ssh', 'ssh'
        RDP = 'rdp', 'rdp'
        VNC = 'vnc', 'vnc'
        TELNET = 'telnet', 'telnet'
        MYSQL = 'mysql', 'mysql'
        ORACLE = 'oracle', 'oracle'
        MARIADB = 'mariadb', 'mariadb'
        POSTGRESQL = 'postgresql', 'postgresql'
        K8S = 'k8s', 'kubernetes'

    id = models.UUIDField(default=uuid.uuid4, primary_key=True)
    user = models.CharField(max_length=128, verbose_name=_("User"), db_index=True)
    user_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
    asset = models.CharField(max_length=128, verbose_name=_("Asset"), db_index=True)
    asset_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
    system_user = models.CharField(max_length=128, verbose_name=_("System user"), db_index=True)
    system_user_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
    login_from = models.CharField(max_length=2, choices=LOGIN_FROM.choices, default="ST", verbose_name=_("Login from"))
    remote_addr = models.CharField(max_length=128, verbose_name=_("Remote addr"), blank=True, null=True)
    is_success = models.BooleanField(default=True, db_index=True)
    is_finished = models.BooleanField(default=False, db_index=True)
    has_replay = models.BooleanField(default=False, verbose_name=_("Replay"))
    has_command = models.BooleanField(default=False, verbose_name=_("Command"))
    terminal = models.ForeignKey(Terminal, null=True, on_delete=models.DO_NOTHING, db_constraint=False)
    protocol = models.CharField(choices=PROTOCOL.choices, default='ssh', max_length=16, db_index=True)
    date_start = models.DateTimeField(verbose_name=_("Date start"), db_index=True, default=timezone.now)
    date_end = models.DateTimeField(verbose_name=_("Date end"), null=True)

    # Directory prefix of replay files on the default (server) storage.
    upload_to = 'replay'
    ACTIVE_CACHE_KEY_PREFIX = 'SESSION_ACTIVE_{}'
    # Class-level cache: start time of the earliest RDP session that has a
    # replay; see _date_start_first_has_replay_rdp_session below.
    _DATE_START_FIRST_HAS_REPLAY_RDP_SESSION = None

    def get_rel_replay_path(self, version=2):
        """Return the replay file path relative to the storage root.

        :param version: naming scheme; v1 used the ``.gz`` suffix, v2
            unified it to ``.replay.gz``.
        """
        suffix = '.replay.gz'
        if version == 1:
            suffix = '.gz'
        date = self.date_start.strftime('%Y-%m-%d')
        return os.path.join(date, str(self.id) + suffix)

    def get_local_path(self, version=2):
        """Return the replay path on the local server storage."""
        rel_path = self.get_rel_replay_path(version=version)
        if version == 2:
            local_path = os.path.join(self.upload_to, rel_path)
        else:
            local_path = rel_path
        return local_path

    @property
    def asset_obj(self):
        # Raises Asset.DoesNotExist when the asset was deleted meanwhile.
        return Asset.objects.get(id=self.asset_id)

    @property
    def _date_start_first_has_replay_rdp_session(self):
        """Start of the earliest RDP session with a replay, cached on the
        class; falls back to one year ago when none exists."""
        if self.__class__._DATE_START_FIRST_HAS_REPLAY_RDP_SESSION is None:
            instance = self.__class__.objects.filter(
                protocol='rdp', has_replay=True
            ).order_by('date_start').first()
            if not instance:
                date_start = timezone.now() - timezone.timedelta(days=365)
            else:
                date_start = instance.date_start
            self.__class__._DATE_START_FIRST_HAS_REPLAY_RDP_SESSION = date_start
        return self.__class__._DATE_START_FIRST_HAS_REPLAY_RDP_SESSION

    def can_replay(self):
        """A session is replayable when it has a replay, or is old enough
        to predate the first stored RDP replay."""
        if self.has_replay:
            return True
        if self.date_start < self._date_start_first_has_replay_rdp_session:
            return True
        return False

    @property
    def can_join(self):
        """Only unfinished character-stream sessions can be joined."""
        _PROTOCOL = self.PROTOCOL
        if self.is_finished:
            return False
        if self.protocol in [_PROTOCOL.SSH, _PROTOCOL.TELNET, _PROTOCOL.K8S]:
            return True
        else:
            return False

    @property
    def db_protocols(self):
        _PROTOCOL = self.PROTOCOL
        return [_PROTOCOL.MYSQL, _PROTOCOL.MARIADB, _PROTOCOL.ORACLE, _PROTOCOL.POSTGRESQL]

    @property
    def can_terminate(self):
        """Database sessions cannot be terminated from the console."""
        # (removed an unused `_PROTOCOL` local kept from copy-paste)
        if self.is_finished:
            return False
        if self.protocol in self.db_protocols:
            return False
        else:
            return True

    def save_replay_to_storage(self, f):
        """Save file-like ``f`` as this session's replay.

        :return: ``(name, None)`` on success, ``(None, error)`` when the
            default storage rejects the write.  When an external replay
            storage is configured, schedules the async upload as well.
        """
        local_path = self.get_local_path()
        try:
            name = default_storage.save(local_path, f)
        except OSError as e:
            return None, e

        if settings.SERVER_REPLAY_STORAGE:
            from .tasks import upload_session_replay_to_external_storage
            upload_session_replay_to_external_storage.delay(str(self.id))
        return name, None

    @classmethod
    def set_sessions_active(cls, sessions_id):
        """Mark the given session ids alive for the next five minutes."""
        data = {cls.ACTIVE_CACHE_KEY_PREFIX.format(i): i for i in sessions_id}
        cache.set_many(data, timeout=5 * 60)

    @classmethod
    def get_active_sessions(cls):
        return cls.objects.filter(is_finished=False)

    def is_active(self):
        # Protocols whose components report liveness use the cache flag;
        # all other protocols are assumed active while unfinished.
        if self.protocol in ['ssh', 'telnet', 'rdp', 'mysql']:
            key = self.ACTIVE_CACHE_KEY_PREFIX.format(self.id)
            return bool(cache.get(key))
        return True

    @property
    def command_amount(self):
        command_store = get_multi_command_storage()
        return command_store.count(session=str(self.id))

    @property
    def login_from_display(self):
        return self.get_login_from_display()

    @classmethod
    def generate_fake(cls, count=100, is_finished=True):
        """Create ``count`` random sessions for demo/testing purposes."""
        import random
        from orgs.models import Organization
        from users.models import User
        from assets.models import Asset, SystemUser
        from orgs.utils import get_current_org
        from common.utils.random import random_datetime, random_ip

        org = get_current_org()
        if not org or not org.is_real():
            Organization.default().change_to()
        i = 0
        users = User.objects.all()[:100]
        assets = Asset.objects.all()[:100]
        system_users = SystemUser.objects.all()[:100]
        while i < count:
            users_random = random.choices(users, k=10)
            assets_random = random.choices(assets, k=10)
            # Bug fix: the original reassigned ``system_users`` here, so
            # every batch after the first sampled from the previous
            # 10-item sample instead of the full pool.
            system_users_random = random.choices(system_users, k=10)
            zipped = zip(users_random, assets_random, system_users_random)
            sessions = []
            now = timezone.now()
            month_ago = now - timezone.timedelta(days=30)
            for user, asset, system_user in zipped:
                ip = random_ip()
                date_start = random_datetime(month_ago, now)
                date_end = random_datetime(date_start, date_start + timezone.timedelta(hours=2))
                data = dict(
                    user=str(user), user_id=user.id,
                    asset=str(asset), asset_id=asset.id,
                    system_user=str(system_user), system_user_id=system_user.id,
                    remote_addr=ip,
                    date_start=date_start,
                    date_end=date_end,
                    is_finished=is_finished,
                )
                sessions.append(cls(**data))
            cls.objects.bulk_create(sessions)
            i += 10

    class Meta:
        db_table = "terminal_session"
        ordering = ["-date_start"]

    def __str__(self):
        return "{0.id} of {0.user} to {0.asset}".format(self)

View File

@ -0,0 +1,28 @@
from __future__ import unicode_literals
import uuid
from django.db import models
from django.utils.translation import ugettext_lazy as _
from .terminal import Terminal
class Status(models.Model):
    """Point-in-time runtime snapshot reported by a terminal component."""
    id = models.UUIDField(default=uuid.uuid4, primary_key=True)
    session_online = models.IntegerField(verbose_name=_("Session Online"), default=0)
    cpu_used = models.FloatField(verbose_name=_("CPU Usage"))
    memory_used = models.FloatField(verbose_name=_("Memory Used"))
    connections = models.IntegerField(verbose_name=_("Connections"))
    threads = models.IntegerField(verbose_name=_("Threads"))
    boot_time = models.FloatField(verbose_name=_("Boot Time"))
    # CASCADE: snapshots are meaningless once their terminal is removed.
    terminal = models.ForeignKey(Terminal, null=True, on_delete=models.CASCADE)
    date_created = models.DateTimeField(auto_now_add=True)

    class Meta:
        db_table = 'terminal_status'
        # Status.objects.latest() resolves via date_created.
        get_latest_by = 'date_created'

    def __str__(self):
        return self.date_created.strftime("%Y-%m-%d %H:%M:%S")

View File

@ -0,0 +1,103 @@
from __future__ import unicode_literals
import os
import jms_storage
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.conf import settings
from common.mixins import CommonModelMixin
from common.fields.model import EncryptJsonDictTextField
from .. import const
from .terminal import Terminal
class CommandStorage(CommonModelMixin):
    """Configuration record for a backend that stores session commands."""

    TYPE_CHOICES = const.COMMAND_STORAGE_TYPE_CHOICES
    # Bug fix: the defaults must come from the *command* storage choices;
    # the original read const.REPLAY_STORAGE_TYPE_CHOICES_DEFAULT
    # (copy-paste from ReplayStorage below).
    TYPE_DEFAULTS = dict(const.COMMAND_STORAGE_TYPE_CHOICES_DEFAULT).keys()
    TYPE_SERVER = const.COMMAND_STORAGE_TYPE_SERVER

    name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True)
    type = models.CharField(
        max_length=16, choices=TYPE_CHOICES, verbose_name=_('Type'),
        default=TYPE_SERVER
    )
    meta = EncryptJsonDictTextField(default={})
    comment = models.TextField(
        max_length=128, default='', blank=True, verbose_name=_('Comment')
    )

    def __str__(self):
        return self.name

    @property
    def config(self):
        # NOTE(review): this mutates self.meta in place ('TYPE' is written
        # back into the field value) — presumably callers only need the
        # merged dict; consider copying. TODO confirm before changing.
        config = self.meta
        config.update({'TYPE': self.type})
        return config

    def in_defaults(self):
        """Whether this storage is one of the built-in default types."""
        return self.type in self.TYPE_DEFAULTS

    def is_valid(self):
        """Ping the backend; built-in defaults are assumed reachable."""
        if self.in_defaults():
            return True
        storage = jms_storage.get_log_storage(self.config)
        return storage.ping()

    def is_using(self):
        """Whether any terminal references this storage by name."""
        return Terminal.objects.filter(command_storage=self.name).exists()
class ReplayStorage(CommonModelMixin):
    """Configuration record for a backend that stores session replays."""

    TYPE_CHOICES = const.REPLAY_STORAGE_TYPE_CHOICES
    TYPE_SERVER = const.REPLAY_STORAGE_TYPE_SERVER
    TYPE_DEFAULTS = dict(const.REPLAY_STORAGE_TYPE_CHOICES_DEFAULT).keys()

    name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True)
    type = models.CharField(
        max_length=16, choices=TYPE_CHOICES, verbose_name=_('Type'),
        default=TYPE_SERVER
    )
    meta = EncryptJsonDictTextField(default={})
    comment = models.TextField(
        max_length=128, default='', blank=True, verbose_name=_('Comment')
    )

    def __str__(self):
        return self.name

    def convert_type(self):
        # Ceph speaks an S3-compatible API, so map it onto the S3 driver.
        if self.type in [const.REPLAY_STORAGE_TYPE_CEPH]:
            return const.REPLAY_STORAGE_TYPE_S3
        return self.type

    def get_extra_config(self):
        """Driver options merged on top of the user-supplied meta."""
        extra_config = {'TYPE': self.convert_type()}
        if self.type == const.REPLAY_STORAGE_TYPE_SWIFT:
            # Swift's S3 middleware requires this signer.
            extra_config['signer'] = 'S3SignerType'
        return extra_config

    @property
    def config(self):
        cfg = self.meta
        cfg.update(self.get_extra_config())
        return cfg

    def in_defaults(self):
        """Whether this storage is one of the built-in default types."""
        return self.type in self.TYPE_DEFAULTS

    def is_valid(self):
        """Verify the backend by uploading a small probe file; built-in
        defaults are assumed reachable."""
        if self.in_defaults():
            return True
        storage = jms_storage.get_object_storage(self.config)
        target = 'tests.py'
        probe_src = os.path.join(settings.BASE_DIR, 'common', target)
        return storage.is_valid(probe_src, target)

    def is_using(self):
        """Whether any terminal references this storage by name."""
        return Terminal.objects.filter(replay_storage=self.name).exists()

View File

@ -0,0 +1,25 @@
from __future__ import unicode_literals
import uuid
from django.db import models
from django.utils.translation import ugettext_lazy as _
from .terminal import Terminal
class Task(models.Model):
    """An asynchronous instruction queued for a terminal component to
    execute (currently only session termination)."""
    NAME_CHOICES = (
        ("kill_session", "Kill Session"),
    )

    id = models.UUIDField(default=uuid.uuid4, primary_key=True)
    name = models.CharField(max_length=128, choices=NAME_CHOICES, verbose_name=_("Name"))
    args = models.CharField(max_length=1024, verbose_name=_("Args"))
    # SET_NULL: keep the task record even if its terminal is removed.
    terminal = models.ForeignKey(Terminal, null=True, on_delete=models.SET_NULL)
    is_finished = models.BooleanField(default=False)
    date_created = models.DateTimeField(auto_now_add=True)
    date_finished = models.DateTimeField(null=True)

    class Meta:
        db_table = "terminal_task"

View File

@ -0,0 +1,247 @@
from __future__ import unicode_literals
import uuid
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.conf import settings
from django.core.cache import cache
from users.models import User
from .. import const
class ComputeStatusMixin:
    """Derive a component health status from reported system metrics.

    Each metric is bucketed into normal/high/critical by its own
    thresholds; the component's overall status is the worst bucket.
    """

    @staticmethod
    def _common_compute_system_status(value, thresholds):
        """Bucket ``value`` by ``thresholds`` = [low, mid, high]."""
        low, mid, high = thresholds
        if low <= value <= mid:
            return const.ComponentStatusChoices.normal.value
        if mid < value <= high:
            return const.ComponentStatusChoices.high.value
        return const.ComponentStatusChoices.critical.value

    def _compute_system_cpu_load_1_status(self, value):
        return self._common_compute_system_status(value, [0, 5, 20])

    def _compute_system_memory_used_percent_status(self, value):
        return self._common_compute_system_status(value, [0, 85, 95])

    def _compute_system_disk_used_percent_status(self, value):
        return self._common_compute_system_status(value, [0, 80, 99])

    def _compute_system_status(self, state):
        """Per-metric statuses, dispatched by naming convention."""
        metric_keys = (
            'system_cpu_load_1', 'system_memory_used_percent', 'system_disk_used_percent'
        )
        return [
            getattr(self, f'_compute_{key}_status')(state[key])
            for key in metric_keys
        ]

    def _compute_component_status(self, state):
        """Worst per-metric status wins."""
        statuses = self._compute_system_status(state)
        for candidate in (const.ComponentStatusChoices.critical,
                          const.ComponentStatusChoices.high):
            if candidate in statuses:
                return candidate
        return const.ComponentStatusChoices.normal

    @staticmethod
    def _compute_component_status_display(status):
        return getattr(const.ComponentStatusChoices, status).label
class TerminalStateMixin(ComputeStatusMixin):
    """Persist a component's self-reported state in the cache with a
    short TTL; expiry means the component stopped reporting."""

    CACHE_KEY_COMPONENT_STATE = 'CACHE_KEY_COMPONENT_STATE_TERMINAL_{}'
    CACHE_TIMEOUT = 120

    @property
    def cache_key(self):
        # One cache entry per terminal id.
        return self.CACHE_KEY_COMPONENT_STATE.format(str(self.id))

    # -- cache access --------------------------------------------------
    def _get_from_cache(self):
        return cache.get(self.cache_key)

    def _set_to_cache(self, state):
        cache.set(self.cache_key, state, self.CACHE_TIMEOUT)

    # -- state decoration ----------------------------------------------
    def _add_status(self, state):
        # Fold the derived status and its label into the state dict
        # before it is cached.
        status = self._compute_component_status(state)
        state['status'] = status
        state['status_display'] = self._compute_component_status_display(status)

    @property
    def state(self):
        """Last reported state, or {} once it expired from the cache."""
        return self._get_from_cache() or {}

    @state.setter
    def state(self, state):
        self._add_status(state)
        self._set_to_cache(state)
class TerminalStatusMixin(TerminalStateMixin):
    """Convenience accessors derived from the cached component state."""

    @property
    def is_alive(self):
        # No cached state means the component stopped reporting.
        return bool(self.state)

    @property
    def status(self):
        if not self.is_alive:
            return const.ComponentStatusChoices.critical.value
        return self.state['status']

    @property
    def status_display(self):
        return self._compute_component_status_display(self.status)

    @property
    def is_normal(self):
        return self.status == const.ComponentStatusChoices.normal.value

    @property
    def is_high(self):
        return self.status == const.ComponentStatusChoices.high.value

    @property
    def is_critical(self):
        return self.status == const.ComponentStatusChoices.critical.value
class Terminal(TerminalStatusMixin, models.Model):
    """A registered terminal component (SSH / Web terminal service).

    Each terminal runs under a dedicated application user account and
    pulls its runtime configuration (storages, TERMINAL_* settings) from
    this record via the ``config`` property.
    """
    id = models.UUIDField(default=uuid.uuid4, primary_key=True)
    name = models.CharField(max_length=128, verbose_name=_('Name'))
    type = models.CharField(choices=const.TerminalTypeChoices.choices, max_length=64, verbose_name=_('type'))
    remote_addr = models.CharField(max_length=128, blank=True, verbose_name=_('Remote Address'))
    ssh_port = models.IntegerField(verbose_name=_('SSH Port'), default=2222)
    http_port = models.IntegerField(verbose_name=_('HTTP Port'), default=5000)
    # Storages are referenced by *name*, not by foreign key.
    command_storage = models.CharField(max_length=128, verbose_name=_("Command storage"), default='default')
    replay_storage = models.CharField(max_length=128, verbose_name=_("Replay storage"), default='default')
    user = models.OneToOneField(User, related_name='terminal', verbose_name='Application User', null=True, on_delete=models.CASCADE)
    is_accepted = models.BooleanField(default=False, verbose_name='Is Accepted')
    is_deleted = models.BooleanField(default=False)
    date_created = models.DateTimeField(auto_now_add=True)
    comment = models.TextField(blank=True, verbose_name=_('Comment'))

    @property
    def is_active(self):
        # Active means the backing application user exists and is active.
        if self.user and self.user.is_active:
            return True
        return False

    @is_active.setter
    def is_active(self, active):
        # Activation is delegated to the backing application user.
        if self.user:
            self.user.is_active = active
            self.user.save()

    def get_command_storage(self):
        # Lazy import avoids a circular dependency between model modules.
        from .storage import CommandStorage
        storage = CommandStorage.objects.filter(name=self.command_storage).first()
        return storage

    def get_command_storage_config(self):
        s = self.get_command_storage()
        if s:
            config = s.config
        else:
            # Fall back to the server-wide default when the named
            # storage no longer exists.
            config = settings.DEFAULT_TERMINAL_COMMAND_STORAGE
        return config

    def get_command_storage_setting(self):
        config = self.get_command_storage_config()
        return {"TERMINAL_COMMAND_STORAGE": config}

    def get_replay_storage(self):
        from .storage import ReplayStorage
        storage = ReplayStorage.objects.filter(name=self.replay_storage).first()
        return storage

    def get_replay_storage_config(self):
        s = self.get_replay_storage()
        if s:
            config = s.config
        else:
            config = settings.DEFAULT_TERMINAL_REPLAY_STORAGE
        return config

    def get_replay_storage_setting(self):
        config = self.get_replay_storage_config()
        return {"TERMINAL_REPLAY_STORAGE": config}

    @staticmethod
    def get_login_title_setting():
        login_title = None
        if settings.XPACK_ENABLED:
            # Interface is an xpack-only model; import guarded accordingly.
            from xpack.plugins.interface.models import Interface
            login_title = Interface.get_login_title()
        return {'TERMINAL_HEADER_TITLE': login_title}

    @property
    def config(self):
        """Full config pushed to the component: every TERMINAL_* setting
        plus storage, login-title and idle-timeout values."""
        configs = {}
        for k in dir(settings):
            if not k.startswith('TERMINAL'):
                continue
            configs[k] = getattr(settings, k)
        configs.update(self.get_command_storage_setting())
        configs.update(self.get_replay_storage_setting())
        configs.update(self.get_login_title_setting())
        configs.update({
            'SECURITY_MAX_IDLE_TIME': settings.SECURITY_MAX_IDLE_TIME
        })
        return configs

    @property
    def service_account(self):
        return self.user

    def create_app_user(self):
        """Create and attach a dedicated application user; returns the
        user together with its API access key."""
        random = uuid.uuid4().hex[:6]
        user, access_key = User.create_app_user(
            name="{}-{}".format(self.name, random), comment=self.comment
        )
        self.user = user
        self.save()
        return user, access_key

    def delete(self, using=None, keep_parents=False):
        # Soft delete: remove the app user but keep the terminal row
        # flagged is_deleted, so session history stays resolvable.
        # NOTE: deliberately does NOT call super().delete().
        if self.user:
            self.user.delete()
        self.user = None
        self.is_deleted = True
        self.save()
        return

    def __str__(self):
        status = "Active"
        if not self.is_accepted:
            status = "NotAccept"
        elif self.is_deleted:
            status = "Deleted"
        elif not self.is_active:
            status = "Disable"
        return '%s: %s' % (self.name, status)

    class Meta:
        ordering = ('is_accepted',)
        db_table = "terminal"

View File

@ -4,3 +4,4 @@ from .terminal import *
from .session import *
from .storage import *
from .command import *
from .components import *

View File

@ -0,0 +1,25 @@
from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
class ComponentsStateSerializer(serializers.Serializer):
    """State payload a terminal component reports about itself.

    ``save`` is overridden to stash the validated payload on the
    reporting terminal rather than creating any model instance.
    """
    # system metrics
    system_cpu_load_1 = serializers.FloatField(
        required=False, default=0, label=_("System cpu load (1 minutes)")
    )
    system_memory_used_percent = serializers.FloatField(
        required=False, default=0, label=_('System memory used percent')
    )
    system_disk_used_percent = serializers.FloatField(
        required=False, default=0, label=_('System disk used percent')
    )
    # sessions
    session_active_count = serializers.IntegerField(
        required=False, default=0, label=_("Session active count")
    )

    def save(self, **kwargs):
        # The authenticated user is the component's app user; its
        # one-to-one `terminal` receives the freshly validated state.
        request = self.context['request']
        terminal = request.user.terminal
        terminal.state = self.validated_data

View File

@ -6,19 +6,25 @@ from common.utils import is_uuid
from ..models import (
Terminal, Status, Session, Task, CommandStorage, ReplayStorage
)
from .components import ComponentsStateSerializer
class TerminalSerializer(BulkModelSerializer):
session_online = serializers.SerializerMethodField()
is_alive = serializers.BooleanField(read_only=True)
status = serializers.CharField(read_only=True)
status_display = serializers.CharField(read_only=True)
state = ComponentsStateSerializer(read_only=True)
class Meta:
model = Terminal
fields = [
'id', 'name', 'remote_addr', 'http_port', 'ssh_port',
'id', 'name', 'type', 'remote_addr', 'http_port', 'ssh_port',
'comment', 'is_accepted', "is_active", 'session_online',
'is_alive', 'date_created', 'command_storage', 'replay_storage'
'is_alive', 'date_created', 'command_storage', 'replay_storage',
'status', 'status_display', 'state'
]
read_only_fields = ['type', 'date_created']
@staticmethod
def get_kwargs_may_be_uuid(value):

View File

@ -33,7 +33,10 @@ urlpatterns = [
path('commands/export/', api.CommandExportApi.as_view(), name="command-export"),
path('commands/insecure-command/', api.InsecureCommandAlertAPI.as_view(), name="command-alert"),
path('replay-storages/<uuid:pk>/test-connective/', api.ReplayStorageTestConnectiveApi.as_view(), name='replay-storage-test-connective'),
path('command-storages/<uuid:pk>/test-connective/', api.CommandStorageTestConnectiveApi.as_view(), name='command-storage-test-connective')
path('command-storages/<uuid:pk>/test-connective/', api.CommandStorageTestConnectiveApi.as_view(), name='command-storage-test-connective'),
# components
path('components/metrics/', api.ComponentsMetricsAPIView.as_view(), name='components-metrics'),
path('components/state/', api.ComponentsStateAPIView.as_view(), name='components-state'),
# v2: get session's replay
# path('v2/sessions/<uuid:pk>/replay/',
# api.SessionReplayV2ViewSet.as_view({'get': 'retrieve'}),

View File

@ -11,6 +11,7 @@ import jms_storage
from common.tasks import send_mail_async
from common.utils import get_logger, reverse
from settings.models import Setting
from . import const
from .models import ReplayStorage, Session, Command
@ -101,3 +102,104 @@ def send_command_alert_mail(command):
logger.debug(message)
send_mail_async.delay(subject, message, recipient_list, html_message=message)
class ComponentsMetricsUtil(object):
    """Aggregate health and session metrics over terminal components.

    When ``component_type`` is given, only components of that type are
    considered; otherwise every registered terminal is included.
    """

    def __init__(self, component_type=None):
        self.type = component_type
        self.components = []
        self.initial_components()

    def initial_components(self):
        # Imported lazily to avoid a circular import with .models.
        from .models import Terminal
        queryset = Terminal.objects.all().order_by('type')
        if self.type:
            queryset = queryset.filter(type=self.type)
        self.components = list(queryset)

    def get_metrics(self):
        """Count components per status bucket and sum active sessions.

        Components without a live state are counted as critical and
        contribute no sessions.
        """
        counters = {
            'total': 0,
            'normal': 0,
            'high': 0,
            'critical': 0,
            'session_active': 0,
        }
        for component in self.components:
            counters['total'] += 1
            if not component.is_alive:
                counters['critical'] += 1
                continue
            counters['session_active'] += component.state.get('session_active_count', 0)
            if component.is_normal:
                counters['normal'] += 1
            elif component.is_high:
                counters['high'] += 1
            else:
                counters['critical'] += 1
        return counters
class ComponentsPrometheusMetricsUtil(ComponentsMetricsUtil):
    """Render component metrics in the Prometheus text exposition format."""

    @staticmethod
    def get_status_metrics(metrics):
        """Project the aggregate counters onto the status label set."""
        return {
            'any': metrics['total'],
            'normal': metrics['normal'],
            'high': metrics['high'],
            'critical': metrics['critical']
        }

    def get_prometheus_metrics_text(self):
        prometheus_metrics = []

        # Perf fix: get_metrics() walks every component, and the original
        # recomputed it inside each per-component loop below (O(n^2)).
        # Its result is loop-invariant, so compute it once up front.
        # NOTE(review): the aggregate spans *all* components, so each
        # component_type label below repeats the same global numbers;
        # presumably per-type aggregation was intended — TODO confirm.
        base_metrics = self.get_metrics()

        prometheus_metrics.append('# JumpServer 各组件状态个数汇总')
        base_status_metric_text = 'jumpserver_components_status_total{component_type="%s", status="%s"} %s'
        status_metrics = self.get_status_metrics(base_metrics)
        for component in self.components:
            component_type = component.type
            prometheus_metrics.append(f'## 组件: {component_type}')
            for status, value in status_metrics.items():
                metric_text = base_status_metric_text % (component_type, status, value)
                prometheus_metrics.append(metric_text)

        prometheus_metrics.append('\n')

        prometheus_metrics.append('# JumpServer 各组件在线会话数汇总')
        base_session_active_metric_text = 'jumpserver_components_session_active_total{component_type="%s"} %s'
        for component in self.components:
            component_type = component.type
            prometheus_metrics.append(f'## 组件: {component_type}')
            metric_text = base_session_active_metric_text % (
                component_type,
                base_metrics['session_active']
            )
            prometheus_metrics.append(metric_text)

        prometheus_metrics.append('\n')

        prometheus_metrics.append('# JumpServer 各组件节点一些指标')
        base_system_state_metric_text = 'jumpserver_components_%s{component_type="%s", component="%s"} %s'
        system_states_name = [
            'system_cpu_load_1', 'system_memory_used_percent',
            'system_disk_used_percent', 'session_active_count'
        ]
        for system_state_name in system_states_name:
            prometheus_metrics.append(f'## 指标: {system_state_name}')
            for component in self.components:
                if not component.is_alive:
                    # Dead components have no cached state to report.
                    continue
                component_type = component.type
                metric_text = base_system_state_metric_text % (
                    system_state_name,
                    component_type,
                    component.name,
                    component.state.get(system_state_name)
                )
                prometheus_metrics.append(metric_text)

        prometheus_metrics_text = '\n'.join(prometheus_metrics)
        return prometheus_metrics_text

View File

@ -6,8 +6,8 @@ from rest_framework.decorators import action
from rest_framework import generics
from rest_framework.response import Response
from rest_framework_bulk import BulkModelViewSet
from django.db.models import Prefetch
from common.db.aggregates import GroupConcat
from common.permissions import (
IsOrgAdmin, IsOrgAdminOrAppUser,
CanUpdateDeleteUser, IsSuperUser
@ -44,9 +44,18 @@ class UserViewSet(CommonApiMixin, UserQuerysetMixin, BulkModelViewSet):
extra_filter_backends = [OrgRoleUserFilterBackend]
def get_queryset(self):
return super().get_queryset().annotate(
gc_m2m_org_members__role=GroupConcat('m2m_org_members__role'),
).prefetch_related('groups')
queryset = super().get_queryset().prefetch_related(
'groups'
)
if current_org.is_real():
# 为在列表中计算用户在真实组织里的角色
queryset = queryset.prefetch_related(
Prefetch(
'm2m_org_members',
queryset=OrganizationMember.objects.filter(org__id=current_org.id)
)
)
return queryset
def send_created_signal(self, users):
if not isinstance(users, list):

View File

@ -28,7 +28,7 @@ class UserCreateUpdateFormMixin(OrgModelForm):
)
source = forms.ChoiceField(
choices=get_source_choices, required=True,
initial=User.SOURCE_LOCAL, label=_("Source")
initial=User.Source.local.value, label=_("Source")
)
public_key = forms.CharField(
label=_('ssh public key'), max_length=5000, required=False,

View File

@ -0,0 +1,18 @@
# Generated by Django 3.1 on 2020-11-18 10:01
from django.db import migrations, models
class Migration(migrations.Migration):
    """Widen User.first_name to max_length=150.

    Auto-generated under Django 3.1 — presumably tracking Django's own
    AbstractUser.first_name length change; confirm against the Django
    release notes before editing.
    """

    dependencies = [
        ('users', '0030_auto_20200819_2041'),
    ]

    operations = [
        migrations.AlterField(
            model_name='user',
            name='first_name',
            field=models.CharField(blank=True, max_length=150, verbose_name='first name'),
        ),
    ]

View File

@ -7,10 +7,10 @@ import string
import random
from django.conf import settings
from django.contrib.auth.hashers import make_password
from django.contrib.auth.models import AbstractUser
from django.core.cache import cache
from django.db import models
from django.db.models import TextChoices
from django.utils.translation import ugettext_lazy as _
from django.utils import timezone
@ -170,22 +170,18 @@ class RoleMixin:
from orgs.models import ROLE as ORG_ROLE
if not current_org.is_real():
# 不是真实的组织,取 User 本身的角色
if self.is_superuser:
return [ORG_ROLE.ADMIN]
else:
return [ORG_ROLE.USER]
if hasattr(self, 'gc_m2m_org_members__role'):
names = self.gc_m2m_org_members__role
if isinstance(names, str):
roles = set(self.gc_m2m_org_members__role.split(','))
else:
roles = set()
else:
roles = set(self.m2m_org_members.filter(
org_id=current_org.id
).values_list('role', flat=True))
roles = list(roles)
# 是真实组织,取 OrganizationMember 中的角色
roles = [
org_member.role
for org_member in self.m2m_org_members.all()
if org_member.org_id == current_org.id
]
roles.sort()
return roles
@ -485,18 +481,12 @@ class MFAMixin:
class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
SOURCE_LOCAL = 'local'
SOURCE_LDAP = 'ldap'
SOURCE_OPENID = 'openid'
SOURCE_RADIUS = 'radius'
SOURCE_CAS = 'cas'
SOURCE_CHOICES = (
(SOURCE_LOCAL, _('Local')),
(SOURCE_LDAP, 'LDAP/AD'),
(SOURCE_OPENID, 'OpenID'),
(SOURCE_RADIUS, 'Radius'),
(SOURCE_CAS, 'CAS'),
)
class Source(TextChoices):
local = 'local', _('Local')
ldap = 'ldap', 'LDAP/AD'
openid = 'openid', 'OpenID'
radius = 'radius', 'Radius'
cas = 'cas', 'CAS'
id = models.UUIDField(default=uuid.uuid4, primary_key=True)
username = models.CharField(
@ -546,7 +536,7 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
max_length=30, default='', blank=True, verbose_name=_('Created by')
)
source = models.CharField(
max_length=30, default=SOURCE_LOCAL, choices=SOURCE_CHOICES,
max_length=30, default=Source.local.value, choices=Source.choices,
verbose_name=_('Source')
)
date_password_last_updated = models.DateTimeField(
@ -597,7 +587,7 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
@property
def is_local(self):
return self.source == self.SOURCE_LOCAL
return self.source == self.Source.local.value
def set_unprovide_attr_if_need(self):
if not self.name:
@ -667,6 +657,6 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
user.groups.add(UserGroup.initial())
def can_send_created_mail(self):
if self.email and self.source == self.SOURCE_LOCAL:
if self.email and self.source == self.Source.local.value:
return True
return False

View File

@ -68,6 +68,10 @@ class UserSerializer(CommonBulkSerializerMixin, serializers.ModelSerializer):
'can_update', 'can_delete', 'login_blocked', 'org_roles'
]
read_only_fields = [
'date_joined', 'last_login', 'created_by', 'is_first_login', 'source'
]
extra_kwargs = {
'password': {'write_only': True, 'required': False, 'allow_null': True, 'allow_blank': True},
'public_key': {'write_only': True},

View File

@ -2,14 +2,12 @@
#
from django.dispatch import receiver
from django.db.models.signals import m2m_changed
from django_auth_ldap.backend import populate_user
from django.conf import settings
from django_cas_ng.signals import cas_user_authenticated
from jms_oidc_rp.signals import openid_create_or_update_user
from perms.tasks import create_rebuild_user_tree_task
from common.utils import get_logger
from .signals import post_user_create
from .models import User
@ -27,19 +25,10 @@ def on_user_create(sender, user=None, **kwargs):
send_user_created_mail(user)
@receiver(m2m_changed, sender=User.groups.through)
def on_user_groups_change(instance, action, reverse, pk_set, **kwargs):
if action.startswith('post'):
if reverse:
create_rebuild_user_tree_task(pk_set)
else:
create_rebuild_user_tree_task([instance.id])
@receiver(cas_user_authenticated)
def on_cas_user_authenticated(sender, user, created, **kwargs):
if created:
user.source = user.SOURCE_CAS
user.source = user.Source.cas.value
user.save()
@ -48,7 +37,7 @@ def on_ldap_create_user(sender, user, ldap_user, **kwargs):
if user and user.username not in ['admin']:
exists = User.objects.filter(username=user.username).exists()
if not exists:
user.source = user.SOURCE_LDAP
user.source = user.Source.ldap.value
user.save()
@ -57,9 +46,9 @@ def on_openid_create_or_update_user(sender, request, user, created, name, userna
if created:
logger.debug(
"Receive OpenID user created signal: {}, "
"Set user source is: {}".format(user, User.SOURCE_OPENID)
"Set user source is: {}".format(user, User.Source.openid.value)
)
user.source = User.SOURCE_OPENID
user.source = User.Source.openid.value
user.save()
elif not created and settings.AUTH_OPENID_ALWAYS_UPDATE_USER:
logger.debug(

View File

@ -22,7 +22,7 @@ logger = get_logger(__file__)
@shared_task
def check_password_expired():
users = User.objects.filter(source=User.SOURCE_LOCAL).exclude(role=User.ROLE.APP)
users = User.objects.filter(source=User.Source.local.value).exclude(role=User.ROLE.APP)
for user in users:
if not user.is_valid:
continue

Some files were not shown because too many files have changed in this diff Show More