merge: with dev

pull/12643/head
ibuler 2024-02-05 09:49:43 +08:00
commit b284bb60f5
103 changed files with 1967 additions and 1189 deletions


@@ -1,11 +1,12 @@
+from django.db.models import Q
 from rest_framework.generics import CreateAPIView
 from accounts import serializers
+from accounts.models import Account
 from accounts.permissions import AccountTaskActionPermission
 from accounts.tasks import (
     remove_accounts_task, verify_accounts_connectivity_task, push_accounts_to_assets_task
 )
-from assets.exceptions import NotSupportedTemporarilyError
 from authentication.permissions import UserConfirmation, ConfirmType
 __all__ = [
@@ -26,25 +27,35 @@ class AccountsTaskCreateAPI(CreateAPIView):
         ]
         return super().get_permissions()

-    def perform_create(self, serializer):
-        data = serializer.validated_data
-        accounts = data.get('accounts', [])
-        params = data.get('params')
-        account_ids = [str(a.id) for a in accounts]
-        if data['action'] == 'push':
-            task = push_accounts_to_assets_task.delay(account_ids, params)
-        elif data['action'] == 'remove':
-            gather_accounts = data.get('gather_accounts', [])
-            gather_account_ids = [str(a.id) for a in gather_accounts]
-            task = remove_accounts_task.delay(gather_account_ids)
-        else:
-            account = accounts[0]
-            asset = account.asset
-            if not asset.auto_config['ansible_enabled'] or \
-                    not asset.auto_config['ping_enabled']:
-                raise NotSupportedTemporarilyError()
-            task = verify_accounts_connectivity_task.delay(account_ids)
+    @staticmethod
+    def get_account_ids(data, action):
+        account_type = 'gather_accounts' if action == 'remove' else 'accounts'
+        accounts = data.get(account_type, [])
+        account_ids = [str(a.id) for a in accounts]
+        if action == 'remove':
+            return account_ids
+
+        assets = data.get('assets', [])
+        asset_ids = [str(a.id) for a in assets]
+        ids = Account.objects.filter(
+            Q(id__in=account_ids) | Q(asset_id__in=asset_ids)
+        ).distinct().values_list('id', flat=True)
+        return [str(_id) for _id in ids]
+
+    def perform_create(self, serializer):
+        data = serializer.validated_data
+        action = data['action']
+        ids = self.get_account_ids(data, action)
+        if action == 'push':
+            task = push_accounts_to_assets_task.delay(ids, data.get('params'))
+        elif action == 'remove':
+            task = remove_accounts_task.delay(ids)
+        elif action == 'verify':
+            task = verify_accounts_connectivity_task.delay(ids)
+        else:
+            raise ValueError(f"Invalid action: {action}")

         data = getattr(serializer, '_data', {})
         data["task"] = task.id


@@ -145,9 +145,9 @@ class AccountBackupHandler:
         wb = Workbook(filename)
         for sheet, data in data_map.items():
             ws = wb.add_worksheet(str(sheet))
-            for row in data:
-                for col, _data in enumerate(row):
-                    ws.write_string(0, col, _data)
+            for row_index, row_data in enumerate(data):
+                for col_index, col_data in enumerate(row_data):
+                    ws.write_string(row_index, col_index, col_data)
         wb.close()
         files.append(filename)
         timedelta = round((time.time() - time_start), 2)
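The old loop passed a constant row index of 0 to write_string, so every row of the backup sheet overwrote the previous one; enumerating both rows and columns writes each cell in place. A small standalone sketch of the fixed pattern with xlsxwriter (the filename and sample data are arbitrary):

    from xlsxwriter import Workbook

    data = [["name", "secret"], ["web01", "s3cr3t"], ["db01", "p4ss"]]

    wb = Workbook("backup-demo.xlsx")   # arbitrary example filename
    ws = wb.add_worksheet("accounts")
    for row_index, row_data in enumerate(data):          # rows 0, 1, 2, ...
        for col_index, col_data in enumerate(row_data):  # columns 0, 1, ...
            ws.write_string(row_index, col_index, col_data)
    wb.close()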


@@ -39,3 +39,4 @@
     login_host: "{{ jms_asset.address }}"
     login_port: "{{ jms_asset.port }}"
     login_database: "{{ jms_asset.spec_info.db_name }}"
+    mode: "{{ account.mode }}"


@@ -4,6 +4,7 @@ from copy import deepcopy
 from django.conf import settings
 from django.utils import timezone
+from django.utils.translation import gettext_lazy as _
 from xlsxwriter import Workbook

 from accounts.const import AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy
@@ -161,7 +162,8 @@ class ChangeSecretManager(AccountBasePlaybookManager):
             print("Account not found, deleted ?")
             return
         account.secret = recorder.new_secret
-        account.save(update_fields=['secret'])
+        account.date_updated = timezone.now()
+        account.save(update_fields=['secret', 'date_updated'])

     def on_host_error(self, host, error, result):
         recorder = self.name_recorder_mapper.get(host)
@@ -182,17 +184,33 @@ class ChangeSecretManager(AccountBasePlaybookManager):
             return False
         return True

+    @staticmethod
+    def get_summary(recorders):
+        total, succeed, failed = 0, 0, 0
+        for recorder in recorders:
+            if recorder.status == 'success':
+                succeed += 1
+            else:
+                failed += 1
+            total += 1
+        summary = _('Success: %s, Failed: %s, Total: %s') % (succeed, failed, total)
+        return summary
+
     def run(self, *args, **kwargs):
         if self.secret_type and not self.check_secret():
             return
         super().run(*args, **kwargs)
+        recorders = list(self.name_recorder_mapper.values())
+        summary = self.get_summary(recorders)
+        print(summary, end='')
+
         if self.record_id:
             return
-        recorders = self.name_recorder_mapper.values()
-        recorders = list(recorders)
-        self.send_recorder_mail(recorders)
+        self.send_recorder_mail(recorders, summary)

-    def send_recorder_mail(self, recorders):
+    def send_recorder_mail(self, recorders, summary):
         recipients = self.execution.recipients
         if not recorders or not recipients:
             return
@@ -212,7 +230,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
         attachment = os.path.join(path, f'{name}-{local_now_filename()}-{time.time()}.zip')
         encrypt_and_compress_zip_file(attachment, password, [filename])
         attachments = [attachment]
-        ChangeSecretExecutionTaskMsg(name, user).publish(attachments)
+        ChangeSecretExecutionTaskMsg(name, user, summary).publish(attachments)
         os.remove(filename)

     @staticmethod
@@ -228,8 +246,8 @@ class ChangeSecretManager(AccountBasePlaybookManager):
         rows.insert(0, header)
         wb = Workbook(filename)
         ws = wb.add_worksheet('Sheet1')
-        for row in rows:
-            for col, data in enumerate(row):
-                ws.write_string(0, col, data)
+        for row_index, row_data in enumerate(rows):
+            for col_index, col_data in enumerate(row_data):
+                ws.write_string(row_index, col_index, col_data)
         wb.close()
         return True
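get_summary simply tallies recorder statuses into a one-line message that is printed after the run and included in the notification mail. A minimal sketch of the counting logic, with a plain dataclass standing in for the real change-secret recorder objects:

    from dataclasses import dataclass

    @dataclass
    class Recorder:          # stand-in for the real recorder model
        status: str

    def get_summary(recorders):
        succeed = sum(1 for r in recorders if r.status == 'success')
        failed = len(recorders) - succeed
        return 'Success: %s, Failed: %s, Total: %s' % (succeed, failed, len(recorders))

    print(get_summary([Recorder('success'), Recorder('failed'), Recorder('success')]))
    # Success: 2, Failed: 1, Total: 3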


@@ -1,9 +1,10 @@
 - hosts: demo
   gather_facts: no
   tasks:
-    - name: Gather posix account
+    - name: Gather windows account
       ansible.builtin.win_shell: net user
       register: result
+      ignore_errors: true

     - name: Define info by set_fact
       ansible.builtin.set_fact:


@@ -39,3 +39,4 @@
     login_host: "{{ jms_asset.address }}"
     login_port: "{{ jms_asset.port }}"
     login_database: "{{ jms_asset.spec_info.db_name }}"
+    mode: "{{ account.mode }}"


@@ -54,20 +54,23 @@ class AccountBackupByObjStorageExecutionTaskMsg(object):
 class ChangeSecretExecutionTaskMsg(object):
     subject = _('Notification of implementation result of encryption change plan')

-    def __init__(self, name: str, user: User):
+    def __init__(self, name: str, user: User, summary):
         self.name = name
         self.user = user
+        self.summary = summary

     @property
     def message(self):
         name = self.name
         if self.user.secret_key:
-            return _('{} - The encryption change task has been completed. '
-                     'See the attachment for details').format(name)
+            default_message = _('{} - The encryption change task has been completed. '
+                                'See the attachment for details').format(name)
         else:
-            return _("{} - The encryption change task has been completed: the encryption "
-                     "password has not been set - please go to personal information -> "
-                     "file encryption password to set the encryption password").format(name)
+            default_message = _("{} - The encryption change task has been completed: the encryption "
+                                "password has not been set - please go to personal information -> "
+                                "set encryption password in preferences").format(name)
+        return self.summary + '\n' + default_message

     def publish(self, attachments=None):
         send_mail_attachment_async(


@@ -60,7 +60,7 @@ class AccountCreateUpdateSerializerMixin(serializers.Serializer):
         for data in initial_data:
             if not data.get('asset') and not self.instance:
                 raise serializers.ValidationError({'asset': UniqueTogetherValidator.missing_message})
-            asset = data.get('asset') or self.instance.asset
+            asset = data.get('asset') or getattr(self.instance, 'asset', None)
             self.from_template_if_need(data)
             self.set_uniq_name_if_need(data, asset)
@@ -457,12 +457,14 @@ class AccountHistorySerializer(serializers.ModelSerializer):
 class AccountTaskSerializer(serializers.Serializer):
     ACTION_CHOICES = (
-        ('test', 'test'),
         ('verify', 'verify'),
         ('push', 'push'),
         ('remove', 'remove'),
     )
     action = serializers.ChoiceField(choices=ACTION_CHOICES, write_only=True)
+    assets = serializers.PrimaryKeyRelatedField(
+        queryset=Asset.objects, required=False, allow_empty=True, many=True
+    )
     accounts = serializers.PrimaryKeyRelatedField(
         queryset=Account.objects, required=False, allow_empty=True, many=True
     )


@@ -21,7 +21,8 @@ def on_account_pre_save(sender, instance, **kwargs):
     if instance.version == 0:
         instance.version = 1
     else:
-        instance.version = instance.history.count()
+        history_account = instance.history.first()
+        instance.version = history_account.version + 1 if history_account else 0

 @merge_delay_run(ttl=5)
@@ -62,7 +63,7 @@ def create_accounts_activities(account, action='create'):
 def on_account_create_by_template(sender, instance, created=False, **kwargs):
     if not created or instance.source != 'template':
         return
-    push_accounts_if_need(accounts=(instance,))
+    push_accounts_if_need.delay(accounts=(instance,))
     create_accounts_activities(instance, action='create')
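Deriving the next version from the latest history record rather than from history.count() means the number keeps increasing even after old history rows are cleaned up. A tiny sketch of the rule, with a plain list (newest first, like instance.history.first() in the diff) standing in for the history manager:

    def next_version(history):
        latest = history[0] if history else None
        return latest['version'] + 1 if latest else 0

    print(next_version([{'version': 7}, {'version': 6}]))  # 8, even if rows 1-5 were purged
    print(next_version([]))                                # 0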


@@ -55,7 +55,7 @@ def clean_historical_accounts():
     history_model = Account.history.model
     history_id_mapper = defaultdict(list)
-    ids = history_model.objects.values('id').annotate(count=Count('id')) \
+    ids = history_model.objects.values('id').annotate(count=Count('id', distinct=True)) \
         .filter(count__gte=limit).values_list('id', flat=True)
     if not ids:


@@ -92,6 +92,7 @@ class AssetViewSet(SuggestionMixin, OrgBulkModelViewSet):
     model = Asset
     filterset_class = AssetFilterSet
     search_fields = ("name", "address", "comment")
+    ordering = ('name',)
     ordering_fields = ('name', 'address', 'connectivity', 'platform', 'date_updated', 'date_created')
     serializer_classes = (
         ("default", serializers.AssetSerializer),


@@ -12,6 +12,6 @@ class Migration(migrations.Migration):
     operations = [
         migrations.AlterModelOptions(
             name='asset',
-            options={'ordering': ['name'], 'permissions': [('refresh_assethardwareinfo', 'Can refresh asset hardware info'), ('test_assetconnectivity', 'Can test asset connectivity'), ('match_asset', 'Can match asset'), ('change_assetnodes', 'Can change asset nodes')], 'verbose_name': 'Asset'},
+            options={'ordering': [], 'permissions': [('refresh_assethardwareinfo', 'Can refresh asset hardware info'), ('test_assetconnectivity', 'Can test asset connectivity'), ('match_asset', 'Can match asset'), ('change_assetnodes', 'Can change asset nodes')], 'verbose_name': 'Asset'},
         ),
     ]


@@ -348,7 +348,7 @@ class Asset(NodesRelationMixin, LabeledMixin, AbsConnectivity, JSONFilterMixin,
     class Meta:
         unique_together = [('org_id', 'name')]
         verbose_name = _("Asset")
-        ordering = ["name", ]
+        ordering = []
         permissions = [
             ('refresh_assethardwareinfo', _('Can refresh asset hardware info')),
             ('test_assetconnectivity', _('Can test asset connectivity')),


@@ -429,7 +429,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
     @classmethod
     @timeit
-    def get_nodes_all_assets(cls, *nodes):
+    def get_nodes_all_assets(cls, *nodes, distinct=True):
         from .asset import Asset
         node_ids = set()
         descendant_node_query = Q()
@@ -439,7 +439,10 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
         if descendant_node_query:
             _ids = Node.objects.order_by().filter(descendant_node_query).values_list('id', flat=True)
             node_ids.update(_ids)
-        return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct()
+        assets = Asset.objects.order_by().filter(nodes__id__in=node_ids)
+        if distinct:
+            assets = assets.distinct()
+        return assets

     def get_all_asset_ids(self):
         asset_ids = self.get_all_asset_ids_by_node_key(org_id=self.org_id, node_key=self.key)


@@ -1,8 +1,8 @@
 from rest_framework.pagination import LimitOffsetPagination
 from rest_framework.request import Request

-from common.utils import get_logger
 from assets.models import Node
+from common.utils import get_logger

 logger = get_logger(__name__)
@@ -28,6 +28,7 @@ class AssetPaginationBase(LimitOffsetPagination):
             'key', 'all', 'show_current_asset',
             'cache_policy', 'display', 'draw',
             'order', 'node', 'node_id', 'fields_size',
+            'asset'
         }
         for k, v in self._request.query_params.items():
             if k not in exclude_query_params and v is not None:


@@ -206,9 +206,12 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
         """ Perform necessary eager loading of data. """
         queryset = queryset.prefetch_related('domain', 'nodes', 'protocols', ) \
             .prefetch_related('platform', 'platform__automation') \
-            .prefetch_related('labels', 'labels__label') \
             .annotate(category=F("platform__category")) \
             .annotate(type=F("platform__type"))
+        if queryset.model is Asset:
+            queryset = queryset.prefetch_related('labels__label', 'labels')
+        else:
+            queryset = queryset.prefetch_related('asset_ptr__labels__label', 'asset_ptr__labels')
         return queryset

     @staticmethod


@@ -56,7 +56,14 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
 class DomainListSerializer(DomainSerializer):
     class Meta(DomainSerializer.Meta):
-        fields = list(set(DomainSerializer.Meta.fields) - {'assets'})
+        fields = list(set(DomainSerializer.Meta.fields + ['assets_amount']) - {'assets'})
+
+    @classmethod
+    def setup_eager_loading(cls, queryset):
+        queryset = queryset.annotate(
+            assets_amount=Count('assets', distinct=True),
+        )
+        return queryset

 class DomainWithGatewaySerializer(serializers.ModelSerializer):


@@ -191,7 +191,6 @@ class PlatformSerializer(ResourceLabelsMixin, WritableNestedModelSerializer):
     def add_type_choices(self, name, label):
         tp = self.fields['type']
         tp.choices[name] = label
-        tp.choice_mapper[name] = label
         tp.choice_strings_to_values[name] = label

     @lazyproperty


@@ -63,13 +63,13 @@ def on_asset_create(sender, instance=None, created=False, **kwargs):
         return
     logger.info("Asset create signal recv: {}".format(instance))

-    ensure_asset_has_node(assets=(instance,))
+    ensure_asset_has_node.delay(assets=(instance,))

     # Gather asset hardware info
     auto_config = instance.auto_config
     if auto_config.get('ping_enabled'):
         logger.debug('Asset {} ping enabled, test connectivity'.format(instance.name))
-        test_assets_connectivity_handler(assets=(instance,))
+        test_assets_connectivity_handler.delay(assets=(instance,))
     if auto_config.get('gather_facts_enabled'):
         logger.debug('Asset {} gather facts enabled, gather facts'.format(instance.name))
         gather_assets_facts_handler(assets=(instance,))


@@ -2,14 +2,16 @@
 #
 from operator import add, sub

+from django.conf import settings
 from django.db.models.signals import m2m_changed
 from django.dispatch import receiver

 from assets.models import Asset, Node
 from common.const.signals import PRE_CLEAR, POST_ADD, PRE_REMOVE
 from common.decorators import on_transaction_commit, merge_delay_run
+from common.signals import django_ready
 from common.utils import get_logger
-from orgs.utils import tmp_to_org
+from orgs.utils import tmp_to_org, tmp_to_root_org
 from ..tasks import check_node_assets_amount_task

 logger = get_logger(__file__)
@@ -34,7 +36,7 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
         node_ids = [instance.id]
     else:
         node_ids = list(pk_set)
-    update_nodes_assets_amount(node_ids=node_ids)
+    update_nodes_assets_amount.delay(node_ids=node_ids)

 @merge_delay_run(ttl=30)
@@ -52,3 +54,18 @@ def update_nodes_assets_amount(node_ids=()):
         node.assets_amount = node.get_assets_amount()
     Node.objects.bulk_update(nodes, ['assets_amount'])
+
+
+@receiver(django_ready)
+def set_assets_size_to_setting(sender, **kwargs):
+    from assets.models import Asset
+    try:
+        with tmp_to_root_org():
+            amount = Asset.objects.order_by().count()
+    except:
+        amount = 0
+    if amount > 20000:
+        settings.ASSET_SIZE = 'large'
+    elif amount > 2000:
+        settings.ASSET_SIZE = 'medium'


@@ -44,18 +44,18 @@ def on_node_post_create(sender, instance, created, update_fields, **kwargs):
         need_expire = False

     if need_expire:
-        expire_node_assets_mapping(org_ids=(instance.org_id,))
+        expire_node_assets_mapping.delay(org_ids=(instance.org_id,))

 @receiver(post_delete, sender=Node)
 def on_node_post_delete(sender, instance, **kwargs):
-    expire_node_assets_mapping(org_ids=(instance.org_id,))
+    expire_node_assets_mapping.delay(org_ids=(instance.org_id,))

 @receiver(m2m_changed, sender=Asset.nodes.through)
 def on_node_asset_change(sender, instance, action='pre_remove', **kwargs):
     if action.startswith('post'):
-        expire_node_assets_mapping(org_ids=(instance.org_id,))
+        expire_node_assets_mapping.delay(org_ids=(instance.org_id,))

 @receiver(django_ready)


@@ -2,6 +2,7 @@
 from django.urls import path
 from rest_framework_bulk.routes import BulkRouter

+from labels.api import LabelViewSet
 from .. import api

 app_name = 'assets'
@@ -22,6 +23,7 @@ router.register(r'domains', api.DomainViewSet, 'domain')
 router.register(r'gateways', api.GatewayViewSet, 'gateway')
 router.register(r'favorite-assets', api.FavoriteAssetViewSet, 'favorite-asset')
 router.register(r'protocol-settings', api.PlatformProtocolViewSet, 'protocol-setting')
+router.register(r'labels', LabelViewSet, 'label')

 urlpatterns = [
     # path('assets/<uuid:pk>/gateways/', api.AssetGatewayListApi.as_view(), name='asset-gateway-list'),


@@ -4,7 +4,6 @@ from urllib.parse import urlencode, urlparse
 from kubernetes import client
 from kubernetes.client import api_client
 from kubernetes.client.api import core_v1_api
-from kubernetes.client.exceptions import ApiException
 from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError

 from common.utils import get_logger
@@ -88,8 +87,9 @@ class KubernetesClient:
         if hasattr(self, func_name):
             try:
                 data = getattr(self, func_name)(*args)
-            except ApiException as e:
-                logger.error(e.reason)
+            except Exception as e:
+                logger.error(e)
+                raise e
         if self.server:
             self.server.stop()


@@ -20,6 +20,7 @@ from common.const.http import GET, POST
 from common.drf.filters import DatetimeRangeFilterBackend
 from common.permissions import IsServiceAccount
 from common.plugins.es import QuerySet as ESQuerySet
+from common.sessions.cache import user_session_manager
 from common.storage.ftp_file import FTPFileStorageHandler
 from common.utils import is_uuid, get_logger, lazyproperty
 from orgs.mixins.api import OrgReadonlyModelViewSet, OrgModelViewSet
@@ -289,8 +290,7 @@ class UserSessionViewSet(CommonApiMixin, viewsets.ModelViewSet):
             return Response(status=status.HTTP_200_OK)
         keys = queryset.values_list('key', flat=True)
-        session_store_cls = import_module(settings.SESSION_ENGINE).SessionStore
         for key in keys:
-            session_store_cls(key).delete()
+            user_session_manager.decrement_or_remove(key)
         queryset.delete()
         return Response(status=status.HTTP_200_OK)


@@ -1,10 +1,9 @@
-from django.core.cache import cache
 from django_filters import rest_framework as drf_filters
 from rest_framework import filters
 from rest_framework.compat import coreapi, coreschema

 from common.drf.filters import BaseFilterSet
-from notifications.ws import WS_SESSION_KEY
+from common.sessions.cache import user_session_manager
 from orgs.utils import current_org
 from .models import UserSession
@@ -41,13 +40,11 @@ class UserSessionFilterSet(BaseFilterSet):
     @staticmethod
     def filter_is_active(queryset, name, is_active):
-        redis_client = cache.client.get_client()
-        members = redis_client.smembers(WS_SESSION_KEY)
-        members = [member.decode('utf-8') for member in members]
+        keys = user_session_manager.get_active_keys()
         if is_active:
-            queryset = queryset.filter(key__in=members)
+            queryset = queryset.filter(key__in=keys)
         else:
-            queryset = queryset.exclude(key__in=members)
+            queryset = queryset.exclude(key__in=keys)
         return queryset

     class Meta:


@@ -4,15 +4,15 @@ from datetime import timedelta
 from importlib import import_module

 from django.conf import settings
-from django.core.cache import caches, cache
+from django.core.cache import caches
 from django.db import models
 from django.db.models import Q
 from django.utils import timezone
 from django.utils.translation import gettext, gettext_lazy as _

 from common.db.encoder import ModelJSONFieldEncoder
+from common.sessions.cache import user_session_manager
 from common.utils import lazyproperty, i18n_trans
-from notifications.ws import WS_SESSION_KEY
 from ops.models import JobExecution
 from orgs.mixins.models import OrgModelMixin, Organization
 from orgs.utils import current_org
@@ -278,8 +278,7 @@ class UserSession(models.Model):
     @property
     def is_active(self):
-        redis_client = cache.client.get_client()
-        return redis_client.sismember(WS_SESSION_KEY, self.key)
+        return user_session_manager.check_active(self.key)

     @property
     def date_expired(self):


@@ -205,7 +205,7 @@ class RDPFileClientProtocolURLMixin:
         return data

     def get_smart_endpoint(self, protocol, asset=None):
-        endpoint = Endpoint.match_by_instance_label(asset, protocol)
+        endpoint = Endpoint.match_by_instance_label(asset, protocol, self.request)
         if not endpoint:
             target_ip = asset.get_target_ip() if asset else ''
             endpoint = EndpointRule.match_endpoint(


@@ -90,6 +90,6 @@ class MFAChallengeVerifyApi(AuthMixin, CreateAPIView):
             return Response({'msg': 'ok'})
         except errors.AuthFailedError as e:
             data = {"error": e.error, "msg": e.msg}
-            raise ValidationError(data)
+            return Response(data, status=401)
         except errors.NeedMoreInfoError as e:
             return Response(e.as_data(), status=200)


@@ -10,6 +10,7 @@ from rest_framework import authentication, exceptions
 from common.auth import signature
 from common.decorators import merge_delay_run
 from common.utils import get_object_or_none, get_request_ip_or_data, contains_ip
+from users.models import User
 from ..models import AccessKey, PrivateToken
@@ -19,22 +20,23 @@ def date_more_than(d, seconds):
 @merge_delay_run(ttl=60)
 def update_token_last_used(tokens=()):
-    for token in tokens:
-        token.date_last_used = timezone.now()
-        token.save(update_fields=['date_last_used'])
+    access_keys_ids = [token.id for token in tokens if isinstance(token, AccessKey)]
+    private_token_keys = [token.key for token in tokens if isinstance(token, PrivateToken)]
+    if len(access_keys_ids) > 0:
+        AccessKey.objects.filter(id__in=access_keys_ids).update(date_last_used=timezone.now())
+    if len(private_token_keys) > 0:
+        PrivateToken.objects.filter(key__in=private_token_keys).update(date_last_used=timezone.now())

 @merge_delay_run(ttl=60)
 def update_user_last_used(users=()):
-    for user in users:
-        user.date_api_key_last_used = timezone.now()
-        user.save(update_fields=['date_api_key_last_used'])
+    User.objects.filter(id__in=users).update(date_api_key_last_used=timezone.now())

 def after_authenticate_update_date(user, token=None):
-    update_user_last_used(users=(user,))
+    update_user_last_used.delay(users=(user.id,))
     if token:
-        update_token_last_used(tokens=(token,))
+        update_token_last_used.delay(tokens=(token,))

 class AccessTokenAuthentication(authentication.BaseAuthentication):
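Instead of saving tokens one by one, the handler now partitions the merged token list by type and issues a single UPDATE per model. A small sketch of just the partition step, with plain classes standing in for the two token models (names are illustrative):

    class AccessKey:                 # stand-in for the AccessKey model
        def __init__(self, id): self.id = id

    class PrivateToken:              # stand-in for the PrivateToken model
        def __init__(self, key): self.key = key

    tokens = [AccessKey(1), PrivateToken('abc'), AccessKey(2)]

    access_key_ids = [t.id for t in tokens if isinstance(t, AccessKey)]
    private_token_keys = [t.key for t in tokens if isinstance(t, PrivateToken)]
    # Each non-empty list then feeds one queryset .update(date_last_used=now) call.
    print(access_key_ids, private_token_keys)   # [1, 2] ['abc']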


@@ -98,16 +98,19 @@ class OAuth2Backend(JMSModelBackend):
         access_token_url = '{url}{separator}{query}'.format(
             url=settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT, separator=separator, query=urlencode(query_dict)
         )
+        # token_method -> get, post(post_data), post_json
         token_method = settings.AUTH_OAUTH2_ACCESS_TOKEN_METHOD.lower()
-        requests_func = getattr(requests, token_method, requests.get)
         logger.debug(log_prompt.format('Call the access token endpoint[method: %s]' % token_method))
         headers = {
             'Accept': 'application/json'
         }
-        if token_method == 'post':
-            access_token_response = requests_func(access_token_url, headers=headers, data=query_dict)
+        if token_method.startswith('post'):
+            body_key = 'json' if token_method.endswith('json') else 'data'
+            access_token_response = requests.post(
+                access_token_url, headers=headers, **{body_key: query_dict}
+            )
         else:
-            access_token_response = requests_func(access_token_url, headers=headers)
+            access_token_response = requests.get(access_token_url, headers=headers)
         try:
             access_token_response.raise_for_status()
             access_token_response_data = access_token_response.json()
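The dispatch is now driven by the configured method string: anything starting with "post" sends a POST, and a "json" suffix switches the body from form data to a JSON payload. A sketch of that selection alone, building the requests keyword arguments without actually performing an HTTP call (URL and payload below are made up for illustration):

    def build_token_request(token_method, url, query_dict):
        headers = {'Accept': 'application/json'}
        method = token_method.lower()
        if method.startswith('post'):
            body_key = 'json' if method.endswith('json') else 'data'
            return 'post', url, {'headers': headers, body_key: query_dict}
        return 'get', url, {'headers': headers}

    # 'post_json' -> ('post', ..., {'headers': ..., 'json': {'code': 'x'}})
    print(build_token_request('post_json', 'https://idp.example.com/token', {'code': 'x'}))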


@@ -18,7 +18,7 @@ class EncryptedField(forms.CharField):
 class UserLoginForm(forms.Form):
     days_auto_login = int(settings.SESSION_COOKIE_AGE / 3600 / 24)
-    disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE \
+    disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE \
                               or days_auto_login < 1

     username = forms.CharField(


@@ -142,23 +142,7 @@ class SessionCookieMiddleware(MiddlewareMixin):
             return response
         response.set_cookie(key, value)

-    @staticmethod
-    def set_cookie_session_expire(request, response):
-        if not request.session.get('auth_session_expiration_required'):
-            return
-        value = 'age'
-        if settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE or \
-                not request.session.get('auto_login', False):
-            value = 'close'
-        age = request.session.get_expiry_age()
-        expire_timestamp = request.session.get_expiry_date().timestamp()
-        response.set_cookie('jms_session_expire_timestamp', expire_timestamp)
-        response.set_cookie('jms_session_expire', value, max_age=age)
-        request.session.pop('auth_session_expiration_required', None)
-
     def process_response(self, request, response: HttpResponse):
         self.set_cookie_session_prefix(request, response)
         self.set_cookie_public_key(request, response)
-        self.set_cookie_session_expire(request, response)
         return response


@@ -37,9 +37,6 @@ def on_user_auth_login_success(sender, user, request, **kwargs):
         UserSession.objects.filter(key=session_key).delete()
     cache.set(lock_key, request.session.session_key, None)

-    # Mark the login and set a cookie the frontend can use to control refresh; the middleware intercepts this to generate the cookie
-    request.session['auth_session_expiration_required'] = 1

 @receiver(cas_user_authenticated)
 def on_cas_user_login_success(sender, request, user, **kwargs):


@@ -70,11 +70,12 @@ class DingTalkQRMixin(DingTalkBaseMixin, View):
         self.request.session[DINGTALK_STATE_SESSION_KEY] = state

         params = {
-            'appid': settings.DINGTALK_APPKEY,
+            'client_id': settings.DINGTALK_APPKEY,
             'response_type': 'code',
-            'scope': 'snsapi_login',
+            'scope': 'openid',
             'state': state,
             'redirect_uri': redirect_uri,
+            'prompt': 'consent'
         }
         url = URL.QR_CONNECT + '?' + urlencode(params)
         return url


@@ -104,9 +104,11 @@ class QuerySetMixin:
         page = super().paginate_queryset(queryset)
         serializer_class = self.get_serializer_class()
         if page and serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
-            ids = [i.id for i in page]
+            ids = [str(obj.id) for obj in page]
             page = self.get_queryset().filter(id__in=ids)
             page = serializer_class.setup_eager_loading(page)
+            page_mapper = {str(obj.id): obj for obj in page}
+            page = [page_mapper.get(_id) for _id in ids if _id in page_mapper]
         return page
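Refetching the page with filter(id__in=ids) for eager loading does not preserve the paginated ordering, so the new code rebuilds the page from an id-to-object map. The reordering trick in isolation, with plain objects standing in for the queryset rows:

    class Obj:
        def __init__(self, id): self.id = id

    page = [Obj(3), Obj(1), Obj(2)]            # order produced by pagination
    ids = [str(o.id) for o in page]

    refetched = [Obj(1), Obj(2), Obj(3)]       # the database returns its own order
    page_mapper = {str(o.id): o for o in refetched}
    page = [page_mapper[_id] for _id in ids if _id in page_mapper]
    print([o.id for o in page])                # [3, 1, 2] - original order restored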


@@ -19,3 +19,17 @@ class Status(models.TextChoices):
     failed = 'failed', _("Failed")
     error = 'error', _("Error")
     canceled = 'canceled', _("Canceled")
+
+COUNTRY_CALLING_CODES = [
+    {'name': 'China(中国)', 'value': '+86'},
+    {'name': 'HongKong(中国香港)', 'value': '+852'},
+    {'name': 'Macao(中国澳门)', 'value': '+853'},
+    {'name': 'Taiwan(中国台湾)', 'value': '+886'},
+    {'name': 'America(America)', 'value': '+1'},
+    {'name': 'Russia(Россия)', 'value': '+7'},
+    {'name': 'France(français)', 'value': '+33'},
+    {'name': 'Britain(Britain)', 'value': '+44'},
+    {'name': 'Germany(Deutschland)', 'value': '+49'},
+    {'name': 'Japan(日本)', 'value': '+81'},
+    {'name': 'Korea(한국)', 'value': '+82'},
+    {'name': 'India(भारत)', 'value': '+91'}
+]


@@ -362,11 +362,15 @@ class RelatedManager:
             if name is None or val is None:
                 continue

-            if custom_attr_filter:
+            custom_filter_q = None
+            spec_attr_filter = getattr(to_model, "get_{}_filter_attr_q".format(name), None)
+            if spec_attr_filter:
+                custom_filter_q = spec_attr_filter(val, match)
+            elif custom_attr_filter:
                 custom_filter_q = custom_attr_filter(name, val, match)
             if custom_filter_q:
                 filters.append(custom_filter_q)
                 continue

             if match == 'ip_in':
                 q = cls.get_ip_in_q(name, val)
@@ -464,11 +468,15 @@ class JSONManyToManyDescriptor:
             rule_value = rule.get('value', '')
             rule_match = rule.get('match', 'exact')

-            if custom_attr_filter:
-                q = custom_attr_filter(rule['name'], rule_value, rule_match)
-                if q:
-                    custom_q &= q
-                continue
+            custom_filter_q = None
+            spec_attr_filter = getattr(to_model, "get_filter_{}_attr_q".format(rule['name']), None)
+            if spec_attr_filter:
+                custom_filter_q = spec_attr_filter(rule_value, rule_match)
+            elif custom_attr_filter:
+                custom_filter_q = custom_attr_filter(rule['name'], rule_value, rule_match)
+            if custom_filter_q:
+                custom_q &= custom_filter_q
+                continue

             if rule_match == 'in':
                 res &= value in rule_value or '*' in rule_value
@@ -517,7 +525,6 @@ class JSONManyToManyDescriptor:
                     res &= rule_value.issubset(value)
                 else:
                     res &= bool(value & rule_value)
-
             else:
                 logging.error("unknown match: {}".format(rule['match']))
                 res &= False
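Both filter paths now look for an optional per-attribute hook on the target model (resolved with getattr and a name pattern) before falling back to the generic custom_attr_filter. A minimal sketch of that lookup-and-fallback pattern; the Asset class and the returned string are placeholders, not the project's real Q construction:

    class Asset:                                   # illustrative model stand-in
        @staticmethod
        def get_address_filter_attr_q(value, match):
            return f"Q(address__{match}={value!r})"

    def build_filter(model, name, value, match, custom_attr_filter=None):
        spec = getattr(model, f"get_{name}_filter_attr_q", None)  # per-attribute hook
        if spec:
            return spec(value, match)
        if custom_attr_filter:
            return custom_attr_filter(name, value, match)
        return None

    print(build_filter(Asset, 'address', '10.0.0.1', 'exact'))   # uses the hook
    print(build_filter(Asset, 'name', 'web', 'exact'))           # falls through to None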


@@ -3,6 +3,7 @@
 import asyncio
 import functools
 import inspect
+import os
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
@@ -101,7 +102,11 @@ def run_debouncer_func(cache_key, org, ttl, func, *args, **kwargs):
         first_run_time = current

     if current - first_run_time > ttl:
+        _loop_debouncer_func_args_cache.pop(cache_key, None)
+        _loop_debouncer_func_task_time_cache.pop(cache_key, None)
         executor.submit(run_func_partial, *args, **kwargs)
+        logger.debug('pid {} executor submit run {}'.format(
+            os.getpid(), func.__name__, ))
         return

     loop = _loop_thread.get_loop()
@@ -133,13 +138,26 @@ class Debouncer(object):
         return await self.loop.run_in_executor(self.executor, func)

+ignore_err_exceptions = (
+    "(3101, 'Plugin instructed the server to rollback the current transaction.')",
+)
+
+
 def _run_func_with_org(key, org, func, *args, **kwargs):
     from orgs.utils import set_current_org
     try:
-        set_current_org(org)
-        func(*args, **kwargs)
+        with transaction.atomic():
+            set_current_org(org)
+            func(*args, **kwargs)
     except Exception as e:
-        logger.error('delay run error: %s' % e)
+        msg = str(e)
+        log_func = logger.error
+        if msg in ignore_err_exceptions:
+            log_func = logger.info
+        pid = os.getpid()
+        thread_name = threading.current_thread()
+        log_func('pid {} thread {} delay run {} error: {}'.format(
+            pid, thread_name, func.__name__, msg))
     _loop_debouncer_func_task_cache.pop(key, None)
     _loop_debouncer_func_args_cache.pop(key, None)
     _loop_debouncer_func_task_time_cache.pop(key, None)
@@ -181,6 +199,32 @@ def merge_delay_run(ttl=5, key=None):
     :return:
     """

+    def delay(func, *args, **kwargs):
+        from orgs.utils import get_current_org
+        suffix_key_func = key if key else default_suffix_key
+        org = get_current_org()
+        func_name = f'{func.__module__}_{func.__name__}'
+        key_suffix = suffix_key_func(*args, **kwargs)
+        cache_key = f'MERGE_DELAY_RUN_{func_name}_{key_suffix}'
+        cache_kwargs = _loop_debouncer_func_args_cache.get(cache_key, {})
+        for k, v in kwargs.items():
+            if not isinstance(v, (tuple, list, set)):
+                raise ValueError('func kwargs value must be list or tuple: %s %s' % (func.__name__, v))
+            v = set(v)
+            if k not in cache_kwargs:
+                cache_kwargs[k] = v
+            else:
+                cache_kwargs[k] = cache_kwargs[k].union(v)
+        _loop_debouncer_func_args_cache[cache_key] = cache_kwargs
+        run_debouncer_func(cache_key, org, ttl, func, *args, **cache_kwargs)
+
+    def apply(func, sync=False, *args, **kwargs):
+        if sync:
+            return func(*args, **kwargs)
+        else:
+            return delay(func, *args, **kwargs)
+
     def inner(func):
         sigs = inspect.signature(func)
         if len(sigs.parameters) != 1:
@@ -188,27 +232,12 @@ def merge_delay_run(ttl=5, key=None):
         param = list(sigs.parameters.values())[0]
         if not isinstance(param.default, tuple):
             raise ValueError('func default must be tuple: %s' % param.default)
-        suffix_key_func = key if key else default_suffix_key
+        func.delay = functools.partial(delay, func)
+        func.apply = functools.partial(apply, func)

         @functools.wraps(func)
         def wrapper(*args, **kwargs):
-            from orgs.utils import get_current_org
-            org = get_current_org()
-            func_name = f'{func.__module__}_{func.__name__}'
-            key_suffix = suffix_key_func(*args, **kwargs)
-            cache_key = f'MERGE_DELAY_RUN_{func_name}_{key_suffix}'
-            cache_kwargs = _loop_debouncer_func_args_cache.get(cache_key, {})
-            for k, v in kwargs.items():
-                if not isinstance(v, (tuple, list, set)):
-                    raise ValueError('func kwargs value must be list or tuple: %s %s' % (func.__name__, v))
-                v = set(v)
-                if k not in cache_kwargs:
-                    cache_kwargs[k] = v
-                else:
-                    cache_kwargs[k] = cache_kwargs[k].union(v)
-            _loop_debouncer_func_args_cache[cache_key] = cache_kwargs
-            run_debouncer_func(cache_key, org, ttl, func, *args, **cache_kwargs)
+            return func(*args, **kwargs)

         return wrapper
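After this refactor a direct call to the decorated function runs synchronously, while the debounced/merged path moves behind func.delay (and func.apply(sync=True) forces an immediate call), which is why the signal handlers in this commit switch to the .delay(...) form. A stripped-down sketch of attaching .delay/.apply with functools.partial, independent of the project's debouncer internals (the print stands in for the real queueing):

    import functools

    def merge_delay_run_sketch(func):
        def delay(f, *args, **kwargs):
            print(f"queued {f.__name__} with {kwargs}")   # real code debounces and merges kwargs here

        def apply(f, sync=False, *args, **kwargs):
            return f(*args, **kwargs) if sync else delay(f, *args, **kwargs)

        func.delay = functools.partial(delay, func)
        func.apply = functools.partial(apply, func)

        @functools.wraps(func)              # wraps copies .delay/.apply onto the wrapper too
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)    # direct call stays synchronous
        return wrapper

    @merge_delay_run_sketch
    def push_accounts(accounts=()):
        print("pushing", accounts)

    push_accounts(accounts=("a",))          # runs immediately
    push_accounts.delay(accounts=("b",))    # goes through the (sketched) debouncer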


@@ -6,7 +6,7 @@ import logging
 from django.core.cache import cache
 from django.core.exceptions import ImproperlyConfigured
-from django.db.models import Q, Count
+from django.db.models import Q
 from django_filters import rest_framework as drf_filters
 from rest_framework import filters
 from rest_framework.compat import coreapi, coreschema
@@ -180,36 +180,30 @@ class LabelFilterBackend(filters.BaseFilterBackend):
     ]

     @staticmethod
-    def filter_resources(resources, labels_id):
+    def parse_label_ids(labels_id):
+        from labels.models import Label
         label_ids = [i.strip() for i in labels_id.split(',')]
+        cleaned = []
         args = []
         for label_id in label_ids:
             kwargs = {}
             if ':' in label_id:
                 k, v = label_id.split(':', 1)
-                kwargs['label__name'] = k.strip()
+                kwargs['name'] = k.strip()
                 if v != '*':
-                    kwargs['label__value'] = v.strip()
+                    kwargs['value'] = v.strip()
+                args.append(kwargs)
             else:
-                kwargs['label_id'] = label_id
-            args.append(kwargs)
+                cleaned.append(label_id)

-        if len(args) == 1:
-            resources = resources.filter(**args[0])
-            return resources
-
-        q = Q()
-        for kwarg in args:
-            q |= Q(**kwarg)
-        resources = resources.filter(q) \
-            .values('res_id') \
-            .order_by('res_id') \
-            .annotate(count=Count('res_id')) \
-            .values('res_id', 'count') \
-            .filter(count=len(args))
-        return resources
+        if len(args) != 0:
+            q = Q()
+            for kwarg in args:
+                q |= Q(**kwarg)
+            ids = Label.objects.filter(q).values_list('id', flat=True)
+            cleaned.extend(list(ids))
+        return cleaned

     def filter_queryset(self, request, queryset, view):
         labels_id = request.query_params.get('labels')
@@ -223,14 +217,15 @@ class LabelFilterBackend(filters.BaseFilterBackend):
             return queryset

         model = queryset.model.label_model()
-        labeled_resource_cls = model._labels.field.related_model
+        labeled_resource_cls = model.labels.field.related_model
         app_label = model._meta.app_label
         model_name = model._meta.model_name
         resources = labeled_resource_cls.objects.filter(
             res_type__app_label=app_label, res_type__model=model_name,
         )
-        resources = self.filter_resources(resources, labels_id)
+        label_ids = self.parse_label_ids(labels_id)
+        resources = model.filter_resources_by_labels(resources, label_ids)
         res_ids = resources.values_list('res_id', flat=True)
         queryset = queryset.filter(id__in=set(res_ids))
         return queryset
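parse_label_ids accepts a comma-separated mix of raw label ids and name:value pairs (with * as a wildcard value); pairs are resolved to label ids through a lookup, plain ids pass through untouched. A sketch of the parsing step alone, with a dict standing in for the Label table:

    LABELS = {('env', 'prod'): 'id-1', ('env', 'dev'): 'id-2', ('team', 'ops'): 'id-3'}

    def parse_label_ids(labels_id):
        cleaned, pairs = [], []
        for item in (i.strip() for i in labels_id.split(',')):
            if ':' in item:
                name, value = item.split(':', 1)
                pairs.append((name.strip(), value.strip()))
            else:
                cleaned.append(item)                 # already a label id
        for name, value in pairs:
            cleaned += [lid for (n, v), lid in LABELS.items()
                        if n == name and (value == '*' or v == value)]
        return cleaned

    print(parse_label_ids('id-9,env:*,team:ops'))    # ['id-9', 'id-1', 'id-2', 'id-3']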


@@ -14,6 +14,7 @@ class CeleryBaseService(BaseService):
         print('\n- Start Celery as Distributed Task Queue: {}'.format(self.queue.capitalize()))
         ansible_config_path = os.path.join(settings.APPS_DIR, 'ops', 'ansible', 'ansible.cfg')
         ansible_modules_path = os.path.join(settings.APPS_DIR, 'ops', 'ansible', 'modules')
+        os.environ.setdefault('LC_ALL', 'C.UTF-8')
         os.environ.setdefault('PYTHONOPTIMIZE', '1')
         os.environ.setdefault('ANSIBLE_FORCE_COLOR', 'True')
         os.environ.setdefault('ANSIBLE_CONFIG', ansible_config_path)


@@ -28,9 +28,10 @@ class ErrorCode:
 class URL:
-    QR_CONNECT = 'https://oapi.dingtalk.com/connect/qrconnect'
+    QR_CONNECT = 'https://login.dingtalk.com/oauth2/auth'
     OAUTH_CONNECT = 'https://oapi.dingtalk.com/connect/oauth2/sns_authorize'
-    GET_USER_INFO_BY_CODE = 'https://oapi.dingtalk.com/sns/getuserinfo_bycode'
+    GET_USER_ACCESSTOKEN = 'https://api.dingtalk.com/v1.0/oauth2/userAccessToken'
+    GET_USER_INFO = 'https://api.dingtalk.com/v1.0/contact/users/me'
     GET_TOKEN = 'https://oapi.dingtalk.com/gettoken'
     SEND_MESSAGE_BY_TEMPLATE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/sendbytemplate'
     SEND_MESSAGE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/asyncsend_v2'
@@ -72,8 +73,9 @@ class DingTalkRequests(BaseRequest):
     def get(self, url, params=None,
             with_token=False, with_sign=False,
             check_errcode_is_0=True,
-            **kwargs):
+            **kwargs) -> dict:
         pass
+
     get = as_request(get)

     def post(self, url, json=None, params=None,
@@ -81,6 +83,7 @@ class DingTalkRequests(BaseRequest):
              check_errcode_is_0=True,
              **kwargs) -> dict:
         pass
+
     post = as_request(post)

     def _add_sign(self, kwargs: dict):
@@ -123,17 +126,22 @@ class DingTalk:
         )

     def get_userinfo_bycode(self, code):
-        # https://developers.dingtalk.com/document/app/obtain-the-user-information-based-on-the-sns-temporary-authorization?spm=ding_open_doc.document.0.0.3a256573y8Y7yg#topic-1995619
         body = {
-            "tmp_auth_code": code
+            'clientId': self._appid,
+            'clientSecret': self._appsecret,
+            'code': code,
+            'grantType': 'authorization_code'
         }
+        data = self._request.post(URL.GET_USER_ACCESSTOKEN, json=body, check_errcode_is_0=False)
+        token = data['accessToken']

-        data = self._request.post(URL.GET_USER_INFO_BY_CODE, json=body, with_sign=True)
-        return data['user_info']
+        user = self._request.get(URL.GET_USER_INFO,
+                                 headers={'x-acs-dingtalk-access-token': token}, check_errcode_is_0=False)
+        return user

     def get_user_id_by_code(self, code):
         user_info = self.get_userinfo_bycode(code)
-        unionid = user_info['unionid']
+        unionid = user_info['unionId']
         userid = self.get_userid_by_unionid(unionid)
         return userid, None


@@ -394,20 +394,20 @@ class CommonBulkModelSerializer(CommonBulkSerializerMixin, serializers.ModelSeri
 class ResourceLabelsMixin(serializers.Serializer):
-    labels = LabelRelatedField(many=True, label=_('Labels'), required=False, allow_null=True)
+    labels = LabelRelatedField(many=True, label=_('Labels'), required=False, allow_null=True, source='res_labels')

     def update(self, instance, validated_data):
-        labels = validated_data.pop('labels', None)
+        labels = validated_data.pop('res_labels', None)
         res = super().update(instance, validated_data)
         if labels is not None:
-            instance.labels.set(labels, bulk=False)
+            instance.res_labels.set(labels, bulk=False)
         return res

     def create(self, validated_data):
-        labels = validated_data.pop('labels', None)
+        labels = validated_data.pop('res_labels', None)
         instance = super().create(validated_data)
         if labels is not None:
-            instance.labels.set(labels, bulk=False)
+            instance.res_labels.set(labels, bulk=False)
         return instance

     @classmethod



@@ -0,0 +1,56 @@
+import re
+
+from django.contrib.sessions.backends.cache import (
+    SessionStore as DjangoSessionStore
+)
+from django.core.cache import cache
+
+from jumpserver.utils import get_current_request
+
+
+class SessionStore(DjangoSessionStore):
+    ignore_urls = [
+        r'^/api/v1/users/profile/'
+    ]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ignore_pattern = re.compile('|'.join(self.ignore_urls))
+
+    def save(self, *args, **kwargs):
+        request = get_current_request()
+        if request is None or not self.ignore_pattern.match(request.path):
+            super().save(*args, **kwargs)
+
+
+class RedisUserSessionManager:
+    JMS_SESSION_KEY = 'jms_session_key'
+
+    def __init__(self):
+        self.client = cache.client.get_client()
+
+    def add_or_increment(self, session_key):
+        self.client.hincrby(self.JMS_SESSION_KEY, session_key, 1)
+
+    def decrement_or_remove(self, session_key):
+        new_count = self.client.hincrby(self.JMS_SESSION_KEY, session_key, -1)
+        if new_count <= 0:
+            self.client.hdel(self.JMS_SESSION_KEY, session_key)
+
+    def check_active(self, session_key):
+        count = self.client.hget(self.JMS_SESSION_KEY, session_key)
+        count = 0 if count is None else int(count.decode('utf-8'))
+        return count > 0
+
+    def get_active_keys(self):
+        session_keys = []
+        for k, v in self.client.hgetall(self.JMS_SESSION_KEY).items():
+            count = int(v.decode('utf-8'))
+            if count <= 0:
+                continue
+            key = k.decode('utf-8')
+            session_keys.append(key)
+        return session_keys
+
+
+user_session_manager = RedisUserSessionManager()
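RedisUserSessionManager keeps a per-session reference count in a Redis hash (hincrby, hdel, hget), so a key only counts as active while at least one holder has incremented it. The same semantics in a few lines of plain Python, useful for reasoning about the hash operations without a Redis server; this in-memory class is only a mimic, not part of the commit:

    class CountingSessionManager:                 # in-memory mimic of the Redis hash
        def __init__(self):
            self.counts = {}

        def add_or_increment(self, key):
            self.counts[key] = self.counts.get(key, 0) + 1

        def decrement_or_remove(self, key):
            new_count = self.counts.get(key, 0) - 1
            if new_count <= 0:
                self.counts.pop(key, None)        # mirrors hdel when the count drops to 0
            else:
                self.counts[key] = new_count

        def check_active(self, key):
            return self.counts.get(key, 0) > 0

    m = CountingSessionManager()
    m.add_or_increment('sess1'); m.add_or_increment('sess1')
    m.decrement_or_remove('sess1')
    print(m.check_active('sess1'))                # True - one holder left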


@@ -69,7 +69,7 @@ def digest_sql_query():
         for query in queries:
             sql = query['sql']
             print(" # {}: {}".format(query['time'], sql[:1000]))

         if len(queries) < 3:
             continue
         print("- Table: {}".format(table_name))

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -282,6 +282,7 @@ class Config(dict):
         'AUTH_LDAP_SYNC_INTERVAL': None,
         'AUTH_LDAP_SYNC_CRONTAB': None,
         'AUTH_LDAP_SYNC_ORG_IDS': ['00000000-0000-0000-0000-000000000002'],
+        'AUTH_LDAP_SYNC_RECEIVERS': [],
         'AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS': False,
         'AUTH_LDAP_OPTIONS_OPT_REFERRALS': -1,
@@ -546,7 +547,6 @@ class Config(dict):
         'REFERER_CHECK_ENABLED': False,
         'SESSION_ENGINE': 'cache',
         'SESSION_SAVE_EVERY_REQUEST': True,
-        'SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE': False,
         'SERVER_REPLAY_STORAGE': {},
         'SECURITY_DATA_CRYPTO_ALGO': None,
         'GMSSL_ENABLED': False,
@@ -605,7 +605,9 @@ class Config(dict):
         'GPT_MODEL': 'gpt-3.5-turbo',
         'VIRTUAL_APP_ENABLED': False,
-        'FILE_UPLOAD_SIZE_LIMIT_MB': 200
+        'FILE_UPLOAD_SIZE_LIMIT_MB': 200,
+        'TICKET_APPLY_ASSET_SCOPE': 'all'
     }

     old_config_map = {


@@ -66,11 +66,6 @@ class RequestMiddleware:
     def __call__(self, request):
         set_current_request(request)
         response = self.get_response(request)
-        is_request_api = request.path.startswith('/api')
-        if not settings.SESSION_EXPIRE_AT_BROWSER_CLOSE and \
-                not is_request_api:
-            age = request.session.get_expiry_age()
-            request.session.set_expiry(age)
         return response


@@ -3,6 +3,7 @@
 path_perms_map = {
     'xpack': '*',
     'settings': '*',
+    'img': '*',
     'replay': 'default',
     'applets': 'terminal.view_applet',
     'virtual_apps': 'terminal.view_virtualapp',


@@ -0,0 +1,14 @@
+from private_storage.servers import NginxXAccelRedirectServer, DjangoServer
+
+
+class StaticFileServer(object):
+
+    @staticmethod
+    def serve(private_file):
+        full_path = private_file.full_path
+        # todo: after nginx serves gzip-compressed recording files the browser cannot parse the content,
+        # which breaks online playback, so for now only mp4 recordings are handled by nginx
+        if full_path.endswith('.mp4'):
+            return NginxXAccelRedirectServer.serve(private_file)
+        else:
+            return DjangoServer.serve(private_file)

View File

@@ -50,6 +50,7 @@ AUTH_LDAP_SYNC_IS_PERIODIC = CONFIG.AUTH_LDAP_SYNC_IS_PERIODIC
AUTH_LDAP_SYNC_INTERVAL = CONFIG.AUTH_LDAP_SYNC_INTERVAL
AUTH_LDAP_SYNC_CRONTAB = CONFIG.AUTH_LDAP_SYNC_CRONTAB
AUTH_LDAP_SYNC_ORG_IDS = CONFIG.AUTH_LDAP_SYNC_ORG_IDS
AUTH_LDAP_SYNC_RECEIVERS = CONFIG.AUTH_LDAP_SYNC_RECEIVERS
AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS = CONFIG.AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS
# ==============================================================================

View File

@@ -234,11 +234,9 @@ CSRF_COOKIE_NAME = '{}csrftoken'.format(SESSION_COOKIE_NAME_PREFIX)
SESSION_COOKIE_NAME = '{}sessionid'.format(SESSION_COOKIE_NAME_PREFIX)
SESSION_COOKIE_AGE = CONFIG.SESSION_COOKIE_AGE
SESSION_EXPIRE_AT_BROWSER_CLOSE = True
# Custom behaviour: SESSION_EXPIRE_AT_BROWSER_CLOSE is always True; the setting below controls whether the cookie is forcibly expired after the browser closes
SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE = CONFIG.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE
SESSION_SAVE_EVERY_REQUEST = CONFIG.SESSION_SAVE_EVERY_REQUEST
SESSION_ENGINE = "django.contrib.sessions.backends.{}".format(CONFIG.SESSION_ENGINE)
SESSION_EXPIRE_AT_BROWSER_CLOSE = CONFIG.SESSION_EXPIRE_AT_BROWSER_CLOSE
SESSION_ENGINE = "common.sessions.{}".format(CONFIG.SESSION_ENGINE)
MESSAGE_STORAGE = 'django.contrib.messages.storage.cookie.CookieStorage'
# Database
@@ -319,9 +317,7 @@ MEDIA_ROOT = os.path.join(PROJECT_DIR, 'data', 'media').replace('\\', '/') + '/'
PRIVATE_STORAGE_ROOT = MEDIA_ROOT
PRIVATE_STORAGE_AUTH_FUNCTION = 'jumpserver.rewriting.storage.permissions.allow_access'
PRIVATE_STORAGE_INTERNAL_URL = '/private-media/'
PRIVATE_STORAGE_SERVER = 'nginx'
PRIVATE_STORAGE_SERVER = 'jumpserver.rewriting.storage.servers.StaticFileServer'
if DEBUG_DEV:
PRIVATE_STORAGE_SERVER = 'django'
# Use django-bootstrap-form to format template, input max width arg
# BOOTSTRAP_COLUMN_COUNT = 11

View File

@ -214,6 +214,9 @@ PERM_TREE_REGEN_INTERVAL = CONFIG.PERM_TREE_REGEN_INTERVAL
MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS
LIMIT_SUPER_PRIV = CONFIG.LIMIT_SUPER_PRIV LIMIT_SUPER_PRIV = CONFIG.LIMIT_SUPER_PRIV
# Asset account may be too many
ASSET_SIZE = 'small'
# Chat AI # Chat AI
CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED
GPT_API_KEY = CONFIG.GPT_API_KEY GPT_API_KEY = CONFIG.GPT_API_KEY
@ -224,3 +227,5 @@ GPT_MODEL = CONFIG.GPT_MODEL
VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED
FILE_UPLOAD_SIZE_LIMIT_MB = CONFIG.FILE_UPLOAD_SIZE_LIMIT_MB FILE_UPLOAD_SIZE_LIMIT_MB = CONFIG.FILE_UPLOAD_SIZE_LIMIT_MB
TICKET_APPLY_ASSET_SCOPE = CONFIG.TICKET_APPLY_ASSET_SCOPE

View File

@@ -1,14 +1,15 @@
from django.contrib.contenttypes.fields import GenericRelation
from django.db import models
from django.db.models import OneToOneField
from django.db.models import OneToOneField, Count
from common.utils import lazyproperty
from .models import LabeledResource
__all__ = ['LabeledMixin']
class LabeledMixin(models.Model):
_labels = GenericRelation(LabeledResource, object_id_field='res_id', content_type_field='res_type')
labels = GenericRelation(LabeledResource, object_id_field='res_id', content_type_field='res_type')
class Meta:
abstract = True
@@ -21,7 +22,7 @@ class LabeledMixin(models.Model):
model = pk_field.related_model
return model
@property
@lazyproperty
def real(self):
pk_field = self._meta.pk
if isinstance(pk_field, OneToOneField):
@@ -29,9 +30,43 @@ class LabeledMixin(models.Model):
return self
@property
def labels(self):
def res_labels(self):
return self.real._labels
return self.real.labels
@labels.setter
@res_labels.setter
def labels(self, value):
def res_labels(self, value):
self.real._labels.set(value, bulk=False)
self.real.labels.set(value, bulk=False)
@classmethod
def filter_resources_by_labels(cls, resources, label_ids):
return cls._get_filter_res_by_labels_m2m_all(resources, label_ids)
@classmethod
def _get_filter_res_by_labels_m2m_in(cls, resources, label_ids):
return resources.filter(label_id__in=label_ids)
@classmethod
def _get_filter_res_by_labels_m2m_all(cls, resources, label_ids):
if len(label_ids) == 1:
return cls._get_filter_res_by_labels_m2m_in(resources, label_ids)
resources = resources.filter(label_id__in=label_ids) \
.values('res_id') \
.order_by('res_id') \
.annotate(count=Count('res_id', distinct=True)) \
.values('res_id', 'count') \
.filter(count=len(label_ids))
return resources
@classmethod
def get_labels_filter_attr_q(cls, value, match):
resources = LabeledResource.objects.all()
if not value:
return None
if match != 'm2m_all':
resources = cls._get_filter_res_by_labels_m2m_in(resources, value)
else:
resources = cls._get_filter_res_by_labels_m2m_all(resources, value)
res_ids = set(resources.values_list('res_id', flat=True))
return models.Q(id__in=res_ids)
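
The `m2m_all` helper implements "resource carries all of the given labels" by grouping the label relation rows per `res_id` and keeping only those whose distinct count equals the number of requested labels. A hedged usage sketch, assuming the model in question (for example `Asset`) mixes in `LabeledMixin`:

```python
from assets.models import Asset

label_ids = ['<label-id-1>', '<label-id-2>']  # placeholder ids of the labels to match
q = Asset.get_labels_filter_attr_q(value=label_ids, match='m2m_all')
# get_labels_filter_attr_q returns None when no label ids are given, meaning "no label filter"
assets = Asset.objects.filter(q) if q is not None else Asset.objects.all()
```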

View File

@@ -34,7 +34,7 @@ class LabelSerializer(BulkOrgResourceModelSerializer):
@classmethod
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
queryset = queryset.annotate(res_count=Count('labeled_resources'))
queryset = queryset.annotate(res_count=Count('labeled_resources', distinct=True))
return queryset
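
`distinct=True` shows up in many hunks of this PR for the same reason: once a queryset joins more than one to-many relation, the joined rows multiply and a bare `Count()` counts every combination. The asset-permission serializer changed later in this diff is the clearest case:

```python
from django.db.models import Count
from perms.models import AssetPermission

# A permission with 2 users and 3 assets produces 6 joined rows, so a plain
# Count("users") would report 6; counting distinct ids gives the expected 2.
qs = AssetPermission.objects.annotate(
    users_amount=Count("users", distinct=True),
    assets_amount=Count("assets", distinct=True),
)
```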

View File

@@ -1,28 +1,32 @@
import json
import time
from threading import Thread
from channels.generic.websocket import JsonWebsocketConsumer
from django.core.cache import cache
from django.conf import settings
from common.db.utils import safe_db_connection
from common.sessions.cache import user_session_manager
from common.utils import get_logger
from .signal_handlers import new_site_msg_chan
from .site_msg import SiteMessageUtil
logger = get_logger(__name__)
WS_SESSION_KEY = 'ws_session_key'
class SiteMsgWebsocket(JsonWebsocketConsumer):
sub = None
refresh_every_seconds = 10
@property
def session(self):
return self.scope['session']
def connect(self):
user = self.scope["user"]
if user.is_authenticated:
self.accept()
session = self.scope['session']
redis_client = cache.client.get_client()
redis_client.sadd(WS_SESSION_KEY, session.session_key)
user_session_manager.add_or_increment(self.session.session_key)
self.sub = self.watch_recv_new_site_msg()
else:
self.close()
@@ -66,6 +70,32 @@ class SiteMsgWebsocket(JsonWebsocketConsumer):
if not self.sub:
return
self.sub.unsubscribe()
session = self.scope['session']
redis_client = cache.client.get_client()
redis_client.srem(WS_SESSION_KEY, session.session_key)
user_session_manager.decrement_or_remove(self.session.session_key)
if self.should_delete_session():
thread = Thread(target=self.delay_delete_session)
thread.start()
def should_delete_session(self):
return (self.session.modified or settings.SESSION_SAVE_EVERY_REQUEST) and \
not self.session.is_empty() and \
self.session.get_expire_at_browser_close() and \
not user_session_manager.check_active(self.session.session_key)
def delay_delete_session(self):
timeout = 3
check_interval = 0.5
start_time = time.time()
while time.time() - start_time < timeout:
time.sleep(check_interval)
if user_session_manager.check_active(self.session.session_key):
return
self.delete_session()
def delete_session(self):
try:
self.session.delete()
except Exception as e:
logger.info(f'delete session error: {e}')
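
This disconnect path works together with the SESSION_EXPIRE_AT_BROWSER_CLOSE changes elsewhere in the PR: when the last websocket for a session closes, the session is only deleted after a short grace period during which a page refresh may reopen a connection. A condensed sketch of that pattern, with hypothetical names, to make the timing explicit:

```python
import time
from threading import Thread


def schedule_delayed_delete(session, manager, timeout=3, check_interval=0.5):
    """Delete an expire-at-browser-close session unless a connection comes back in time."""
    def worker():
        start = time.time()
        while time.time() - start < timeout:
            time.sleep(check_interval)
            if manager.check_active(session.session_key):
                return  # the browser reconnected (e.g. a refresh); keep the session
        try:
            session.delete()
        except Exception:
            pass  # the session may already be gone

    Thread(target=worker, daemon=True).start()
```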

View File

@ -4,6 +4,21 @@ import time
import paramiko import paramiko
from sshtunnel import SSHTunnelForwarder from sshtunnel import SSHTunnelForwarder
from packaging import version
if version.parse(paramiko.__version__) > version.parse("2.8.1"):
_preferred_pubkeys = (
"ssh-ed25519",
"ecdsa-sha2-nistp256",
"ecdsa-sha2-nistp384",
"ecdsa-sha2-nistp521",
"ssh-rsa",
"rsa-sha2-256",
"rsa-sha2-512",
"ssh-dss",
)
paramiko.transport.Transport._preferred_pubkeys = _preferred_pubkeys
def common_argument_spec():
options = dict(
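
The patch above pushes the legacy `ssh-rsa` algorithm ahead of the `rsa-sha2-*` variants that paramiko prefers after 2.8.1, which some older SSH servers fail to negotiate. As an aside, recent paramiko also offers a per-connection knob for the same problem, `disabled_algorithms` (shown only for comparison; it is not what this PR uses):

```python
import paramiko

client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
# Stop offering rsa-sha2-* signatures to a legacy sshd that mishandles them.
client.connect(
    'legacy-host.example.com', username='root', key_filename='/path/to/id_rsa',
    disabled_algorithms={'pubkeys': ['rsa-sha2-256', 'rsa-sha2-512']},
)
```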

View File

@ -2,6 +2,7 @@
# #
import os import os
import re import re
from collections import defaultdict
from celery.result import AsyncResult from celery.result import AsyncResult
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
@ -166,16 +167,58 @@ class CeleryTaskViewSet(
i.next_exec_time = now + next_run_at i.next_exec_time = now + next_run_at
return queryset return queryset
def generate_summary_state(self, execution_qs):
model = self.get_queryset().model
executions = execution_qs.order_by('-date_published').values('name', 'state')
summary_state_dict = defaultdict(
lambda: {
'states': [], 'state': 'green',
'summary': {'total': 0, 'success': 0}
}
)
for execution in executions:
name = execution['name']
state = execution['state']
summary = summary_state_dict[name]['summary']
summary['total'] += 1
summary['success'] += 1 if state == 'SUCCESS' else 0
states = summary_state_dict[name].get('states')
if states is not None and len(states) >= 5:
color = model.compute_state_color(states)
summary_state_dict[name]['state'] = color
summary_state_dict[name].pop('states', None)
elif isinstance(states, list):
states.append(state)
return summary_state_dict
def loading_summary_state(self, queryset):
if isinstance(queryset, list):
names = [i.name for i in queryset]
execution_qs = CeleryTaskExecution.objects.filter(name__in=names)
else:
execution_qs = CeleryTaskExecution.objects.all()
summary_state_dict = self.generate_summary_state(execution_qs)
for i in queryset:
i.summary = summary_state_dict.get(i.name, {}).get('summary', {})
i.state = summary_state_dict.get(i.name, {}).get('state', 'green')
return queryset
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset)
if page is not None:
page = self.generate_execute_time(page)
page = self.loading_summary_state(page)
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
queryset = self.generate_execute_time(queryset)
queryset = self.loading_summary_state(queryset)
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
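
`loading_summary_state` precomputes a per-task-name summary for the whole page in one query over `CeleryTaskExecution` and assigns it onto the list items, so the `summary`/`state` model properties (changed below) do not issue one query per row. The colour rules it relies on can be read off `compute_state_color`; assuming `CeleryTask` is importable from `ops.models`:

```python
from ops.models import CeleryTask

# states are ordered newest first (order_by('-date_published'))
assert CeleryTask.compute_state_color(['FAILURE', 'SUCCESS']) == 'red'     # newest run failed
assert CeleryTask.compute_state_color(['SUCCESS', 'FAILURE']) == 'yellow'  # a failure among the last five
assert CeleryTask.compute_state_color(['SUCCESS'] * 5) == 'green'          # all recent runs succeeded
assert CeleryTask.compute_state_color([]) == 'green'                       # no history yet
```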

View File

@@ -246,6 +246,6 @@ class UsernameHintsAPI(APIView):
.filter(username__icontains=query) \
.filter(asset__in=assets) \
.values('username') \
.annotate(total=Count('username')) \
.annotate(total=Count('username', distinct=True)) \
.order_by('total', '-username')[:10]
return Response(data=top_accounts)

View File

@@ -15,6 +15,9 @@ class CeleryTask(models.Model):
name = models.CharField(max_length=1024, verbose_name=_('Name'))
date_last_publish = models.DateTimeField(null=True, verbose_name=_("Date last publish"))
__summary = None
__state = None
@property
def meta(self):
task = app.tasks.get(self.name, None)
@@ -25,25 +28,43 @@ class CeleryTask(models.Model):
@property
def summary(self):
if self.__summary is not None:
return self.__summary
executions = CeleryTaskExecution.objects.filter(name=self.name)
total = executions.count()
success = executions.filter(state='SUCCESS').count()
return {'total': total, 'success': success}
@summary.setter
def summary(self, value):
self.__summary = value
@staticmethod
def compute_state_color(states: list, default_count=5):
color = 'green'
states = states[:default_count]
if not states:
return color
if states[0] == 'FAILURE':
color = 'red'
elif 'FAILURE' in states:
color = 'yellow'
return color
@property
def state(self):
last_five_executions = CeleryTaskExecution.objects \
.filter(name=self.name) \
.order_by('-date_published')[:5]
if len(last_five_executions) > 0:
if last_five_executions[0].state == 'FAILURE':
return "red"
for execution in last_five_executions:
if execution.state == 'FAILURE':
return "yellow"
return "green"
if self.__state is not None:
return self.__state
last_five_executions = CeleryTaskExecution.objects.filter(
name=self.name
).order_by('-date_published').values('state')[:5]
states = [i['state'] for i in last_five_executions]
color = self.compute_state_color(states)
return color
@state.setter
def state(self, value):
self.__state = value
class Meta:
verbose_name = _("Celery Task")

View File

@ -67,6 +67,7 @@ class JMSPermedInventory(JMSInventory):
'postgresql': ['postgresql'], 'postgresql': ['postgresql'],
'sqlserver': ['sqlserver'], 'sqlserver': ['sqlserver'],
'ssh': ['shell', 'python', 'win_shell', 'raw'], 'ssh': ['shell', 'python', 'win_shell', 'raw'],
'winrm': ['win_shell', 'shell'],
} }
if self.module not in protocol_supported_modules_mapping.get(protocol.name, []): if self.module not in protocol_supported_modules_mapping.get(protocol.name, []):

View File

@ -87,7 +87,8 @@ class OrgResourceStatisticsRefreshUtil:
if not cache_field_name: if not cache_field_name:
return return
org = getattr(instance, 'org', None) org = getattr(instance, 'org', None)
cls.refresh_org_fields(((org, cache_field_name),)) cache_field_name = tuple(cache_field_name)
cls.refresh_org_fields.delay(org_fields=((org, cache_field_name),))
@receiver(post_save) @receiver(post_save)

View File

@ -1,5 +1,6 @@
import abc import abc
from django.conf import settings
from rest_framework.generics import ListAPIView, RetrieveAPIView from rest_framework.generics import ListAPIView, RetrieveAPIView
from assets.api.asset.asset import AssetFilterSet from assets.api.asset.asset import AssetFilterSet
@ -7,8 +8,7 @@ from assets.models import Asset, Node
from common.utils import get_logger, lazyproperty, is_uuid from common.utils import get_logger, lazyproperty, is_uuid
from orgs.utils import tmp_to_root_org from orgs.utils import tmp_to_root_org
from perms import serializers from perms import serializers
from perms.pagination import AllPermedAssetPagination from perms.pagination import NodePermedAssetPagination, AllPermedAssetPagination
from perms.pagination import NodePermedAssetPagination
from perms.utils import UserPermAssetUtil, PermAssetDetailUtil from perms.utils import UserPermAssetUtil, PermAssetDetailUtil
from .mixin import ( from .mixin import (
SelfOrPKUserMixin SelfOrPKUserMixin
@ -39,7 +39,7 @@ class UserPermedAssetRetrieveApi(SelfOrPKUserMixin, RetrieveAPIView):
class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView): class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
ordering = ('name',) ordering = []
search_fields = ('name', 'address', 'comment') search_fields = ('name', 'address', 'comment')
ordering_fields = ("name", "address") ordering_fields = ("name", "address")
filterset_class = AssetFilterSet filterset_class = AssetFilterSet
@ -48,6 +48,8 @@ class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
def get_queryset(self): def get_queryset(self):
if getattr(self, 'swagger_fake_view', False): if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none() return Asset.objects.none()
if settings.ASSET_SIZE == 'small':
self.ordering = ['name']
assets = self.get_assets() assets = self.get_assets()
assets = self.serializer_class.setup_eager_loading(assets) assets = self.serializer_class.setup_eager_loading(assets)
return assets return assets

View File

@ -14,6 +14,7 @@ from assets.api import SerializeToTreeNodeMixin
from assets.models import Asset from assets.models import Asset
from assets.utils import KubernetesTree from assets.utils import KubernetesTree
from authentication.models import ConnectionToken from authentication.models import ConnectionToken
from common.exceptions import JMSException
from common.utils import get_object_or_none, lazyproperty from common.utils import get_object_or_none, lazyproperty
from common.utils.common import timeit from common.utils.common import timeit
from perms.hands import Node from perms.hands import Node
@ -181,6 +182,8 @@ class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(BaseUserNodeWithAssetAsT
return self.query_asset_util.get_all_assets() return self.query_asset_util.get_all_assets()
def _get_tree_nodes_async(self): def _get_tree_nodes_async(self):
if self.request.query_params.get('lv') == '0':
return [], []
if not self.tp or not all(self.tp): if not self.tp or not all(self.tp):
nodes = UserPermAssetUtil.get_type_nodes_tree_or_cached(self.user) nodes = UserPermAssetUtil.get_type_nodes_tree_or_cached(self.user)
return nodes, [] return nodes, []
@ -262,5 +265,8 @@ class UserGrantedK8sAsTreeApi(SelfOrPKUserMixin, ListAPIView):
if not any([namespace, pod]) and not key: if not any([namespace, pod]) and not key:
asset_node = k8s_tree_instance.as_asset_tree_node() asset_node = k8s_tree_instance.as_asset_tree_node()
tree.append(asset_node) tree.append(asset_node)
tree.extend(k8s_tree_instance.async_tree_node(namespace, pod)) try:
return Response(data=tree) tree.extend(k8s_tree_instance.async_tree_node(namespace, pod))
return Response(data=tree)
except Exception as e:
raise JMSException(e)

View File

@ -130,7 +130,7 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel):
qs1_ids = User.objects.filter(id__in=user_ids).distinct().values_list('id', flat=True) qs1_ids = User.objects.filter(id__in=user_ids).distinct().values_list('id', flat=True)
qs2_ids = User.objects.filter(groups__id__in=group_ids).distinct().values_list('id', flat=True) qs2_ids = User.objects.filter(groups__id__in=group_ids).distinct().values_list('id', flat=True)
qs_ids = list(qs1_ids) + list(qs2_ids) qs_ids = list(qs1_ids) + list(qs2_ids)
qs = User.objects.filter(id__in=qs_ids) qs = User.objects.filter(id__in=qs_ids, is_service_account=False)
return qs return qs
def get_all_assets(self, flat=False): def get_all_assets(self, flat=False):

View File

@ -9,7 +9,7 @@ class PermedAssetsWillExpireUserMsg(UserMessage):
def __init__(self, user, assets, day_count=0): def __init__(self, user, assets, day_count=0):
super().__init__(user) super().__init__(user)
self.assets = assets self.assets = assets
self.day_count = _('today') if day_count == 0 else day_count + _('day') self.day_count = _('today') if day_count == 0 else str(day_count) + _('day')
def get_html_msg(self) -> dict: def get_html_msg(self) -> dict:
subject = _("You permed assets is about to expire") subject = _("You permed assets is about to expire")
@ -41,7 +41,7 @@ class AssetPermsWillExpireForOrgAdminMsg(UserMessage):
super().__init__(user) super().__init__(user)
self.perms = perms self.perms = perms
self.org = org self.org = org
self.day_count = _('today') if day_count == 0 else day_count + _('day') self.day_count = _('today') if day_count == 0 else str(day_count) + _('day')
def get_items_with_url(self): def get_items_with_url(self):
items_with_url = [] items_with_url = []

View File

@ -197,9 +197,9 @@ class AssetPermissionListSerializer(AssetPermissionSerializer):
"""Perform necessary eager loading of data.""" """Perform necessary eager loading of data."""
queryset = queryset \ queryset = queryset \
.prefetch_related('labels', 'labels__label') \ .prefetch_related('labels', 'labels__label') \
.annotate(users_amount=Count("users"), .annotate(users_amount=Count("users", distinct=True),
user_groups_amount=Count("user_groups"), user_groups_amount=Count("user_groups", distinct=True),
assets_amount=Count("assets"), assets_amount=Count("assets", distinct=True),
nodes_amount=Count("nodes"), nodes_amount=Count("nodes", distinct=True),
) )
return queryset return queryset

View File

@ -8,9 +8,9 @@ from rest_framework import serializers
from accounts.models import Account from accounts.models import Account
from assets.const import Category, AllTypes from assets.const import Category, AllTypes
from assets.models import Node, Asset, Platform from assets.models import Node, Asset, Platform
from assets.serializers.asset.common import AssetLabelSerializer, AssetProtocolsPermsSerializer from assets.serializers.asset.common import AssetProtocolsPermsSerializer
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from common.serializers import ResourceLabelsMixin from common.serializers import ResourceLabelsMixin
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from orgs.mixins.serializers import OrgResourceModelSerializerMixin from orgs.mixins.serializers import OrgResourceModelSerializerMixin
from perms.serializers.permission import ActionChoicesField from perms.serializers.permission import ActionChoicesField

View File

@@ -13,7 +13,7 @@ class AssetPermissionUtil(object):
""" Utility methods for asset permissions """
@timeit
def get_permissions_for_user(self, user, with_group=True, flat=False):
def get_permissions_for_user(self, user, with_group=True, flat=False, with_expired=False):
""" Get the permission rules granted to a user """
perm_ids = set()
# user
@@ -25,7 +25,7 @@ class AssetPermissionUtil(object):
groups = user.groups.all()
group_perm_ids = self.get_permissions_for_user_groups(groups, flat=True)
perm_ids.update(group_perm_ids)
perms = self.get_permissions(ids=perm_ids)
perms = self.get_permissions(ids=perm_ids, with_expired=with_expired)
if flat:
return perms.values_list('id', flat=True)
return perms
@@ -102,6 +102,8 @@ class AssetPermissionUtil(object):
return model.objects.filter(id__in=ids)
@staticmethod
def get_permissions(ids):
def get_permissions(ids, with_expired=False):
perms = AssetPermission.objects.filter(id__in=ids).valid().order_by('-date_expired')
return perms
perms = AssetPermission.objects.filter(id__in=ids)
if not with_expired:
perms = perms.valid()
return perms.order_by('-date_expired')
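
The new `with_expired` switch exists for the ticket "apply asset" views added later in this PR: when applying for assets, a user should also be able to pick assets whose previous grant has already expired. Roughly how those views use it (mirrors apps/tickets/api/perms.py further down):

```python
from perms.utils import AssetPermissionPermAssetUtil
from perms.utils.permission import AssetPermissionUtil


def assets_user_may_apply_for(user):
    # Include expired grants so users can re-apply for assets they used to have.
    perm_ids = AssetPermissionUtil().get_permissions_for_user(
        user, flat=True, with_expired=True
    )
    return AssetPermissionPermAssetUtil(perm_ids).get_all_assets()
```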

View File

@@ -7,10 +7,10 @@ from django.db.models import Q
from rest_framework.utils.encoders import JSONEncoder
from assets.const import AllTypes
from assets.models import FavoriteAsset, Asset
from assets.models import FavoriteAsset, Asset, Node
from common.utils.common import timeit, get_logger
from orgs.utils import current_org, tmp_to_root_org
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation, AssetPermission
from .permission import AssetPermissionUtil
__all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil']
@@ -21,36 +21,37 @@ logger = get_logger(__name__)
class AssetPermissionPermAssetUtil:
def __init__(self, perm_ids):
self.perm_ids = perm_ids
self.perm_ids = set(perm_ids)
def get_all_assets(self):
""" Get all granted assets """
node_assets = self.get_perm_nodes_assets()
direct_assets = self.get_direct_assets()
# Much faster than first collecting every asset id and then searching on them, which is slow when there are many assets
return (node_assets | direct_assets).distinct()
return (node_assets | direct_assets).order_by().distinct()
def get_perm_nodes(self):
""" Get all granted nodes """
nodes_ids = AssetPermission.objects \
.filter(id__in=self.perm_ids) \
.values_list('nodes', flat=True)
nodes_ids = set(nodes_ids)
nodes = Node.objects.filter(id__in=nodes_ids).only('id', 'key')
return nodes
@timeit
def get_perm_nodes_assets(self, flat=False):
""" Get all assets under the granted nodes """
from assets.models import Node
nodes = Node.objects \
.prefetch_related('granted_by_permissions') \
.filter(granted_by_permissions__in=self.perm_ids) \
.only('id', 'key')
assets = PermNode.get_nodes_all_assets(*nodes)
if flat:
return set(assets.values_list('id', flat=True))
return assets
def get_perm_nodes_assets(self):
""" Get all assets under the granted nodes """
nodes = self.get_perm_nodes()
assets = PermNode.get_nodes_all_assets(*nodes, distinct=False)
return assets
@timeit
def get_direct_assets(self, flat=False):
""" Get directly granted assets """
assets = Asset.objects.order_by() \
.filter(granted_by_permissions__id__in=self.perm_ids) \
.distinct()
if flat:
return set(assets.values_list('id', flat=True))
return assets
def get_direct_assets(self):
""" Get directly granted assets """
asset_ids = AssetPermission.assets.through.objects \
.filter(assetpermission_id__in=self.perm_ids) \
.values_list('asset_id', flat=True)
assets = Asset.objects.filter(id__in=asset_ids)
return assets
@@ -152,6 +153,7 @@ class UserPermAssetUtil(AssetPermissionPermAssetUtil):
assets = assets.filter(nodes__id=node.id).order_by().distinct()
return assets
@timeit
def _get_indirect_perm_node_all_assets(self, node):
""" Get all assets under an indirectly granted node
The algorithm queries data from `UserAssetGrantedTreeNodeRelation`
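
Two details of the rewrite above are worth calling out: the direct-asset lookup now goes through the auto-generated M2M through table (`AssetPermission.assets.through`), so the large assets table is only touched by the final `id__in` filter, and `get_all_assets()` clears ordering with `order_by()` before `distinct()` so ordering columns do not interfere with the DISTINCT. Typical use, as the permed-asset APIs do:

```python
from perms.utils import AssetPermissionPermAssetUtil

perm_ids = [...]                    # permission ids for the current user
util = AssetPermissionPermAssetUtil(perm_ids)
assets = util.get_all_assets()      # node-granted | directly granted, de-duplicated
print(assets.query)                 # inspect the generated SQL when tuning
```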

View File

@@ -72,7 +72,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
@timeit
def refresh_if_need(self, force=False):
built_just_now = cache.get(self.cache_key_time)
built_just_now = False if settings.ASSET_SIZE == 'small' else cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
return
@@ -80,12 +80,18 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
if not to_refresh_orgs:
logger.info('Not have to refresh orgs')
return
logger.info("Delay refresh user orgs: {} {}".format(self.user, [o.name for o in to_refresh_orgs]))
refresh_user_orgs_perm_tree(user_orgs=((self.user, tuple(to_refresh_orgs)),))
refresh_user_favorite_assets(users=(self.user,))
sync = True if settings.ASSET_SIZE == 'small' else False
refresh_user_orgs_perm_tree.apply(sync=sync, user_orgs=((self.user, tuple(to_refresh_orgs)),))
refresh_user_favorite_assets.apply(sync=sync, users=(self.user,))
@timeit
def refresh_tree_manual(self):
"""
For manual debugging
:return:
"""
built_just_now = cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh just now, pass: {}'.format(built_just_now))
@@ -105,8 +111,9 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
return
self._clean_user_perm_tree_for_legacy_org()
ttl = settings.PERM_TREE_REGEN_INTERVAL
cache.set(self.cache_key_time, int(time.time()), ttl)
if settings.ASSET_SIZE != 'small':
ttl = settings.PERM_TREE_REGEN_INTERVAL
cache.set(self.cache_key_time, int(time.time()), ttl)
lock = UserGrantedTreeRebuildLock(self.user.id)
got = lock.acquire(blocking=False)
@@ -193,7 +200,13 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):
cache_key = self.get_cache_key(uid)
p.srem(cache_key, *org_ids)
p.execute()
logger.info('Expire perm tree for users: [{}], orgs: [{}]'.format(user_ids, org_ids))
users_display = ','.join([str(i) for i in user_ids[:3]])
if len(user_ids) > 3:
users_display += '...'
orgs_display = ','.join([str(i) for i in org_ids[:3]])
if len(org_ids) > 3:
orgs_display += '...'
logger.info('Expire perm tree for users: [{}], orgs: [{}]'.format(users_display, orgs_display))
def expire_perm_tree_for_all_user(self):
keys = self.client.keys(self.cache_key_all_user)
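
With `ASSET_SIZE == 'small'` the per-user tree refresh now runs inline and the "built just now" cache gate is skipped, so small installs always see a fresh tree; larger installs keep the deferred, rate-limited behaviour. Note that `.apply(sync=...)` appears to be JumpServer's own task wrapper rather than the stock Celery signature (Celery's `Task.apply` takes `args`/`kwargs`, not arbitrary keyword arguments). With plain Celery, the same choice would look roughly like this (a sketch, assuming a `shared_task` named `refresh_user_orgs_perm_tree`):

```python
from django.conf import settings

user_orgs = ((user, tuple(orgs_to_refresh)),)   # same payload as above

if settings.ASSET_SIZE == 'small':
    # run eagerly, in-process, so the request that triggered it sees the result
    refresh_user_orgs_perm_tree.apply(kwargs={'user_orgs': user_orgs})
else:
    # hand off to a worker to keep the request fast on large installs
    refresh_user_orgs_perm_tree.delay(user_orgs=user_orgs)
```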

View File

@ -80,9 +80,11 @@ class RoleViewSet(JMSModelViewSet):
queryset = Role.objects.filter(id__in=ids).order_by(*self.ordering) queryset = Role.objects.filter(id__in=ids).order_by(*self.ordering)
org_id = current_org.id org_id = current_org.id
q = Q(role__scope=Role.Scope.system) | Q(role__scope=Role.Scope.org, org_id=org_id) q = Q(role__scope=Role.Scope.system) | Q(role__scope=Role.Scope.org, org_id=org_id)
role_bindings = RoleBinding.objects.filter(q).values_list('role_id').annotate(user_count=Count('user_id')) role_bindings = RoleBinding.objects.filter(q).values_list('role_id').annotate(
user_count=Count('user_id', distinct=True)
)
role_user_amount_mapper = {role_id: user_count for role_id, user_count in role_bindings} role_user_amount_mapper = {role_id: user_count for role_id, user_count in role_bindings}
queryset = queryset.annotate(permissions_amount=Count('permissions')) queryset = queryset.annotate(permissions_amount=Count('permissions', distinct=True))
queryset = list(queryset) queryset = list(queryset)
for role in queryset: for role in queryset:
role.users_amount = role_user_amount_mapper.get(role.id, 0) role.users_amount = role_user_amount_mapper.get(role.id, 0)

View File

@ -137,7 +137,7 @@ class LDAPUserImportAPI(APIView):
return Response({'msg': _('Get ldap users is None')}, status=400) return Response({'msg': _('Get ldap users is None')}, status=400)
orgs = self.get_orgs() orgs = self.get_orgs()
errors = LDAPImportUtil().perform_import(users, orgs) new_users, errors = LDAPImportUtil().perform_import(users, orgs)
if errors: if errors:
return Response({'errors': errors}, status=400) return Response({'errors': errors}, status=400)

View File

@ -3,6 +3,7 @@ from rest_framework import generics
from rest_framework.permissions import AllowAny from rest_framework.permissions import AllowAny
from authentication.permissions import IsValidUserOrConnectionToken from authentication.permissions import IsValidUserOrConnectionToken
from common.const.choices import COUNTRY_CALLING_CODES
from common.utils import get_logger, lazyproperty from common.utils import get_logger, lazyproperty
from common.utils.timezone import local_now from common.utils.timezone import local_now
from .. import serializers from .. import serializers
@ -24,7 +25,8 @@ class OpenPublicSettingApi(generics.RetrieveAPIView):
def get_object(self): def get_object(self):
return { return {
"XPACK_ENABLED": settings.XPACK_ENABLED, "XPACK_ENABLED": settings.XPACK_ENABLED,
"INTERFACE": self.interface_setting "INTERFACE": self.interface_setting,
"COUNTRY_CALLING_CODES": COUNTRY_CALLING_CODES
} }

View File

@ -0,0 +1,36 @@
from django.template.loader import render_to_string
from django.utils.translation import gettext as _
from common.utils import get_logger
from common.utils.timezone import local_now_display
from notifications.notifications import UserMessage
logger = get_logger(__file__)
class LDAPImportMessage(UserMessage):
def __init__(self, user, extra_kwargs):
super().__init__(user)
self.orgs = extra_kwargs.pop('orgs', [])
self.end_time = extra_kwargs.pop('end_time', '')
self.start_time = extra_kwargs.pop('start_time', '')
self.time_start_display = extra_kwargs.pop('time_start_display', '')
self.new_users = extra_kwargs.pop('new_users', [])
self.errors = extra_kwargs.pop('errors', [])
self.cost_time = extra_kwargs.pop('cost_time', '')
def get_html_msg(self) -> dict:
subject = _('Notification of Synchronized LDAP User Task Results')
context = {
'orgs': self.orgs,
'start_time': self.time_start_display,
'end_time': local_now_display(),
'cost_time': self.cost_time,
'users': self.new_users,
'errors': self.errors
}
message = render_to_string('ldap/_msg_import_ldap_user.html', context)
return {
'subject': subject,
'message': message
}

View File

@ -77,6 +77,9 @@ class LDAPSettingSerializer(serializers.Serializer):
required=False, label=_('Connect timeout (s)'), required=False, label=_('Connect timeout (s)'),
) )
AUTH_LDAP_SEARCH_PAGED_SIZE = serializers.IntegerField(required=False, label=_('Search paged size (piece)')) AUTH_LDAP_SEARCH_PAGED_SIZE = serializers.IntegerField(required=False, label=_('Search paged size (piece)'))
AUTH_LDAP_SYNC_RECEIVERS = serializers.ListField(
required=False, label=_('Recipient'), max_length=36
)
AUTH_LDAP = serializers.BooleanField(required=False, label=_('Enable LDAP auth')) AUTH_LDAP = serializers.BooleanField(required=False, label=_('Enable LDAP auth'))

View File

@@ -43,7 +43,7 @@ class OAuth2SettingSerializer(serializers.Serializer):
)
AUTH_OAUTH2_ACCESS_TOKEN_METHOD = serializers.ChoiceField(
default='GET', label=_('Client authentication method'),
choices=(('GET', 'GET'), ('POST', 'POST'))
choices=(('GET', 'GET'), ('POST', 'POST-DATA'), ('POST_JSON', 'POST-JSON'))
)
AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT = serializers.CharField(
required=True, max_length=1024, label=_('Provider userinfo endpoint')
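
The extra `POST_JSON` choice distinguishes providers that want the access-token request as a JSON body from those expecting form data or query parameters. Illustrative only (the real request is built inside the OAuth2 auth backend, which is outside this excerpt), a client branching on this setting might look like:

```python
import requests


def fetch_access_token(endpoint, params, method='GET'):
    # method mirrors AUTH_OAUTH2_ACCESS_TOKEN_METHOD: 'GET', 'POST' (form) or 'POST_JSON'
    if method == 'GET':
        resp = requests.get(endpoint, params=params)
    elif method == 'POST':
        resp = requests.post(endpoint, data=params)   # application/x-www-form-urlencoded
    else:  # 'POST_JSON'
        resp = requests.post(endpoint, json=params)   # application/json body
    resp.raise_for_status()
    return resp.json()
```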

View File

@ -11,6 +11,7 @@ __all__ = [
class PublicSettingSerializer(serializers.Serializer): class PublicSettingSerializer(serializers.Serializer):
XPACK_ENABLED = serializers.BooleanField() XPACK_ENABLED = serializers.BooleanField()
INTERFACE = serializers.DictField() INTERFACE = serializers.DictField()
COUNTRY_CALLING_CODES = serializers.ListField()
class PrivateSettingSerializer(PublicSettingSerializer): class PrivateSettingSerializer(PublicSettingSerializer):

View File

@ -1,15 +1,19 @@
# coding: utf-8 # coding: utf-8
# #
import time
from celery import shared_task from celery import shared_task
from django.conf import settings from django.conf import settings
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from common.utils import get_logger from common.utils import get_logger
from common.utils.timezone import local_now_display
from ops.celery.decorator import after_app_ready_start from ops.celery.decorator import after_app_ready_start
from ops.celery.utils import ( from ops.celery.utils import (
create_or_update_celery_periodic_tasks, disable_celery_periodic_task create_or_update_celery_periodic_tasks, disable_celery_periodic_task
) )
from orgs.models import Organization from orgs.models import Organization
from settings.notifications import LDAPImportMessage
from users.models import User
from ..utils import LDAPSyncUtil, LDAPServerUtil, LDAPImportUtil from ..utils import LDAPSyncUtil, LDAPServerUtil, LDAPImportUtil
__all__ = ['sync_ldap_user', 'import_ldap_user_periodic', 'import_ldap_user'] __all__ = ['sync_ldap_user', 'import_ldap_user_periodic', 'import_ldap_user']
@ -23,6 +27,8 @@ def sync_ldap_user():
@shared_task(verbose_name=_('Periodic import ldap user')) @shared_task(verbose_name=_('Periodic import ldap user'))
def import_ldap_user(): def import_ldap_user():
start_time = time.time()
time_start_display = local_now_display()
logger.info("Start import ldap user task") logger.info("Start import ldap user task")
util_server = LDAPServerUtil() util_server = LDAPServerUtil()
util_import = LDAPImportUtil() util_import = LDAPImportUtil()
@ -35,11 +41,26 @@ def import_ldap_user():
org_ids = [Organization.DEFAULT_ID] org_ids = [Organization.DEFAULT_ID]
default_org = Organization.default() default_org = Organization.default()
orgs = list(set([Organization.get_instance(org_id, default=default_org) for org_id in org_ids])) orgs = list(set([Organization.get_instance(org_id, default=default_org) for org_id in org_ids]))
errors = util_import.perform_import(users, orgs) new_users, errors = util_import.perform_import(users, orgs)
if errors: if errors:
logger.error("Imported LDAP users errors: {}".format(errors)) logger.error("Imported LDAP users errors: {}".format(errors))
else: else:
logger.info('Imported {} users successfully'.format(len(users))) logger.info('Imported {} users successfully'.format(len(users)))
if settings.AUTH_LDAP_SYNC_RECEIVERS:
user_ids = settings.AUTH_LDAP_SYNC_RECEIVERS
recipient_list = User.objects.filter(id__in=list(user_ids))
end_time = time.time()
extra_kwargs = {
'orgs': orgs,
'end_time': end_time,
'start_time': start_time,
'time_start_display': time_start_display,
'new_users': new_users,
'errors': errors,
'cost_time': end_time - start_time,
}
for user in recipient_list:
LDAPImportMessage(user, extra_kwargs).publish()
@shared_task(verbose_name=_('Registration periodic import ldap user task')) @shared_task(verbose_name=_('Registration periodic import ldap user task'))

View File

@ -0,0 +1,34 @@
{% load i18n %}
<p>{% trans "Sync task Finish" %}</p>
<b>{% trans 'Time' %}:</b>
<ul>
<li>{% trans 'Date start' %}: {{ start_time }}</li>
<li>{% trans 'Date end' %}: {{ end_time }}</li>
<li>{% trans 'Time cost' %}: {{ cost_time| floatformat:0 }}s</li>
</ul>
<b>{% trans "Synced Organization" %}:</b>
<ul>
{% for org in orgs %}
<li>{{ org }}</li>
{% endfor %}
</ul>
<b>{% trans "Synced User" %}:</b>
<ul>
{% if users %}
{% for user in users %}
<li>{{ user }}</li>
{% endfor %}
{% else %}
<li>{% trans 'No user synchronization required' %}</li>
{% endif %}
</ul>
{% if errors %}
<b>{% trans 'Error' %}:</b>
<ul>
{% for error in errors %}
<li>{{ error }}</li>
{% endfor %}
</ul>
{% endif %}

View File

@ -400,11 +400,14 @@ class LDAPImportUtil(object):
logger.info('Start perform import ldap users, count: {}'.format(len(users))) logger.info('Start perform import ldap users, count: {}'.format(len(users)))
errors = [] errors = []
objs = [] objs = []
new_users = []
group_users_mapper = defaultdict(set) group_users_mapper = defaultdict(set)
for user in users: for user in users:
groups = user.pop('groups', []) groups = user.pop('groups', [])
try: try:
obj, created = self.update_or_create(user) obj, created = self.update_or_create(user)
if created:
new_users.append(obj)
objs.append(obj) objs.append(obj)
except Exception as e: except Exception as e:
errors.append({user['username']: str(e)}) errors.append({user['username']: str(e)})
@ -421,7 +424,7 @@ class LDAPImportUtil(object):
for org in orgs: for org in orgs:
self.bind_org(org, objs, group_users_mapper) self.bind_org(org, objs, group_users_mapper)
logger.info('End perform import ldap users') logger.info('End perform import ldap users')
return errors return new_users, errors
def exit_user_group(self, user_groups_mapper): def exit_user_group(self, user_groups_mapper):
# 通过对比查询本次导入用户需要移除的用户组 # 通过对比查询本次导入用户需要移除的用户组

View File

@ -42,7 +42,7 @@ class SmartEndpointViewMixin:
return endpoint return endpoint
def match_endpoint_by_label(self): def match_endpoint_by_label(self):
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol) return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol, self.request)
def match_endpoint_by_target_ip(self): def match_endpoint_by_target_ip(self):
target_ip = self.request.GET.get('target_ip', '') # 支持target_ip参数用来方便测试 target_ip = self.request.GET.get('target_ip', '') # 支持target_ip参数用来方便测试

View File

@@ -75,7 +75,20 @@ class Endpoint(JMSBaseModel):
return endpoint
@classmethod
def match_by_instance_label(cls, instance, protocol):
def handle_endpoint_host(cls, endpoint, request=None):
if not endpoint.host and request:
# Dynamically fill in the current request host
host_port = request.get_host()
# IPv6
if host_port.startswith('['):
host = host_port.split(']:')[0].rstrip(']') + ']'
else:
host = host_port.split(':')[0]
endpoint.host = host
return endpoint
@classmethod
def match_by_instance_label(cls, instance, protocol, request=None):
from assets.models import Asset
from terminal.models import Session
if isinstance(instance, Session):
@@ -88,6 +101,7 @@ class Endpoint(JMSBaseModel):
endpoints = cls.objects.filter(name__in=list(values)).order_by('-date_updated')
for endpoint in endpoints:
if endpoint.is_valid_for(instance, protocol):
endpoint = cls.handle_endpoint_host(endpoint, request)
return endpoint
@@ -130,13 +144,5 @@ class EndpointRule(JMSBaseModel):
endpoint = endpoint_rule.endpoint
else:
endpoint = Endpoint.get_or_create_default(request)
if not endpoint.host and request:
# Dynamically fill in the current request host
host_port = request.get_host()
# IPv6
if host_port.startswith('['):
host = host_port.split(']:')[0].rstrip(']') + ']'
else:
host = host_port.split(':')[0]
endpoint.host = host
endpoint = Endpoint.handle_endpoint_host(endpoint, request)
return endpoint
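
`handle_endpoint_host` fills a blank endpoint host from the current request, stripping the port while keeping the brackets of an IPv6 literal. The splitting rule in isolation (helper name is ours, for illustration):

```python
def host_from_host_port(host_port):
    """'demo.jumpserver.org:443' -> 'demo.jumpserver.org'; '[2001:db8::1]:443' -> '[2001:db8::1]'"""
    if host_port.startswith('['):  # IPv6 literal: keep the brackets, drop the port
        return host_port.split(']:')[0].rstrip(']') + ']'
    return host_port.split(':')[0]


assert host_from_host_port('demo.jumpserver.org:443') == 'demo.jumpserver.org'
assert host_from_host_port('[2001:db8::1]:443') == '[2001:db8::1]'
assert host_from_host_port('[2001:db8::1]') == '[2001:db8::1]'
```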

View File

@ -5,3 +5,4 @@ from .ticket import *
from .comment import * from .comment import *
from .relation import * from .relation import *
from .super_ticket import * from .super_ticket import *
from .perms import *

apps/tickets/api/perms.py (new file, 66 lines)
View File

@ -0,0 +1,66 @@
from django.conf import settings
from assets.models import Asset, Node
from assets.serializers.asset.common import MiniAssetSerializer
from assets.serializers.node import NodeSerializer
from common.api import SuggestionMixin
from orgs.mixins.api import OrgReadonlyModelViewSet
from perms.utils import AssetPermissionPermAssetUtil
from perms.utils.permission import AssetPermissionUtil
from tickets.const import TicketApplyAssetScope
__all__ = ['ApplyAssetsViewSet', 'ApplyNodesViewSet']
class ApplyAssetsViewSet(OrgReadonlyModelViewSet, SuggestionMixin):
model = Asset
serializer_class = MiniAssetSerializer
rbac_perms = (
("match", "assets.match_asset"),
)
search_fields = ("name", "address", "comment")
def get_queryset(self):
if TicketApplyAssetScope.is_permed():
queryset = self.get_assets(with_expired=True)
elif TicketApplyAssetScope.is_permed_valid():
queryset = self.get_assets()
else:
queryset = super().get_queryset()
return queryset
def get_assets(self, with_expired=False):
perms = AssetPermissionUtil().get_permissions_for_user(
self.request.user, flat=True, with_expired=with_expired
)
util = AssetPermissionPermAssetUtil(perms)
assets = util.get_all_assets()
return assets
class ApplyNodesViewSet(OrgReadonlyModelViewSet, SuggestionMixin):
model = Node
serializer_class = NodeSerializer
rbac_perms = (
("match", "assets.match_node"),
)
search_fields = ('full_value',)
def get_queryset(self):
if TicketApplyAssetScope.is_permed():
queryset = self.get_nodes(with_expired=True)
elif TicketApplyAssetScope.is_permed_valid():
queryset = self.get_nodes()
else:
queryset = super().get_queryset()
return queryset
def get_nodes(self, with_expired=False):
perms = AssetPermissionUtil().get_permissions_for_user(
self.request.user, flat=True, with_expired=with_expired
)
util = AssetPermissionPermAssetUtil(perms)
nodes = util.get_perm_nodes()
return nodes

View File

@ -4,6 +4,7 @@ from django.utils.translation import gettext_lazy as _
from rest_framework import viewsets from rest_framework import viewsets
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import MethodNotAllowed from rest_framework.exceptions import MethodNotAllowed
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from audits.handler import create_or_update_operate_log from audits.handler import create_or_update_operate_log
@ -41,7 +42,6 @@ class TicketViewSet(CommonApiMixin, viewsets.ModelViewSet):
ordering = ('-date_created',) ordering = ('-date_created',)
rbac_perms = { rbac_perms = {
'open': 'tickets.view_ticket', 'open': 'tickets.view_ticket',
'bulk': 'tickets.change_ticket',
} }
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
@ -122,7 +122,7 @@ class TicketViewSet(CommonApiMixin, viewsets.ModelViewSet):
self._record_operate_log(instance, TicketAction.close) self._record_operate_log(instance, TicketAction.close)
return Response('ok') return Response('ok')
@action(detail=False, methods=[PUT], permission_classes=[RBACPermission, ]) @action(detail=False, methods=[PUT], permission_classes=[IsAuthenticated, ])
def bulk(self, request, *args, **kwargs): def bulk(self, request, *args, **kwargs):
self.ticket_not_allowed() self.ticket_not_allowed()

View File

@ -1,3 +1,4 @@
from django.conf import settings
from django.db.models import TextChoices, IntegerChoices from django.db.models import TextChoices, IntegerChoices
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -56,3 +57,21 @@ class TicketApprovalStrategy(TextChoices):
custom_user = 'custom_user', _("Custom user") custom_user = 'custom_user', _("Custom user")
super_admin = 'super_admin', _("Super admin") super_admin = 'super_admin', _("Super admin")
super_org_admin = 'super_org_admin', _("Super admin and org admin") super_org_admin = 'super_org_admin', _("Super admin and org admin")
class TicketApplyAssetScope(TextChoices):
all = 'all', _("All assets")
permed = 'permed', _("Permed assets")
permed_valid = 'permed_valid', _('Permed valid assets')
@classmethod
def get_scope(cls):
return settings.TICKET_APPLY_ASSET_SCOPE.lower()
@classmethod
def is_permed(cls):
return cls.get_scope() == cls.permed
@classmethod
def is_permed_valid(cls):
return cls.get_scope() == cls.permed_valid

View File

@@ -57,7 +57,7 @@ class TicketStep(JMSBaseModel):
assignees.update(state=state)
self.status = StepStatus.closed
self.state = state
self.save(update_fields=['state', 'status'])
self.save(update_fields=['state', 'status', 'date_updated'])
def set_active(self):
self.status = StepStatus.active
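
The only change in this hunk is adding `date_updated` to `update_fields`: when `save()` is given `update_fields`, Django writes only the listed columns, so an `auto_now` timestamp (which `date_updated` is assumed to be, via the base model) stays stale unless it is listed explicitly. In short, for a `TicketStep` instance `step`:

```python
step.state = 'closed'
step.save(update_fields=['state', 'status'])                   # date_updated is NOT refreshed
step.save(update_fields=['state', 'status', 'date_updated'])   # auto_now value is persisted too
```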

View File

@ -16,6 +16,8 @@ router.register('apply-login-tickets', api.ApplyLoginTicketViewSet, 'apply-login
router.register('apply-command-tickets', api.ApplyCommandTicketViewSet, 'apply-command-ticket') router.register('apply-command-tickets', api.ApplyCommandTicketViewSet, 'apply-command-ticket')
router.register('apply-login-asset-tickets', api.ApplyLoginAssetTicketViewSet, 'apply-login-asset-ticket') router.register('apply-login-asset-tickets', api.ApplyLoginAssetTicketViewSet, 'apply-login-asset-ticket')
router.register('ticket-session-relation', api.TicketSessionRelationViewSet, 'ticket-session-relation') router.register('ticket-session-relation', api.TicketSessionRelationViewSet, 'ticket-session-relation')
router.register('apply-assets', api.ApplyAssetsViewSet, 'ticket-session-relation')
router.register('apply-nodes', api.ApplyNodesViewSet, 'ticket-session-relation')
urlpatterns = [ urlpatterns = [
path('tickets/<uuid:ticket_id>/session/', api.TicketSessionApi.as_view(), name='ticket-session'), path('tickets/<uuid:ticket_id>/session/', api.TicketSessionApi.as_view(), name='ticket-session'),

View File

@ -729,7 +729,7 @@ class JSONFilterMixin:
bindings = RoleBinding.objects.filter(**kwargs, role__in=value) bindings = RoleBinding.objects.filter(**kwargs, role__in=value)
if match == 'm2m_all': if match == 'm2m_all':
user_id = bindings.values('user_id').annotate(count=Count('user_id')) \ user_id = bindings.values('user_id').annotate(count=Count('user_id', distinct=True)) \
.filter(count=len(value)).values_list('user_id', flat=True) .filter(count=len(value)).values_list('user_id', flat=True)
else: else:
user_id = bindings.values_list('user_id', flat=True) user_id = bindings.values_list('user_id', flat=True)

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.db.models import Count from django.db.models import Count, Q
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
@ -46,7 +46,7 @@ class UserGroupSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
def setup_eager_loading(cls, queryset): def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """ """ Perform necessary eager loading of data. """
queryset = queryset.prefetch_related('labels', 'labels__label') \ queryset = queryset.prefetch_related('labels', 'labels__label') \
.annotate(users_amount=Count('users')) .annotate(users_amount=Count('users', distinct=True, filter=Q(users__is_service_account=False)))
return queryset return queryset

View File

@@ -163,9 +163,9 @@ def on_openid_create_or_update_user(sender, request, user, created, name, userna
user.save()
@shared_task(verbose_name=_('Clean audits session task log'))
@shared_task(verbose_name=_('Clean up expired user sessions'))
@register_as_period_task(crontab=CRONTAB_AT_PM_TWO)
def clean_audits_log_period():
def clean_expired_user_session_period():
UserSession.clear_expired_sessions()

View File

@ -86,7 +86,7 @@ def check_user_expired_periodic():
@tmp_to_root_org() @tmp_to_root_org()
def check_unused_users(): def check_unused_users():
uncommon_users_ttl = settings.SECURITY_UNCOMMON_USERS_TTL uncommon_users_ttl = settings.SECURITY_UNCOMMON_USERS_TTL
if not uncommon_users_ttl or not uncommon_users_ttl.isdigit(): if not uncommon_users_ttl:
return return
uncommon_users_ttl = int(uncommon_users_ttl) uncommon_users_ttl = int(uncommon_users_ttl)

View File

@ -7,6 +7,7 @@
.margin-bottom { .margin-bottom {
margin-bottom: 15px; margin-bottom: 15px;
} }
.input-style { .input-style {
width: 100%; width: 100%;
display: inline-block; display: inline-block;
@ -22,6 +23,19 @@
height: 100%; height: 100%;
vertical-align: top; vertical-align: top;
} }
.scrollable-menu {
height: auto;
max-height: 18rem;
overflow-x: hidden;
}
.input-group {
.input-group-btn .btn-secondary {
color: #464a4c;
background-color: #eceeef;
}
}
</style> </style>
{% endblock %} {% endblock %}
{% block html_title %}{% trans 'Forgot password' %}{% endblock %} {% block html_title %}{% trans 'Forgot password' %}{% endblock %}
@ -57,9 +71,26 @@
placeholder="{% trans 'Email account' %}" value="{{ email }}"> placeholder="{% trans 'Email account' %}" value="{{ email }}">
</div> </div>
<div id="validate-sms" class="validate-field margin-bottom"> <div id="validate-sms" class="validate-field margin-bottom">
<input type="tel" id="sms" name="sms" class="form-control input-style" <div class="input-group">
placeholder="{% trans 'Mobile number' %}" value="{{ sms }}"> <div class="input-group-btn">
<small style="color: #999; margin-left: 5px">{{ form.sms.help_text }}</small> <button type="button" class="btn btn-secondary dropdown-toggle" data-toggle="dropdown"
aria-haspopup="true" aria-expanded="false">
<span class="country-code-value">+86</span>
</button>
<ul class="dropdown-menu scrollable-menu">
{% for country in countries %}
<li>
<a href="#" class="dropdown-item d-flex justify-content-between">
<span class="country-name text-left">{{ country.name }}</span>
<span class="country-code">{{ country.value }}</span>
</a>
</li>
{% endfor %}
</ul>
</div>
<input type="tel" id="sms" name="sms" class="form-control input-style"
placeholder="{% trans 'Mobile number' %}" value="{{ sms }}">
</div>
</div> </div>
<div class="margin-bottom challenge-required"> <div class="margin-bottom challenge-required">
<input type="text" id="code" name="code" class="form-control input-style" <input type="text" id="code" name="code" class="form-control input-style"
@ -76,7 +107,7 @@
</div> </div>
</div> </div>
<script> <script>
$(function (){ $(function () {
const validateSelectRef = $('#validate-backend-select') const validateSelectRef = $('#validate-backend-select')
const formType = $('input[name="form_type"]').val() const formType = $('input[name="form_type"]').val()
validateSelectRef.val(formType) validateSelectRef.val(formType)
@ -84,19 +115,31 @@
selectChange(formType); selectChange(formType);
} }
}) })
$(".dropdown-menu li a").click(function (evt) {
const inputGroup = $('.input-group');
const inputGroupAddon = inputGroup.find('.country-code-value');
const selectedCountry = $(evt.target).closest('li');
const selectedCountryCode = selectedCountry.find('.country-code').html();
inputGroupAddon.html(selectedCountryCode)
});
function getQueryString(name) { function getQueryString(name) {
const reg = new RegExp("(^|&)"+ name +"=([^&]*)(&|$)"); const reg = new RegExp("(^|&)" + name + "=([^&]*)(&|$)");
const r = window.location.search.substr(1).match(reg); const r = window.location.search.substr(1).match(reg);
if(r !== null) if (r !== null)
return unescape(r[2]) return unescape(r[2])
return null return null
} }
function selectChange(name) { function selectChange(name) {
$('.validate-field').hide() $('.validate-field').hide()
$('#validate-' + name).show() $('#validate-' + name).show()
$('#validate-' + name + '-tip').show() $('#validate-' + name + '-tip').show()
$('input[name="form_type"]').attr('value', name) $('input[name="form_type"]').attr('value', name)
} }
function sendChallengeCode(currentBtn) { function sendChallengeCode(currentBtn) {
let time = 60; let time = 60;
const token = getQueryString('token') const token = getQueryString('token')
@ -104,7 +147,7 @@
const formType = $('input[name="form_type"]').val() const formType = $('input[name="form_type"]').val()
const email = $('#email').val() const email = $('#email').val()
const sms = $('#sms').val() let sms = $('#sms').val();
const errMsg = "{% trans 'The {} cannot be empty' %}" const errMsg = "{% trans 'The {} cannot be empty' %}"
if (formType === 'sms') { if (formType === 'sms') {
@ -118,10 +161,11 @@
return return
} }
} }
sms = $(".input-group .country-code-value").html() + sms
const data = { const data = {
form_type: formType, email: email, sms: sms, form_type: formType, email: email, sms: sms,
} }
function onSuccess() { function onSuccess() {
const originBtnText = currentBtn.innerHTML; const originBtnText = currentBtn.innerHTML;
currentBtn.disabled = true currentBtn.disabled = true

View File
@@ -14,22 +14,24 @@
                 </strong>
             </p>
             <div>
-                <img src="{% static 'img/authenticator_android.png' %}" width="128" height="128" alt="">
+                <img src="{{ authenticator_android_url }}" width="128" height="128" alt="">
                 <p>{% trans 'Android downloads' %}</p>
             </div>
             <div>
-                <img src="{% static 'img/authenticator_iphone.png' %}" width="128" height="128" alt="">
+                <img src="{{ authenticator_iphone_url }}" width="128" height="128" alt="">
                 <p>{% trans 'iPhone downloads' %}</p>
             </div>
-            <p style="margin: 20px auto;"><strong style="color: #000000">{% trans 'After installation, click the next step to enter the binding page (if installed, go to the next step directly).' %}</strong></p>
+            <p style="margin: 20px auto;"><strong
+                    style="color: #000000">{% trans 'After installation, click the next step to enter the binding page (if installed, go to the next step directly).' %}</strong>
+            </p>
         </div>
         <a href="{% url 'authentication:user-otp-enable-bind' %}" class="next">{% trans 'Next' %}</a>
     <script>
-        $(function(){
+        $(function () {
             $('.change-color li:eq(1) i').css('color', '{{ INTERFACE.primary_color }}')
         })
     </script>
View File
@@ -1,10 +1,14 @@
 # ~*~ coding: utf-8 ~*~
+import os
+from django.conf import settings
 from django.contrib.auth import logout as auth_logout
 from django.http.response import HttpResponseRedirect
 from django.shortcuts import redirect
+from django.templatetags.static import static
 from django.urls import reverse
 from django.utils.translation import gettext as _
+from django.utils._os import safe_join
 from django.views.generic.base import TemplateView
 from django.views.generic.edit import FormView
@@ -45,9 +49,26 @@ class UserOtpEnableStartView(AuthMixin, TemplateView):
 class UserOtpEnableInstallAppView(TemplateView):
     template_name = 'users/user_otp_enable_install_app.html'

+    @staticmethod
+    def replace_authenticator_png(platform):
+        media_url = settings.MEDIA_URL
+        base_path = f'img/authenticator_{platform}.png'
+        authenticator_media_path = safe_join(settings.MEDIA_ROOT, base_path)
+        if os.path.exists(authenticator_media_path):
+            authenticator_url = f'{media_url}{base_path}'
+        else:
+            authenticator_url = static(base_path)
+        return authenticator_url
+
     def get_context_data(self, **kwargs):
         user = get_user_or_pre_auth_user(self.request)
-        context = {'user': user}
+        authenticator_android_url = self.replace_authenticator_png('android')
+        authenticator_iphone_url = self.replace_authenticator_png('iphone')
+        context = {
+            'user': user,
+            'authenticator_android_url': authenticator_android_url,
+            'authenticator_iphone_url': authenticator_iphone_url
+        }
         kwargs.update(context)
         return super().get_context_data(**kwargs)
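
The new replace_authenticator_png helper makes the authenticator download images overridable per deployment: if img/authenticator_<platform>.png exists under MEDIA_ROOT it is served from MEDIA_URL, otherwise the packaged static image is used. Below is a minimal standalone sketch of the same fallback pattern, assuming a configured Django settings module; the function name is illustrative and not part of this diff.

import os

from django.conf import settings
from django.templatetags.static import static
from django.utils._os import safe_join


def resolve_overridable_image(base_path):
    # base_path is relative to both MEDIA_ROOT and the static tree,
    # e.g. 'img/authenticator_android.png'.
    media_path = safe_join(settings.MEDIA_ROOT, base_path)
    if os.path.exists(media_path):
        # A custom file was dropped under MEDIA_ROOT: serve it via MEDIA_URL.
        return f'{settings.MEDIA_URL}{base_path}'
    # Otherwise fall back to the image shipped with the static files.
    return static(base_path)

In practice this lets an operator replace the Android/iPhone download images simply by placing files with those names under the media directory, without rebuilding static assets.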
View File
@@ -13,6 +13,7 @@ from django.views.generic import FormView, RedirectView
 from authentication.errors import IntervalTooShort
 from authentication.utils import check_user_property_is_correct
+from common.const.choices import COUNTRY_CALLING_CODES
 from common.utils import FlashMessageUtil, get_object_or_none, random_string
 from common.utils.verify_code import SendAndVerifyCodeUtil
 from users.notifications import ResetPasswordSuccessMsg
@@ -108,7 +109,7 @@ class UserForgotPasswordView(FormView):
         for k, v in cleaned_data.items():
             if v:
                 context[k] = v
+        context['countries'] = COUNTRY_CALLING_CODES
         context['form_type'] = 'email'
         context['XPACK_ENABLED'] = settings.XPACK_ENABLED
         validate_backends = self.get_validate_backends_context(has_phone)
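
The forgot-password template earlier in this diff iterates countries and reads country.name and country.value, so COUNTRY_CALLING_CODES is expected to be a sequence of objects exposing those two attributes. Its real definition lives in common.const.choices and is not shown here; the following is only a hedged sketch of one shape that would satisfy the template.

from collections import namedtuple

# Hypothetical stand-in for common.const.choices.COUNTRY_CALLING_CODES;
# the actual constant is defined elsewhere and may differ in structure.
Country = namedtuple('Country', ['name', 'value'])

COUNTRY_CALLING_CODES = [
    Country(name='China', value='+86'),
    Country(name='United States', value='+1'),
    Country(name='United Kingdom', value='+44'),
]

# The view then only needs to hand the list to the template context:
# context['countries'] = COUNTRY_CALLING_CODES

Django's template variable lookup resolves country.name against dict keys or attributes, so a list of dicts with the same keys would work equally well.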
View File
@@ -85,7 +85,7 @@ REDIS_PORT: 6379
 # SECURITY_WATERMARK_ENABLED: False
 # The session expires after the browser closes the page
-# SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE: False
+# SESSION_EXPIRE_AT_BROWSER_CLOSE: False
 # Renew the session on every API request
 # SESSION_SAVE_EVERY_REQUEST: True
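
The sample config now documents Django's standard SESSION_EXPIRE_AT_BROWSER_CLOSE name instead of the old *_FORCE variant. Both commented options correspond to ordinary Django session settings; a minimal sketch follows, assuming the YAML keys are mapped onto settings verbatim.

# settings.py (sketch): values normally come from the YAML config above.
# Expire the session cookie when the browser is closed.
SESSION_EXPIRE_AT_BROWSER_CLOSE = False

# Touch the session on every request so its expiry keeps sliding forward.
SESSION_SAVE_EVERY_REQUEST = True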
Some files were not shown because too many files have changed in this diff.