pull/8873/head
feng626 2022-08-29 19:53:04 +08:00
commit ca3d2271a8
138 changed files with 4665 additions and 1147 deletions

View File

@ -17,6 +17,7 @@ ARG DEPENDENCIES=" \
libxmlsec1-dev \
libxmlsec1-openssl \
libaio-dev \
openssh-client \
sshpass"
ARG TOOLS=" \
@ -29,24 +30,22 @@ ARG TOOLS=" \
redis-tools \
telnet \
vim \
unzip \
unzip \
wget"
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list \
&& sed -i 's/security.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list \
&& apt update && sleep 1 && apt update \
&& apt -y install ${BUILD_DEPENDENCIES} \
&& apt -y install ${DEPENDENCIES} \
&& apt -y install ${TOOLS} \
RUN sed -i 's@http://.*.debian.org@http://mirrors.ustc.edu.cn@g' /etc/apt/sources.list \
&& apt-get update \
&& apt-get -y install --no-install-recommends ${BUILD_DEPENDENCIES} \
&& apt-get -y install --no-install-recommends ${DEPENDENCIES} \
&& apt-get -y install --no-install-recommends ${TOOLS} \
&& localedef -c -f UTF-8 -i zh_CN zh_CN.UTF-8 \
&& cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
&& mkdir -p /root/.ssh/ \
&& echo "Host *\n\tStrictHostKeyChecking no\n\tUserKnownHostsFile /dev/null" > /root/.ssh/config \
&& sed -i "s@# alias l@alias l@g" ~/.bashrc \
&& echo "set mouse-=a" > ~/.vimrc \
&& rm -rf /var/lib/apt/lists/* \
&& mv /bin/sh /bin/sh.bak \
&& ln -s /bin/bash /bin/sh
&& echo "no" | dpkg-reconfigure dash \
&& rm -rf /var/lib/apt/lists/*
ARG TARGETARCH
ARG ORACLE_LIB_MAJOR=19
@ -65,9 +64,9 @@ RUN mkdir -p /opt/oracle/ \
WORKDIR /tmp/build
COPY ./requirements ./requirements
ARG PIP_MIRROR=https://mirrors.aliyun.com/pypi/simple/
ARG PIP_MIRROR=https://pypi.douban.com/simple
ENV PIP_MIRROR=$PIP_MIRROR
ARG PIP_JMS_MIRROR=https://mirrors.aliyun.com/pypi/simple/
ARG PIP_JMS_MIRROR=https://pypi.douban.com/simple
ENV PIP_JMS_MIRROR=$PIP_JMS_MIRROR
# 因为以 jms 或者 jumpserver 开头的 mirror 上可能没有
RUN pip install --upgrade pip==20.2.4 setuptools==49.6.0 wheel==0.34.2 -i ${PIP_MIRROR} \

View File

@ -16,7 +16,7 @@
JumpServer 是全球首款开源的堡垒机,使用 GPLv3 开源协议,是符合 4A 规范的运维安全审计系统。
JumpServer 是广受欢迎的开源堡垒机,是符合 4A 规范的专业运维安全审计系统。
JumpServer 使用 Python 开发,配备了业界领先的 Web Terminal 方案,交互界面美观、用户体验好。
@ -95,11 +95,15 @@ JumpServer 采纳分布式架构,支持多机房跨区域部署,支持横向
### 案例研究
- [JumpServer 堡垒机护航顺丰科技超大规模资产安全运维](https://blog.fit2cloud.com/?p=1147)
- [JumpServer 堡垒机让“大智慧”的混合 IT 运维更智慧](https://blog.fit2cloud.com/?p=882)
- [携程 JumpServer 堡垒机部署与运营实战](https://blog.fit2cloud.com/?p=851)
- [小红书的JumpServer堡垒机大规模资产跨版本迁移之路](https://blog.fit2cloud.com/?p=516)
- [JumpServer堡垒机助力中手游提升多云环境下安全运维能力](https://blog.fit2cloud.com/?p=732)
- [腾讯海外游戏基于JumpServer构建游戏安全运营能力](https://blog.fit2cloud.com/?p=3704)
- [万华化学通过JumpServer管理全球化分布式IT资产并且实现与云管平台的联动](https://blog.fit2cloud.com/?p=3504)
- [雪花啤酒JumpServer堡垒机使用体会](https://blog.fit2cloud.com/?p=3412)
- [顺丰科技JumpServer 堡垒机护航顺丰科技超大规模资产安全运维](https://blog.fit2cloud.com/?p=1147)
- [沐瞳游戏通过JumpServer管控多项目分布式资产](https://blog.fit2cloud.com/?p=3213)
- [携程JumpServer 堡垒机部署与运营实战](https://blog.fit2cloud.com/?p=851)
- [大智慧JumpServer 堡垒机让“大智慧”的混合 IT 运维更智慧](https://blog.fit2cloud.com/?p=882)
- [小红书的JumpServer堡垒机大规模资产跨版本迁移之路](https://blog.fit2cloud.com/?p=516)
- [中手游JumpServer堡垒机助力中手游提升多云环境下安全运维能力](https://blog.fit2cloud.com/?p=732)
- [中通快递JumpServer主机安全运维实践](https://blog.fit2cloud.com/?p=708)
- [东方明珠JumpServer高效管控异构化、分布式云端资产](https://blog.fit2cloud.com/?p=687)
- [江苏农信JumpServer堡垒机助力行业云安全运维](https://blog.fit2cloud.com/?p=666)

View File

@ -44,58 +44,29 @@ class LoginACL(BaseACL):
def __str__(self):
return self.name
@property
def action_reject(self):
return self.action == self.ActionChoices.reject
@property
def action_allow(self):
return self.action == self.ActionChoices.allow
def is_action(self, action):
return self.action == action
@classmethod
def filter_acl(cls, user):
return user.login_acls.all().valid().distinct()
@staticmethod
def allow_user_confirm_if_need(user, ip):
acl = LoginACL.filter_acl(user).filter(
action=LoginACL.ActionChoices.confirm
).first()
acl = acl if acl and acl.reviewers.exists() else None
if not acl:
return False, acl
ip_group = acl.rules.get('ip_group')
time_periods = acl.rules.get('time_period')
is_contain_ip = contains_ip(ip, ip_group)
is_contain_time_period = contains_time_period(time_periods)
return is_contain_ip and is_contain_time_period, acl
def match(user, ip):
acls = LoginACL.filter_acl(user)
if not acls:
return
@staticmethod
def allow_user_to_login(user, ip):
acl = LoginACL.filter_acl(user).exclude(
action=LoginACL.ActionChoices.confirm
).first()
if not acl:
return True, ''
ip_group = acl.rules.get('ip_group')
time_periods = acl.rules.get('time_period')
is_contain_ip = contains_ip(ip, ip_group)
is_contain_time_period = contains_time_period(time_periods)
reject_type = ''
if is_contain_ip and is_contain_time_period:
# 满足条件
allow = acl.action_allow
if not allow:
reject_type = 'ip' if is_contain_ip else 'time'
else:
# 不满足条件
# 如果acl本身允许那就拒绝如果本身拒绝那就允许
allow = not acl.action_allow
if not allow:
reject_type = 'ip' if not is_contain_ip else 'time'
return allow, reject_type
for acl in acls:
if acl.is_action(LoginACL.ActionChoices.confirm) and not acl.reviewers.exists():
continue
ip_group = acl.rules.get('ip_group')
time_periods = acl.rules.get('time_period')
is_contain_ip = contains_ip(ip, ip_group)
is_contain_time_period = contains_time_period(time_periods)
if is_contain_ip and is_contain_time_period:
# 满足条件,则返回
return acl
def create_confirm_ticket(self, request):
from tickets import const

View File

@ -0,0 +1,91 @@
# coding: utf-8
#
from django.db import models
from django.utils.translation import ugettext_lazy as _
class AppCategory(models.TextChoices):
db = 'db', _('Database')
remote_app = 'remote_app', _('Remote app')
cloud = 'cloud', 'Cloud'
@classmethod
def get_label(cls, category):
return dict(cls.choices).get(category, '')
@classmethod
def is_xpack(cls, category):
return category in ['remote_app']
class AppType(models.TextChoices):
# db category
mysql = 'mysql', 'MySQL'
mariadb = 'mariadb', 'MariaDB'
oracle = 'oracle', 'Oracle'
pgsql = 'postgresql', 'PostgreSQL'
sqlserver = 'sqlserver', 'SQLServer'
redis = 'redis', 'Redis'
mongodb = 'mongodb', 'MongoDB'
# remote-app category
chrome = 'chrome', 'Chrome'
mysql_workbench = 'mysql_workbench', 'MySQL Workbench'
vmware_client = 'vmware_client', 'vSphere Client'
custom = 'custom', _('Custom')
# cloud category
k8s = 'k8s', 'Kubernetes'
@classmethod
def category_types_mapper(cls):
return {
AppCategory.db: [
cls.mysql, cls.mariadb, cls.oracle, cls.pgsql,
cls.sqlserver, cls.redis, cls.mongodb
],
AppCategory.remote_app: [
cls.chrome, cls.mysql_workbench,
cls.vmware_client, cls.custom
],
AppCategory.cloud: [cls.k8s]
}
@classmethod
def type_category_mapper(cls):
mapper = {}
for category, tps in cls.category_types_mapper().items():
for tp in tps:
mapper[tp] = category
return mapper
@classmethod
def get_label(cls, tp):
return dict(cls.choices).get(tp, '')
@classmethod
def db_types(cls):
return [tp.value for tp in cls.category_types_mapper()[AppCategory.db]]
@classmethod
def remote_app_types(cls):
return [tp.value for tp in cls.category_types_mapper()[AppCategory.remote_app]]
@classmethod
def cloud_types(cls):
return [tp.value for tp in cls.category_types_mapper()[AppCategory.cloud]]
@classmethod
def is_xpack(cls, tp):
tp_category_mapper = cls.type_category_mapper()
category = tp_category_mapper[tp]
if AppCategory.is_xpack(category):
return True
return tp in ['oracle', 'postgresql', 'sqlserver']
class OracleVersion(models.TextChoices):
version_11g = '11g', '11g'
version_12c = '12c', '12c'
version_other = 'other', _('Other')

View File

@ -0,0 +1,23 @@
# Generated by Django 3.2.12 on 2022-07-14 02:46
from django.db import migrations
def migrate_db_oracle_version_to_attrs(apps, schema_editor):
db_alias = schema_editor.connection.alias
model = apps.get_model("applications", "Application")
oracles = list(model.objects.using(db_alias).filter(type='oracle'))
for o in oracles:
o.attrs['version'] = '12c'
model.objects.using(db_alias).bulk_update(oracles, ['attrs'])
class Migration(migrations.Migration):
dependencies = [
('applications', '0021_auto_20220629_1826'),
]
operations = [
migrations.RunPython(migrate_db_oracle_version_to_attrs)
]

View File

@ -0,0 +1,48 @@
# Generated by Django 3.1.14 on 2022-07-15 07:56
import time
from collections import defaultdict
from django.db import migrations
def migrate_account_dirty_data(apps, schema_editor):
db_alias = schema_editor.connection.alias
account_model = apps.get_model('applications', 'Account')
count = 0
bulk_size = 1000
while True:
accounts = account_model.objects.using(db_alias) \
.filter(org_id='')[count:count + bulk_size]
if not accounts:
break
accounts = list(accounts)
start = time.time()
for i in accounts:
if i.app:
org_id = i.app.org_id
elif i.systemuser:
org_id = i.systemuser.org_id
else:
org_id = ''
if org_id:
i.org_id = org_id
account_model.objects.bulk_update(accounts, ['org_id', ])
print("Update account org is empty: {}-{} using: {:.2f}s".format(
count, count + len(accounts), time.time() - start
))
count += len(accounts)
class Migration(migrations.Migration):
dependencies = [
('applications', '0022_auto_20220714_1046'),
]
operations = [
migrations.RunPython(migrate_account_dirty_data),
]

View File

@ -0,0 +1,320 @@
from collections import defaultdict
from urllib.parse import urlencode, parse_qsl
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.conf import settings
from orgs.mixins.models import OrgModelMixin
from common.mixins import CommonModelMixin
from common.tree import TreeNode
from common.utils import is_uuid
from assets.models import Asset, SystemUser
from ..const import OracleVersion
from ..utils import KubernetesTree
from .. import const
class ApplicationTreeNodeMixin:
id: str
name: str
type: str
category: str
attrs: dict
@staticmethod
def create_tree_id(pid, type, v):
i = dict(parse_qsl(pid))
i[type] = v
tree_id = urlencode(i)
return tree_id
@classmethod
def create_choice_node(cls, c, id_, pid, tp, opened=False, counts=None,
show_empty=True, show_count=True):
count = counts.get(c.value, 0)
if count == 0 and not show_empty:
return None
label = c.label
if count is not None and show_count:
label = '{} ({})'.format(label, count)
data = {
'id': id_,
'name': label,
'title': label,
'pId': pid,
'isParent': bool(count),
'open': opened,
'iconSkin': '',
'meta': {
'type': tp,
'data': {
'name': c.name,
'value': c.value
}
}
}
return TreeNode(**data)
@classmethod
def create_root_tree_node(cls, queryset, show_count=True):
count = queryset.count() if show_count else None
root_id = 'applications'
root_name = _('Applications')
if count is not None and show_count:
root_name = '{} ({})'.format(root_name, count)
node = TreeNode(**{
'id': root_id,
'name': root_name,
'title': root_name,
'pId': '',
'isParent': True,
'open': True,
'iconSkin': '',
'meta': {
'type': 'applications_root',
}
})
return node
@classmethod
def create_category_tree_nodes(cls, pid, counts=None, show_empty=True, show_count=True):
nodes = []
categories = const.AppType.category_types_mapper().keys()
for category in categories:
if not settings.XPACK_ENABLED and const.AppCategory.is_xpack(category):
continue
i = cls.create_tree_id(pid, 'category', category.value)
node = cls.create_choice_node(
category, i, pid=pid, tp='category',
counts=counts, opened=False, show_empty=show_empty,
show_count=show_count
)
if not node:
continue
nodes.append(node)
return nodes
@classmethod
def create_types_tree_nodes(cls, pid, counts, show_empty=True, show_count=True):
nodes = []
temp_pid = pid
type_category_mapper = const.AppType.type_category_mapper()
types = const.AppType.type_category_mapper().keys()
for tp in types:
if not settings.XPACK_ENABLED and const.AppType.is_xpack(tp):
continue
category = type_category_mapper.get(tp)
pid = cls.create_tree_id(pid, 'category', category.value)
i = cls.create_tree_id(pid, 'type', tp.value)
node = cls.create_choice_node(
tp, i, pid, tp='type', counts=counts, opened=False,
show_empty=show_empty, show_count=show_count
)
pid = temp_pid
if not node:
continue
nodes.append(node)
return nodes
@staticmethod
def get_tree_node_counts(queryset):
counts = defaultdict(int)
values = queryset.values_list('type', 'category')
for i in values:
tp = i[0]
category = i[1]
counts[tp] += 1
counts[category] += 1
return counts
@classmethod
def create_category_type_tree_nodes(cls, queryset, pid, show_empty=True, show_count=True):
counts = cls.get_tree_node_counts(queryset)
tree_nodes = []
# 类别的节点
tree_nodes += cls.create_category_tree_nodes(
pid, counts, show_empty=show_empty,
show_count=show_count
)
# 类型的节点
tree_nodes += cls.create_types_tree_nodes(
pid, counts, show_empty=show_empty,
show_count=show_count
)
return tree_nodes
@classmethod
def create_tree_nodes(cls, queryset, root_node=None, show_empty=True, show_count=True):
tree_nodes = []
# 根节点有可能是组织名称
if root_node is None:
root_node = cls.create_root_tree_node(queryset, show_count=show_count)
tree_nodes.append(root_node)
tree_nodes += cls.create_category_type_tree_nodes(
queryset, root_node.id, show_empty=show_empty, show_count=show_count
)
# 应用的节点
for app in queryset:
if not settings.XPACK_ENABLED and const.AppType.is_xpack(app.type):
continue
node = app.as_tree_node(root_node.id)
tree_nodes.append(node)
return tree_nodes
def create_app_tree_pid(self, root_id):
pid = self.create_tree_id(root_id, 'category', self.category)
pid = self.create_tree_id(pid, 'type', self.type)
return pid
def as_tree_node(self, pid, k8s_as_tree=False):
if self.type == const.AppType.k8s and k8s_as_tree:
node = KubernetesTree(pid).as_tree_node(self)
else:
node = self._as_tree_node(pid)
return node
def _attrs_to_tree(self):
if self.category == const.AppCategory.db:
return self.attrs
return {}
def _as_tree_node(self, pid):
icon_skin_category_mapper = {
'remote_app': 'chrome',
'db': 'database',
'cloud': 'cloud'
}
icon_skin = icon_skin_category_mapper.get(self.category, 'file')
pid = self.create_app_tree_pid(pid)
node = TreeNode(**{
'id': str(self.id),
'name': self.name,
'title': self.name,
'pId': pid,
'isParent': False,
'open': False,
'iconSkin': icon_skin,
'meta': {
'type': 'application',
'data': {
'category': self.category,
'type': self.type,
'attrs': self._attrs_to_tree()
}
}
})
return node
class Application(CommonModelMixin, OrgModelMixin, ApplicationTreeNodeMixin):
APP_TYPE = const.AppType
name = models.CharField(max_length=128, verbose_name=_('Name'))
category = models.CharField(
max_length=16, choices=const.AppCategory.choices, verbose_name=_('Category')
)
type = models.CharField(
max_length=16, choices=const.AppType.choices, verbose_name=_('Type')
)
domain = models.ForeignKey(
'assets.Domain', null=True, blank=True, related_name='applications',
on_delete=models.SET_NULL, verbose_name=_("Domain"),
)
attrs = models.JSONField(default=dict, verbose_name=_('Attrs'))
comment = models.TextField(
max_length=128, default='', blank=True, verbose_name=_('Comment')
)
class Meta:
verbose_name = _('Application')
unique_together = [('org_id', 'name')]
ordering = ('name',)
permissions = [
('match_application', _('Can match application')),
]
def __str__(self):
category_display = self.get_category_display()
type_display = self.get_type_display()
return f'{self.name}({type_display})[{category_display}]'
@property
def category_remote_app(self):
return self.category == const.AppCategory.remote_app.value
@property
def category_cloud(self):
return self.category == const.AppCategory.cloud.value
@property
def category_db(self):
return self.category == const.AppCategory.db.value
def is_type(self, tp):
return self.type == tp
def get_rdp_remote_app_setting(self):
from applications.serializers.attrs import get_serializer_class_by_application_type
if not self.category_remote_app:
raise ValueError(f"Not a remote app application: {self.name}")
serializer_class = get_serializer_class_by_application_type(self.type)
fields = serializer_class().get_fields()
parameters = [self.type]
for field_name in list(fields.keys()):
if field_name in ['asset']:
continue
value = self.attrs.get(field_name)
if not value:
continue
if field_name == 'path':
value = '\"%s\"' % value
parameters.append(str(value))
parameters = ' '.join(parameters)
return {
'program': '||jmservisor',
'working_directory': '',
'parameters': parameters
}
def get_remote_app_asset(self, raise_exception=True):
asset_id = self.attrs.get('asset')
if is_uuid(asset_id):
return Asset.objects.filter(id=asset_id).first()
if raise_exception:
raise ValueError("Remote App not has asset attr")
def get_target_ip(self):
target_ip = ''
if self.category_remote_app:
asset = self.get_remote_app_asset()
target_ip = asset.ip if asset else target_ip
elif self.category_cloud:
target_ip = self.attrs.get('cluster')
elif self.category_db:
target_ip = self.attrs.get('host')
return target_ip
def get_target_protocol_for_oracle(self):
""" Oracle 类型需要单独处理,因为要携带版本号 """
if not self.is_type(self.APP_TYPE.oracle):
return
version = self.attrs.get('version', OracleVersion.version_12c)
if version == OracleVersion.version_other:
return
return 'oracle_%s' % version
class ApplicationUser(SystemUser):
class Meta:
proxy = True
verbose_name = _('Application user')

View File

@ -0,0 +1,60 @@
# coding: utf-8
#
from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from django.core.exceptions import ObjectDoesNotExist
from common.utils import get_logger, is_uuid, get_object_or_none
from assets.models import Asset
logger = get_logger(__file__)
__all__ = ['RemoteAppSerializer']
class ExistAssetPrimaryKeyRelatedField(serializers.PrimaryKeyRelatedField):
def to_internal_value(self, data):
instance = super().to_internal_value(data)
return str(instance.id)
def to_representation(self, _id):
# _id 是 instance.id
if self.pk_field is not None:
return self.pk_field.to_representation(_id)
# 解决删除资产后远程应用更新页面会显示资产ID的问题
asset = get_object_or_none(Asset, id=_id)
if not asset:
return None
return _id
class RemoteAppSerializer(serializers.Serializer):
asset_info = serializers.SerializerMethodField(label=_('Asset Info'))
asset = ExistAssetPrimaryKeyRelatedField(
queryset=Asset.objects, required=True, label=_("Asset"), allow_null=True
)
path = serializers.CharField(
max_length=128, label=_('Application path'), allow_null=True
)
def validate_asset(self, asset):
if not asset:
raise serializers.ValidationError(_('This field is required.'))
return asset
@staticmethod
def get_asset_info(obj):
asset_id = obj.get('asset')
if not asset_id or not is_uuid(asset_id):
return {}
try:
asset = Asset.objects.get(id=str(asset_id))
except ObjectDoesNotExist as e:
logger.error(e)
return {}
if not asset:
return {}
asset_info = {'id': str(asset.id), 'hostname': asset.hostname}
return asset_info

View File

@ -0,0 +1,16 @@
from rest_framework import serializers
from django.utils.translation import ugettext_lazy as _
from ..application_category import DBSerializer
from applications.const import OracleVersion
__all__ = ['OracleSerializer']
class OracleSerializer(DBSerializer):
version = serializers.ChoiceField(
choices=OracleVersion.choices, default=OracleVersion.version_12c,
allow_null=True, label=_('Version'),
help_text=_('Magnus currently supports only 11g and 12c connections')
)
port = serializers.IntegerField(default=1521, label=_('Port'), allow_null=True)

View File

@ -26,6 +26,17 @@ class AccountHistoryViewSet(AccountViewSet):
}
http_method_names = ['get', 'options']
<<<<<<< HEAD
=======
def get_queryset(self):
queryset = self.model.objects.all() \
.annotate(ip=F('asset__ip')) \
.annotate(hostname=F('asset__hostname')) \
.annotate(platform=F('asset__platform__name')) \
.annotate(protocols=F('asset__protocols'))
return queryset
>>>>>>> origin
class AccountHistorySecretsViewSet(RecordViewLogMixin, AccountHistoryViewSet):
serializer_classes = {

View File

@ -0,0 +1,313 @@
<<<<<<< HEAD
=======
# -*- coding: utf-8 -*-
#
from rest_framework.viewsets import ModelViewSet
from rest_framework.generics import RetrieveAPIView, ListAPIView
from django.shortcuts import get_object_or_404
from django.db.models import Q
from common.utils import get_logger, get_object_or_none
from common.mixins.api import SuggestionMixin, RenderToJsonMixin
from users.models import User, UserGroup
from users.serializers import UserSerializer, UserGroupSerializer
from users.filters import UserFilter
from perms.models import AssetPermission
from perms.serializers import AssetPermissionSerializer
from perms.filters import AssetPermissionFilter
from orgs.mixins.api import OrgBulkModelViewSet
from orgs.mixins import generics
from assets.api import FilterAssetByNodeMixin
from ..models import Asset, Node, Platform, Gateway
from .. import serializers
from ..tasks import (
update_assets_hardware_info_manual, test_assets_connectivity_manual,
test_system_users_connectivity_a_asset, push_system_users_a_asset
)
from ..filters import FilterAssetByNodeFilterBackend, LabelFilterBackend, IpInFilterBackend
logger = get_logger(__file__)
__all__ = [
'AssetViewSet', 'AssetPlatformRetrieveApi',
'AssetGatewayListApi', 'AssetPlatformViewSet',
'AssetTaskCreateApi', 'AssetsTaskCreateApi',
'AssetPermUserListApi', 'AssetPermUserPermissionsListApi',
'AssetPermUserGroupListApi', 'AssetPermUserGroupPermissionsListApi',
]
class AssetViewSet(SuggestionMixin, FilterAssetByNodeMixin, OrgBulkModelViewSet):
"""
API endpoint that allows Asset to be viewed or edited.
"""
model = Asset
filterset_fields = {
'hostname': ['exact'],
'ip': ['exact'],
'system_users__id': ['exact'],
'platform__base': ['exact'],
'is_active': ['exact'],
'protocols': ['exact', 'icontains']
}
search_fields = ("hostname", "ip")
ordering_fields = ("hostname", "ip", "port", "cpu_cores")
ordering = ('hostname', )
serializer_classes = {
'default': serializers.AssetSerializer,
'suggestion': serializers.MiniAssetSerializer
}
rbac_perms = {
'match': 'assets.match_asset'
}
extra_filter_backends = [FilterAssetByNodeFilterBackend, LabelFilterBackend, IpInFilterBackend]
def set_assets_node(self, assets):
if not isinstance(assets, list):
assets = [assets]
node_id = self.request.query_params.get('node_id')
if not node_id:
return
node = get_object_or_none(Node, pk=node_id)
if not node:
return
node.assets.add(*assets)
def perform_create(self, serializer):
assets = serializer.save()
self.set_assets_node(assets)
class AssetPlatformRetrieveApi(RetrieveAPIView):
queryset = Platform.objects.all()
serializer_class = serializers.PlatformSerializer
rbac_perms = {
'retrieve': 'assets.view_gateway'
}
def get_object(self):
asset_pk = self.kwargs.get('pk')
asset = get_object_or_404(Asset, pk=asset_pk)
return asset.platform
class AssetPlatformViewSet(ModelViewSet, RenderToJsonMixin):
queryset = Platform.objects.all()
serializer_class = serializers.PlatformSerializer
filterset_fields = ['name', 'base']
search_fields = ['name']
def check_object_permissions(self, request, obj):
if request.method.lower() in ['delete', 'put', 'patch'] and obj.internal:
self.permission_denied(
request, message={"detail": "Internal platform"}
)
return super().check_object_permissions(request, obj)
class AssetsTaskMixin:
def perform_assets_task(self, serializer):
data = serializer.validated_data
action = data['action']
assets = data.get('assets', [])
if action == "refresh":
task = update_assets_hardware_info_manual.delay(assets)
else:
# action == 'test':
task = test_assets_connectivity_manual.delay(assets)
return task
def perform_create(self, serializer):
task = self.perform_assets_task(serializer)
self.set_task_to_serializer_data(serializer, task)
def set_task_to_serializer_data(self, serializer, task):
data = getattr(serializer, '_data', {})
data["task"] = task.id
setattr(serializer, '_data', data)
class AssetTaskCreateApi(AssetsTaskMixin, generics.CreateAPIView):
model = Asset
serializer_class = serializers.AssetTaskSerializer
def create(self, request, *args, **kwargs):
pk = self.kwargs.get('pk')
request.data['asset'] = pk
request.data['assets'] = [pk]
return super().create(request, *args, **kwargs)
def check_permissions(self, request):
action = request.data.get('action')
action_perm_require = {
'refresh': 'assets.refresh_assethardwareinfo',
'push_system_user': 'assets.push_assetsystemuser',
'test': 'assets.test_assetconnectivity',
'test_system_user': 'assets.test_assetconnectivity'
}
perm_required = action_perm_require.get(action)
has = self.request.user.has_perm(perm_required)
if not has:
self.permission_denied(request)
def perform_asset_task(self, serializer):
data = serializer.validated_data
action = data['action']
if action not in ['push_system_user', 'test_system_user']:
return
asset = data['asset']
system_users = data.get('system_users')
if not system_users:
system_users = asset.get_all_system_users()
if action == 'push_system_user':
task = push_system_users_a_asset.delay(system_users, asset=asset)
elif action == 'test_system_user':
task = test_system_users_connectivity_a_asset.delay(system_users, asset=asset)
else:
task = None
return task
def perform_create(self, serializer):
task = self.perform_asset_task(serializer)
if not task:
task = self.perform_assets_task(serializer)
self.set_task_to_serializer_data(serializer, task)
class AssetsTaskCreateApi(AssetsTaskMixin, generics.CreateAPIView):
model = Asset
serializer_class = serializers.AssetsTaskSerializer
def check_permissions(self, request):
action = request.data.get('action')
action_perm_require = {
'refresh': 'assets.refresh_assethardwareinfo',
}
perm_required = action_perm_require.get(action)
has = self.request.user.has_perm(perm_required)
if not has:
self.permission_denied(request)
class AssetGatewayListApi(generics.ListAPIView):
serializer_class = serializers.GatewayWithAuthSerializer
rbac_perms = {
'list': 'assets.view_gateway'
}
def get_queryset(self):
asset_id = self.kwargs.get('pk')
asset = get_object_or_404(Asset, pk=asset_id)
if not asset.domain:
return Gateway.objects.none()
queryset = asset.domain.gateways.filter(protocol='ssh')
return queryset
class BaseAssetPermUserOrUserGroupListApi(ListAPIView):
rbac_perms = {
'GET': 'perms.view_assetpermission'
}
def get_object(self):
asset_id = self.kwargs.get('pk')
asset = get_object_or_404(Asset, pk=asset_id)
return asset
def get_asset_related_perms(self):
asset = self.get_object()
nodes = asset.get_all_nodes(flat=True)
perms = AssetPermission.objects.filter(Q(assets=asset) | Q(nodes__in=nodes))
return perms
class AssetPermUserListApi(BaseAssetPermUserOrUserGroupListApi):
filterset_class = UserFilter
search_fields = ('username', 'email', 'name', 'id', 'source', 'role')
serializer_class = UserSerializer
rbac_perms = {
'GET': 'perms.view_assetpermission'
}
def get_queryset(self):
perms = self.get_asset_related_perms()
users = User.objects.filter(
Q(assetpermissions__in=perms) | Q(groups__assetpermissions__in=perms)
).distinct()
return users
class AssetPermUserGroupListApi(BaseAssetPermUserOrUserGroupListApi):
serializer_class = UserGroupSerializer
def get_queryset(self):
perms = self.get_asset_related_perms()
user_groups = UserGroup.objects.filter(assetpermissions__in=perms).distinct()
return user_groups
class BaseAssetPermUserOrUserGroupPermissionsListApiMixin(generics.ListAPIView):
model = AssetPermission
serializer_class = AssetPermissionSerializer
filterset_class = AssetPermissionFilter
search_fields = ('name',)
rbac_perms = {
'list': 'perms.view_assetpermission'
}
def get_object(self):
asset_id = self.kwargs.get('pk')
asset = get_object_or_404(Asset, pk=asset_id)
return asset
def filter_asset_related(self, queryset):
asset = self.get_object()
nodes = asset.get_all_nodes(flat=True)
perms = queryset.filter(Q(assets=asset) | Q(nodes__in=nodes))
return perms
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
queryset = self.filter_asset_related(queryset)
return queryset
class AssetPermUserPermissionsListApi(BaseAssetPermUserOrUserGroupPermissionsListApiMixin):
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
queryset = self.filter_user_related(queryset)
queryset = queryset.distinct()
return queryset
def filter_user_related(self, queryset):
user = self.get_perm_user()
user_groups = user.groups.all()
perms = queryset.filter(Q(users=user) | Q(user_groups__in=user_groups))
return perms
def get_perm_user(self):
user_id = self.kwargs.get('perm_user_id')
user = get_object_or_404(User, pk=user_id)
return user
class AssetPermUserGroupPermissionsListApi(BaseAssetPermUserOrUserGroupPermissionsListApiMixin):
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
queryset = self.filter_user_group_related(queryset)
queryset = queryset.distinct()
return queryset
def filter_user_group_related(self, queryset):
user_group = self.get_perm_user_group()
perms = queryset.filter(user_groups=user_group)
return perms
def get_perm_user_group(self):
user_group_id = self.kwargs.get('perm_user_group_id')
user_group = get_object_or_404(UserGroup, pk=user_group_id)
return user_group
>>>>>>> origin

View File

@ -24,7 +24,7 @@ class SerializeToTreeNodeMixin:
'title': _name(node),
'pId': node.parent_key,
'isParent': True,
'open': node.is_org_root(),
'open': True,
'meta': {
'data': {
"id": node.id,

View File

@ -44,7 +44,7 @@ __all__ = [
class NodeViewSet(SuggestionMixin, OrgBulkModelViewSet):
model = Node
filterset_fields = ('value', 'key', 'id')
search_fields = ('value',)
search_fields = ('full_value',)
serializer_class = serializers.NodeSerializer
rbac_perms = {
'match': 'assets.match_node',
@ -102,6 +102,8 @@ class NodeListAsTreeApi(generics.ListAPIView):
class NodeChildrenApi(generics.ListCreateAPIView):
serializer_class = serializers.NodeSerializer
search_fields = ('value',)
instance = None
is_initial = False
@ -180,8 +182,15 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
"""
model = Node
def filter_queryset(self, queryset):
if not self.request.GET.get('search'):
return queryset
queryset = super().filter_queryset(queryset)
queryset = self.model.get_ancestor_queryset(queryset)
return queryset
def list(self, request, *args, **kwargs):
nodes = self.get_queryset().order_by('value')
nodes = self.filter_queryset(self.get_queryset()).order_by('value')
nodes = self.serialize_nodes(nodes, with_asset_amount=True)
assets = self.get_assets()
data = [*nodes, *assets]

View File

@ -0,0 +1,138 @@
# -*- coding: utf-8 -*-
#
from django.db import models
from django.db.models import F
from django.utils.translation import ugettext_lazy as _
from simple_history.models import HistoricalRecords
from common.utils import lazyproperty, get_logger
from .base import BaseUser, AbsConnectivity
logger = get_logger(__name__)
__all__ = ['AuthBook']
class AuthBook(BaseUser, AbsConnectivity):
asset = models.ForeignKey('assets.Asset', on_delete=models.CASCADE, verbose_name=_('Asset'))
systemuser = models.ForeignKey('assets.SystemUser', on_delete=models.CASCADE, null=True, verbose_name=_("System user"))
version = models.IntegerField(default=1, verbose_name=_('Version'))
history = HistoricalRecords()
auth_attrs = ['username', 'password', 'private_key', 'public_key']
class Meta:
verbose_name = _('AuthBook')
unique_together = [('username', 'asset', 'systemuser')]
permissions = [
('test_authbook', _('Can test asset account connectivity')),
('view_assetaccountsecret', _('Can view asset account secret')),
('change_assetaccountsecret', _('Can change asset account secret')),
('view_assethistoryaccount', _('Can view asset history account')),
('view_assethistoryaccountsecret', _('Can view asset history account secret')),
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.auth_snapshot = {}
def get_or_systemuser_attr(self, attr):
val = getattr(self, attr, None)
if val:
return val
if self.systemuser:
return getattr(self.systemuser, attr, '')
return ''
def load_auth(self):
for attr in self.auth_attrs:
value = self.get_or_systemuser_attr(attr)
self.auth_snapshot[attr] = [getattr(self, attr), value]
setattr(self, attr, value)
def unload_auth(self):
if not self.systemuser:
return
for attr, values in self.auth_snapshot.items():
origin_value, loaded_value = values
current_value = getattr(self, attr, '')
if current_value == loaded_value:
setattr(self, attr, origin_value)
def save(self, *args, **kwargs):
self.unload_auth()
instance = super().save(*args, **kwargs)
self.load_auth()
return instance
@property
def username_display(self):
return self.get_or_systemuser_attr('username') or '*'
@lazyproperty
def systemuser_display(self):
if not self.systemuser:
return ''
return str(self.systemuser)
@property
def smart_name(self):
username = self.username_display
if self.asset:
asset = str(self.asset)
else:
asset = '*'
return '{}@{}'.format(username, asset)
def sync_to_system_user_account(self):
if self.systemuser:
return
matched = AuthBook.objects.filter(
asset=self.asset, systemuser__username=self.username
)
if not matched:
return
for i in matched:
i.password = self.password
i.private_key = self.private_key
i.public_key = self.public_key
i.comment = 'Update triggered by account {}'.format(self.id)
# 不触发post_save信号
self.__class__.objects.bulk_update(matched, fields=['password', 'private_key', 'public_key'])
def remove_asset_admin_user_if_need(self):
if not self.asset or not self.systemuser:
return
if not self.systemuser.is_admin_user or self.asset.admin_user != self.systemuser:
return
self.asset.admin_user = None
self.asset.save()
logger.debug('Remove asset admin user: {} {}'.format(self.asset, self.systemuser))
def update_asset_admin_user_if_need(self):
if not self.asset or not self.systemuser:
return
if not self.systemuser.is_admin_user or self.asset.admin_user == self.systemuser:
return
self.asset.admin_user = self.systemuser
self.asset.save()
logger.debug('Update asset admin user: {} {}'.format(self.asset, self.systemuser))
@classmethod
def get_queryset(cls):
queryset = cls.objects.all() \
.annotate(ip=F('asset__ip')) \
.annotate(hostname=F('asset__hostname')) \
.annotate(platform=F('asset__platform__name')) \
.annotate(protocols=F('asset__protocols'))
return queryset
def __str__(self):
return self.smart_name

View File

@ -25,7 +25,6 @@ from orgs.mixins.models import OrgModelMixin, OrgManager
from orgs.utils import get_current_org, tmp_to_org, tmp_to_root_org
from orgs.models import Organization
__all__ = ['Node', 'FamilyMixin', 'compute_parent_key', 'NodeQuerySet']
logger = get_logger(__name__)
@ -98,6 +97,14 @@ class FamilyMixin:
q |= Q(key=self.key)
return Node.objects.filter(q)
@classmethod
def get_ancestor_queryset(cls, queryset, with_self=True):
parent_keys = set()
for i in queryset:
parent_keys.update(set(i.get_ancestor_keys(with_self=with_self)))
queryset = queryset.model.objects.filter(key__in=list(parent_keys)).distinct()
return queryset
@property
def children(self):
return self.get_children(with_self=False)
@ -396,7 +403,7 @@ class NodeAllAssetsMappingMixin:
mapping[ancestor_key].update(asset_ids)
t3 = time.time()
logger.info('t1-t2(DB Query): {} s, t3-t2(Generate mapping): {} s'.format(t2-t1, t3-t2))
logger.info('t1-t2(DB Query): {} s, t3-t2(Generate mapping): {} s'.format(t2 - t1, t3 - t2))
return mapping

View File

@ -0,0 +1,226 @@
# -*- coding: utf-8 -*-
#
from rest_framework import serializers
from django.core.validators import RegexValidator
from django.utils.translation import ugettext_lazy as _
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from ..models import Asset, Node, Platform, SystemUser
__all__ = [
'AssetSerializer', 'AssetSimpleSerializer', 'MiniAssetSerializer',
'ProtocolsField', 'PlatformSerializer',
'AssetTaskSerializer', 'AssetsTaskSerializer', 'ProtocolsField',
]
class ProtocolField(serializers.RegexField):
protocols = '|'.join(dict(Asset.Protocol.choices).keys())
default_error_messages = {
'invalid': _('Protocol format should {}/{}').format(protocols, '1-65535')
}
regex = r'^(%s)/(\d{1,5})$' % protocols
def __init__(self, *args, **kwargs):
super().__init__(self.regex, **kwargs)
def validate_duplicate_protocols(values):
errors = []
names = []
for value in values:
if not value or '/' not in value:
continue
name = value.split('/')[0]
if name in names:
errors.append(_("Protocol duplicate: {}").format(name))
names.append(name)
errors.append('')
if any(errors):
raise serializers.ValidationError(errors)
class ProtocolsField(serializers.ListField):
default_validators = [validate_duplicate_protocols]
def __init__(self, *args, **kwargs):
kwargs['child'] = ProtocolField()
kwargs['allow_null'] = True
kwargs['allow_empty'] = True
kwargs['min_length'] = 1
kwargs['max_length'] = 4
super().__init__(*args, **kwargs)
def to_representation(self, value):
if not value:
return []
return value.split(' ')
class AssetSerializer(BulkOrgResourceModelSerializer):
platform = serializers.SlugRelatedField(
slug_field='name', queryset=Platform.objects.all(), label=_("Platform")
)
protocols = ProtocolsField(label=_('Protocols'), required=False, default=['ssh/22'])
domain_display = serializers.ReadOnlyField(source='domain.name', label=_('Domain name'))
nodes_display = serializers.ListField(
child=serializers.CharField(), label=_('Nodes name'), required=False
)
labels_display = serializers.ListField(
child=serializers.CharField(), label=_('Labels name'), required=False, read_only=True
)
"""
资产的数据结构
"""
class Meta:
model = Asset
fields_mini = ['id', 'hostname', 'ip', 'platform', 'protocols']
fields_small = fields_mini + [
'protocol', 'port', 'protocols', 'is_active',
'public_ip', 'number', 'comment',
]
fields_hardware = [
'vendor', 'model', 'sn', 'cpu_model', 'cpu_count',
'cpu_cores', 'cpu_vcpus', 'memory', 'disk_total', 'disk_info',
'os', 'os_version', 'os_arch', 'hostname_raw',
'cpu_info', 'hardware_info',
]
fields_fk = [
'domain', 'domain_display', 'platform', 'admin_user', 'admin_user_display'
]
fields_m2m = [
'nodes', 'nodes_display', 'labels', 'labels_display',
]
read_only_fields = [
'connectivity', 'date_verified', 'cpu_info', 'hardware_info',
'created_by', 'date_created',
]
fields = fields_small + fields_hardware + fields_fk + fields_m2m + read_only_fields
extra_kwargs = {
'protocol': {'write_only': True},
'port': {'write_only': True},
'hardware_info': {'label': _('Hardware info'), 'read_only': True},
'admin_user_display': {'label': _('Admin user display'), 'read_only': True},
'cpu_info': {'label': _('CPU info')},
}
def get_fields(self):
fields = super().get_fields()
admin_user_field = fields.get('admin_user')
# 因为 mixin 中对 fields 有处理,可能不需要返回 admin_user
if admin_user_field:
admin_user_field.queryset = SystemUser.objects.filter(type=SystemUser.Type.admin)
return fields
@classmethod
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
queryset = queryset.prefetch_related('domain', 'platform', 'admin_user')
queryset = queryset.prefetch_related('nodes', 'labels')
return queryset
def compatible_with_old_protocol(self, validated_data):
protocols_data = validated_data.pop("protocols", [])
# 兼容老的api
name = validated_data.get("protocol")
port = validated_data.get("port")
if not protocols_data and name and port:
protocols_data.insert(0, '/'.join([name, str(port)]))
elif not name and not port and protocols_data:
protocol = protocols_data[0].split('/')
validated_data["protocol"] = protocol[0]
validated_data["port"] = int(protocol[1])
if protocols_data:
validated_data["protocols"] = ' '.join(protocols_data)
def perform_nodes_display_create(self, instance, nodes_display):
if not nodes_display:
return
nodes_to_set = []
for full_value in nodes_display:
node = Node.objects.filter(full_value=full_value).first()
if node:
nodes_to_set.append(node)
else:
node = Node.create_node_by_full_value(full_value)
nodes_to_set.append(node)
instance.nodes.set(nodes_to_set)
def create(self, validated_data):
self.compatible_with_old_protocol(validated_data)
nodes_display = validated_data.pop('nodes_display', '')
instance = super().create(validated_data)
self.perform_nodes_display_create(instance, nodes_display)
return instance
def update(self, instance, validated_data):
nodes_display = validated_data.pop('nodes_display', '')
self.compatible_with_old_protocol(validated_data)
instance = super().update(instance, validated_data)
self.perform_nodes_display_create(instance, nodes_display)
return instance
class MiniAssetSerializer(serializers.ModelSerializer):
class Meta:
model = Asset
fields = AssetSerializer.Meta.fields_mini
class PlatformSerializer(serializers.ModelSerializer):
meta = serializers.DictField(required=False, allow_null=True, label=_('Meta'))
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO 修复 drf SlugField RegexValidator bug之后记得删除
validators = self.fields['name'].validators
if isinstance(validators[-1], RegexValidator):
validators.pop()
class Meta:
model = Platform
fields = [
'id', 'name', 'base', 'charset',
'internal', 'meta', 'comment'
]
extra_kwargs = {
'internal': {'read_only': True},
}
class AssetSimpleSerializer(serializers.ModelSerializer):
class Meta:
model = Asset
fields = ['id', 'hostname', 'ip', 'port', 'connectivity', 'date_verified']
class AssetsTaskSerializer(serializers.Serializer):
ACTION_CHOICES = (
('refresh', 'refresh'),
('test', 'test'),
)
task = serializers.CharField(read_only=True)
action = serializers.ChoiceField(choices=ACTION_CHOICES, write_only=True)
assets = serializers.PrimaryKeyRelatedField(
queryset=Asset.objects, required=False, allow_empty=True, many=True
)
class AssetTaskSerializer(AssetsTaskSerializer):
ACTION_CHOICES = tuple(list(AssetsTaskSerializer.ACTION_CHOICES) + [
('push_system_user', 'push_system_user'),
('test_system_user', 'test_system_user')
])
action = serializers.ChoiceField(choices=ACTION_CHOICES, write_only=True)
asset = serializers.PrimaryKeyRelatedField(
queryset=Asset.objects, required=False, allow_empty=True, many=False
)
system_users = serializers.PrimaryKeyRelatedField(
queryset=SystemUser.objects, required=False, allow_empty=True, many=True
)

View File

@ -12,6 +12,7 @@ from common.api import CommonGenericViewSet
from orgs.mixins.api import OrgGenericViewSet, OrgBulkModelViewSet, OrgRelationMixin
from orgs.utils import current_org
from ops.models import CommandExecution
from . import filters
from .models import FTPLog, UserLoginLog, OperateLog, PasswordChangeLog
from .serializers import FTPLogSerializer, UserLoginLogSerializer, CommandExecutionSerializer
from .serializers import OperateLogSerializer, PasswordChangeLogSerializer, CommandExecutionHostsRelationSerializer
@ -128,10 +129,15 @@ class CommandExecutionViewSet(ListModelMixin, OrgGenericViewSet):
class CommandExecutionHostRelationViewSet(OrgRelationMixin, OrgBulkModelViewSet):
serializer_class = CommandExecutionHostsRelationSerializer
m2m_field = CommandExecution.hosts.field
<<<<<<< HEAD
filterset_fields = [
'id', 'asset', 'commandexecution'
]
search_fields = ('asset__name', )
=======
filterset_class = filters.CommandExecutionFilter
search_fields = ('asset__hostname', )
>>>>>>> origin
http_method_names = ['options', 'get']
rbac_perms = {
'GET': 'ops.view_commandexecution',

View File

@ -1,10 +1,14 @@
from django.db.models import F, Value
from django.db.models.functions import Concat
from django_filters.rest_framework import CharFilter
from rest_framework import filters
from rest_framework.compat import coreapi, coreschema
from orgs.utils import current_org
from ops.models import CommandExecution
from common.drf.filters import BaseFilterSet
__all__ = ['CurrentOrgMembersFilter']
__all__ = ['CurrentOrgMembersFilter', 'CommandExecutionFilter']
class CurrentOrgMembersFilter(filters.BaseFilterBackend):
@ -30,3 +34,22 @@ class CurrentOrgMembersFilter(filters.BaseFilterBackend):
else:
queryset = queryset.filter(user__in=self._get_user_list())
return queryset
class CommandExecutionFilter(BaseFilterSet):
hostname_ip = CharFilter(method='filter_hostname_ip')
class Meta:
model = CommandExecution.hosts.through
fields = (
'id', 'asset', 'commandexecution', 'hostname_ip'
)
def filter_hostname_ip(self, queryset, name, value):
queryset = queryset.annotate(
hostname_ip=Concat(
F('asset__hostname'), Value('('),
F('asset__ip'), Value(')')
)
).filter(hostname_ip__icontains=value)
return queryset

View File

@ -29,7 +29,7 @@ def clean_ftp_log_period():
now = timezone.now()
days = get_log_keep_day('FTP_LOG_KEEP_DAYS')
expired_day = now - datetime.timedelta(days=days)
FTPLog.objects.filter(datetime__lt=expired_day).delete()
FTPLog.objects.filter(date_start__lt=expired_day).delete()
@register_as_period_task(interval=3600*24)

View File

@ -1,3 +1,4 @@
import abc
import os
import json
import base64
@ -16,12 +17,11 @@ from orgs.mixins.api import RootOrgViewMixin
from perms.models import Action
from terminal.models import EndpointRule
from ..serializers import (
ConnectionTokenSerializer, ConnectionTokenSecretSerializer, SuperConnectionTokenSerializer,
ConnectionTokenDisplaySerializer,
ConnectionTokenSerializer, ConnectionTokenSecretSerializer,
SuperConnectionTokenSerializer, ConnectionTokenDisplaySerializer,
)
from ..models import ConnectionToken
__all__ = ['ConnectionTokenViewSet', 'SuperConnectionTokenViewSet']
@ -34,9 +34,12 @@ class ConnectionTokenMixin:
if not is_valid:
raise PermissionDenied(error)
@staticmethod
def get_request_resources(serializer):
user = serializer.validated_data.get('user')
@abc.abstractmethod
def get_request_resource_user(self, serializer):
raise NotImplementedError
def get_request_resources(self, serializer):
user = self.get_request_resource_user(serializer)
asset = serializer.validated_data.get('asset')
application = serializer.validated_data.get('application')
system_user = serializer.validated_data.get('system_user')
@ -164,9 +167,8 @@ class ConnectionTokenMixin:
rdp_options['remoteapplicationname:s'] = name
else:
name = '*'
filename = "{}-{}-jumpserver".format(token.user.username, name)
filename = urllib.parse.quote(filename)
prefix_name = f'{token.user.username}-{name}'
filename = self.get_connect_filename(prefix_name)
content = ''
for k, v in rdp_options.items():
@ -174,6 +176,15 @@ class ConnectionTokenMixin:
return filename, content
@staticmethod
def get_connect_filename(prefix_name):
prefix_name = prefix_name.replace('/', '_')
prefix_name = prefix_name.replace('\\', '_')
prefix_name = prefix_name.replace('.', '_')
filename = f'{prefix_name}-jumpserver'
filename = urllib.parse.quote(filename)
return filename
def get_ssh_token(self, token: ConnectionToken):
if token.asset:
name = token.asset.name
@ -181,7 +192,8 @@ class ConnectionTokenMixin:
name = token.application.name
else:
name = '*'
filename = f'{token.user.username}-{name}-jumpserver'
prefix_name = f'{token.user.username}-{name}'
filename = self.get_connect_filename(prefix_name)
endpoint = self.get_smart_endpoint(
protocol='ssh', asset=token.asset, application=token.application
@ -198,7 +210,12 @@ class ConnectionTokenMixin:
class ConnectionTokenViewSet(ConnectionTokenMixin, RootOrgViewMixin, JMSModelViewSet):
filterset_fields = (
<<<<<<< HEAD
'type', 'user_display', 'asset_display'
=======
'type', 'user_display', 'system_user_display',
'application_display', 'asset_display'
>>>>>>> origin
)
search_fields = filterset_fields
serializer_classes = {
@ -215,7 +232,20 @@ class ConnectionTokenViewSet(ConnectionTokenMixin, RootOrgViewMixin, JMSModelVie
'get_rdp_file': 'authentication.add_connectiontoken',
'get_client_protocol_url': 'authentication.add_connectiontoken',
}
queryset = ConnectionToken.objects.all()
def get_queryset(self):
return ConnectionToken.objects.filter(user=self.request.user)
def get_request_resource_user(self, serializer):
return self.request.user
def get_object(self):
if self.request.user.is_service_account:
# TODO: 组件获取 token 详情,将来放在 Super-connection-token API 中
obj = get_object_or_404(ConnectionToken, pk=self.kwargs.get('pk'))
else:
obj = super(ConnectionTokenViewSet, self).get_object()
return obj
def create_connection_token(self):
data = self.request.query_params if self.request.method == 'GET' else self.request.data
@ -284,6 +314,9 @@ class SuperConnectionTokenViewSet(ConnectionTokenViewSet):
'renewal': 'authentication.add_superconnectiontoken'
}
def get_request_resource_user(self, serializer):
return serializer.validated_data.get('user')
@action(methods=['PATCH'], detail=False)
def renewal(self, request, *args, **kwargs):
from common.utils.timezone import as_current_tz
@ -299,4 +332,3 @@ class SuperConnectionTokenViewSet(ConnectionTokenViewSet):
'msg': f'Token is renewed, date expired: {date_expired}'
}
return Response(data=data, status=status.HTTP_200_OK)

View File

@ -6,6 +6,8 @@ from rest_framework.permissions import AllowAny
from common.utils import get_logger
from .. import errors, mixins
from django.contrib.auth import logout as auth_logout
__all__ = ['TicketStatusApi']
logger = get_logger(__name__)
@ -17,7 +19,15 @@ class TicketStatusApi(mixins.AuthMixin, APIView):
def get(self, request, *args, **kwargs):
try:
self.check_user_login_confirm()
self.request.session['auth_third_party_done'] = 1
return Response({"msg": "ok"})
except errors.LoginConfirmOtherError as e:
reason = e.msg
username = e.username
self.send_auth_signal(success=False, username=username, reason=reason)
# 若为三方登录,此时应退出登录
auth_logout(request)
return Response(e.as_data(), status=200)
except errors.NeedMoreInfoError as e:
return Response(e.as_data(), status=200)

View File

@ -49,7 +49,7 @@ class JMSBaseAuthBackend:
if not allow:
info = 'User {} skip authentication backend {}, because it not in {}'
info = info.format(username, backend_name, ','.join(allowed_backend_names))
logger.debug(info)
logger.info(info)
return allow

View File

@ -3,9 +3,10 @@
from django.urls import path
import django_cas_ng.views
from .views import CASLoginView
urlpatterns = [
path('login/', django_cas_ng.views.LoginView.as_view(), name='cas-login'),
path('login/', CASLoginView.as_view(), name='cas-login'),
path('logout/', django_cas_ng.views.LogoutView.as_view(), name='cas-logout'),
path('callback/', django_cas_ng.views.CallbackView.as_view(), name='cas-proxy-callback'),
]

View File

@ -0,0 +1,15 @@
from django_cas_ng.views import LoginView
from django.core.exceptions import PermissionDenied
from django.http import HttpResponseRedirect
__all__ = ['LoginView']
class CASLoginView(LoginView):
def get(self, request):
try:
return super().get(request)
except PermissionDenied:
return HttpResponseRedirect('/')

View File

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
#
from .backends import *

View File

@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
#
import requests
from django.contrib.auth import get_user_model
from django.utils.http import urlencode
from django.conf import settings
from django.urls import reverse
from common.utils import get_logger
from users.utils import construct_user_email
from authentication.utils import build_absolute_uri
from common.exceptions import JMSException
from .signals import (
oauth2_create_or_update_user, oauth2_user_login_failed,
oauth2_user_login_success
)
from ..base import JMSModelBackend
__all__ = ['OAuth2Backend']
logger = get_logger(__name__)
class OAuth2Backend(JMSModelBackend):
@staticmethod
def is_enabled():
return settings.AUTH_OAUTH2
def get_or_create_user_from_userinfo(self, request, userinfo):
log_prompt = "Get or Create user [OAuth2Backend]: {}"
logger.debug(log_prompt.format('start'))
# Construct user attrs value
user_attrs = {}
for field, attr in settings.AUTH_OAUTH2_USER_ATTR_MAP.items():
user_attrs[field] = userinfo.get(attr, '')
username = user_attrs.get('username')
if not username:
error_msg = 'username is missing'
logger.error(log_prompt.format(error_msg))
raise JMSException(error_msg)
email = user_attrs.get('email', '')
email = construct_user_email(user_attrs.get('username'), email)
user_attrs.update({'email': email})
logger.debug(log_prompt.format(user_attrs))
user, created = get_user_model().objects.get_or_create(
username=username, defaults=user_attrs
)
logger.debug(log_prompt.format("user: {}|created: {}".format(user, created)))
logger.debug(log_prompt.format("Send signal => oauth2 create or update user"))
oauth2_create_or_update_user.send(
sender=self.__class__, request=request, user=user, created=created,
attrs=user_attrs
)
return user, created
@staticmethod
def get_response_data(response_data):
if response_data.get('data') is not None:
response_data = response_data['data']
return response_data
@staticmethod
def get_query_dict(response_data, query_dict):
query_dict.update({
'uid': response_data.get('uid', ''),
'access_token': response_data.get('access_token', '')
})
return query_dict
def authenticate(self, request, code=None, **kwargs):
log_prompt = "Process authenticate [OAuth2Backend]: {}"
logger.debug(log_prompt.format('Start'))
if code is None:
logger.error(log_prompt.format('code is missing'))
return None
query_dict = {
'client_id': settings.AUTH_OAUTH2_CLIENT_ID,
'client_secret': settings.AUTH_OAUTH2_CLIENT_SECRET,
'grant_type': 'authorization_code',
'code': code,
'redirect_uri': build_absolute_uri(
request, path=reverse(settings.AUTH_OAUTH2_AUTH_LOGIN_CALLBACK_URL_NAME)
)
}
access_token_url = '{url}?{query}'.format(
url=settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT, query=urlencode(query_dict)
)
token_method = settings.AUTH_OAUTH2_ACCESS_TOKEN_METHOD.lower()
requests_func = getattr(requests, token_method, requests.get)
logger.debug(log_prompt.format('Call the access token endpoint[method: %s]' % token_method))
headers = {
'Accept': 'application/json'
}
access_token_response = requests_func(access_token_url, headers=headers)
try:
access_token_response.raise_for_status()
access_token_response_data = access_token_response.json()
response_data = self.get_response_data(access_token_response_data)
except Exception as e:
error = "Json access token response error, access token response " \
"content is: {}, error is: {}".format(access_token_response.content, str(e))
logger.error(log_prompt.format(error))
return None
query_dict = self.get_query_dict(response_data, query_dict)
headers = {
'Accept': 'application/json',
'Authorization': 'token {}'.format(response_data.get('access_token', ''))
}
logger.debug(log_prompt.format('Get userinfo endpoint'))
userinfo_url = '{url}?{query}'.format(
url=settings.AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT,
query=urlencode(query_dict)
)
userinfo_response = requests.get(userinfo_url, headers=headers)
try:
userinfo_response.raise_for_status()
userinfo_response_data = userinfo_response.json()
if 'data' in userinfo_response_data:
userinfo = userinfo_response_data['data']
else:
userinfo = userinfo_response_data
except Exception as e:
error = "Json userinfo response error, userinfo response " \
"content is: {}, error is: {}".format(userinfo_response.content, str(e))
logger.error(log_prompt.format(error))
return None
try:
logger.debug(log_prompt.format('Update or create oauth2 user'))
user, created = self.get_or_create_user_from_userinfo(request, userinfo)
except JMSException:
return None
if self.user_can_authenticate(user):
logger.debug(log_prompt.format('OAuth2 user login success'))
logger.debug(log_prompt.format('Send signal => oauth2 user login success'))
oauth2_user_login_success.send(sender=self.__class__, request=request, user=user)
return user
else:
logger.debug(log_prompt.format('OAuth2 user login failed'))
logger.debug(log_prompt.format('Send signal => oauth2 user login failed'))
oauth2_user_login_failed.send(
sender=self.__class__, request=request, username=user.username,
reason=_('User invalid, disabled or expired')
)
return None

View File

@ -0,0 +1,9 @@
from django.dispatch import Signal
oauth2_create_or_update_user = Signal(
providing_args=['request', 'user', 'created', 'name', 'username', 'email']
)
oauth2_user_login_success = Signal(providing_args=['request', 'user'])
oauth2_user_login_failed = Signal(providing_args=['request', 'username', 'reason'])

View File

@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
#
from django.urls import path
from . import views
urlpatterns = [
path('login/', views.OAuth2AuthRequestView.as_view(), name='login'),
path('callback/', views.OAuth2AuthCallbackView.as_view(), name='login-callback')
]

View File

@ -0,0 +1,58 @@
from django.views import View
from django.conf import settings
from django.contrib.auth import login
from django.http import HttpResponseRedirect
from django.urls import reverse
from django.utils.http import urlencode
from authentication.utils import build_absolute_uri
from common.utils import get_logger
from authentication.mixins import authenticate
logger = get_logger(__file__)
class OAuth2AuthRequestView(View):
def get(self, request):
log_prompt = "Process OAuth2 GET requests: {}"
logger.debug(log_prompt.format('Start'))
query_dict = {
'client_id': settings.AUTH_OAUTH2_CLIENT_ID, 'response_type': 'code',
'scope': settings.AUTH_OAUTH2_SCOPE,
'redirect_uri': build_absolute_uri(
request, path=reverse(settings.AUTH_OAUTH2_AUTH_LOGIN_CALLBACK_URL_NAME)
)
}
redirect_url = '{url}?{query}'.format(
url=settings.AUTH_OAUTH2_PROVIDER_AUTHORIZATION_ENDPOINT,
query=urlencode(query_dict)
)
logger.debug(log_prompt.format('Redirect login url'))
return HttpResponseRedirect(redirect_url)
class OAuth2AuthCallbackView(View):
http_method_names = ['get', ]
def get(self, request):
""" Processes GET requests. """
log_prompt = "Process GET requests [OAuth2AuthCallbackView]: {}"
logger.debug(log_prompt.format('Start'))
callback_params = request.GET
if 'code' in callback_params:
logger.debug(log_prompt.format('Process authenticate'))
user = authenticate(code=callback_params['code'], request=request)
if user and user.is_valid:
logger.debug(log_prompt.format('Login: {}'.format(user)))
login(self.request, user)
logger.debug(log_prompt.format('Redirect'))
return HttpResponseRedirect(
settings.AUTH_OAUTH2_AUTHENTICATION_REDIRECT_URI
)
logger.debug(log_prompt.format('Redirect'))
return HttpResponseRedirect(settings.AUTH_OAUTH2_AUTHENTICATION_FAILURE_REDIRECT_URI)

View File

@ -9,6 +9,7 @@
import base64
import requests
from rest_framework.exceptions import ParseError
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
@ -18,10 +19,11 @@ from django.urls import reverse
from django.conf import settings
from common.utils import get_logger
from authentication.utils import build_absolute_uri_for_oidc
from users.utils import construct_user_email
from ..base import JMSBaseAuthBackend
from .utils import validate_and_return_id_token, build_absolute_uri
from .utils import validate_and_return_id_token
from .decorator import ssl_verification
from .signals import (
openid_create_or_update_user, openid_user_login_failed, openid_user_login_success
@ -127,7 +129,7 @@ class OIDCAuthCodeBackend(OIDCBaseBackend):
token_payload = {
'grant_type': 'authorization_code',
'code': code,
'redirect_uri': build_absolute_uri(
'redirect_uri': build_absolute_uri_for_oidc(
request, path=reverse(settings.AUTH_OPENID_AUTH_LOGIN_CALLBACK_URL_NAME)
)
}

View File

@ -8,7 +8,7 @@
import datetime as dt
from calendar import timegm
from urllib.parse import urlparse, urljoin
from urllib.parse import urlparse
from django.core.exceptions import SuspiciousOperation
from django.utils.encoding import force_bytes, smart_bytes
@ -110,17 +110,3 @@ def _validate_claims(id_token, nonce=None, validate_nonce=True):
raise SuspiciousOperation('Incorrect id_token: nonce')
logger.debug(log_prompt.format('End'))
def build_absolute_uri(request, path=None):
"""
Build absolute redirect uri
"""
if path is None:
path = '/'
if settings.BASE_SITE_URL:
redirect_uri = urljoin(settings.BASE_SITE_URL, path)
else:
redirect_uri = request.build_absolute_uri(path)
return redirect_uri

View File

@ -20,7 +20,8 @@ from django.utils.crypto import get_random_string
from django.utils.http import is_safe_url, urlencode
from django.views.generic import View
from .utils import get_logger, build_absolute_uri
from authentication.utils import build_absolute_uri_for_oidc
from .utils import get_logger
logger = get_logger(__file__)
@ -50,7 +51,7 @@ class OIDCAuthRequestView(View):
'scope': settings.AUTH_OPENID_SCOPES,
'response_type': 'code',
'client_id': settings.AUTH_OPENID_CLIENT_ID,
'redirect_uri': build_absolute_uri(
'redirect_uri': build_absolute_uri_for_oidc(
request, path=reverse(settings.AUTH_OPENID_AUTH_LOGIN_CALLBACK_URL_NAME)
)
})
@ -216,7 +217,7 @@ class OIDCEndSessionView(View):
""" Returns the end-session URL. """
q = QueryDict(mutable=True)
q[settings.AUTH_OPENID_PROVIDER_END_SESSION_REDIRECT_URI_PARAMETER] = \
build_absolute_uri(self.request, path=settings.LOGOUT_REDIRECT_URL or '/')
build_absolute_uri_for_oidc(self.request, path=settings.LOGOUT_REDIRECT_URL or '/')
q[settings.AUTH_OPENID_PROVIDER_END_SESSION_ID_TOKEN_PARAMETER] = \
self.request.session['oidc_auth_id_token']
return '{}?{}'.format(settings.AUTH_OPENID_PROVIDER_END_SESSION_ENDPOINT, q.urlencode())

View File

@ -39,7 +39,7 @@ class SAML2Backend(JMSModelBackend):
return user, created
def authenticate(self, request, saml_user_data=None, **kwargs):
log_prompt = "Process authenticate [SAML2AuthCodeBackend]: {}"
log_prompt = "Process authenticate [SAML2Backend]: {}"
logger.debug(log_prompt.format('Start'))
if saml_user_data is None:
logger.error(log_prompt.format('saml_user_data is missing'))
@ -48,7 +48,7 @@ class SAML2Backend(JMSModelBackend):
logger.debug(log_prompt.format('saml data, {}'.format(saml_user_data)))
username = saml_user_data.get('username')
if not username:
logger.debug(log_prompt.format('username is missing'))
logger.warning(log_prompt.format('username is missing'))
return None
user, created = self.get_or_create_from_saml_data(request, **saml_user_data)

View File

@ -12,12 +12,13 @@ class AuthFailedNeedLogMixin:
username = ''
request = None
error = ''
msg = ''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
post_auth_failed.send(
sender=self.__class__, username=self.username,
request=self.request, reason=self.error
request=self.request, reason=self.msg
)
@ -55,7 +56,8 @@ class BlockGlobalIpLoginError(AuthFailedError):
error = 'block_global_ip_login'
def __init__(self, username, ip, **kwargs):
self.msg = const.block_ip_login_msg.format(settings.SECURITY_LOGIN_IP_LIMIT_TIME)
if not self.msg:
self.msg = const.block_ip_login_msg.format(settings.SECURITY_LOGIN_IP_LIMIT_TIME)
LoginIpBlockUtil(ip).set_block_if_need()
super().__init__(username=username, ip=ip, **kwargs)
@ -65,22 +67,21 @@ class CredentialError(
BlockGlobalIpLoginError, AuthFailedError
):
def __init__(self, error, username, ip, request):
super().__init__(error=error, username=username, ip=ip, request=request)
util = LoginBlockUtil(username, ip)
times_remainder = util.get_remainder_times()
block_time = settings.SECURITY_LOGIN_LIMIT_TIME
if times_remainder < 1:
self.msg = const.block_user_login_msg.format(settings.SECURITY_LOGIN_LIMIT_TIME)
return
default_msg = const.invalid_login_msg.format(
times_try=times_remainder, block_time=block_time
)
if error == const.reason_password_failed:
self.msg = default_msg
else:
self.msg = const.reason_choices.get(error, default_msg)
default_msg = const.invalid_login_msg.format(
times_try=times_remainder, block_time=block_time
)
if error == const.reason_password_failed:
self.msg = default_msg
else:
self.msg = const.reason_choices.get(error, default_msg)
# 先处理 msg 在 super记录日志时原因才准确
super().__init__(error=error, username=username, ip=ip, request=request)
class MFAFailedError(AuthFailedNeedLogMixin, AuthFailedError):
@ -138,18 +139,11 @@ class ACLError(AuthFailedNeedLogMixin, AuthFailedError):
}
class LoginIPNotAllowed(ACLError):
class LoginACLIPAndTimePeriodNotAllowed(ACLError):
def __init__(self, username, request, **kwargs):
self.username = username
self.request = request
super().__init__(_("IP is not allowed"), **kwargs)
class TimePeriodNotAllowed(ACLError):
def __init__(self, username, request, **kwargs):
self.username = username
self.request = request
super().__init__(_("Time Period is not allowed"), **kwargs)
super().__init__(_("Current IP and Time period is not allowed"), **kwargs)
class MFACodeRequiredError(AuthFailedError):

View File

@ -14,23 +14,23 @@ class WeComCodeInvalid(JMSException):
class WeComBindAlready(JMSException):
default_code = 'wecom_bind_already'
default_detail = 'WeCom already binded'
default_code = 'wecom_not_bound'
default_detail = _('WeCom is already bound')
class WeComNotBound(JMSException):
default_code = 'wecom_not_bound'
default_detail = 'WeCom is not bound'
default_detail = _('WeCom is not bound')
class DingTalkNotBound(JMSException):
default_code = 'dingtalk_not_bound'
default_detail = 'DingTalk is not bound'
default_detail = _('DingTalk is not bound')
class FeiShuNotBound(JMSException):
default_code = 'feishu_not_bound'
default_detail = 'FeiShu is not bound'
default_detail = _('FeiShu is not bound')
class PasswordInvalid(JMSException):

View File

@ -69,10 +69,16 @@ class LoginConfirmWaitError(LoginConfirmBaseError):
class LoginConfirmOtherError(LoginConfirmBaseError):
error = 'login_confirm_error'
def __init__(self, ticket_id, status):
def __init__(self, ticket_id, status, username):
self.username = username
msg = const.login_confirm_error_msg.format(status)
super().__init__(ticket_id=ticket_id, msg=msg)
def as_data(self):
ret = super().as_data()
ret['data']['username'] = self.username
return ret
class PasswordTooSimple(NeedRedirectError):
default_code = 'passwd_too_simple'

View File

@ -1,11 +1,16 @@
import base64
from django.shortcuts import redirect, reverse
from django.shortcuts import redirect, reverse, render
from django.utils.deprecation import MiddlewareMixin
from django.http import HttpResponse
from django.conf import settings
from django.utils.translation import ugettext as _
from django.contrib.auth import logout as auth_logout
from apps.authentication import mixins
from common.utils import gen_key_pair
from common.utils import get_request_ip
from .signals import post_auth_failed
class MFAMiddleware:
@ -13,6 +18,7 @@ class MFAMiddleware:
这个 中间件 是用来全局拦截开启了 MFA 却没有认证的 OIDC, CAS使用第三方库做的登录直接 login
所以只能在 Middleware 中控制
"""
def __init__(self, get_response):
self.get_response = get_response
@ -42,6 +48,50 @@ class MFAMiddleware:
return redirect(url)
class ThirdPartyLoginMiddleware(mixins.AuthMixin):
"""OpenID、CAS、SAML2登录规则设置验证"""
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
response = self.get_response(request)
# 没有认证过,证明不是从 第三方 来的
if request.user.is_anonymous:
return response
if not request.session.get('auth_third_party_required'):
return response
ip = get_request_ip(request)
try:
self.request = request
self._check_login_acl(request.user, ip)
except Exception as e:
post_auth_failed.send(
sender=self.__class__, username=request.user.username,
request=self.request, reason=e.msg
)
auth_logout(request)
context = {
'title': _('Authentication failed'),
'message': _('Authentication failed (before login check failed): {}').format(e),
'interval': 10,
'redirect_url': reverse('authentication:login'),
'auto_redirect': True,
}
response = render(request, 'authentication/auth_fail_flash_message_standalone.html', context)
else:
if not self.request.session['auth_confirm_required']:
return response
guard_url = reverse('authentication:login-guard')
args = request.META.get('QUERY_STRING', '')
if args:
guard_url = "%s?%s" % (guard_url, args)
response = redirect(guard_url)
finally:
request.session.pop('auth_third_party_required', '')
return response
class SessionCookieMiddleware(MiddlewareMixin):
@staticmethod

View File

@ -328,13 +328,59 @@ class AuthACLMixin:
def _check_login_acl(self, user, ip):
# ACL 限制用户登录
is_allowed, limit_type = LoginACL.allow_user_to_login(user, ip)
if is_allowed:
acl = LoginACL.match(user, ip)
if not acl:
return
if limit_type == 'ip':
raise errors.LoginIPNotAllowed(username=user.username, request=self.request)
elif limit_type == 'time':
raise errors.TimePeriodNotAllowed(username=user.username, request=self.request)
acl: LoginACL
if acl.is_action(acl.ActionChoices.allow):
return
if acl.is_action(acl.ActionChoices.reject):
raise errors.LoginACLIPAndTimePeriodNotAllowed(user.username, request=self.request)
if acl.is_action(acl.ActionChoices.confirm):
self.request.session['auth_confirm_required'] = '1'
self.request.session['auth_acl_id'] = str(acl.id)
return
def check_user_login_confirm_if_need(self, user):
if not self.request.session.get("auth_confirm_required"):
return
acl_id = self.request.session.get('auth_acl_id')
logger.debug('Login confirm acl id: {}'.format(acl_id))
if not acl_id:
return
acl = LoginACL.filter_acl(user).filter(id=acl_id).first()
if not acl:
return
if not acl.is_action(acl.ActionChoices.confirm):
return
self.get_ticket_or_create(acl)
self.check_user_login_confirm()
def get_ticket_or_create(self, acl):
ticket = self.get_ticket()
if not ticket or ticket.is_state(ticket.State.closed):
ticket = acl.create_confirm_ticket(self.request)
self.request.session['auth_ticket_id'] = str(ticket.id)
return ticket
def check_user_login_confirm(self):
ticket = self.get_ticket()
if not ticket:
raise errors.LoginConfirmOtherError('', "Not found")
elif ticket.is_state(ticket.State.approved):
self.request.session["auth_confirm_required"] = ''
return
elif ticket.is_status(ticket.Status.open):
raise errors.LoginConfirmWaitError(ticket.id)
else:
# rejected, closed
ticket_id = ticket.id
status = ticket.get_state_display()
username = ticket.applicant.username
raise errors.LoginConfirmOtherError(ticket_id, status, username)
def get_ticket(self):
from tickets.models import ApplyLoginTicket
@ -346,44 +392,6 @@ class AuthACLMixin:
ticket = ApplyLoginTicket.all().filter(id=ticket_id).first()
return ticket
def get_ticket_or_create(self, confirm_setting):
ticket = self.get_ticket()
if not ticket or ticket.is_status(ticket.Status.closed):
ticket = confirm_setting.create_confirm_ticket(self.request)
self.request.session['auth_ticket_id'] = str(ticket.id)
return ticket
def check_user_login_confirm(self):
ticket = self.get_ticket()
if not ticket:
raise errors.LoginConfirmOtherError('', "Not found")
if ticket.is_status(ticket.Status.open):
raise errors.LoginConfirmWaitError(ticket.id)
elif ticket.is_state(ticket.State.approved):
self.request.session["auth_confirm"] = "1"
return
elif ticket.is_state(ticket.State.rejected):
raise errors.LoginConfirmOtherError(
ticket.id, ticket.get_state_display()
)
elif ticket.is_state(ticket.State.closed):
raise errors.LoginConfirmOtherError(
ticket.id, ticket.get_state_display()
)
else:
raise errors.LoginConfirmOtherError(
ticket.id, ticket.get_status_display()
)
def check_user_login_confirm_if_need(self, user):
ip = self.get_request_ip()
is_allowed, confirm_setting = LoginACL.allow_user_confirm_if_need(user, ip)
if self.request.session.get('auth_confirm') or not is_allowed:
return
self.get_ticket_or_create(confirm_setting)
self.check_user_login_confirm()
class AuthMixin(CommonMixin, AuthPreCheckMixin, AuthACLMixin, MFAMixin, AuthPostCheckMixin):
request = None
@ -482,7 +490,9 @@ class AuthMixin(CommonMixin, AuthPreCheckMixin, AuthACLMixin, MFAMixin, AuthPost
return self.check_user_auth(valid_data)
def clear_auth_mark(self):
keys = ['auth_password', 'user_id', 'auth_confirm', 'auth_ticket_id']
keys = [
'auth_password', 'user_id', 'auth_confirm_required', 'auth_ticket_id', 'auth_acl_id'
]
for k in keys:
self.request.session.pop(k, '')

View File

@ -216,6 +216,13 @@ class ConnectionToken(OrgModelMixin, JMSBaseModel):
return {}
return self.application.get_rdp_remote_app_setting()
@lazyproperty
def asset_or_remote_app_asset(self):
if self.asset:
return self.asset
if self.application and self.application.category_remote_app:
return self.application.get_remote_app_asset()
@lazyproperty
def cmd_filter_rules(self):
from assets.models import CommandFilterRule

View File

@ -25,9 +25,8 @@ class ConnectionTokenSerializer(OrgResourceModelSerializerMixin):
model = ConnectionToken
fields_mini = ['id', 'type']
fields_small = fields_mini + [
'secret', 'date_expired',
'date_created', 'date_updated', 'created_by', 'updated_by',
'org_id', 'org_name',
'secret', 'date_expired', 'date_created', 'date_updated',
'created_by', 'updated_by', 'org_id', 'org_name',
]
fields_fk = [
'user', 'system_user', 'asset', 'application',
@ -35,8 +34,8 @@ class ConnectionTokenSerializer(OrgResourceModelSerializerMixin):
read_only_fields = [
# 普通 Token 不支持指定 user
'user', 'is_valid', 'expire_time',
'type_display', 'user_display', 'system_user_display', 'asset_display',
'application_display',
'type_display', 'user_display', 'system_user_display',
'asset_display', 'application_display',
]
fields = fields_small + fields_fk + read_only_fields
@ -59,7 +58,7 @@ class ConnectionTokenSerializer(OrgResourceModelSerializerMixin):
system_user = attrs.get('system_user') or ''
asset = attrs.get('asset') or ''
application = attrs.get('application') or ''
secret = attrs.get('secret') or random_string(64)
secret = attrs.get('secret') or random_string(16)
date_expired = attrs.get('date_expired') or ConnectionToken.get_default_date_expired()
if isinstance(asset, Asset):
@ -97,8 +96,8 @@ class SuperConnectionTokenSerializer(ConnectionTokenSerializer):
class Meta(ConnectionTokenSerializer.Meta):
read_only_fields = [
'validity',
'user_display', 'system_user_display', 'asset_display', 'application_display',
'validity', 'user_display', 'system_user_display',
'asset_display', 'application_display',
]
def get_user(self, attrs):
@ -154,7 +153,12 @@ class ConnectionTokenCmdFilterRuleSerializer(serializers.ModelSerializer):
class ConnectionTokenSecretSerializer(OrgResourceModelSerializerMixin):
user = ConnectionTokenUserSerializer(read_only=True)
<<<<<<< HEAD
asset = ConnectionTokenAssetSerializer(read_only=True)
=======
asset = ConnectionTokenAssetSerializer(read_only=True, source='asset_or_remote_app_asset')
application = ConnectionTokenApplicationSerializer(read_only=True)
>>>>>>> origin
remote_app = ConnectionTokenRemoteAppSerializer(read_only=True)
account = serializers.CharField(read_only=True)
gateway = ConnectionTokenGatewaySerializer(read_only=True)

View File

@ -6,12 +6,16 @@ from django.core.cache import cache
from django.dispatch import receiver
from django_cas_ng.signals import cas_user_authenticated
from apps.jumpserver.settings.auth import AUTHENTICATION_BACKENDS_THIRD_PARTY
from authentication.backends.oidc.signals import (
openid_user_login_failed, openid_user_login_success
)
from authentication.backends.saml2.signals import (
saml2_user_authenticated, saml2_user_authentication_failed
)
from authentication.backends.oauth2.signals import (
oauth2_user_login_failed, oauth2_user_login_success
)
from .signals import post_auth_success, post_auth_failed
@ -25,7 +29,8 @@ def on_user_auth_login_success(sender, user, request, **kwargs):
and user.mfa_enabled \
and not request.session.get('auth_mfa'):
request.session['auth_mfa_required'] = 1
if not request.session.get("auth_third_party_done") and request.session.get('auth_backend') in AUTHENTICATION_BACKENDS_THIRD_PARTY:
request.session['auth_third_party_required'] = 1
# 单点登录,超过了自动退出
if settings.USER_LOGIN_SINGLE_MACHINE_ENABLED:
lock_key = 'single_machine_login_' + str(user.id)
@ -67,3 +72,15 @@ def on_saml2_user_login_success(sender, request, user, **kwargs):
def on_saml2_user_login_failed(sender, request, username, reason, **kwargs):
request.session['auth_backend'] = settings.AUTH_BACKEND_SAML2
post_auth_failed.send(sender, username=username, request=request, reason=reason)
@receiver(oauth2_user_login_success)
def on_oauth2_user_login_success(sender, request, user, **kwargs):
request.session['auth_backend'] = settings.AUTH_BACKEND_OAUTH2
post_auth_success.send(sender, user=user, request=request)
@receiver(oauth2_user_login_failed)
def on_oauth2_user_login_failed(sender, username, request, reason, **kwargs):
request.session['auth_backend'] = settings.AUTH_BACKEND_OAUTH2
post_auth_failed.send(sender, username=username, request=request, reason=reason)

View File

@ -0,0 +1,70 @@
{% extends '_base_only_content.html' %}
{% load static %}
{% load i18n %}
{% block html_title %} {{ title }} {% endblock %}
{% block title %} {{ title }}{% endblock %}
{% block content %}
<style>
.alert.alert-msg {
background: #F5F5F7;
}
</style>
<div>
<p>
<div class="alert alert-msg" id="messages">
{% if error %}
{{ error }}
{% else %}
{{ message|safe }}
{% endif %}
</div>
</p>
<div class="row">
{% if has_cancel %}
<div class="col-sm-3">
<a href="{{ cancel_url }}" class="btn btn-default block full-width m-b">
{% trans 'Cancel' %}
</a>
</div>
{% endif %}
<div class="col-sm-3">
<a href="{{ redirect_url }}" class="btn btn-primary block full-width m-b">
{% if confirm_button %}
{{ confirm_button }}
{% else %}
{% trans 'Confirm' %}
{% endif %}
</a>
</div>
</div>
</div>
{% endblock %}
{% block custom_foot_js %}
<script>
var message = ''
var time = '{{ interval }}'
{% if error %}
message = '{{ error }}'
{% else %}
message = '{{ message|safe }}'
{% endif %}
function redirect_page() {
if (time >= 0) {
var msg = message + ' <b>' + time + '</b> ...';
$('#messages').html(msg);
time--;
setTimeout(redirect_page, 1000);
} else {
window.location.href = "{{ redirect_url }}";
}
}
{% if auto_redirect %}
window.onload = redirect_page;
{% endif %}
</script>
{% endblock %}

View File

@ -79,6 +79,9 @@ function doRequestAuth() {
requestApi({
url: url,
method: "GET",
headers: {
"X-JMS-LOGIN-TYPE": "W"
},
success: function (data) {
if (!data.error && data.msg === 'ok') {
window.onbeforeunload = function(){};
@ -98,7 +101,7 @@ function doRequestAuth() {
},
error: function (text, data) {
},
flash_message: false
flash_message: false, // 是否显示flash消息
})
}
function initClipboard() {

View File

@ -56,9 +56,11 @@ urlpatterns = [
path('profile/otp/disable/', users_view.UserOtpDisableView.as_view(),
name='user-otp-disable'),
# openid
# other authentication protocol
path('cas/', include(('authentication.backends.cas.urls', 'authentication'), namespace='cas')),
path('openid/', include(('authentication.backends.oidc.urls', 'authentication'), namespace='openid')),
path('saml2/', include(('authentication.backends.saml2.urls', 'authentication'), namespace='saml2')),
path('oauth2/', include(('authentication.backends.oauth2.urls', 'authentication'), namespace='oauth2')),
path('captcha/', include('captcha.urls')),
]

View File

@ -1,7 +1,10 @@
# -*- coding: utf-8 -*-
#
import ipaddress
from urllib.parse import urljoin, urlparse
from django.conf import settings
from django.utils.translation import ugettext_lazy as _
from common.utils import validate_ip, get_ip_city, get_request_ip
from common.utils import get_logger
@ -22,10 +25,34 @@ def check_different_city_login_if_need(user, request):
else:
city = get_ip_city(ip) or DEFAULT_CITY
city_white = ['LAN', ]
if city not in city_white:
city_white = [_('LAN'), 'LAN']
is_private = ipaddress.ip_address(ip).is_private
if not is_private:
last_user_login = UserLoginLog.objects.exclude(city__in=city_white) \
.filter(username=user.username, status=True).first()
if last_user_login and last_user_login.city != city:
DifferentCityLoginMessage(user, ip, city).publish_async()
def build_absolute_uri(request, path=None):
""" Build absolute redirect """
if path is None:
path = '/'
site_url = urlparse(settings.SITE_URL)
scheme = site_url.scheme or request.scheme
host = request.get_host()
url = f'{scheme}://{host}'
redirect_uri = urljoin(url, path)
return redirect_uri
def build_absolute_uri_for_oidc(request, path=None):
""" Build absolute redirect uri for OIDC """
if path is None:
path = '/'
if settings.BASE_SITE_URL:
# OIDC 专用配置项
redirect_uri = urljoin(settings.BASE_SITE_URL, path)
return redirect_uri
return build_absolute_uri(request, path=path)

View File

@ -21,7 +21,7 @@ from django.conf import settings
from django.urls import reverse_lazy
from django.contrib.auth import BACKEND_SESSION_KEY
from common.utils import FlashMessageUtil
from common.utils import FlashMessageUtil, static_or_direct
from users.utils import (
redirect_user_first_login_or_index
)
@ -39,8 +39,7 @@ class UserLoginContextMixin:
get_user_mfa_context: Callable
request: HttpRequest
@staticmethod
def get_support_auth_methods():
def get_support_auth_methods(self):
auth_methods = [
{
'name': 'OpenID',
@ -63,6 +62,13 @@ class UserLoginContextMixin:
'logo': static('img/login_saml2_logo.png'),
'auto_redirect': True
},
{
'name': settings.AUTH_OAUTH2_PROVIDER,
'enabled': settings.AUTH_OAUTH2,
'url': reverse('authentication:oauth2:login'),
'logo': static_or_direct(settings.AUTH_OAUTH2_LOGO_PATH),
'auto_redirect': True
},
{
'name': _('WeCom'),
'enabled': settings.AUTH_WECOM,

View File

@ -68,7 +68,7 @@ class SimpleMetadataWithFilters(SimpleMetadata):
default = getattr(field, 'default', None)
if default is not None and default != empty:
if isinstance(default, (str, int, bool, datetime.datetime, list)):
if isinstance(default, (str, int, bool, float, datetime.datetime, list)):
field_info['default'] = default
for attr in self.attrs:

View File

@ -0,0 +1 @@
from .sm3 import PBKDF2SM3PasswordHasher

View File

@ -0,0 +1,23 @@
from gmssl import sm3, func
from django.contrib.auth.hashers import PBKDF2PasswordHasher
class Hasher:
name = 'sm3'
def __init__(self, key):
self.key = key
def hexdigest(self):
return sm3.sm3_hash(func.bytes_to_list(self.key))
@staticmethod
def hash(msg):
return Hasher(msg)
class PBKDF2SM3PasswordHasher(PBKDF2PasswordHasher):
algorithm = "pbkdf2_sm3"
digest = Hasher.hash

View File

@ -7,6 +7,9 @@ from rest_framework import permissions
from authentication.const import ConfirmType
from common.exceptions import UserConfirmRequired
from orgs.utils import tmp_to_root_org
from authentication.models import ConnectionToken
from common.utils import get_object_or_none
class IsValidUser(permissions.IsAuthenticated, permissions.BasePermission):
@ -17,6 +20,22 @@ class IsValidUser(permissions.IsAuthenticated, permissions.BasePermission):
and request.user.is_valid
class IsValidUserOrConnectionToken(IsValidUser):
def has_permission(self, request, view):
return super(IsValidUserOrConnectionToken, self).has_permission(request, view) \
or self.is_valid_connection_token(request)
@staticmethod
def is_valid_connection_token(request):
token_id = request.query_params.get('token')
if not token_id:
return False
with tmp_to_root_org():
token = get_object_or_none(ConnectionToken, id=token_id)
return token and token.is_valid
class OnlySuperUser(IsValidUser):
def has_permission(self, request, view):
return super().has_permission(request, view) \
@ -38,6 +57,9 @@ class UserConfirmation(permissions.BasePermission):
confirm_type = ConfirmType.ReLogin
def has_permission(self, request, view):
if not settings.SECURITY_VIEW_AUTH_NEED_MFA:
return True
confirm_level = request.session.get('CONFIRM_LEVEL')
confirm_time = request.session.get('CONFIRM_TIME')

View File

@ -17,4 +17,8 @@ class BaseSMSClient:
def send_sms(self, phone_numbers: list, sign_name: str, template_code: str, template_param: dict, **kwargs):
raise NotImplementedError
@staticmethod
def need_pre_check():
return True

View File

@ -0,0 +1,329 @@
import hashlib
import socket
import struct
import time
from django.conf import settings
from django.utils.translation import ugettext_lazy as _
from common.utils import get_logger
from common.exceptions import JMSException
from .base import BaseSMSClient
logger = get_logger(__file__)
CMPP_CONNECT = 0x00000001 # 请求连接
CMPP_CONNECT_RESP = 0x80000001 # 请求连接应答
CMPP_TERMINATE = 0x00000002 # 终止连接
CMPP_TERMINATE_RESP = 0x80000002 # 终止连接应答
CMPP_SUBMIT = 0x00000004 # 提交短信
CMPP_SUBMIT_RESP = 0x80000004 # 提交短信应答
CMPP_DELIVER = 0x00000005 # 短信下发
CMPP_DELIVER_RESP = 0x80000005 # 下发短信应答
class CMPPBaseRequestInstance(object):
def __init__(self):
self.command_id = ''
self.body = b''
self.length = 0
def get_header(self, sequence_id):
length = struct.pack('!L', 12 + self.length)
command_id = struct.pack('!L', self.command_id)
sequence_id = struct.pack('!L', sequence_id)
return length + command_id + sequence_id
def get_message(self, sequence_id):
return self.get_header(sequence_id) + self.body
class CMPPConnectRequestInstance(CMPPBaseRequestInstance):
def __init__(self, sp_id, sp_secret):
if len(sp_id) != 6:
raise ValueError(_("sp_id is 6 bits"))
super().__init__()
source_addr = sp_id.encode('utf-8')
sp_secret = sp_secret.encode('utf-8')
version = struct.pack('!B', 0x02)
timestamp = struct.pack('!L', int(self.get_now()))
authenticator_source = source_addr + 9 * b'\x00' + sp_secret + self.get_now().encode('utf-8')
auth_source_md5 = hashlib.md5(authenticator_source).digest()
self.body = source_addr + auth_source_md5 + version + timestamp
self.length = len(self.body)
self.command_id = CMPP_CONNECT
@staticmethod
def get_now():
return time.strftime('%m%d%H%M%S', time.localtime(time.time()))
class CMPPSubmitRequestInstance(CMPPBaseRequestInstance):
def __init__(self, msg_src, dest_terminal_id, msg_content, src_id,
service_id='', dest_usr_tl=1):
if len(msg_content) >= 70:
raise JMSException('The message length should be within 70 characters')
if len(dest_terminal_id) > 100:
raise JMSException('The number of users receiving information should be less than 100')
super().__init__()
msg_id = 8 * b'\x00'
pk_total = struct.pack('!B', 1)
pk_number = struct.pack('!B', 1)
registered_delivery = struct.pack('!B', 0)
msg_level = struct.pack('!B', 0)
service_id = ((10 - len(service_id)) * '\x00' + service_id).encode('utf-8')
fee_user_type = struct.pack('!B', 2)
fee_terminal_id = ('0' * 21).encode('utf-8')
tp_pid = struct.pack('!B', 0)
tp_udhi = struct.pack('!B', 0)
msg_fmt = struct.pack('!B', 8)
fee_type = '01'.encode('utf-8')
fee_code = '000000'.encode('utf-8')
valid_time = ('\x00' * 17).encode('utf-8')
at_time = ('\x00' * 17).encode('utf-8')
src_id = ((21 - len(src_id)) * '\x00' + src_id).encode('utf-8')
reserve = b'\x00' * 8
_msg_length = struct.pack('!B', len(msg_content) * 2)
_msg_src = msg_src.encode('utf-8')
_dest_usr_tl = struct.pack('!B', dest_usr_tl)
_msg_content = msg_content.encode('utf-16-be')
_dest_terminal_id = b''.join([
(i + (21 - len(i)) * '\x00').encode('utf-8') for i in dest_terminal_id
])
self.length = 126 + 21 * dest_usr_tl + len(_msg_content)
self.command_id = CMPP_SUBMIT
self.body = msg_id + pk_total + pk_number + registered_delivery \
+ msg_level + service_id + fee_user_type + fee_terminal_id \
+ tp_pid + tp_udhi + msg_fmt + _msg_src + fee_type + fee_code \
+ valid_time + at_time + src_id + _dest_usr_tl + _dest_terminal_id \
+ _msg_length + _msg_content + reserve
class CMPPTerminateRequestInstance(CMPPBaseRequestInstance):
def __init__(self):
super().__init__()
self.body = b''
self.command_id = CMPP_TERMINATE
class CMPPDeliverRespRequestInstance(CMPPBaseRequestInstance):
def __init__(self, msg_id, result=0):
super().__init__()
msg_id = struct.pack('!Q', msg_id)
result = struct.pack('!B', result)
self.length = len(self.body)
self.body = msg_id + result
class CMPPResponseInstance(object):
def __init__(self):
self.command_id = None
self.length = None
self.response_handler_map = {
CMPP_CONNECT_RESP: self.connect_response_parse,
CMPP_SUBMIT_RESP: self.submit_response_parse,
CMPP_DELIVER: self.deliver_request_parse,
}
@staticmethod
def connect_response_parse(body):
status, = struct.unpack('!B', body[0:1])
authenticator_ISMG = body[1:17]
version, = struct.unpack('!B', body[17:18])
return {
'Status': status,
'AuthenticatorISMG': authenticator_ISMG,
'Version': version
}
@staticmethod
def submit_response_parse(body):
msg_id = body[:8]
result = struct.unpack('!B', body[8:9])
return {
'Msg_Id': msg_id, 'Result': result[0]
}
@staticmethod
def deliver_request_parse(body):
msg_id, = struct.unpack('!Q', body[0:8])
dest_id = body[8:29]
service_id = body[29:39]
tp_pid = struct.unpack('!B', body[39:40])
tp_udhi = struct.unpack('!B', body[40:41])
msg_fmt = struct.unpack('!B', body[41:42])
src_terminal_id = body[42:63]
registered_delivery = struct.unpack('!B', body[63:64])
msg_length = struct.unpack('!B', body[64:65])
msg_content = body[65:msg_length[0]+65]
return {
'Msg_Id': msg_id, 'Dest_Id': dest_id, 'Service_Id': service_id,
'TP_pid': tp_pid, 'TP_udhi': tp_udhi, 'Msg_Fmt': msg_fmt,
'Src_terminal_Id': src_terminal_id, 'Registered_Delivery': registered_delivery,
'Msg_Length': msg_length, 'Msg_content': msg_content
}
def parse_header(self, data):
self.command_id, = struct.unpack('!L', data[4:8])
sequence_id, = struct.unpack('!L', data[8:12])
return {
'length': self.length,
'command_id': hex(self.command_id),
'sequence_id': sequence_id
}
def parse_body(self, body):
response_body_func = self.response_handler_map.get(self.command_id)
if response_body_func is None:
raise JMSException('Unable to parse the returned result: %s' % body)
return response_body_func(body)
def parse(self, data):
self.length, = struct.unpack('!L', data[0:4])
header = self.parse_header(data)
body = self.parse_body(data[12:self.length])
return header, body
class CMPPClient(object):
def __init__(self, host, port, sp_id, sp_secret, src_id, service_id):
self.ip = host
self.port = port
self.sp_id = sp_id
self.sp_secret = sp_secret
self.src_id = src_id
self.service_id = service_id
self._sequence_id = 0
self._is_connect = False
self._times = 3
self.__socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._connect()
@property
def sequence_id(self):
s = self._sequence_id
self._sequence_id += 1
return s
def _connect(self):
self.__socket.settimeout(5)
error_msg = _('Failed to connect to the CMPP gateway server, err: {}')
for i in range(self._times):
try:
self.__socket.connect((self.ip, self.port))
except Exception as err:
error_msg = error_msg.format(str(err))
logger.warning(error_msg)
time.sleep(1)
else:
self._is_connect = True
break
else:
raise JMSException(error_msg)
def send(self, instance):
if isinstance(instance, CMPPBaseRequestInstance):
message = instance.get_message(sequence_id=self.sequence_id)
else:
message = instance
self.__socket.send(message)
def recv(self):
raw_length = self.__socket.recv(4)
length, = struct.unpack('!L', raw_length)
header, body = CMPPResponseInstance().parse(
raw_length + self.__socket.recv(length - 4)
)
return header, body
def close(self):
if self._is_connect:
terminate_request = CMPPTerminateRequestInstance()
self.send(terminate_request)
self.__socket.close()
def _cmpp_connect(self):
connect_request = CMPPConnectRequestInstance(self.sp_id, self.sp_secret)
self.send(connect_request)
header, body = self.recv()
if body['Status'] != 0:
raise JMSException('CMPPv2.0 authentication failed: %s' % body)
def _cmpp_send_sms(self, dest, sign_name, template_code, template_param):
"""
优先发送template_param中message的信息
若该内容不存在则根据template_code构建验证码发送
"""
message = template_param.get('message')
if message is None:
code = template_param.get('code')
message = template_code.replace('{code}', code)
msg = '%s%s' % (sign_name, message)
submit_request = CMPPSubmitRequestInstance(
msg_src=self.sp_id, src_id=self.src_id, msg_content=msg,
dest_usr_tl=len(dest), dest_terminal_id=dest,
service_id=self.service_id
)
self.send(submit_request)
header, body = self.recv()
command_id = header.get('command_id')
if command_id == CMPP_DELIVER:
deliver_request = CMPPDeliverRespRequestInstance(
msg_id=body['Msg_Id'], result=body['Result']
)
self.send(deliver_request)
def send_sms(self, dest, sign_name, template_code, template_param):
try:
self._cmpp_connect()
self._cmpp_send_sms(dest, sign_name, template_code, template_param)
except Exception as e:
logger.error('CMPPv2.0 Error: %s', e)
self.close()
raise JMSException(e)
class CMPP2SMS(BaseSMSClient):
SIGN_AND_TMPL_SETTING_FIELD_PREFIX = 'CMPP2'
@classmethod
def new_from_settings(cls):
return cls(
host=settings.CMPP2_HOST, port=settings.CMPP2_PORT,
sp_id=settings.CMPP2_SP_ID, sp_secret=settings.CMPP2_SP_SECRET,
service_id=settings.CMPP2_SERVICE_ID, src_id=getattr(settings, 'CMPP2_SRC_ID', ''),
)
def __init__(self, host: str, port: int, sp_id: str, sp_secret: str, service_id: str, src_id=''):
try:
self.client = CMPPClient(
host=host, port=port, sp_id=sp_id, sp_secret=sp_secret, src_id=src_id, service_id=service_id
)
except Exception as err:
self.client = None
logger.warning(err)
raise JMSException(err)
@staticmethod
def need_pre_check():
return False
def send_sms(self, phone_numbers: list, sign_name: str, template_code: str, template_param: dict, **kwargs):
try:
logger.info(f'CMPPv2.0 sms send: '
f'phone_numbers={phone_numbers} '
f'sign_name={sign_name} '
f'template_code={template_code} '
f'template_param={template_param}')
self.client.send_sms(phone_numbers, sign_name, template_code, template_param)
except Exception as e:
raise JMSException(e)
client = CMPP2SMS

View File

@ -15,6 +15,7 @@ logger = get_logger(__name__)
class BACKENDS(TextChoices):
ALIBABA = 'alibaba', _('Alibaba cloud')
TENCENT = 'tencent', _('Tencent cloud')
CMPP2 = 'cmpp2', _('CMPP v2.0')
class SMS:
@ -43,7 +44,7 @@ class SMS:
sign_name = getattr(settings, f'{self.client.SIGN_AND_TMPL_SETTING_FIELD_PREFIX}_VERIFY_SIGN_NAME')
template_code = getattr(settings, f'{self.client.SIGN_AND_TMPL_SETTING_FIELD_PREFIX}_VERIFY_TEMPLATE_CODE')
if not (sign_name and template_code):
if self.client.need_pre_check() and not (sign_name and template_code):
raise JMSException(
code='verify_code_sign_tmpl_invalid',
detail=_('SMS verification code signature or template invalid')

View File

@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
#
import re
import socket
from django.templatetags.static import static
from collections import OrderedDict
from itertools import chain
import logging
@ -381,3 +383,13 @@ def test_ip_connectivity(host, port, timeout=0.5):
else:
connectivity = False
return connectivity
<<<<<<< HEAD
=======
def static_or_direct(logo_path):
if logo_path.startswith('img/'):
return static(logo_path)
else:
return logo_path
>>>>>>> origin

View File

@ -1,9 +1,10 @@
import base64
import logging
import re
from Cryptodome.Cipher import AES, PKCS1_v1_5
from Cryptodome.Util.Padding import pad
from Cryptodome.Random import get_random_bytes
from Cryptodome.PublicKey import RSA
from Cryptodome.Util.Padding import pad
from Cryptodome import Random
from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT
@ -11,21 +12,25 @@ from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
def process_key(key):
secret_pattern = re.compile(r'password|secret|key|token', re.IGNORECASE)
def padding_key(key, max_length=32):
"""
返回32 bytes 的key
"""
if not isinstance(key, bytes):
key = bytes(key, encoding='utf-8')
if len(key) >= 32:
return key[:32]
if len(key) >= max_length:
return key[:max_length]
return pad(key, 32)
while len(key) % 16 != 0:
key += b'\0'
return key
class BaseCrypto:
def encrypt(self, text):
return base64.urlsafe_b64encode(
self._encrypt(bytes(text, encoding='utf8'))
@ -45,7 +50,7 @@ class BaseCrypto:
class GMSM4EcbCrypto(BaseCrypto):
def __init__(self, key):
self.key = process_key(key)
self.key = padding_key(key, 16)
self.sm4_encryptor = CryptSM4()
self.sm4_encryptor.set_key(self.key, SM4_ENCRYPT)
@ -70,9 +75,8 @@ class AESCrypto:
"""
def __init__(self, key):
if len(key) > 32:
key = key[:32]
self.key = self.to_16(key)
self.key = padding_key(key, 32)
self.aes = AES.new(self.key, AES.MODE_ECB)
@staticmethod
def to_16(key):
@ -87,17 +91,15 @@ class AESCrypto:
return key # 返回bytes
def aes(self):
return AES.new(self.key, AES.MODE_ECB) # 初始化加密器
return AES.new(self.key, AES.MODE_ECB)
def encrypt(self, text):
aes = self.aes()
cipher = base64.encodebytes(aes.encrypt(self.to_16(text)))
cipher = base64.encodebytes(self.aes.encrypt(self.to_16(text)))
return str(cipher, encoding='utf8').replace('\n', '') # 加密
def decrypt(self, text):
aes = self.aes()
text_decoded = base64.decodebytes(bytes(text, encoding='utf8'))
return str(aes.decrypt(text_decoded).rstrip(b'\0').decode("utf8"))
return str(self.aes.decrypt(text_decoded).rstrip(b'\0').decode("utf8"))
class AESCryptoGCM:
@ -106,7 +108,15 @@ class AESCryptoGCM:
"""
def __init__(self, key):
self.key = process_key(key)
self.key = self.process_key(key)
@staticmethod
def process_key(key):
if not isinstance(key, bytes):
key = bytes(key, encoding='utf-8')
if len(key) >= 32:
return key[:32]
return pad(key, 32)
def encrypt(self, text):
"""
@ -133,7 +143,6 @@ class AESCryptoGCM:
nonce = base64.b64decode(metadata[24:48])
tag = base64.b64decode(metadata[48:])
ciphertext = base64.b64decode(text[72:])
cipher = AES.new(self.key, AES.MODE_GCM, nonce=nonce)
cipher.update(header)
@ -144,11 +153,10 @@ class AESCryptoGCM:
def get_aes_crypto(key=None, mode='GCM'):
if key is None:
key = settings.SECRET_KEY
if mode == 'ECB':
a = AESCrypto(key)
elif mode == 'GCM':
a = AESCryptoGCM(key)
return a
if mode == 'GCM':
return AESCryptoGCM(key)
else:
return AESCrypto(key)
def get_gm_sm4_ecb_crypto(key=None):
@ -162,34 +170,42 @@ gm_sm4_ecb_crypto = get_gm_sm4_ecb_crypto()
class Crypto:
cryptoes = {
cryptor_map = {
'aes_ecb': aes_ecb_crypto,
'aes_gcm': aes_crypto,
'aes': aes_crypto,
'gm_sm4_ecb': gm_sm4_ecb_crypto,
'gm': gm_sm4_ecb_crypto,
}
cryptos = []
def __init__(self):
cryptoes = self.__class__.cryptoes.copy()
crypto = cryptoes.pop(settings.SECURITY_DATA_CRYPTO_ALGO, None)
if crypto is None:
crypt_algo = settings.SECURITY_DATA_CRYPTO_ALGO
if not crypt_algo:
if settings.GMSSL_ENABLED:
crypt_algo = 'gm'
else:
crypt_algo = 'aes'
cryptor = self.cryptor_map.get(crypt_algo, None)
if cryptor is None:
raise ImproperlyConfigured(
f'Crypto method not supported {settings.SECURITY_DATA_CRYPTO_ALGO}'
)
self.cryptoes = [crypto, *cryptoes.values()]
others = set(self.cryptor_map.values()) - {cryptor}
self.cryptos = [cryptor, *others]
@property
def encryptor(self):
return self.cryptoes[0]
return self.cryptos[0]
def encrypt(self, text):
return self.encryptor.encrypt(text)
def decrypt(self, text):
for decryptor in self.cryptoes:
for cryptor in self.cryptos:
try:
origin_text = decryptor.decrypt(text)
origin_text = cryptor.decrypt(text)
if origin_text:
# 有时不同算法解密不报错,但是返回空字符串
return origin_text
@ -255,11 +271,13 @@ def decrypt_password(value):
if len(cipher) != 2:
return value
key_cipher, password_cipher = cipher
if not all([key_cipher, password_cipher]):
return value
aes_key = rsa_decrypt_by_session_pkey(key_cipher)
aes = get_aes_crypto(aes_key, 'ECB')
try:
password = aes.decrypt(password_cipher)
except UnicodeDecodeError as e:
except Exception as e:
logging.error("Decrypt password error: {}, {}".format(password_cipher, e))
return value
return password

View File

@ -8,12 +8,14 @@ from django.utils import timezone
from django.db import models
from django.db.models.signals import post_save, pre_save
UUID_PATTERN = re.compile(r'[0-9a-zA-Z\-]{36}')
def reverse(view_name, urlconf=None, args=None, kwargs=None,
current_app=None, external=False, api_to_ui=False):
def reverse(
view_name, urlconf=None, args=None, kwargs=None,
current_app=None, external=False, api_to_ui=False,
is_console=False, is_audit=False, is_workbench=False
):
url = dj_reverse(view_name, urlconf=urlconf, args=args,
kwargs=kwargs, current_app=current_app)
@ -21,7 +23,15 @@ def reverse(view_name, urlconf=None, args=None, kwargs=None,
site_url = settings.SITE_URL
url = site_url.strip('/') + url
if api_to_ui:
url = url.replace('api/v1', 'ui/#').rstrip('/')
replace_str = 'ui/#'
if is_console:
replace_str += '/console'
elif is_audit:
replace_str += '/audit'
elif is_workbench:
replace_str += '/workbench'
url = url.replace('api/v1', replace_str).rstrip('/')
return url
@ -38,7 +48,7 @@ def date_expired_default():
years = int(settings.DEFAULT_EXPIRED_YEARS)
except TypeError:
years = 70
return timezone.now() + timezone.timedelta(days=365*years)
return timezone.now() + timezone.timedelta(days=365 * years)
def union_queryset(*args, base_queryset=None):

View File

@ -196,7 +196,8 @@ def encrypt_password(password, salt=None, algorithm='sha512'):
return des_crypt.hash(password, salt=salt[:2])
support_algorithm = {
'sha512': sha512, 'des': des
'sha512': sha512,
'des': des
}
if isinstance(algorithm, str):
@ -222,9 +223,6 @@ def ensure_last_char_is_ascii(data):
remain = ''
secret_pattern = re.compile(r'password|secret|key', re.IGNORECASE)
def data_to_json(data, sort_keys=True, indent=2, cls=None):
if cls is None:
cls = DjangoJSONEncoder

View File

@ -15,18 +15,23 @@ import errno
import json
import yaml
import copy
import base64
import logging
from importlib import import_module
from urllib.parse import urljoin, urlparse
from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT
from django.urls import reverse_lazy
from django.conf import settings
from django.utils.translation import ugettext_lazy as _
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
PROJECT_DIR = os.path.dirname(BASE_DIR)
XPACK_DIR = os.path.join(BASE_DIR, 'xpack')
HAS_XPACK = os.path.isdir(XPACK_DIR)
logger = logging.getLogger('jumpserver.conf')
def import_string(dotted_path):
try:
@ -39,9 +44,9 @@ def import_string(dotted_path):
try:
return getattr(module, class_name)
except AttributeError as err:
raise ImportError('Module "%s" does not define a "%s" attribute/class' % (
module_path, class_name)
) from err
raise ImportError(
'Module "%s" does not define a "%s" attribute/class' %
(module_path, class_name)) from err
def is_absolute_uri(uri):
@ -80,6 +85,59 @@ class DoesNotExist(Exception):
pass
class ConfigCrypto:
secret_keys = [
'SECRET_KEY', 'DB_PASSWORD', 'REDIS_PASSWORD',
]
def __init__(self, key):
self.safe_key = self.process_key(key)
self.sm4_encryptor = CryptSM4()
self.sm4_encryptor.set_key(self.safe_key, SM4_ENCRYPT)
self.sm4_decryptor = CryptSM4()
self.sm4_decryptor.set_key(self.safe_key, SM4_DECRYPT)
@staticmethod
def process_key(secret_encrypt_key):
key = secret_encrypt_key.encode()
if len(key) >= 16:
key = key[:16]
else:
key += b'\0' * (16 - len(key))
return key
def encrypt(self, data):
data = bytes(data, encoding='utf8')
return base64.b64encode(self.sm4_encryptor.crypt_ecb(data)).decode('utf8')
def decrypt(self, data):
data = base64.urlsafe_b64decode(bytes(data, encoding='utf8'))
return self.sm4_decryptor.crypt_ecb(data).decode('utf8')
def decrypt_if_need(self, value, item):
if item not in self.secret_keys:
return value
try:
plaintext = self.decrypt(value)
if plaintext:
value = plaintext
except Exception as e:
logger.error('decrypt %s error: %s', item, e)
return value
@classmethod
def get_secret_encryptor(cls):
# 使用 SM4 加密配置文件敏感信息
# https://the-x.cn/cryptography/Sm4.aspx
secret_encrypt_key = os.environ.get('SECRET_ENCRYPT_KEY', '')
if not secret_encrypt_key:
return None
print('Info: Using SM4 to encrypt config secret value')
return cls(secret_encrypt_key)
class Config(dict):
"""Works exactly like a dict but provides ways to fill it from files
or special dictionaries. There are two common patterns to populate the
@ -160,7 +218,7 @@ class Config(dict):
'SESSION_COOKIE_DOMAIN': None,
'CSRF_COOKIE_DOMAIN': None,
'SESSION_COOKIE_NAME_PREFIX': None,
'SESSION_COOKIE_AGE': 3600,
'SESSION_COOKIE_AGE': 3600 * 24,
'SESSION_EXPIRE_AT_BROWSER_CLOSE': False,
'LOGIN_URL': reverse_lazy('authentication:login'),
'CONNECTION_TOKEN_EXPIRATION': 5 * 60,
@ -265,6 +323,22 @@ class Config(dict):
'AUTH_SAML2_PROVIDER_AUTHORIZATION_ENDPOINT': '/',
'AUTH_SAML2_AUTHENTICATION_FAILURE_REDIRECT_URI': '/',
# OAuth2 认证
'AUTH_OAUTH2': False,
'AUTH_OAUTH2_LOGO_PATH': 'img/login_oauth2_logo.png',
'AUTH_OAUTH2_PROVIDER': 'OAuth2',
'AUTH_OAUTH2_ALWAYS_UPDATE_USER': True,
'AUTH_OAUTH2_CLIENT_ID': 'client-id',
'AUTH_OAUTH2_SCOPE': '',
'AUTH_OAUTH2_CLIENT_SECRET': '',
'AUTH_OAUTH2_PROVIDER_AUTHORIZATION_ENDPOINT': 'https://oauth2.example.com/authorize',
'AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT': 'https://oauth2.example.com/userinfo',
'AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT': 'https://oauth2.example.com/access_token',
'AUTH_OAUTH2_ACCESS_TOKEN_METHOD': 'GET',
'AUTH_OAUTH2_USER_ATTR_MAP': {
'name': 'name', 'username': 'username', 'email': 'email'
},
'AUTH_TEMP_TOKEN': False,
# 企业微信
@ -302,6 +376,15 @@ class Config(dict):
'TENCENT_VERIFY_SIGN_NAME': '',
'TENCENT_VERIFY_TEMPLATE_CODE': '',
'CMPP2_HOST': '',
'CMPP2_PORT': 7890,
'CMPP2_SP_ID': '',
'CMPP2_SP_SECRET': '',
'CMPP2_SRC_ID': '',
'CMPP2_SERVICE_ID': '',
'CMPP2_VERIFY_SIGN_NAME': '',
'CMPP2_VERIFY_TEMPLATE_CODE': '{code}',
# Email
'EMAIL_CUSTOM_USER_CREATED_SUBJECT': _('Create account successfully'),
'EMAIL_CUSTOM_USER_CREATED_HONORIFIC': _('Hello'),
@ -387,7 +470,8 @@ class Config(dict):
'SESSION_SAVE_EVERY_REQUEST': True,
'SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE': False,
'SERVER_REPLAY_STORAGE': {},
'SECURITY_DATA_CRYPTO_ALGO': 'aes',
'SECURITY_DATA_CRYPTO_ALGO': None,
'GMSSL_ENABLED': False,
# 记录清理清理
'LOGIN_LOG_KEEP_DAYS': 200,
@ -405,6 +489,7 @@ class Config(dict):
'CONNECTION_TOKEN_ENABLED': False,
'PERM_SINGLE_ASSET_TO_UNGROUP_NODE': False,
'TICKET_AUTHORIZE_DEFAULT_TIME': 7,
'WINDOWS_SSH_DEFAULT_SHELL': 'cmd',
'PERIOD_TASK_ENABLED': True,
@ -416,6 +501,10 @@ class Config(dict):
'HEALTH_CHECK_TOKEN': '',
}
def __init__(self, *args):
super().__init__(*args)
self.secret_encryptor = ConfigCrypto.get_secret_encryptor()
@staticmethod
def convert_keycloak_to_openid(keycloak_config):
"""
@ -427,7 +516,6 @@ class Config(dict):
"""
openid_config = copy.deepcopy(keycloak_config)
auth_openid = openid_config.get('AUTH_OPENID')
auth_openid_realm_name = openid_config.get('AUTH_OPENID_REALM_NAME')
auth_openid_server_url = openid_config.get('AUTH_OPENID_SERVER_URL')
@ -556,13 +644,12 @@ class Config(dict):
def get(self, item):
# 再从配置文件中获取
value = self.get_from_config(item)
if value is not None:
return value
# 其次从环境变量来
value = self.get_from_env(item)
if value is not None:
return value
value = self.defaults.get(item)
if value is None:
value = self.get_from_env(item)
if value is None:
value = self.defaults.get(item)
if self.secret_encryptor:
value = self.secret_encryptor.decrypt_if_need(value, item)
return value
def __getitem__(self, item):

View File

@ -11,7 +11,7 @@ default_interface = dict((
('favicon', static('img/facio.ico')),
('login_title', _('JumpServer Open Source Bastion Host')),
('theme', 'classic_green'),
('theme_info', None),
('theme_info', {}),
))
default_context = {

View File

@ -4,10 +4,10 @@ from django.core.asgi import get_asgi_application
from ops.urls.ws_urls import urlpatterns as ops_urlpatterns
from notifications.urls.ws_urls import urlpatterns as notifications_urlpatterns
from settings.urls.ws_urls import urlpatterns as setting_urlpatterns
urlpatterns = []
urlpatterns += ops_urlpatterns \
+ notifications_urlpatterns
urlpatterns += ops_urlpatterns + notifications_urlpatterns + setting_urlpatterns
application = ProtocolTypeRouter({
'websocket': AuthMiddlewareStack(

View File

@ -24,9 +24,15 @@ AUTH_LDAP_GLOBAL_OPTIONS = {
ldap.OPT_X_TLS_REQUIRE_CERT: ldap.OPT_X_TLS_NEVER,
ldap.OPT_REFERRALS: CONFIG.AUTH_LDAP_OPTIONS_OPT_REFERRALS
}
LDAP_CERT_FILE = os.path.join(PROJECT_DIR, "data", "certs", "ldap_ca.pem")
LDAP_CACERT_FILE = os.path.join(PROJECT_DIR, "data", "certs", "ldap_ca.pem")
if os.path.isfile(LDAP_CACERT_FILE):
AUTH_LDAP_GLOBAL_OPTIONS[ldap.OPT_X_TLS_CACERTFILE] = LDAP_CACERT_FILE
LDAP_CERT_FILE = os.path.join(PROJECT_DIR, "data", "certs", "ldap_cert.pem")
if os.path.isfile(LDAP_CERT_FILE):
AUTH_LDAP_GLOBAL_OPTIONS[ldap.OPT_X_TLS_CACERTFILE] = LDAP_CERT_FILE
AUTH_LDAP_GLOBAL_OPTIONS[ldap.OPT_X_TLS_CERTFILE] = LDAP_CERT_FILE
LDAP_KEY_FILE = os.path.join(PROJECT_DIR, "data", "certs", "ldap_cert.key")
if os.path.isfile(LDAP_KEY_FILE):
AUTH_LDAP_GLOBAL_OPTIONS[ldap.OPT_X_TLS_KEYFILE] = LDAP_KEY_FILE
# AUTH_LDAP_GROUP_SEARCH_OU = CONFIG.AUTH_LDAP_GROUP_SEARCH_OU
# AUTH_LDAP_GROUP_SEARCH_FILTER = CONFIG.AUTH_LDAP_GROUP_SEARCH_FILTER
# AUTH_LDAP_GROUP_SEARCH = LDAPSearch(
@ -143,6 +149,23 @@ SAML2_SP_ADVANCED_SETTINGS = CONFIG.SAML2_SP_ADVANCED_SETTINGS
SAML2_LOGIN_URL_NAME = "authentication:saml2:saml2-login"
SAML2_LOGOUT_URL_NAME = "authentication:saml2:saml2-logout"
# OAuth2 auth
AUTH_OAUTH2 = CONFIG.AUTH_OAUTH2
AUTH_OAUTH2_LOGO_PATH = CONFIG.AUTH_OAUTH2_LOGO_PATH
AUTH_OAUTH2_PROVIDER = CONFIG.AUTH_OAUTH2_PROVIDER
AUTH_OAUTH2_ALWAYS_UPDATE_USER = CONFIG.AUTH_OAUTH2_ALWAYS_UPDATE_USER
AUTH_OAUTH2_PROVIDER_AUTHORIZATION_ENDPOINT = CONFIG.AUTH_OAUTH2_PROVIDER_AUTHORIZATION_ENDPOINT
AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT = CONFIG.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT
AUTH_OAUTH2_ACCESS_TOKEN_METHOD = CONFIG.AUTH_OAUTH2_ACCESS_TOKEN_METHOD
AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT = CONFIG.AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT
AUTH_OAUTH2_CLIENT_SECRET = CONFIG.AUTH_OAUTH2_CLIENT_SECRET
AUTH_OAUTH2_CLIENT_ID = CONFIG.AUTH_OAUTH2_CLIENT_ID
AUTH_OAUTH2_SCOPE = CONFIG.AUTH_OAUTH2_SCOPE
AUTH_OAUTH2_USER_ATTR_MAP = CONFIG.AUTH_OAUTH2_USER_ATTR_MAP
AUTH_OAUTH2_AUTH_LOGIN_CALLBACK_URL_NAME = 'authentication:oauth2:login-callback'
AUTH_OAUTH2_AUTHENTICATION_REDIRECT_URI = '/'
AUTH_OAUTH2_AUTHENTICATION_FAILURE_REDIRECT_URI = '/'
# 临时 token
AUTH_TEMP_TOKEN = CONFIG.AUTH_TEMP_TOKEN
@ -170,6 +193,7 @@ AUTH_BACKEND_DINGTALK = 'authentication.backends.sso.DingTalkAuthentication'
AUTH_BACKEND_FEISHU = 'authentication.backends.sso.FeiShuAuthentication'
AUTH_BACKEND_AUTH_TOKEN = 'authentication.backends.sso.AuthorizationTokenAuthentication'
AUTH_BACKEND_SAML2 = 'authentication.backends.saml2.SAML2Backend'
AUTH_BACKEND_OAUTH2 = 'authentication.backends.oauth2.OAuth2Backend'
AUTH_BACKEND_TEMP_TOKEN = 'authentication.backends.token.TempTokenAuthBackend'
@ -180,12 +204,14 @@ AUTHENTICATION_BACKENDS = [
AUTH_BACKEND_MODEL, AUTH_BACKEND_PUBKEY, AUTH_BACKEND_LDAP, AUTH_BACKEND_RADIUS,
# 跳转形式
AUTH_BACKEND_CAS, AUTH_BACKEND_OIDC_PASSWORD, AUTH_BACKEND_OIDC_CODE, AUTH_BACKEND_SAML2,
AUTH_BACKEND_OAUTH2,
# 扫码模式
AUTH_BACKEND_WECOM, AUTH_BACKEND_DINGTALK, AUTH_BACKEND_FEISHU,
# Token模式
AUTH_BACKEND_AUTH_TOKEN, AUTH_BACKEND_SSO, AUTH_BACKEND_TEMP_TOKEN
]
AUTHENTICATION_BACKENDS_THIRD_PARTY = [AUTH_BACKEND_OIDC_CODE, AUTH_BACKEND_CAS, AUTH_BACKEND_SAML2, AUTH_BACKEND_OAUTH2]
ONLY_ALLOW_EXIST_USER_AUTH = CONFIG.ONLY_ALLOW_EXIST_USER_AUTH
ONLY_ALLOW_AUTH_FROM_SOURCE = CONFIG.ONLY_ALLOW_AUTH_FROM_SOURCE

View File

@ -43,6 +43,9 @@ DEBUG_DEV = CONFIG.DEBUG_DEV
# Absolute url for some case, for example email link
SITE_URL = CONFIG.SITE_URL
# https://docs.djangoproject.com/en/4.1/ref/settings/
SECURE_PROXY_SSL_HEADER = ('HTTP_X_FORWARDED_PROTO', 'https')
# LOG LEVEL
LOG_LEVEL = CONFIG.LOG_LEVEL
@ -106,6 +109,7 @@ MIDDLEWARE = [
'authentication.backends.oidc.middleware.OIDCRefreshIDTokenMiddleware',
'authentication.backends.cas.middleware.CASMiddleware',
'authentication.middleware.MFAMiddleware',
'authentication.middleware.ThirdPartyLoginMiddleware',
'authentication.middleware.SessionCookieMiddleware',
'simple_history.middleware.HistoryRequestMiddleware',
]
@ -307,6 +311,21 @@ CSRF_COOKIE_SECURE = CONFIG.CSRF_COOKIE_SECURE
DEFAULT_AUTO_FIELD = 'django.db.models.AutoField'
PASSWORD_HASHERS = [
'django.contrib.auth.hashers.PBKDF2PasswordHasher',
'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher',
'django.contrib.auth.hashers.Argon2PasswordHasher',
'django.contrib.auth.hashers.BCryptSHA256PasswordHasher',
]
GMSSL_ENABLED = CONFIG.GMSSL_ENABLED
GM_HASHER = 'common.hashers.PBKDF2SM3PasswordHasher'
if GMSSL_ENABLED:
PASSWORD_HASHERS.insert(0, GM_HASHER)
else:
PASSWORD_HASHERS.append(GM_HASHER)
# For Debug toolbar
INTERNAL_IPS = ["127.0.0.1"]
if os.environ.get('DEBUG_TOOLBAR', False):
@ -315,3 +334,4 @@ if os.environ.get('DEBUG_TOOLBAR', False):
DEBUG_TOOLBAR_PANELS = [
'debug_toolbar.panels.profiling.ProfilingPanel',
]

View File

@ -84,6 +84,7 @@ TERMINAL_TELNET_REGEX = CONFIG.TERMINAL_TELNET_REGEX
BACKEND_ASSET_USER_AUTH_VAULT = False
PERM_SINGLE_ASSET_TO_UNGROUP_NODE = CONFIG.PERM_SINGLE_ASSET_TO_UNGROUP_NODE
TICKET_AUTHORIZE_DEFAULT_TIME = CONFIG.TICKET_AUTHORIZE_DEFAULT_TIME
PERM_EXPIRED_CHECK_PERIODIC = CONFIG.PERM_EXPIRED_CHECK_PERIODIC
WINDOWS_SSH_DEFAULT_SHELL = CONFIG.WINDOWS_SSH_DEFAULT_SHELL
FLOWER_URL = CONFIG.FLOWER_URL

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7b79695fe8cb323097c12171db8f6ae58b8e016b317f08562183b677f537e8b3
size 129597
oid sha256:261eee68117787809a9bc6b2034846ee7b222677224f97055f7d7398d427b1d7
size 255

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54c9c54a2e5ae5d27eb79f8ce0d19e7f362c016efb8c6011cace7bd2cb7eec1c
size 108123
oid sha256:c6f584a0c74107ceddce6b403ff8755b59aabb093a0e6cc0c5f9b47eb6ae49f4
size 255

File diff suppressed because it is too large Load Diff

View File

@ -15,8 +15,6 @@ logger = get_logger(__file__)
class JMSBaseInventory(BaseInventory):
windows_ssh_default_shell = settings.WINDOWS_SSH_DEFAULT_SHELL
def convert_to_ansible(self, asset, run_as_admin=False):
info = {
'id': asset.id,
@ -33,7 +31,7 @@ class JMSBaseInventory(BaseInventory):
if asset.is_windows():
info["vars"].update({
"ansible_connection": "ssh",
"ansible_shell_type": self.windows_ssh_default_shell,
"ansible_shell_type": settings.WINDOWS_SSH_DEFAULT_SHELL,
})
for label in asset.labels.all():
info["vars"].update({

View File

@ -0,0 +1,17 @@
# Generated by Django 3.2.12 on 2022-07-18 05:57
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('orgs', '0012_auto_20220118_1054'),
]
operations = [
migrations.AlterModelOptions(
name='organization',
options={'permissions': (('view_rootorg', 'Can view root org'), ('view_alljoinedorg', 'Can view all joined org')), 'verbose_name': 'Organization'},
),
]

View File

@ -45,7 +45,7 @@ class OrgManager(models.Manager):
org = get_current_org()
for obj in objs:
if org.is_root():
if not self.org_id:
if not obj.org_id:
raise ValidationError('Please save in a organization')
else:
obj.org_id = org.id

View File

@ -16,9 +16,14 @@ class OrgRoleMixin:
def add_member(self, user, role=None):
from rbac.builtin import BuiltinRole
from .utils import tmp_to_org
role_id = BuiltinRole.org_user.id
if role:
role_id = role.id
elif user.is_service_account:
role_id = BuiltinRole.system_component.id
else:
role_id = BuiltinRole.org_user.id
with tmp_to_org(self):
defaults = {
'user': user, 'role_id': role_id,
@ -80,6 +85,7 @@ class Organization(OrgRoleMixin, models.Model):
verbose_name = _("Organization")
permissions = (
('view_rootorg', _('Can view root org')),
('view_alljoinedorg', _('Can view all joined org')),
)
def __str__(self):

View File

@ -145,6 +145,9 @@ def _clear_users_from_org(org, users):
@receiver(post_save, sender=User)
@on_transaction_commit
def on_user_created_set_default_org(sender, instance, created, **kwargs):
if not instance.id:
# 用户已被手动删除instance.orgs 时会使用 id 进行查找报错所以判断不存在id时不做处理
return
if not created:
return
if instance.orgs.count() > 0:

View File

@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
#
from rest_framework.response import Response
from rest_framework.generics import RetrieveAPIView
from perms import serializers
from perms.models import ApplicationPermission
from applications.models import Application
from common.permissions import IsValidUser
from ..base import BasePermissionViewSet
class ApplicationPermissionViewSet(BasePermissionViewSet):
"""
应用授权列表的增删改查API
"""
model = ApplicationPermission
serializer_class = serializers.ApplicationPermissionSerializer
filterset_fields = {
'name': ['exact'],
'category': ['exact'],
'type': ['exact', 'in'],
'from_ticket': ['exact']
}
search_fields = ['name', 'category', 'type']
custom_filter_fields = BasePermissionViewSet.custom_filter_fields + [
'application_id', 'application', 'app', 'app_name'
]
ordering_fields = ('name',)
ordering = ('name',)
def get_queryset(self):
queryset = super().get_queryset().prefetch_related(
"applications", "users", "user_groups", "system_users"
)
return queryset
def filter_application(self, queryset):
app_id = self.request.query_params.get('application_id') or \
self.request.query_params.get('app')
app_name = self.request.query_params.get('application') or \
self.request.query_params.get('app_name')
if app_id:
applications = Application.objects.filter(pk=app_id)
elif app_name:
applications = Application.objects.filter(name=app_name)
else:
return queryset
if not applications:
return queryset.none()
queryset = queryset.filter(applications__in=applications)
return queryset
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
queryset = self.filter_application(queryset)
return queryset
class ApplicationPermissionActionsApi(RetrieveAPIView):
permission_classes = (IsValidUser,)
def retrieve(self, request, *args, **kwargs):
category = request.GET.get('category')
actions = ApplicationPermission.get_include_actions_choices(category=category)
return Response(data=actions)

View File

@ -33,7 +33,9 @@ class UserAllGrantedAssetsQuerysetMixin:
only_fields = serializers.AssetGrantedSerializer.Meta.only_fields
pagination_class = AllGrantedAssetPagination
user: User
ordering_fields = ("hostname", "ip", "port", "cpu_cores")
ordering = ('hostname', )
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()

View File

@ -1,4 +1,7 @@
<<<<<<< HEAD
=======
>>>>>>> origin
from django.utils.translation import ugettext as _
from django.template.loader import render_to_string
@ -10,7 +13,7 @@ class PermedAssetsWillExpireUserMsg(UserMessage):
def __init__(self, user, assets, day_count=0):
super().__init__(user)
self.assets = assets
self.day_count = day_count
self.day_count = _('today') if day_count == 0 else day_count
def get_html_msg(self) -> dict:
subject = _("You permed assets is about to expire")
@ -42,7 +45,7 @@ class AssetPermsWillExpireForOrgAdminMsg(UserMessage):
super().__init__(user)
self.perms = perms
self.org = org
self.day_count = day_count
self.day_count = _('today') if day_count == 0 else day_count
def get_items_with_url(self):
items_with_url = []
@ -50,7 +53,7 @@ class AssetPermsWillExpireForOrgAdminMsg(UserMessage):
url = js_reverse(
'perms:asset-permission-detail',
kwargs={'pk': perm.id}, external=True,
api_to_ui=True
api_to_ui=True, is_console=True
) + f'?oid={perm.org_id}'
items_with_url.append([perm.name, url])
return items_with_url
@ -60,7 +63,7 @@ class AssetPermsWillExpireForOrgAdminMsg(UserMessage):
subject = _("Asset permissions is about to expire")
context = {
'name': self.user.name,
'count': self.day_count,
'count': str(self.day_count),
'items_with_url': items_with_url,
'item_type': _('asset permissions of organization {}').format(self.org)
}
@ -80,3 +83,81 @@ class AssetPermsWillExpireForOrgAdminMsg(UserMessage):
perms = AssetPermission.objects.all()[:10]
org = Organization.objects.first()
return cls(user, perms, org)
<<<<<<< HEAD
=======
class PermedAppsWillExpireUserMsg(UserMessage):
def __init__(self, user, apps, day_count=0):
super().__init__(user)
self.apps = apps
self.day_count = _('today') if day_count == 0 else day_count
def get_html_msg(self) -> dict:
subject = _("Your permed applications is about to expire")
context = {
'name': self.user.name,
'count': str(self.day_count),
'item_type': _('permed applications'),
'items': [str(app) for app in self.apps]
}
message = render_to_string('perms/_msg_permed_items_expire.html', context)
return {
'subject': subject,
'message': message
}
@classmethod
def gen_test_msg(cls):
from users.models import User
from applications.models import Application
user = User.objects.first()
apps = Application.objects.all()[:10]
return cls(user, apps)
class AppPermsWillExpireForOrgAdminMsg(UserMessage):
def __init__(self, user, perms, org, day_count=0):
super().__init__(user)
self.perms = perms
self.org = org
self.day_count = _('today') if day_count == 0 else day_count
def get_items_with_url(self):
items_with_url = []
for perm in self.perms:
url = js_reverse(
'perms:application-permission-detail',
kwargs={'pk': perm.id}, external=True,
api_to_ui=True, is_console=True
) + f'?oid={perm.org_id}'
items_with_url.append([perm.name, url])
return items_with_url
def get_html_msg(self) -> dict:
items = self.get_items_with_url()
subject = _('Application permissions is about to expire')
context = {
'name': self.user.name,
'count': str(self.day_count),
'item_type': _('application permissions of organization {}').format(self.org),
'items_with_url': items
}
message = render_to_string('perms/_msg_item_permissions_expire.html', context)
return {
'subject': subject,
'message': message
}
@classmethod
def gen_test_msg(cls):
from users.models import User
from perms.models import ApplicationPermission
from orgs.models import Organization
user = User.objects.first()
perms = ApplicationPermission.objects.all()[:10]
org = Organization.objects.first()
return cls(user, perms, org)
>>>>>>> origin

View File

@ -72,7 +72,7 @@ def check_asset_permission_will_expired():
for asset_perm in asset_perms:
date_expired = dt_parser(asset_perm.date_expired)
remain_days = (end - date_expired).days
remain_days = (date_expired - start).days
org = asset_perm.org
# 资产授权按照组织分类
@ -100,3 +100,51 @@ def check_asset_permission_will_expired():
org_admins = org.admins.all()
for org_admin in org_admins:
AssetPermsWillExpireForOrgAdminMsg(org_admin, perms, org, day_count).publish_async()
<<<<<<< HEAD
=======
@register_as_period_task(crontab='0 10 * * *')
@shared_task()
@atomic()
@tmp_to_root_org()
def check_app_permission_will_expired():
start = local_now()
end = start + timedelta(days=3)
app_perms = ApplicationPermission.objects.filter(
date_expired__gte=start,
date_expired__lte=end
).distinct()
user_app_remain_day_mapper = defaultdict(dict)
org_perm_remain_day_mapper = defaultdict(dict)
for app_perm in app_perms:
date_expired = dt_parser(app_perm.date_expired)
remain_days = (date_expired - start).days
org = app_perm.org
if org in org_perm_remain_day_mapper[remain_days]:
org_perm_remain_day_mapper[remain_days][org].add(app_perm)
else:
org_perm_remain_day_mapper[remain_days][org] = {app_perm, }
users = app_perm.get_all_users()
apps = app_perm.applications.all()
for u in users:
if u in user_app_remain_day_mapper[remain_days]:
user_app_remain_day_mapper[remain_days][u].update(apps)
else:
user_app_remain_day_mapper[remain_days][u] = set(apps)
for day_count, user_app_mapper in user_app_remain_day_mapper.items():
for user, apps in user_app_mapper.items():
PermedAppsWillExpireUserMsg(user, apps, day_count).publish_async()
for day_count, org_perm_mapper in org_perm_remain_day_mapper.items():
for org, perms in org_perm_mapper.items():
org_admins = org.admins.all()
for org_admin in org_admins:
AppPermsWillExpireForOrgAdminMsg(org_admin, perms, org, day_count).publish_async()
>>>>>>> origin

View File

@ -0,0 +1,50 @@
# coding: utf-8
#
from django.urls import path, include
from rest_framework_bulk.routes import BulkRouter
from .. import api
router = BulkRouter()
router.register('application-permissions', api.ApplicationPermissionViewSet, 'application-permission')
router.register('application-permissions-users-relations', api.ApplicationPermissionUserRelationViewSet, 'application-permissions-users-relation')
router.register('application-permissions-user-groups-relations', api.ApplicationPermissionUserGroupRelationViewSet, 'application-permissions-user-groups-relation')
router.register('application-permissions-applications-relations', api.ApplicationPermissionApplicationRelationViewSet, 'application-permissions-application-relation')
router.register('application-permissions-system-users-relations', api.ApplicationPermissionSystemUserRelationViewSet, 'application-permissions-system-users-relation')
user_permission_urlpatterns = [
path('<uuid:pk>/applications/', api.UserAllGrantedApplicationsApi.as_view(), name='user-applications'),
path('applications/', api.MyAllGrantedApplicationsApi.as_view(), name='my-applications'),
# Application As Tree
path('<uuid:pk>/applications/tree/', api.UserAllGrantedApplicationsAsTreeApi.as_view(), name='user-applications-as-tree'),
path('applications/tree/', api.MyAllGrantedApplicationsAsTreeApi.as_view(), name='my-applications-as-tree'),
# Application System Users
path('<uuid:pk>/applications/<uuid:application_id>/system-users/', api.UserGrantedApplicationSystemUsersApi.as_view(), name='user-application-system-users'),
path('applications/<uuid:application_id>/system-users/', api.MyGrantedApplicationSystemUsersApi.as_view(), name='my-application-system-users'),
]
user_group_permission_urlpatterns = [
path('<uuid:pk>/applications/', api.UserGroupGrantedApplicationsApi.as_view(), name='user-group-applications'),
]
permission_urlpatterns = [
# 授权规则中授权的用户和应用
path('<uuid:pk>/applications/all/', api.ApplicationPermissionAllApplicationListApi.as_view(), name='application-permission-all-applications'),
path('<uuid:pk>/users/all/', api.ApplicationPermissionAllUserListApi.as_view(), name='application-permission-all-users'),
# 验证用户是否有某个应用的权限
path('user/validate/', api.ValidateUserApplicationPermissionApi.as_view(), name='validate-user-application-permission'),
path('applications/actions/', api.ApplicationPermissionActionsApi.as_view(), name='application-actions'),
]
application_permission_urlpatterns = [
path('users/', include(user_permission_urlpatterns)),
path('user-groups/', include(user_group_permission_urlpatterns)),
path('application-permissions/', include(permission_urlpatterns))
]
application_permission_urlpatterns += router.urls

View File

@ -5,8 +5,10 @@ from .const import Scope, system_exclude_permissions, org_exclude_permissions
_view_root_perms = (
('orgs', 'organization', 'view', 'rootorg'),
)
_view_all_joined_org_perms = (
('orgs', 'organization', 'view', 'alljoinedorg'),
)
# 工作台也区分组织后再考虑
user_perms = (
('rbac', 'menupermission', 'view', 'workbench'),
('rbac', 'menupermission', 'view', 'webterminal'),
@ -21,11 +23,11 @@ user_perms = (
)
system_user_perms = (
('authentication', 'connectiontoken', 'add', 'connectiontoken'),
('authentication', 'connectiontoken', 'add,view', 'connectiontoken'),
('authentication', 'temptoken', 'add,change,view', 'temptoken'),
('authentication', 'accesskey', '*', '*'),
('tickets', 'ticket', 'view', 'ticket'),
) + user_perms
) + user_perms + _view_all_joined_org_perms
_auditor_perms = (
('rbac', 'menupermission', 'view', 'audit'),

View File

@ -40,6 +40,10 @@ exclude_permissions = (
('assets', 'gathereduser', 'add,delete,change', 'gathereduser'),
('assets', 'accountbackupplanexecution', 'delete,change', 'accountbackupplanexecution'),
('assets', 'authbook', 'change', 'authbook'),
# TODO 暂时去掉历史账号的权限
('assets', 'authbook', '*', 'assethistoryaccount'),
('assets', 'authbook', '*', 'assethistoryaccountsecret'),
('perms', 'userassetgrantedtreenoderelation', '*', '*'),
('perms', 'usergrantedmappingnode', '*', '*'),
('perms', 'permnode', '*', '*'),

View File

@ -60,11 +60,11 @@ class Permission(DjangoPermission):
if actions == '*' and resource == '*':
pass
elif actions == '*' and resource != '*':
kwargs['codename__iregex'] = r'[a-z]+_{}'.format(resource)
kwargs['codename__iregex'] = r'[a-z]+_{}$'.format(resource)
elif actions != '*' and resource == '*':
kwargs['codename__iregex'] = r'({})_[a-z]+'.format(actions_regex)
else:
kwargs['codename__iregex'] = r'({})_{}'.format(actions_regex, resource)
kwargs['codename__iregex'] = r'({})_{}$'.format(actions_regex, resource)
q |= Q(**kwargs)
return q

View File

@ -126,9 +126,16 @@ class RoleBinding(JMSBaseModel):
org_ids = [b.org.id for b in bindings if b.org]
orgs = all_orgs.filter(id__in=org_ids)
workbench_perm = 'rbac.view_workbench'
# 全局组织
if orgs and perm != 'rbac.view_workbench' and user.has_perm('orgs.view_rootorg'):
orgs = [Organization.root(), *list(orgs)]
if orgs and perm != workbench_perm and user.has_perm('orgs.view_rootorg'):
root_org = Organization.root()
orgs = [root_org, *list(orgs)]
elif orgs and perm == workbench_perm and user.has_perm('orgs.view_alljoinedorg'):
# Todo: 先复用组织
root_org = Organization.root()
root_org.name = _("All organizations")
orgs = [root_org, *list(orgs)]
return orgs

View File

@ -5,6 +5,4 @@ from .dingtalk import *
from .feishu import *
from .public import *
from .email import *
from .alibaba_sms import *
from .tencent_sms import *
from .sms import *

View File

@ -1,58 +0,0 @@
from rest_framework.views import Response
from rest_framework.generics import GenericAPIView
from rest_framework.exceptions import APIException
from rest_framework import status
from django.utils.translation import gettext_lazy as _
from common.sdk.sms.alibaba import AlibabaSMS
from settings.models import Setting
from common.exceptions import JMSException
from .. import serializers
class AlibabaSMSTestingAPI(GenericAPIView):
serializer_class = serializers.AlibabaSMSSettingSerializer
rbac_perms = {
'POST': 'settings.change_sms'
}
def post(self, request):
serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
alibaba_access_key_id = serializer.validated_data['ALIBABA_ACCESS_KEY_ID']
alibaba_access_key_secret = serializer.validated_data.get('ALIBABA_ACCESS_KEY_SECRET')
alibaba_verify_sign_name = serializer.validated_data['ALIBABA_VERIFY_SIGN_NAME']
alibaba_verify_template_code = serializer.validated_data['ALIBABA_VERIFY_TEMPLATE_CODE']
test_phone = serializer.validated_data.get('SMS_TEST_PHONE')
if not test_phone:
raise JMSException(code='test_phone_required', detail=_('test_phone is required'))
if not alibaba_access_key_secret:
secret = Setting.objects.filter(name='ALIBABA_ACCESS_KEY_SECRET').first()
if secret:
alibaba_access_key_secret = secret.cleaned_value
alibaba_access_key_secret = alibaba_access_key_secret or ''
try:
client = AlibabaSMS(
access_key_id=alibaba_access_key_id,
access_key_secret=alibaba_access_key_secret
)
client.send_sms(
phone_numbers=[test_phone],
sign_name=alibaba_verify_sign_name,
template_code=alibaba_verify_template_code,
template_param={'code': 'test'}
)
return Response(status=status.HTTP_200_OK, data={'msg': _('Test success')})
except APIException as e:
try:
error = e.detail['errmsg']
except:
error = e.detail
return Response(status=status.HTTP_400_BAD_REQUEST, data={'error': error})

View File

@ -3,7 +3,11 @@ from rest_framework.permissions import AllowAny, IsAuthenticated
from django.conf import settings
from jumpserver.utils import has_valid_xpack_license, get_xpack_license_info
from common.utils import get_logger, lazyproperty
from common.utils import get_logger, lazyproperty, get_object_or_none
from authentication.models import ConnectionToken
from orgs.utils import tmp_to_root_org
from common.permissions import IsValidUserOrConnectionToken
from .. import serializers
from ..utils import get_interface_setting_or_default
@ -28,7 +32,7 @@ class OpenPublicSettingApi(generics.RetrieveAPIView):
class PublicSettingApi(OpenPublicSettingApi):
permission_classes = (IsAuthenticated,)
permission_classes = (IsValidUserOrConnectionToken,)
serializer_class = serializers.PrivateSettingSerializer
def get_object(self):

View File

@ -34,11 +34,13 @@ class SettingsApi(generics.RetrieveUpdateAPIView):
'cas': serializers.CASSettingSerializer,
'sso': serializers.SSOSettingSerializer,
'saml2': serializers.SAML2SettingSerializer,
'oauth2': serializers.OAuth2SettingSerializer,
'clean': serializers.CleaningSerializer,
'other': serializers.OtherSettingSerializer,
'sms': serializers.SMSSettingSerializer,
'alibaba': serializers.AlibabaSMSSettingSerializer,
'tencent': serializers.TencentSMSSettingSerializer,
'cmpp2': serializers.CMPP2SMSSettingSerializer,
}
rbac_category_permissions = {
@ -113,9 +115,12 @@ class SettingsApi(generics.RetrieveUpdateAPIView):
return data
def perform_update(self, serializer):
post_data_names = list(self.request.data.keys())
settings_items = self.parse_serializer_data(serializer)
serializer_data = getattr(serializer, 'data', {})
for item in settings_items:
if item['name'] not in post_data_names:
continue
changed, setting = Setting.update_or_create(**item)
if not changed:
continue

View File

@ -1,8 +1,19 @@
from rest_framework.generics import ListAPIView
import importlib
from collections import OrderedDict
from rest_framework.generics import ListAPIView, GenericAPIView
from rest_framework.response import Response
from rest_framework.exceptions import APIException
from rest_framework import status
from django.utils.translation import gettext_lazy as _
from common.sdk.sms import BACKENDS
from common.exceptions import JMSException
from settings.serializers.sms import SMSBackendSerializer
from settings.models import Setting
from .. import serializers
class SMSBackendAPI(ListAPIView):
@ -21,3 +32,111 @@ class SMSBackendAPI(ListAPIView):
]
return Response(data)
class SMSTestingAPI(GenericAPIView):
backends_serializer = {
'alibaba': serializers.AlibabaSMSSettingSerializer,
'tencent': serializers.TencentSMSSettingSerializer,
'cmpp2': serializers.CMPP2SMSSettingSerializer
}
rbac_perms = {
'POST': 'settings.change_sms'
}
@staticmethod
def get_or_from_setting(key, value=''):
if not value:
secret = Setting.objects.filter(name=key).first()
if secret:
value = secret.cleaned_value
return value or ''
def get_alibaba_params(self, data):
init_params = {
'access_key_id': data['ALIBABA_ACCESS_KEY_ID'],
'access_key_secret': self.get_or_from_setting(
'ALIBABA_ACCESS_KEY_SECRET', data.get('ALIBABA_ACCESS_KEY_SECRET')
)
}
send_sms_params = {
'sign_name': data['ALIBABA_VERIFY_SIGN_NAME'],
'template_code': data['ALIBABA_VERIFY_TEMPLATE_CODE'],
'template_param': {'code': '666666'}
}
return init_params, send_sms_params
def get_tencent_params(self, data):
init_params = {
'secret_id': data['TENCENT_SECRET_ID'],
'secret_key': self.get_or_from_setting(
'TENCENT_SECRET_KEY', data.get('TENCENT_SECRET_KEY')
),
'sdkappid': data['TENCENT_SDKAPPID']
}
send_sms_params = {
'sign_name': data['TENCENT_VERIFY_SIGN_NAME'],
'template_code': data['TENCENT_VERIFY_TEMPLATE_CODE'],
'template_param': OrderedDict(code='666666')
}
return init_params, send_sms_params
def get_cmpp2_params(self, data):
init_params = {
'host': data['CMPP2_HOST'], 'port': data['CMPP2_PORT'],
'sp_id': data['CMPP2_SP_ID'], 'src_id': data['CMPP2_SRC_ID'],
'sp_secret': self.get_or_from_setting(
'CMPP2_SP_SECRET', data.get('CMPP2_SP_SECRET')
),
'service_id': data['CMPP2_SERVICE_ID'],
}
send_sms_params = {
'sign_name': data['CMPP2_VERIFY_SIGN_NAME'],
'template_code': data['CMPP2_VERIFY_TEMPLATE_CODE'],
'template_param': OrderedDict(code='666666')
}
return init_params, send_sms_params
def get_params_by_backend(self, backend, data):
"""
返回两部分参数
1实例化参数
2发送测试短信参数
"""
get_params_func = getattr(self, 'get_%s_params' % backend)
return get_params_func(data)
def post(self, request, backend):
serializer_class = self.backends_serializer.get(backend)
if serializer_class is None:
raise JMSException(_('Invalid SMS platform'))
serializer = serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
test_phone = serializer.validated_data.get('SMS_TEST_PHONE')
if not test_phone:
raise JMSException(code='test_phone_required', detail=_('test_phone is required'))
init_params, send_sms_params = self.get_params_by_backend(backend, serializer.validated_data)
m = importlib.import_module(f'common.sdk.sms.{backend}', __package__)
try:
client = m.client(**init_params)
client.send_sms(
phone_numbers=[test_phone],
**send_sms_params
)
status_code = status.HTTP_200_OK
data = {'msg': _('Test success')}
except APIException as e:
try:
error = e.detail['errmsg']
except:
error = e.detail
status_code = status.HTTP_400_BAD_REQUEST
data = {'error': error}
except Exception as e:
status_code = status.HTTP_400_BAD_REQUEST
data = {'error': str(e)}
return Response(status=status_code, data=data)

View File

@ -1,63 +0,0 @@
from collections import OrderedDict
from rest_framework.views import Response
from rest_framework.generics import GenericAPIView
from rest_framework.exceptions import APIException
from rest_framework import status
from django.utils.translation import gettext_lazy as _
from common.sdk.sms.tencent import TencentSMS
from settings.models import Setting
from common.exceptions import JMSException
from .. import serializers
class TencentSMSTestingAPI(GenericAPIView):
serializer_class = serializers.TencentSMSSettingSerializer
rbac_perms = {
'POST': 'settings.change_sms'
}
def post(self, request):
serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
tencent_secret_id = serializer.validated_data['TENCENT_SECRET_ID']
tencent_secret_key = serializer.validated_data.get('TENCENT_SECRET_KEY')
tencent_verify_sign_name = serializer.validated_data['TENCENT_VERIFY_SIGN_NAME']
tencent_verify_template_code = serializer.validated_data['TENCENT_VERIFY_TEMPLATE_CODE']
tencent_sdkappid = serializer.validated_data.get('TENCENT_SDKAPPID')
test_phone = serializer.validated_data.get('SMS_TEST_PHONE')
if not test_phone:
raise JMSException(code='test_phone_required', detail=_('test_phone is required'))
if not tencent_secret_key:
secret = Setting.objects.filter(name='TENCENT_SECRET_KEY').first()
if secret:
tencent_secret_key = secret.cleaned_value
tencent_secret_key = tencent_secret_key or ''
try:
client = TencentSMS(
secret_id=tencent_secret_id,
secret_key=tencent_secret_key,
sdkappid=tencent_sdkappid
)
client.send_sms(
phone_numbers=[test_phone],
sign_name=tencent_verify_sign_name,
template_code=tencent_verify_template_code,
template_param=OrderedDict(code='666666')
)
return Response(status=status.HTTP_200_OK, data={'msg': _('Test success')})
except APIException as e:
try:
error = e.detail['errmsg']
except:
error = e.detail
return Response(status=status.HTTP_400_BAD_REQUEST, data={'error': error})

View File

@ -1,9 +1,13 @@
import os
import json
from django.db import models
from django.db.utils import ProgrammingError, OperationalError
from django.utils.translation import ugettext_lazy as _
from django.conf import settings
from django.core.files.storage import default_storage
from django.core.files.base import ContentFile
from django.core.files.uploadedfile import InMemoryUploadedFile
from common.utils import signer, get_logger
@ -118,6 +122,14 @@ class Setting(models.Model):
setattr(settings, key, value)
self.__class__.update_or_create(key, value, encrypted=False, category=self.category)
@classmethod
def save_to_file(cls, value: InMemoryUploadedFile):
filename = value.name
filepath = f'settings/{filename}'
path = default_storage.save(filepath, ContentFile(value.read()))
url = default_storage.url(path)
return url
@classmethod
def update_or_create(cls, name='', value='', encrypted=False, category=''):
"""
@ -128,6 +140,10 @@ class Setting(models.Model):
changed = False
if not setting:
setting = Setting(name=name, encrypted=encrypted, category=category)
if isinstance(value, InMemoryUploadedFile):
value = cls.save_to_file(value)
if setting.cleaned_value != value:
setting.encrypted = encrypted
setting.cleaned_value = value

View File

@ -9,3 +9,4 @@ from .sso import *
from .base import *
from .sms import *
from .saml2 import *
from .oauth2 import *

View File

@ -0,0 +1,55 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from common.drf.fields import EncryptedField
from common.utils import static_or_direct
__all__ = [
'OAuth2SettingSerializer',
]
class SettingImageField(serializers.ImageField):
def to_representation(self, value):
return static_or_direct(value)
class OAuth2SettingSerializer(serializers.Serializer):
AUTH_OAUTH2 = serializers.BooleanField(
default=False, label=_('Enable OAuth2 Auth')
)
AUTH_OAUTH2_LOGO_PATH = SettingImageField(
allow_null=True, required=False, label=_('Logo')
)
AUTH_OAUTH2_PROVIDER = serializers.CharField(
required=True, max_length=16, label=_('Service provider')
)
AUTH_OAUTH2_CLIENT_ID = serializers.CharField(
required=True, max_length=1024, label=_('Client Id')
)
AUTH_OAUTH2_CLIENT_SECRET = EncryptedField(
required=False, max_length=1024, label=_('Client Secret')
)
AUTH_OAUTH2_SCOPE = serializers.CharField(
required=True, max_length=1024, label=_('Scope'), allow_blank=True
)
AUTH_OAUTH2_PROVIDER_AUTHORIZATION_ENDPOINT = serializers.CharField(
required=True, max_length=1024, label=_('Provider auth endpoint')
)
AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT = serializers.CharField(
required=True, max_length=1024, label=_('Provider token endpoint')
)
AUTH_OAUTH2_ACCESS_TOKEN_METHOD = serializers.ChoiceField(
default='GET', label=_('Client authentication method'),
choices=(('GET', 'GET'), ('POST', 'POST'))
)
AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT = serializers.CharField(
required=True, max_length=1024, label=_('Provider userinfo endpoint')
)
AUTH_OAUTH2_USER_ATTR_MAP = serializers.DictField(
required=True, label=_('User attr map')
)
AUTH_OAUTH2_ALWAYS_UPDATE_USER = serializers.BooleanField(
default=True, label=_('Always update user')
)

View File

@ -2,15 +2,19 @@ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from common.drf.fields import EncryptedField
from common.validators import PhoneValidator
from common.sdk.sms import BACKENDS
__all__ = ['SMSSettingSerializer', 'AlibabaSMSSettingSerializer', 'TencentSMSSettingSerializer']
__all__ = [
'SMSSettingSerializer', 'AlibabaSMSSettingSerializer', 'TencentSMSSettingSerializer',
'CMPP2SMSSettingSerializer'
]
class SMSSettingSerializer(serializers.Serializer):
SMS_ENABLED = serializers.BooleanField(default=False, label=_('Enable SMS'))
SMS_BACKEND = serializers.ChoiceField(
choices=BACKENDS.choices, default=BACKENDS.ALIBABA, label=_('SMS provider')
choices=BACKENDS.choices, default=BACKENDS.ALIBABA, label=_('SMS provider / Protocol')
)
@ -20,7 +24,10 @@ class SignTmplPairSerializer(serializers.Serializer):
class BaseSMSSettingSerializer(serializers.Serializer):
SMS_TEST_PHONE = serializers.CharField(max_length=256, required=False, allow_blank=True, label=_('Test phone'))
SMS_TEST_PHONE = serializers.CharField(
max_length=256, required=False, validators=[PhoneValidator(), ],
allow_blank=True, label=_('Test phone')
)
def to_representation(self, instance):
data = super().to_representation(instance)
@ -43,3 +50,29 @@ class TencentSMSSettingSerializer(BaseSMSSettingSerializer):
TENCENT_SDKAPPID = serializers.CharField(max_length=256, required=True, label='SDK app id')
TENCENT_VERIFY_SIGN_NAME = serializers.CharField(max_length=256, required=True, label=_('Signature'))
TENCENT_VERIFY_TEMPLATE_CODE = serializers.CharField(max_length=256, required=True, label=_('Template code'))
class CMPP2SMSSettingSerializer(BaseSMSSettingSerializer):
CMPP2_HOST = serializers.CharField(max_length=256, required=True, label=_('Host'))
CMPP2_PORT = serializers.IntegerField(default=7890, label=_('Port'))
CMPP2_SP_ID = serializers.CharField(max_length=128, required=True, label=_('Enterprise code(SP id)'))
CMPP2_SP_SECRET = EncryptedField(max_length=256, required=False, label=_('Shared secret(Shared secret)'))
CMPP2_SRC_ID = serializers.CharField(max_length=256, required=False, label=_('Original number(Src id)'))
CMPP2_SERVICE_ID = serializers.CharField(max_length=256, required=True, label=_('Business type(Service id)'))
CMPP2_VERIFY_SIGN_NAME = serializers.CharField(max_length=256, required=True, label=_('Signature'))
CMPP2_VERIFY_TEMPLATE_CODE = serializers.CharField(
max_length=69, required=True, label=_('Template'),
help_text=_('Template need contain {code} and Signature + template length does not exceed 67 words. '
'For example, your verification code is {code}, which is valid for 5 minutes. '
'Please do not disclose it to others.')
)
def validate(self, attrs):
sign_name = attrs.get('CMPP2_VERIFY_SIGN_NAME', '')
template_code = attrs.get('CMPP2_VERIFY_TEMPLATE_CODE', '')
if template_code.find('{code}') == -1:
raise serializers.ValidationError(_('The template needs to contain {code}'))
if len(sign_name + template_code) > 65:
# 保证验证码内容在一条短信中(长度小于70字), 签名两边的括号和空格占3个字再减去2个即可(验证码占用4个但占位符6个
raise serializers.ValidationError(_('Signature + Template must not exceed 65 words'))
return attrs

View File

@ -1,4 +1,3 @@
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
@ -14,5 +13,5 @@ class SSOSettingSerializer(serializers.Serializer):
)
AUTH_SSO_AUTHKEY_TTL = serializers.IntegerField(
required=False, label=_('SSO auth key TTL'), help_text=_("Unit: second"),
min_value=1, max_value=60*30
min_value=60, max_value=60 * 30
)

View File

@ -30,6 +30,11 @@ class OtherSettingSerializer(serializers.Serializer):
help_text=_("Perm single to ungroup node")
)
TICKET_AUTHORIZE_DEFAULT_TIME = serializers.IntegerField(
min_value=7, max_value=9999, required=False,
label=_("Ticket authorize default time"), help_text=_("Unit: day")
)
HELP_DOCUMENT_URL = serializers.URLField(
required=False, allow_blank=True, allow_null=True, label=_("Help Docs URL"),
help_text=_('default: http://docs.jumpserver.org')

View File

@ -14,6 +14,7 @@ class PublicSettingSerializer(serializers.Serializer):
class PrivateSettingSerializer(PublicSettingSerializer):
WINDOWS_SKIP_ALL_MANUAL_PASSWORD = serializers.BooleanField()
OLD_PASSWORD_HISTORY_LIMIT_COUNT = serializers.IntegerField()
TICKET_AUTHORIZE_DEFAULT_TIME = serializers.IntegerField()
SECURITY_MAX_IDLE_TIME = serializers.IntegerField()
SECURITY_VIEW_AUTH_NEED_MFA = serializers.BooleanField()
SECURITY_MFA_VERIFY_TTL = serializers.IntegerField()

View File

@ -7,7 +7,7 @@ from .auth import (
LDAPSettingSerializer, OIDCSettingSerializer, KeycloakSettingSerializer,
CASSettingSerializer, RadiusSettingSerializer, FeiShuSettingSerializer,
WeComSettingSerializer, DingTalkSettingSerializer, AlibabaSMSSettingSerializer,
TencentSMSSettingSerializer,
TencentSMSSettingSerializer, CMPP2SMSSettingSerializer
)
from .terminal import TerminalSettingSerializer
from .security import SecuritySettingSerializer
@ -37,6 +37,7 @@ class SettingsSerializer(
CleaningSerializer,
AlibabaSMSSettingSerializer,
TencentSMSSettingSerializer,
CMPP2SMSSettingSerializer,
):
# encrypt_fields 现在使用 write_only 来判断了
pass

View File

@ -16,8 +16,7 @@ urlpatterns = [
path('wecom/testing/', api.WeComTestingAPI.as_view(), name='wecom-testing'),
path('dingtalk/testing/', api.DingTalkTestingAPI.as_view(), name='dingtalk-testing'),
path('feishu/testing/', api.FeiShuTestingAPI.as_view(), name='feishu-testing'),
path('alibaba/testing/', api.AlibabaSMSTestingAPI.as_view(), name='alibaba-sms-testing'),
path('tencent/testing/', api.TencentSMSTestingAPI.as_view(), name='tencent-sms-testing'),
path('sms/<str:backend>/testing/', api.SMSTestingAPI.as_view(), name='sms-testing'),
path('sms/backend/', api.SMSBackendAPI.as_view(), name='sms-backend'),
path('setting/', api.SettingsApi.as_view(), name='settings-setting'),

View File

@ -0,0 +1,9 @@
from django.urls import path
from .. import ws
app_name = 'common'
urlpatterns = [
path('ws/setting/tools/', ws.ToolsWebsocket.as_asgi(), name='setting-tools-ws'),
]

View File

@ -3,3 +3,5 @@
from .ldap import *
from .common import *
from .ping import *
from .telnet import *

154
apps/settings/utils/ping.py Normal file
View File

@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
#
import os
import select
import socket
import struct
import time
# From /usr/include/linux/icmp.h; your milage may vary.
ICMP_ECHO_REQUEST = 8 # Seems to be the same on Solaris.
def checksum(source_string):
"""
I'm not too confident that this is right but testing seems
to suggest that it gives the same answers as in_cksum in ping.c
"""
sum = 0
count_to = int((len(source_string) / 2) * 2)
for count in range(0, count_to, 2):
this = source_string[count + 1] * 256 + source_string[count]
sum = sum + this
sum = sum & 0xffffffff # Necessary?
if count_to < len(source_string):
sum = sum + ord(source_string[len(source_string) - 1])
sum = sum & 0xffffffff # Necessary?
sum = (sum >> 16) + (sum & 0xffff)
sum = sum + (sum >> 16)
answer = ~sum
answer = answer & 0xffff
# Swap bytes. Bugger me if I know why.
answer = answer >> 8 | (answer << 8 & 0xff00)
return answer
def receive_one_ping(my_socket, id, timeout):
"""
Receive the ping from the socket.
"""
time_left = timeout
while True:
started_select = time.time()
what_ready = select.select([my_socket], [], [], time_left)
how_long_in_select = time.time() - started_select
if not what_ready[0]: # Timeout
return
time_received = time.time()
received_packet, addr = my_socket.recvfrom(1024)
icmpHeader = received_packet[20:28]
type, code, checksum, packet_id, sequence = struct.unpack("bbHHh", icmpHeader)
if packet_id == id:
bytes = struct.calcsize("d")
time_sent = struct.unpack("d", received_packet[28: 28 + bytes])[0]
return time_received - time_sent
time_left = time_left - how_long_in_select
if time_left <= 0:
return
def send_one_ping(my_socket, dest_addr, id, psize):
"""
Send one ping to the given >dest_addr<.
"""
dest_addr = socket.gethostbyname(dest_addr)
# Remove header size from packet size
# psize = psize - 8
# laixintao edit:
# Do not need to remove header here. From BSD ping man:
# The default is 56, which translates into 64 ICMP data
# bytes when combined with the 8 bytes of ICMP header data.
# Header is type (8), code (8), checksum (16), id (16), sequence (16)
my_checksum = 0
# Make a dummy heder with a 0 checksum.
header = struct.pack("bbHHh", ICMP_ECHO_REQUEST, 0, my_checksum, id, 1)
bytes = struct.calcsize("d")
data = (psize - bytes) * b"Q"
data = struct.pack("d", time.time()) + data
# Calculate the checksum on the data and the dummy header.
my_checksum = checksum(header + data)
# Now that we have the right checksum, we put that in. It's just easier
# to make up a new header than to stuff it into the dummy.
header = struct.pack(
"bbHHh", ICMP_ECHO_REQUEST, 0, socket.htons(my_checksum), id, 1
)
packet = header + data
my_socket.sendto(packet, (dest_addr, 1)) # Don't know about the 1
def ping(dest_addr, timeout, psize, flag=0):
"""
Returns either the delay (in seconds) or none on timeout.
"""
icmp = socket.getprotobyname("icmp")
try:
if os.getuid() != 0:
my_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, icmp)
else:
my_socket = socket.socket(socket.AF_INET, socket.SOCK_RAW, icmp)
except socket.error as e:
if e.errno == 1:
# Operation not permitted
msg = str(e)
raise socket.error(msg)
raise # raise the original error
process_pre = os.getpid() & 0xFF00
flag = flag & 0x00FF
my_id = process_pre | flag
send_one_ping(my_socket, dest_addr, my_id, psize)
delay = receive_one_ping(my_socket, my_id, timeout)
my_socket.close()
return delay
def verbose_ping(dest_addr, timeout=2, count=5, psize=64):
"""
Send `count' ping with `psize' size to `dest_addr' with
the given `timeout' and display the result.
"""
for i in range(count):
print("ping %s with ..." % dest_addr, end="")
try:
delay = ping(dest_addr, timeout, psize)
except socket.gaierror as e:
print("failed. (socket error: '%s')" % str(e))
break
if delay is None:
print("failed. (timeout within %ssec.)" % timeout)
else:
delay = delay * 1000
print("get ping in %0.4fms" % delay)
print()
if __name__ == "__main__":
verbose_ping("google.com")
verbose_ping("192.168.4.1")
verbose_ping("www.baidu.com")
verbose_ping("sssssss")

View File

@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
#
import socket
import telnetlib
PROMPT_REGEX = r'[\<|\[](.*)[\>|\]]'
def telnet(dest_addr, port_number=23, timeout=10):
try:
connection = telnetlib.Telnet(dest_addr, port_number, timeout)
except (ConnectionRefusedError, socket.timeout, socket.gaierror) as e:
return False, str(e)
expected_regexes = [bytes(PROMPT_REGEX, encoding='ascii')]
index, prompt_regex, output = connection.expect(expected_regexes, timeout=3)
return True, output.decode('ascii')
if __name__ == "__main__":
print(telnet(dest_addr='1.1.1.1', port_number=2222))
print(telnet(dest_addr='baidu.com', port_number=80))
print(telnet(dest_addr='baidu.com', port_number=8080))
print(telnet(dest_addr='192.168.4.1', port_number=2222))
print(telnet(dest_addr='192.168.4.1', port_number=2223))
print(telnet(dest_addr='ssssss', port_number=-1))

Some files were not shown because too many files have changed in this diff Show More