mirror of https://github.com/jumpserver/jumpserver
merge: with dev
commit
b284bb60f5
|
@ -1,11 +1,12 @@
|
|||
from django.db.models import Q
|
||||
from rest_framework.generics import CreateAPIView
|
||||
|
||||
from accounts import serializers
|
||||
from accounts.models import Account
|
||||
from accounts.permissions import AccountTaskActionPermission
|
||||
from accounts.tasks import (
|
||||
remove_accounts_task, verify_accounts_connectivity_task, push_accounts_to_assets_task
|
||||
)
|
||||
from assets.exceptions import NotSupportedTemporarilyError
|
||||
from authentication.permissions import UserConfirmation, ConfirmType
|
||||
|
||||
__all__ = [
|
||||
|
@ -26,25 +27,35 @@ class AccountsTaskCreateAPI(CreateAPIView):
|
|||
]
|
||||
return super().get_permissions()
|
||||
|
||||
def perform_create(self, serializer):
|
||||
data = serializer.validated_data
|
||||
accounts = data.get('accounts', [])
|
||||
params = data.get('params')
|
||||
@staticmethod
|
||||
def get_account_ids(data, action):
|
||||
account_type = 'gather_accounts' if action == 'remove' else 'accounts'
|
||||
accounts = data.get(account_type, [])
|
||||
account_ids = [str(a.id) for a in accounts]
|
||||
|
||||
if data['action'] == 'push':
|
||||
task = push_accounts_to_assets_task.delay(account_ids, params)
|
||||
elif data['action'] == 'remove':
|
||||
gather_accounts = data.get('gather_accounts', [])
|
||||
gather_account_ids = [str(a.id) for a in gather_accounts]
|
||||
task = remove_accounts_task.delay(gather_account_ids)
|
||||
if action == 'remove':
|
||||
return account_ids
|
||||
|
||||
assets = data.get('assets', [])
|
||||
asset_ids = [str(a.id) for a in assets]
|
||||
ids = Account.objects.filter(
|
||||
Q(id__in=account_ids) | Q(asset_id__in=asset_ids)
|
||||
).distinct().values_list('id', flat=True)
|
||||
return [str(_id) for _id in ids]
|
||||
|
||||
def perform_create(self, serializer):
|
||||
data = serializer.validated_data
|
||||
action = data['action']
|
||||
ids = self.get_account_ids(data, action)
|
||||
|
||||
if action == 'push':
|
||||
task = push_accounts_to_assets_task.delay(ids, data.get('params'))
|
||||
elif action == 'remove':
|
||||
task = remove_accounts_task.delay(ids)
|
||||
elif action == 'verify':
|
||||
task = verify_accounts_connectivity_task.delay(ids)
|
||||
else:
|
||||
account = accounts[0]
|
||||
asset = account.asset
|
||||
if not asset.auto_config['ansible_enabled'] or \
|
||||
not asset.auto_config['ping_enabled']:
|
||||
raise NotSupportedTemporarilyError()
|
||||
task = verify_accounts_connectivity_task.delay(account_ids)
|
||||
raise ValueError(f"Invalid action: {action}")
|
||||
|
||||
data = getattr(serializer, '_data', {})
|
||||
data["task"] = task.id
|
||||
|
|
|
@ -145,9 +145,9 @@ class AccountBackupHandler:
|
|||
wb = Workbook(filename)
|
||||
for sheet, data in data_map.items():
|
||||
ws = wb.add_worksheet(str(sheet))
|
||||
for row in data:
|
||||
for col, _data in enumerate(row):
|
||||
ws.write_string(0, col, _data)
|
||||
for row_index, row_data in enumerate(data):
|
||||
for col_index, col_data in enumerate(row_data):
|
||||
ws.write_string(row_index, col_index, col_data)
|
||||
wb.close()
|
||||
files.append(filename)
|
||||
timedelta = round((time.time() - time_start), 2)
|
||||
|
|
|
@ -39,3 +39,4 @@
|
|||
login_host: "{{ jms_asset.address }}"
|
||||
login_port: "{{ jms_asset.port }}"
|
||||
login_database: "{{ jms_asset.spec_info.db_name }}"
|
||||
mode: "{{ account.mode }}"
|
||||
|
|
|
@ -4,6 +4,7 @@ from copy import deepcopy
|
|||
|
||||
from django.conf import settings
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from xlsxwriter import Workbook
|
||||
|
||||
from accounts.const import AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy
|
||||
|
@ -161,7 +162,8 @@ class ChangeSecretManager(AccountBasePlaybookManager):
|
|||
print("Account not found, deleted ?")
|
||||
return
|
||||
account.secret = recorder.new_secret
|
||||
account.save(update_fields=['secret'])
|
||||
account.date_updated = timezone.now()
|
||||
account.save(update_fields=['secret', 'date_updated'])
|
||||
|
||||
def on_host_error(self, host, error, result):
|
||||
recorder = self.name_recorder_mapper.get(host)
|
||||
|
@ -182,17 +184,33 @@ class ChangeSecretManager(AccountBasePlaybookManager):
|
|||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_summary(recorders):
|
||||
total, succeed, failed = 0, 0, 0
|
||||
for recorder in recorders:
|
||||
if recorder.status == 'success':
|
||||
succeed += 1
|
||||
else:
|
||||
failed += 1
|
||||
total += 1
|
||||
|
||||
summary = _('Success: %s, Failed: %s, Total: %s') % (succeed, failed, total)
|
||||
return summary
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
if self.secret_type and not self.check_secret():
|
||||
return
|
||||
super().run(*args, **kwargs)
|
||||
recorders = list(self.name_recorder_mapper.values())
|
||||
summary = self.get_summary(recorders)
|
||||
print(summary, end='')
|
||||
|
||||
if self.record_id:
|
||||
return
|
||||
recorders = self.name_recorder_mapper.values()
|
||||
recorders = list(recorders)
|
||||
self.send_recorder_mail(recorders)
|
||||
|
||||
def send_recorder_mail(self, recorders):
|
||||
self.send_recorder_mail(recorders, summary)
|
||||
|
||||
def send_recorder_mail(self, recorders, summary):
|
||||
recipients = self.execution.recipients
|
||||
if not recorders or not recipients:
|
||||
return
|
||||
|
@ -212,7 +230,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
|
|||
attachment = os.path.join(path, f'{name}-{local_now_filename()}-{time.time()}.zip')
|
||||
encrypt_and_compress_zip_file(attachment, password, [filename])
|
||||
attachments = [attachment]
|
||||
ChangeSecretExecutionTaskMsg(name, user).publish(attachments)
|
||||
ChangeSecretExecutionTaskMsg(name, user, summary).publish(attachments)
|
||||
os.remove(filename)
|
||||
|
||||
@staticmethod
|
||||
|
@ -228,8 +246,8 @@ class ChangeSecretManager(AccountBasePlaybookManager):
|
|||
rows.insert(0, header)
|
||||
wb = Workbook(filename)
|
||||
ws = wb.add_worksheet('Sheet1')
|
||||
for row in rows:
|
||||
for col, data in enumerate(row):
|
||||
ws.write_string(0, col, data)
|
||||
for row_index, row_data in enumerate(rows):
|
||||
for col_index, col_data in enumerate(row_data):
|
||||
ws.write_string(row_index, col_index, col_data)
|
||||
wb.close()
|
||||
return True
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
- hosts: demo
|
||||
gather_facts: no
|
||||
tasks:
|
||||
- name: Gather posix account
|
||||
- name: Gather windows account
|
||||
ansible.builtin.win_shell: net user
|
||||
register: result
|
||||
ignore_errors: true
|
||||
|
||||
- name: Define info by set_fact
|
||||
ansible.builtin.set_fact:
|
||||
|
|
|
@ -39,3 +39,4 @@
|
|||
login_host: "{{ jms_asset.address }}"
|
||||
login_port: "{{ jms_asset.port }}"
|
||||
login_database: "{{ jms_asset.spec_info.db_name }}"
|
||||
mode: "{{ account.mode }}"
|
||||
|
|
|
@ -54,20 +54,23 @@ class AccountBackupByObjStorageExecutionTaskMsg(object):
|
|||
class ChangeSecretExecutionTaskMsg(object):
|
||||
subject = _('Notification of implementation result of encryption change plan')
|
||||
|
||||
def __init__(self, name: str, user: User):
|
||||
def __init__(self, name: str, user: User, summary):
|
||||
self.name = name
|
||||
self.user = user
|
||||
self.summary = summary
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
name = self.name
|
||||
if self.user.secret_key:
|
||||
return _('{} - The encryption change task has been completed. '
|
||||
'See the attachment for details').format(name)
|
||||
default_message = _('{} - The encryption change task has been completed. '
|
||||
'See the attachment for details').format(name)
|
||||
|
||||
else:
|
||||
return _("{} - The encryption change task has been completed: the encryption "
|
||||
"password has not been set - please go to personal information -> "
|
||||
"file encryption password to set the encryption password").format(name)
|
||||
default_message = _("{} - The encryption change task has been completed: the encryption "
|
||||
"password has not been set - please go to personal information -> "
|
||||
"set encryption password in preferences").format(name)
|
||||
return self.summary + '\n' + default_message
|
||||
|
||||
def publish(self, attachments=None):
|
||||
send_mail_attachment_async(
|
||||
|
|
|
@ -60,7 +60,7 @@ class AccountCreateUpdateSerializerMixin(serializers.Serializer):
|
|||
for data in initial_data:
|
||||
if not data.get('asset') and not self.instance:
|
||||
raise serializers.ValidationError({'asset': UniqueTogetherValidator.missing_message})
|
||||
asset = data.get('asset') or self.instance.asset
|
||||
asset = data.get('asset') or getattr(self.instance, 'asset', None)
|
||||
self.from_template_if_need(data)
|
||||
self.set_uniq_name_if_need(data, asset)
|
||||
|
||||
|
@ -457,12 +457,14 @@ class AccountHistorySerializer(serializers.ModelSerializer):
|
|||
|
||||
class AccountTaskSerializer(serializers.Serializer):
|
||||
ACTION_CHOICES = (
|
||||
('test', 'test'),
|
||||
('verify', 'verify'),
|
||||
('push', 'push'),
|
||||
('remove', 'remove'),
|
||||
)
|
||||
action = serializers.ChoiceField(choices=ACTION_CHOICES, write_only=True)
|
||||
assets = serializers.PrimaryKeyRelatedField(
|
||||
queryset=Asset.objects, required=False, allow_empty=True, many=True
|
||||
)
|
||||
accounts = serializers.PrimaryKeyRelatedField(
|
||||
queryset=Account.objects, required=False, allow_empty=True, many=True
|
||||
)
|
||||
|
|
|
@ -21,7 +21,8 @@ def on_account_pre_save(sender, instance, **kwargs):
|
|||
if instance.version == 0:
|
||||
instance.version = 1
|
||||
else:
|
||||
instance.version = instance.history.count()
|
||||
history_account = instance.history.first()
|
||||
instance.version = history_account.version + 1 if history_account else 0
|
||||
|
||||
|
||||
@merge_delay_run(ttl=5)
|
||||
|
@ -62,7 +63,7 @@ def create_accounts_activities(account, action='create'):
|
|||
def on_account_create_by_template(sender, instance, created=False, **kwargs):
|
||||
if not created or instance.source != 'template':
|
||||
return
|
||||
push_accounts_if_need(accounts=(instance,))
|
||||
push_accounts_if_need.delay(accounts=(instance,))
|
||||
create_accounts_activities(instance, action='create')
|
||||
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ def clean_historical_accounts():
|
|||
history_model = Account.history.model
|
||||
history_id_mapper = defaultdict(list)
|
||||
|
||||
ids = history_model.objects.values('id').annotate(count=Count('id')) \
|
||||
ids = history_model.objects.values('id').annotate(count=Count('id', distinct=True)) \
|
||||
.filter(count__gte=limit).values_list('id', flat=True)
|
||||
|
||||
if not ids:
|
||||
|
|
|
@ -92,6 +92,7 @@ class AssetViewSet(SuggestionMixin, OrgBulkModelViewSet):
|
|||
model = Asset
|
||||
filterset_class = AssetFilterSet
|
||||
search_fields = ("name", "address", "comment")
|
||||
ordering = ('name',)
|
||||
ordering_fields = ('name', 'address', 'connectivity', 'platform', 'date_updated', 'date_created')
|
||||
serializer_classes = (
|
||||
("default", serializers.AssetSerializer),
|
||||
|
|
|
@ -12,6 +12,6 @@ class Migration(migrations.Migration):
|
|||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name='asset',
|
||||
options={'ordering': ['name'], 'permissions': [('refresh_assethardwareinfo', 'Can refresh asset hardware info'), ('test_assetconnectivity', 'Can test asset connectivity'), ('match_asset', 'Can match asset'), ('change_assetnodes', 'Can change asset nodes')], 'verbose_name': 'Asset'},
|
||||
options={'ordering': [], 'permissions': [('refresh_assethardwareinfo', 'Can refresh asset hardware info'), ('test_assetconnectivity', 'Can test asset connectivity'), ('match_asset', 'Can match asset'), ('change_assetnodes', 'Can change asset nodes')], 'verbose_name': 'Asset'},
|
||||
),
|
||||
]
|
||||
|
|
|
@ -348,7 +348,7 @@ class Asset(NodesRelationMixin, LabeledMixin, AbsConnectivity, JSONFilterMixin,
|
|||
class Meta:
|
||||
unique_together = [('org_id', 'name')]
|
||||
verbose_name = _("Asset")
|
||||
ordering = ["name", ]
|
||||
ordering = []
|
||||
permissions = [
|
||||
('refresh_assethardwareinfo', _('Can refresh asset hardware info')),
|
||||
('test_assetconnectivity', _('Can test asset connectivity')),
|
||||
|
|
|
@ -429,7 +429,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
|
|||
|
||||
@classmethod
|
||||
@timeit
|
||||
def get_nodes_all_assets(cls, *nodes):
|
||||
def get_nodes_all_assets(cls, *nodes, distinct=True):
|
||||
from .asset import Asset
|
||||
node_ids = set()
|
||||
descendant_node_query = Q()
|
||||
|
@ -439,7 +439,10 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
|
|||
if descendant_node_query:
|
||||
_ids = Node.objects.order_by().filter(descendant_node_query).values_list('id', flat=True)
|
||||
node_ids.update(_ids)
|
||||
return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct()
|
||||
assets = Asset.objects.order_by().filter(nodes__id__in=node_ids)
|
||||
if distinct:
|
||||
assets = assets.distinct()
|
||||
return assets
|
||||
|
||||
def get_all_asset_ids(self):
|
||||
asset_ids = self.get_all_asset_ids_by_node_key(org_id=self.org_id, node_key=self.key)
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from rest_framework.pagination import LimitOffsetPagination
|
||||
from rest_framework.request import Request
|
||||
|
||||
from common.utils import get_logger
|
||||
from assets.models import Node
|
||||
from common.utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -28,6 +28,7 @@ class AssetPaginationBase(LimitOffsetPagination):
|
|||
'key', 'all', 'show_current_asset',
|
||||
'cache_policy', 'display', 'draw',
|
||||
'order', 'node', 'node_id', 'fields_size',
|
||||
'asset'
|
||||
}
|
||||
for k, v in self._request.query_params.items():
|
||||
if k not in exclude_query_params and v is not None:
|
||||
|
|
|
@ -206,9 +206,12 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
|
|||
""" Perform necessary eager loading of data. """
|
||||
queryset = queryset.prefetch_related('domain', 'nodes', 'protocols', ) \
|
||||
.prefetch_related('platform', 'platform__automation') \
|
||||
.prefetch_related('labels', 'labels__label') \
|
||||
.annotate(category=F("platform__category")) \
|
||||
.annotate(type=F("platform__type"))
|
||||
if queryset.model is Asset:
|
||||
queryset = queryset.prefetch_related('labels__label', 'labels')
|
||||
else:
|
||||
queryset = queryset.prefetch_related('asset_ptr__labels__label', 'asset_ptr__labels')
|
||||
return queryset
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -56,7 +56,14 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
|
|||
|
||||
class DomainListSerializer(DomainSerializer):
|
||||
class Meta(DomainSerializer.Meta):
|
||||
fields = list(set(DomainSerializer.Meta.fields) - {'assets'})
|
||||
fields = list(set(DomainSerializer.Meta.fields + ['assets_amount']) - {'assets'})
|
||||
|
||||
@classmethod
|
||||
def setup_eager_loading(cls, queryset):
|
||||
queryset = queryset.annotate(
|
||||
assets_amount=Count('assets', distinct=True),
|
||||
)
|
||||
return queryset
|
||||
|
||||
|
||||
class DomainWithGatewaySerializer(serializers.ModelSerializer):
|
||||
|
|
|
@ -191,7 +191,6 @@ class PlatformSerializer(ResourceLabelsMixin, WritableNestedModelSerializer):
|
|||
def add_type_choices(self, name, label):
|
||||
tp = self.fields['type']
|
||||
tp.choices[name] = label
|
||||
tp.choice_mapper[name] = label
|
||||
tp.choice_strings_to_values[name] = label
|
||||
|
||||
@lazyproperty
|
||||
|
|
|
@ -63,13 +63,13 @@ def on_asset_create(sender, instance=None, created=False, **kwargs):
|
|||
return
|
||||
logger.info("Asset create signal recv: {}".format(instance))
|
||||
|
||||
ensure_asset_has_node(assets=(instance,))
|
||||
ensure_asset_has_node.delay(assets=(instance,))
|
||||
|
||||
# 获取资产硬件信息
|
||||
auto_config = instance.auto_config
|
||||
if auto_config.get('ping_enabled'):
|
||||
logger.debug('Asset {} ping enabled, test connectivity'.format(instance.name))
|
||||
test_assets_connectivity_handler(assets=(instance,))
|
||||
test_assets_connectivity_handler.delay(assets=(instance,))
|
||||
if auto_config.get('gather_facts_enabled'):
|
||||
logger.debug('Asset {} gather facts enabled, gather facts'.format(instance.name))
|
||||
gather_assets_facts_handler(assets=(instance,))
|
||||
|
|
|
@ -2,14 +2,16 @@
|
|||
#
|
||||
from operator import add, sub
|
||||
|
||||
from django.conf import settings
|
||||
from django.db.models.signals import m2m_changed
|
||||
from django.dispatch import receiver
|
||||
|
||||
from assets.models import Asset, Node
|
||||
from common.const.signals import PRE_CLEAR, POST_ADD, PRE_REMOVE
|
||||
from common.decorators import on_transaction_commit, merge_delay_run
|
||||
from common.signals import django_ready
|
||||
from common.utils import get_logger
|
||||
from orgs.utils import tmp_to_org
|
||||
from orgs.utils import tmp_to_org, tmp_to_root_org
|
||||
from ..tasks import check_node_assets_amount_task
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
@ -34,7 +36,7 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
|
|||
node_ids = [instance.id]
|
||||
else:
|
||||
node_ids = list(pk_set)
|
||||
update_nodes_assets_amount(node_ids=node_ids)
|
||||
update_nodes_assets_amount.delay(node_ids=node_ids)
|
||||
|
||||
|
||||
@merge_delay_run(ttl=30)
|
||||
|
@ -52,3 +54,18 @@ def update_nodes_assets_amount(node_ids=()):
|
|||
node.assets_amount = node.get_assets_amount()
|
||||
|
||||
Node.objects.bulk_update(nodes, ['assets_amount'])
|
||||
|
||||
|
||||
@receiver(django_ready)
|
||||
def set_assets_size_to_setting(sender, **kwargs):
|
||||
from assets.models import Asset
|
||||
try:
|
||||
with tmp_to_root_org():
|
||||
amount = Asset.objects.order_by().count()
|
||||
except:
|
||||
amount = 0
|
||||
|
||||
if amount > 20000:
|
||||
settings.ASSET_SIZE = 'large'
|
||||
elif amount > 2000:
|
||||
settings.ASSET_SIZE = 'medium'
|
||||
|
|
|
@ -44,18 +44,18 @@ def on_node_post_create(sender, instance, created, update_fields, **kwargs):
|
|||
need_expire = False
|
||||
|
||||
if need_expire:
|
||||
expire_node_assets_mapping(org_ids=(instance.org_id,))
|
||||
expire_node_assets_mapping.delay(org_ids=(instance.org_id,))
|
||||
|
||||
|
||||
@receiver(post_delete, sender=Node)
|
||||
def on_node_post_delete(sender, instance, **kwargs):
|
||||
expire_node_assets_mapping(org_ids=(instance.org_id,))
|
||||
expire_node_assets_mapping.delay(org_ids=(instance.org_id,))
|
||||
|
||||
|
||||
@receiver(m2m_changed, sender=Asset.nodes.through)
|
||||
def on_node_asset_change(sender, instance, action='pre_remove', **kwargs):
|
||||
if action.startswith('post'):
|
||||
expire_node_assets_mapping(org_ids=(instance.org_id,))
|
||||
expire_node_assets_mapping.delay(org_ids=(instance.org_id,))
|
||||
|
||||
|
||||
@receiver(django_ready)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from django.urls import path
|
||||
from rest_framework_bulk.routes import BulkRouter
|
||||
|
||||
from labels.api import LabelViewSet
|
||||
from .. import api
|
||||
|
||||
app_name = 'assets'
|
||||
|
@ -22,6 +23,7 @@ router.register(r'domains', api.DomainViewSet, 'domain')
|
|||
router.register(r'gateways', api.GatewayViewSet, 'gateway')
|
||||
router.register(r'favorite-assets', api.FavoriteAssetViewSet, 'favorite-asset')
|
||||
router.register(r'protocol-settings', api.PlatformProtocolViewSet, 'protocol-setting')
|
||||
router.register(r'labels', LabelViewSet, 'label')
|
||||
|
||||
urlpatterns = [
|
||||
# path('assets/<uuid:pk>/gateways/', api.AssetGatewayListApi.as_view(), name='asset-gateway-list'),
|
||||
|
|
|
@ -4,7 +4,6 @@ from urllib.parse import urlencode, urlparse
|
|||
from kubernetes import client
|
||||
from kubernetes.client import api_client
|
||||
from kubernetes.client.api import core_v1_api
|
||||
from kubernetes.client.exceptions import ApiException
|
||||
from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError
|
||||
|
||||
from common.utils import get_logger
|
||||
|
@ -88,8 +87,9 @@ class KubernetesClient:
|
|||
if hasattr(self, func_name):
|
||||
try:
|
||||
data = getattr(self, func_name)(*args)
|
||||
except ApiException as e:
|
||||
logger.error(e.reason)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise e
|
||||
|
||||
if self.server:
|
||||
self.server.stop()
|
||||
|
|
|
@ -20,6 +20,7 @@ from common.const.http import GET, POST
|
|||
from common.drf.filters import DatetimeRangeFilterBackend
|
||||
from common.permissions import IsServiceAccount
|
||||
from common.plugins.es import QuerySet as ESQuerySet
|
||||
from common.sessions.cache import user_session_manager
|
||||
from common.storage.ftp_file import FTPFileStorageHandler
|
||||
from common.utils import is_uuid, get_logger, lazyproperty
|
||||
from orgs.mixins.api import OrgReadonlyModelViewSet, OrgModelViewSet
|
||||
|
@ -289,8 +290,7 @@ class UserSessionViewSet(CommonApiMixin, viewsets.ModelViewSet):
|
|||
return Response(status=status.HTTP_200_OK)
|
||||
|
||||
keys = queryset.values_list('key', flat=True)
|
||||
session_store_cls = import_module(settings.SESSION_ENGINE).SessionStore
|
||||
for key in keys:
|
||||
session_store_cls(key).delete()
|
||||
user_session_manager.decrement_or_remove(key)
|
||||
queryset.delete()
|
||||
return Response(status=status.HTTP_200_OK)
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
from django.core.cache import cache
|
||||
from django_filters import rest_framework as drf_filters
|
||||
from rest_framework import filters
|
||||
from rest_framework.compat import coreapi, coreschema
|
||||
|
||||
from common.drf.filters import BaseFilterSet
|
||||
from notifications.ws import WS_SESSION_KEY
|
||||
from common.sessions.cache import user_session_manager
|
||||
from orgs.utils import current_org
|
||||
from .models import UserSession
|
||||
|
||||
|
@ -41,13 +40,11 @@ class UserSessionFilterSet(BaseFilterSet):
|
|||
|
||||
@staticmethod
|
||||
def filter_is_active(queryset, name, is_active):
|
||||
redis_client = cache.client.get_client()
|
||||
members = redis_client.smembers(WS_SESSION_KEY)
|
||||
members = [member.decode('utf-8') for member in members]
|
||||
keys = user_session_manager.get_active_keys()
|
||||
if is_active:
|
||||
queryset = queryset.filter(key__in=members)
|
||||
queryset = queryset.filter(key__in=keys)
|
||||
else:
|
||||
queryset = queryset.exclude(key__in=members)
|
||||
queryset = queryset.exclude(key__in=keys)
|
||||
return queryset
|
||||
|
||||
class Meta:
|
||||
|
|
|
@ -4,15 +4,15 @@ from datetime import timedelta
|
|||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.cache import caches, cache
|
||||
from django.core.cache import caches
|
||||
from django.db import models
|
||||
from django.db.models import Q
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import gettext, gettext_lazy as _
|
||||
|
||||
from common.db.encoder import ModelJSONFieldEncoder
|
||||
from common.sessions.cache import user_session_manager
|
||||
from common.utils import lazyproperty, i18n_trans
|
||||
from notifications.ws import WS_SESSION_KEY
|
||||
from ops.models import JobExecution
|
||||
from orgs.mixins.models import OrgModelMixin, Organization
|
||||
from orgs.utils import current_org
|
||||
|
@ -278,8 +278,7 @@ class UserSession(models.Model):
|
|||
|
||||
@property
|
||||
def is_active(self):
|
||||
redis_client = cache.client.get_client()
|
||||
return redis_client.sismember(WS_SESSION_KEY, self.key)
|
||||
return user_session_manager.check_active(self.key)
|
||||
|
||||
@property
|
||||
def date_expired(self):
|
||||
|
|
|
@ -205,7 +205,7 @@ class RDPFileClientProtocolURLMixin:
|
|||
return data
|
||||
|
||||
def get_smart_endpoint(self, protocol, asset=None):
|
||||
endpoint = Endpoint.match_by_instance_label(asset, protocol)
|
||||
endpoint = Endpoint.match_by_instance_label(asset, protocol, self.request)
|
||||
if not endpoint:
|
||||
target_ip = asset.get_target_ip() if asset else ''
|
||||
endpoint = EndpointRule.match_endpoint(
|
||||
|
|
|
@ -90,6 +90,6 @@ class MFAChallengeVerifyApi(AuthMixin, CreateAPIView):
|
|||
return Response({'msg': 'ok'})
|
||||
except errors.AuthFailedError as e:
|
||||
data = {"error": e.error, "msg": e.msg}
|
||||
raise ValidationError(data)
|
||||
return Response(data, status=401)
|
||||
except errors.NeedMoreInfoError as e:
|
||||
return Response(e.as_data(), status=200)
|
||||
|
|
|
@ -10,6 +10,7 @@ from rest_framework import authentication, exceptions
|
|||
from common.auth import signature
|
||||
from common.decorators import merge_delay_run
|
||||
from common.utils import get_object_or_none, get_request_ip_or_data, contains_ip
|
||||
from users.models import User
|
||||
from ..models import AccessKey, PrivateToken
|
||||
|
||||
|
||||
|
@ -19,22 +20,23 @@ def date_more_than(d, seconds):
|
|||
|
||||
@merge_delay_run(ttl=60)
|
||||
def update_token_last_used(tokens=()):
|
||||
for token in tokens:
|
||||
token.date_last_used = timezone.now()
|
||||
token.save(update_fields=['date_last_used'])
|
||||
access_keys_ids = [token.id for token in tokens if isinstance(token, AccessKey)]
|
||||
private_token_keys = [token.key for token in tokens if isinstance(token, PrivateToken)]
|
||||
if len(access_keys_ids) > 0:
|
||||
AccessKey.objects.filter(id__in=access_keys_ids).update(date_last_used=timezone.now())
|
||||
if len(private_token_keys) > 0:
|
||||
PrivateToken.objects.filter(key__in=private_token_keys).update(date_last_used=timezone.now())
|
||||
|
||||
|
||||
@merge_delay_run(ttl=60)
|
||||
def update_user_last_used(users=()):
|
||||
for user in users:
|
||||
user.date_api_key_last_used = timezone.now()
|
||||
user.save(update_fields=['date_api_key_last_used'])
|
||||
User.objects.filter(id__in=users).update(date_api_key_last_used=timezone.now())
|
||||
|
||||
|
||||
def after_authenticate_update_date(user, token=None):
|
||||
update_user_last_used(users=(user,))
|
||||
update_user_last_used.delay(users=(user.id,))
|
||||
if token:
|
||||
update_token_last_used(tokens=(token,))
|
||||
update_token_last_used.delay(tokens=(token,))
|
||||
|
||||
|
||||
class AccessTokenAuthentication(authentication.BaseAuthentication):
|
||||
|
|
|
@ -98,16 +98,19 @@ class OAuth2Backend(JMSModelBackend):
|
|||
access_token_url = '{url}{separator}{query}'.format(
|
||||
url=settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT, separator=separator, query=urlencode(query_dict)
|
||||
)
|
||||
# token_method -> get, post(post_data), post_json
|
||||
token_method = settings.AUTH_OAUTH2_ACCESS_TOKEN_METHOD.lower()
|
||||
requests_func = getattr(requests, token_method, requests.get)
|
||||
logger.debug(log_prompt.format('Call the access token endpoint[method: %s]' % token_method))
|
||||
headers = {
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
if token_method == 'post':
|
||||
access_token_response = requests_func(access_token_url, headers=headers, data=query_dict)
|
||||
if token_method.startswith('post'):
|
||||
body_key = 'json' if token_method.endswith('json') else 'data'
|
||||
access_token_response = requests.post(
|
||||
access_token_url, headers=headers, **{body_key: query_dict}
|
||||
)
|
||||
else:
|
||||
access_token_response = requests_func(access_token_url, headers=headers)
|
||||
access_token_response = requests.get(access_token_url, headers=headers)
|
||||
try:
|
||||
access_token_response.raise_for_status()
|
||||
access_token_response_data = access_token_response.json()
|
||||
|
|
|
@ -18,7 +18,7 @@ class EncryptedField(forms.CharField):
|
|||
|
||||
class UserLoginForm(forms.Form):
|
||||
days_auto_login = int(settings.SESSION_COOKIE_AGE / 3600 / 24)
|
||||
disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE \
|
||||
disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE \
|
||||
or days_auto_login < 1
|
||||
|
||||
username = forms.CharField(
|
||||
|
|
|
@ -142,23 +142,7 @@ class SessionCookieMiddleware(MiddlewareMixin):
|
|||
return response
|
||||
response.set_cookie(key, value)
|
||||
|
||||
@staticmethod
|
||||
def set_cookie_session_expire(request, response):
|
||||
if not request.session.get('auth_session_expiration_required'):
|
||||
return
|
||||
value = 'age'
|
||||
if settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE or \
|
||||
not request.session.get('auto_login', False):
|
||||
value = 'close'
|
||||
|
||||
age = request.session.get_expiry_age()
|
||||
expire_timestamp = request.session.get_expiry_date().timestamp()
|
||||
response.set_cookie('jms_session_expire_timestamp', expire_timestamp)
|
||||
response.set_cookie('jms_session_expire', value, max_age=age)
|
||||
request.session.pop('auth_session_expiration_required', None)
|
||||
|
||||
def process_response(self, request, response: HttpResponse):
|
||||
self.set_cookie_session_prefix(request, response)
|
||||
self.set_cookie_public_key(request, response)
|
||||
self.set_cookie_session_expire(request, response)
|
||||
return response
|
||||
|
|
|
@ -37,9 +37,6 @@ def on_user_auth_login_success(sender, user, request, **kwargs):
|
|||
UserSession.objects.filter(key=session_key).delete()
|
||||
cache.set(lock_key, request.session.session_key, None)
|
||||
|
||||
# 标记登录,设置 cookie,前端可以控制刷新, Middleware 会拦截这个生成 cookie
|
||||
request.session['auth_session_expiration_required'] = 1
|
||||
|
||||
|
||||
@receiver(cas_user_authenticated)
|
||||
def on_cas_user_login_success(sender, request, user, **kwargs):
|
||||
|
|
|
@ -70,11 +70,12 @@ class DingTalkQRMixin(DingTalkBaseMixin, View):
|
|||
self.request.session[DINGTALK_STATE_SESSION_KEY] = state
|
||||
|
||||
params = {
|
||||
'appid': settings.DINGTALK_APPKEY,
|
||||
'client_id': settings.DINGTALK_APPKEY,
|
||||
'response_type': 'code',
|
||||
'scope': 'snsapi_login',
|
||||
'scope': 'openid',
|
||||
'state': state,
|
||||
'redirect_uri': redirect_uri,
|
||||
'prompt': 'consent'
|
||||
}
|
||||
url = URL.QR_CONNECT + '?' + urlencode(params)
|
||||
return url
|
||||
|
|
|
@ -104,9 +104,11 @@ class QuerySetMixin:
|
|||
page = super().paginate_queryset(queryset)
|
||||
serializer_class = self.get_serializer_class()
|
||||
if page and serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
|
||||
ids = [i.id for i in page]
|
||||
ids = [str(obj.id) for obj in page]
|
||||
page = self.get_queryset().filter(id__in=ids)
|
||||
page = serializer_class.setup_eager_loading(page)
|
||||
page_mapper = {str(obj.id): obj for obj in page}
|
||||
page = [page_mapper.get(_id) for _id in ids if _id in page_mapper]
|
||||
return page
|
||||
|
||||
|
||||
|
|
|
@ -19,3 +19,17 @@ class Status(models.TextChoices):
|
|||
failed = 'failed', _("Failed")
|
||||
error = 'error', _("Error")
|
||||
canceled = 'canceled', _("Canceled")
|
||||
|
||||
|
||||
COUNTRY_CALLING_CODES = [
|
||||
{'name': 'China(中国)', 'value': '+86'},
|
||||
{'name': 'HongKong(中国香港)', 'value': '+852'},
|
||||
{'name': 'Macao(中国澳门)', 'value': '+853'},
|
||||
{'name': 'Taiwan(中国台湾)', 'value': '+886'},
|
||||
{'name': 'America(America)', 'value': '+1'}, {'name': 'Russia(Россия)', 'value': '+7'},
|
||||
{'name': 'France(français)', 'value': '+33'},
|
||||
{'name': 'Britain(Britain)', 'value': '+44'},
|
||||
{'name': 'Germany(Deutschland)', 'value': '+49'},
|
||||
{'name': 'Japan(日本)', 'value': '+81'}, {'name': 'Korea(한국)', 'value': '+82'},
|
||||
{'name': 'India(भारत)', 'value': '+91'}
|
||||
]
|
||||
|
|
|
@ -362,11 +362,15 @@ class RelatedManager:
|
|||
if name is None or val is None:
|
||||
continue
|
||||
|
||||
if custom_attr_filter:
|
||||
custom_filter_q = None
|
||||
spec_attr_filter = getattr(to_model, "get_{}_filter_attr_q".format(name), None)
|
||||
if spec_attr_filter:
|
||||
custom_filter_q = spec_attr_filter(val, match)
|
||||
elif custom_attr_filter:
|
||||
custom_filter_q = custom_attr_filter(name, val, match)
|
||||
if custom_filter_q:
|
||||
filters.append(custom_filter_q)
|
||||
continue
|
||||
if custom_filter_q:
|
||||
filters.append(custom_filter_q)
|
||||
continue
|
||||
|
||||
if match == 'ip_in':
|
||||
q = cls.get_ip_in_q(name, val)
|
||||
|
@ -464,11 +468,15 @@ class JSONManyToManyDescriptor:
|
|||
rule_value = rule.get('value', '')
|
||||
rule_match = rule.get('match', 'exact')
|
||||
|
||||
if custom_attr_filter:
|
||||
q = custom_attr_filter(rule['name'], rule_value, rule_match)
|
||||
if q:
|
||||
custom_q &= q
|
||||
continue
|
||||
custom_filter_q = None
|
||||
spec_attr_filter = getattr(to_model, "get_filter_{}_attr_q".format(rule['name']), None)
|
||||
if spec_attr_filter:
|
||||
custom_filter_q = spec_attr_filter(rule_value, rule_match)
|
||||
elif custom_attr_filter:
|
||||
custom_filter_q = custom_attr_filter(rule['name'], rule_value, rule_match)
|
||||
if custom_filter_q:
|
||||
custom_q &= custom_filter_q
|
||||
continue
|
||||
|
||||
if rule_match == 'in':
|
||||
res &= value in rule_value or '*' in rule_value
|
||||
|
@ -517,7 +525,6 @@ class JSONManyToManyDescriptor:
|
|||
res &= rule_value.issubset(value)
|
||||
else:
|
||||
res &= bool(value & rule_value)
|
||||
|
||||
else:
|
||||
logging.error("unknown match: {}".format(rule['match']))
|
||||
res &= False
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
@ -101,7 +102,11 @@ def run_debouncer_func(cache_key, org, ttl, func, *args, **kwargs):
|
|||
first_run_time = current
|
||||
|
||||
if current - first_run_time > ttl:
|
||||
_loop_debouncer_func_args_cache.pop(cache_key, None)
|
||||
_loop_debouncer_func_task_time_cache.pop(cache_key, None)
|
||||
executor.submit(run_func_partial, *args, **kwargs)
|
||||
logger.debug('pid {} executor submit run {}'.format(
|
||||
os.getpid(), func.__name__, ))
|
||||
return
|
||||
|
||||
loop = _loop_thread.get_loop()
|
||||
|
@ -133,13 +138,26 @@ class Debouncer(object):
|
|||
return await self.loop.run_in_executor(self.executor, func)
|
||||
|
||||
|
||||
ignore_err_exceptions = (
|
||||
"(3101, 'Plugin instructed the server to rollback the current transaction.')",
|
||||
)
|
||||
|
||||
|
||||
def _run_func_with_org(key, org, func, *args, **kwargs):
|
||||
from orgs.utils import set_current_org
|
||||
try:
|
||||
set_current_org(org)
|
||||
func(*args, **kwargs)
|
||||
with transaction.atomic():
|
||||
set_current_org(org)
|
||||
func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error('delay run error: %s' % e)
|
||||
msg = str(e)
|
||||
log_func = logger.error
|
||||
if msg in ignore_err_exceptions:
|
||||
log_func = logger.info
|
||||
pid = os.getpid()
|
||||
thread_name = threading.current_thread()
|
||||
log_func('pid {} thread {} delay run {} error: {}'.format(
|
||||
pid, thread_name, func.__name__, msg))
|
||||
_loop_debouncer_func_task_cache.pop(key, None)
|
||||
_loop_debouncer_func_args_cache.pop(key, None)
|
||||
_loop_debouncer_func_task_time_cache.pop(key, None)
|
||||
|
@ -181,6 +199,32 @@ def merge_delay_run(ttl=5, key=None):
|
|||
:return:
|
||||
"""
|
||||
|
||||
def delay(func, *args, **kwargs):
|
||||
from orgs.utils import get_current_org
|
||||
suffix_key_func = key if key else default_suffix_key
|
||||
org = get_current_org()
|
||||
func_name = f'{func.__module__}_{func.__name__}'
|
||||
key_suffix = suffix_key_func(*args, **kwargs)
|
||||
cache_key = f'MERGE_DELAY_RUN_{func_name}_{key_suffix}'
|
||||
cache_kwargs = _loop_debouncer_func_args_cache.get(cache_key, {})
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if not isinstance(v, (tuple, list, set)):
|
||||
raise ValueError('func kwargs value must be list or tuple: %s %s' % (func.__name__, v))
|
||||
v = set(v)
|
||||
if k not in cache_kwargs:
|
||||
cache_kwargs[k] = v
|
||||
else:
|
||||
cache_kwargs[k] = cache_kwargs[k].union(v)
|
||||
_loop_debouncer_func_args_cache[cache_key] = cache_kwargs
|
||||
run_debouncer_func(cache_key, org, ttl, func, *args, **cache_kwargs)
|
||||
|
||||
def apply(func, sync=False, *args, **kwargs):
|
||||
if sync:
|
||||
return func(*args, **kwargs)
|
||||
else:
|
||||
return delay(func, *args, **kwargs)
|
||||
|
||||
def inner(func):
|
||||
sigs = inspect.signature(func)
|
||||
if len(sigs.parameters) != 1:
|
||||
|
@ -188,27 +232,12 @@ def merge_delay_run(ttl=5, key=None):
|
|||
param = list(sigs.parameters.values())[0]
|
||||
if not isinstance(param.default, tuple):
|
||||
raise ValueError('func default must be tuple: %s' % param.default)
|
||||
suffix_key_func = key if key else default_suffix_key
|
||||
func.delay = functools.partial(delay, func)
|
||||
func.apply = functools.partial(apply, func)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
from orgs.utils import get_current_org
|
||||
org = get_current_org()
|
||||
func_name = f'{func.__module__}_{func.__name__}'
|
||||
key_suffix = suffix_key_func(*args, **kwargs)
|
||||
cache_key = f'MERGE_DELAY_RUN_{func_name}_{key_suffix}'
|
||||
cache_kwargs = _loop_debouncer_func_args_cache.get(cache_key, {})
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if not isinstance(v, (tuple, list, set)):
|
||||
raise ValueError('func kwargs value must be list or tuple: %s %s' % (func.__name__, v))
|
||||
v = set(v)
|
||||
if k not in cache_kwargs:
|
||||
cache_kwargs[k] = v
|
||||
else:
|
||||
cache_kwargs[k] = cache_kwargs[k].union(v)
|
||||
_loop_debouncer_func_args_cache[cache_key] = cache_kwargs
|
||||
run_debouncer_func(cache_key, org, ttl, func, *args, **cache_kwargs)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import logging
|
|||
|
||||
from django.core.cache import cache
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db.models import Q, Count
|
||||
from django.db.models import Q
|
||||
from django_filters import rest_framework as drf_filters
|
||||
from rest_framework import filters
|
||||
from rest_framework.compat import coreapi, coreschema
|
||||
|
@ -180,36 +180,30 @@ class LabelFilterBackend(filters.BaseFilterBackend):
|
|||
]
|
||||
|
||||
@staticmethod
|
||||
def filter_resources(resources, labels_id):
|
||||
def parse_label_ids(labels_id):
|
||||
from labels.models import Label
|
||||
label_ids = [i.strip() for i in labels_id.split(',')]
|
||||
cleaned = []
|
||||
|
||||
args = []
|
||||
for label_id in label_ids:
|
||||
kwargs = {}
|
||||
if ':' in label_id:
|
||||
k, v = label_id.split(':', 1)
|
||||
kwargs['label__name'] = k.strip()
|
||||
kwargs['name'] = k.strip()
|
||||
if v != '*':
|
||||
kwargs['label__value'] = v.strip()
|
||||
kwargs['value'] = v.strip()
|
||||
args.append(kwargs)
|
||||
else:
|
||||
kwargs['label_id'] = label_id
|
||||
args.append(kwargs)
|
||||
cleaned.append(label_id)
|
||||
|
||||
if len(args) == 1:
|
||||
resources = resources.filter(**args[0])
|
||||
return resources
|
||||
|
||||
q = Q()
|
||||
for kwarg in args:
|
||||
q |= Q(**kwarg)
|
||||
|
||||
resources = resources.filter(q) \
|
||||
.values('res_id') \
|
||||
.order_by('res_id') \
|
||||
.annotate(count=Count('res_id')) \
|
||||
.values('res_id', 'count') \
|
||||
.filter(count=len(args))
|
||||
return resources
|
||||
if len(args) != 0:
|
||||
q = Q()
|
||||
for kwarg in args:
|
||||
q |= Q(**kwarg)
|
||||
ids = Label.objects.filter(q).values_list('id', flat=True)
|
||||
cleaned.extend(list(ids))
|
||||
return cleaned
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
labels_id = request.query_params.get('labels')
|
||||
|
@ -223,14 +217,15 @@ class LabelFilterBackend(filters.BaseFilterBackend):
|
|||
return queryset
|
||||
|
||||
model = queryset.model.label_model()
|
||||
labeled_resource_cls = model._labels.field.related_model
|
||||
labeled_resource_cls = model.labels.field.related_model
|
||||
app_label = model._meta.app_label
|
||||
model_name = model._meta.model_name
|
||||
|
||||
resources = labeled_resource_cls.objects.filter(
|
||||
res_type__app_label=app_label, res_type__model=model_name,
|
||||
)
|
||||
resources = self.filter_resources(resources, labels_id)
|
||||
label_ids = self.parse_label_ids(labels_id)
|
||||
resources = model.filter_resources_by_labels(resources, label_ids)
|
||||
res_ids = resources.values_list('res_id', flat=True)
|
||||
queryset = queryset.filter(id__in=set(res_ids))
|
||||
return queryset
|
||||
|
|
|
@ -14,6 +14,7 @@ class CeleryBaseService(BaseService):
|
|||
print('\n- Start Celery as Distributed Task Queue: {}'.format(self.queue.capitalize()))
|
||||
ansible_config_path = os.path.join(settings.APPS_DIR, 'ops', 'ansible', 'ansible.cfg')
|
||||
ansible_modules_path = os.path.join(settings.APPS_DIR, 'ops', 'ansible', 'modules')
|
||||
os.environ.setdefault('LC_ALL', 'C.UTF-8')
|
||||
os.environ.setdefault('PYTHONOPTIMIZE', '1')
|
||||
os.environ.setdefault('ANSIBLE_FORCE_COLOR', 'True')
|
||||
os.environ.setdefault('ANSIBLE_CONFIG', ansible_config_path)
|
||||
|
|
|
@ -28,9 +28,10 @@ class ErrorCode:
|
|||
|
||||
|
||||
class URL:
|
||||
QR_CONNECT = 'https://oapi.dingtalk.com/connect/qrconnect'
|
||||
QR_CONNECT = 'https://login.dingtalk.com/oauth2/auth'
|
||||
OAUTH_CONNECT = 'https://oapi.dingtalk.com/connect/oauth2/sns_authorize'
|
||||
GET_USER_INFO_BY_CODE = 'https://oapi.dingtalk.com/sns/getuserinfo_bycode'
|
||||
GET_USER_ACCESSTOKEN = 'https://api.dingtalk.com/v1.0/oauth2/userAccessToken'
|
||||
GET_USER_INFO = 'https://api.dingtalk.com/v1.0/contact/users/me'
|
||||
GET_TOKEN = 'https://oapi.dingtalk.com/gettoken'
|
||||
SEND_MESSAGE_BY_TEMPLATE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/sendbytemplate'
|
||||
SEND_MESSAGE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/asyncsend_v2'
|
||||
|
@ -72,8 +73,9 @@ class DingTalkRequests(BaseRequest):
|
|||
def get(self, url, params=None,
|
||||
with_token=False, with_sign=False,
|
||||
check_errcode_is_0=True,
|
||||
**kwargs):
|
||||
**kwargs) -> dict:
|
||||
pass
|
||||
|
||||
get = as_request(get)
|
||||
|
||||
def post(self, url, json=None, params=None,
|
||||
|
@ -81,6 +83,7 @@ class DingTalkRequests(BaseRequest):
|
|||
check_errcode_is_0=True,
|
||||
**kwargs) -> dict:
|
||||
pass
|
||||
|
||||
post = as_request(post)
|
||||
|
||||
def _add_sign(self, kwargs: dict):
|
||||
|
@ -123,17 +126,22 @@ class DingTalk:
|
|||
)
|
||||
|
||||
def get_userinfo_bycode(self, code):
|
||||
# https://developers.dingtalk.com/document/app/obtain-the-user-information-based-on-the-sns-temporary-authorization?spm=ding_open_doc.document.0.0.3a256573y8Y7yg#topic-1995619
|
||||
body = {
|
||||
"tmp_auth_code": code
|
||||
'clientId': self._appid,
|
||||
'clientSecret': self._appsecret,
|
||||
'code': code,
|
||||
'grantType': 'authorization_code'
|
||||
}
|
||||
data = self._request.post(URL.GET_USER_ACCESSTOKEN, json=body, check_errcode_is_0=False)
|
||||
token = data['accessToken']
|
||||
|
||||
data = self._request.post(URL.GET_USER_INFO_BY_CODE, json=body, with_sign=True)
|
||||
return data['user_info']
|
||||
user = self._request.get(URL.GET_USER_INFO,
|
||||
headers={'x-acs-dingtalk-access-token': token}, check_errcode_is_0=False)
|
||||
return user
|
||||
|
||||
def get_user_id_by_code(self, code):
|
||||
user_info = self.get_userinfo_bycode(code)
|
||||
unionid = user_info['unionid']
|
||||
unionid = user_info['unionId']
|
||||
userid = self.get_userid_by_unionid(unionid)
|
||||
return userid, None
|
||||
|
||||
|
|
|
@ -394,20 +394,20 @@ class CommonBulkModelSerializer(CommonBulkSerializerMixin, serializers.ModelSeri
|
|||
|
||||
|
||||
class ResourceLabelsMixin(serializers.Serializer):
|
||||
labels = LabelRelatedField(many=True, label=_('Labels'), required=False, allow_null=True)
|
||||
labels = LabelRelatedField(many=True, label=_('Labels'), required=False, allow_null=True, source='res_labels')
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
labels = validated_data.pop('labels', None)
|
||||
labels = validated_data.pop('res_labels', None)
|
||||
res = super().update(instance, validated_data)
|
||||
if labels is not None:
|
||||
instance.labels.set(labels, bulk=False)
|
||||
instance.res_labels.set(labels, bulk=False)
|
||||
return res
|
||||
|
||||
def create(self, validated_data):
|
||||
labels = validated_data.pop('labels', None)
|
||||
labels = validated_data.pop('res_labels', None)
|
||||
instance = super().create(validated_data)
|
||||
if labels is not None:
|
||||
instance.labels.set(labels, bulk=False)
|
||||
instance.res_labels.set(labels, bulk=False)
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
import re
|
||||
|
||||
from django.contrib.sessions.backends.cache import (
|
||||
SessionStore as DjangoSessionStore
|
||||
)
|
||||
from django.core.cache import cache
|
||||
|
||||
from jumpserver.utils import get_current_request
|
||||
|
||||
|
||||
class SessionStore(DjangoSessionStore):
|
||||
ignore_urls = [
|
||||
r'^/api/v1/users/profile/'
|
||||
]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ignore_pattern = re.compile('|'.join(self.ignore_urls))
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
request = get_current_request()
|
||||
if request is None or not self.ignore_pattern.match(request.path):
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
|
||||
class RedisUserSessionManager:
|
||||
JMS_SESSION_KEY = 'jms_session_key'
|
||||
|
||||
def __init__(self):
|
||||
self.client = cache.client.get_client()
|
||||
|
||||
def add_or_increment(self, session_key):
|
||||
self.client.hincrby(self.JMS_SESSION_KEY, session_key, 1)
|
||||
|
||||
def decrement_or_remove(self, session_key):
|
||||
new_count = self.client.hincrby(self.JMS_SESSION_KEY, session_key, -1)
|
||||
if new_count <= 0:
|
||||
self.client.hdel(self.JMS_SESSION_KEY, session_key)
|
||||
|
||||
def check_active(self, session_key):
|
||||
count = self.client.hget(self.JMS_SESSION_KEY, session_key)
|
||||
count = 0 if count is None else int(count.decode('utf-8'))
|
||||
return count > 0
|
||||
|
||||
def get_active_keys(self):
|
||||
session_keys = []
|
||||
for k, v in self.client.hgetall(self.JMS_SESSION_KEY).items():
|
||||
count = int(v.decode('utf-8'))
|
||||
if count <= 0:
|
||||
continue
|
||||
key = k.decode('utf-8')
|
||||
session_keys.append(key)
|
||||
return session_keys
|
||||
|
||||
|
||||
user_session_manager = RedisUserSessionManager()
|
|
@ -69,7 +69,7 @@ def digest_sql_query():
|
|||
|
||||
for query in queries:
|
||||
sql = query['sql']
|
||||
print(" # {}: {}".format(query['time'], sql[:1000]))
|
||||
print(" # {}: {}".format(query['time'], sql[:1000]))
|
||||
if len(queries) < 3:
|
||||
continue
|
||||
print("- Table: {}".format(table_name))
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -282,6 +282,7 @@ class Config(dict):
|
|||
'AUTH_LDAP_SYNC_INTERVAL': None,
|
||||
'AUTH_LDAP_SYNC_CRONTAB': None,
|
||||
'AUTH_LDAP_SYNC_ORG_IDS': ['00000000-0000-0000-0000-000000000002'],
|
||||
'AUTH_LDAP_SYNC_RECEIVERS': [],
|
||||
'AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS': False,
|
||||
'AUTH_LDAP_OPTIONS_OPT_REFERRALS': -1,
|
||||
|
||||
|
@ -546,7 +547,6 @@ class Config(dict):
|
|||
'REFERER_CHECK_ENABLED': False,
|
||||
'SESSION_ENGINE': 'cache',
|
||||
'SESSION_SAVE_EVERY_REQUEST': True,
|
||||
'SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE': False,
|
||||
'SERVER_REPLAY_STORAGE': {},
|
||||
'SECURITY_DATA_CRYPTO_ALGO': None,
|
||||
'GMSSL_ENABLED': False,
|
||||
|
@ -605,7 +605,9 @@ class Config(dict):
|
|||
'GPT_MODEL': 'gpt-3.5-turbo',
|
||||
'VIRTUAL_APP_ENABLED': False,
|
||||
|
||||
'FILE_UPLOAD_SIZE_LIMIT_MB': 200
|
||||
'FILE_UPLOAD_SIZE_LIMIT_MB': 200,
|
||||
|
||||
'TICKET_APPLY_ASSET_SCOPE': 'all'
|
||||
}
|
||||
|
||||
old_config_map = {
|
||||
|
|
|
@ -66,11 +66,6 @@ class RequestMiddleware:
|
|||
def __call__(self, request):
|
||||
set_current_request(request)
|
||||
response = self.get_response(request)
|
||||
is_request_api = request.path.startswith('/api')
|
||||
if not settings.SESSION_EXPIRE_AT_BROWSER_CLOSE and \
|
||||
not is_request_api:
|
||||
age = request.session.get_expiry_age()
|
||||
request.session.set_expiry(age)
|
||||
return response
|
||||
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
path_perms_map = {
|
||||
'xpack': '*',
|
||||
'settings': '*',
|
||||
'img': '*',
|
||||
'replay': 'default',
|
||||
'applets': 'terminal.view_applet',
|
||||
'virtual_apps': 'terminal.view_virtualapp',
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
from private_storage.servers import NginxXAccelRedirectServer, DjangoServer
|
||||
|
||||
|
||||
class StaticFileServer(object):
|
||||
|
||||
@staticmethod
|
||||
def serve(private_file):
|
||||
full_path = private_file.full_path
|
||||
# todo: gzip 文件录像 nginx 处理后,浏览器无法正常解析内容
|
||||
# 造成在线播放失败,暂时仅使用 nginx 处理 mp4 录像文件
|
||||
if full_path.endswith('.mp4'):
|
||||
return NginxXAccelRedirectServer.serve(private_file)
|
||||
else:
|
||||
return DjangoServer.serve(private_file)
|
|
@ -50,6 +50,7 @@ AUTH_LDAP_SYNC_IS_PERIODIC = CONFIG.AUTH_LDAP_SYNC_IS_PERIODIC
|
|||
AUTH_LDAP_SYNC_INTERVAL = CONFIG.AUTH_LDAP_SYNC_INTERVAL
|
||||
AUTH_LDAP_SYNC_CRONTAB = CONFIG.AUTH_LDAP_SYNC_CRONTAB
|
||||
AUTH_LDAP_SYNC_ORG_IDS = CONFIG.AUTH_LDAP_SYNC_ORG_IDS
|
||||
AUTH_LDAP_SYNC_RECEIVERS = CONFIG.AUTH_LDAP_SYNC_RECEIVERS
|
||||
AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS = CONFIG.AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS
|
||||
|
||||
# ==============================================================================
|
||||
|
|
|
@ -234,11 +234,9 @@ CSRF_COOKIE_NAME = '{}csrftoken'.format(SESSION_COOKIE_NAME_PREFIX)
|
|||
SESSION_COOKIE_NAME = '{}sessionid'.format(SESSION_COOKIE_NAME_PREFIX)
|
||||
|
||||
SESSION_COOKIE_AGE = CONFIG.SESSION_COOKIE_AGE
|
||||
SESSION_EXPIRE_AT_BROWSER_CLOSE = True
|
||||
# 自定义的配置,SESSION_EXPIRE_AT_BROWSER_CLOSE 始终为 True, 下面这个来控制是否强制关闭后过期 cookie
|
||||
SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE = CONFIG.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE
|
||||
SESSION_SAVE_EVERY_REQUEST = CONFIG.SESSION_SAVE_EVERY_REQUEST
|
||||
SESSION_ENGINE = "django.contrib.sessions.backends.{}".format(CONFIG.SESSION_ENGINE)
|
||||
SESSION_EXPIRE_AT_BROWSER_CLOSE = CONFIG.SESSION_EXPIRE_AT_BROWSER_CLOSE
|
||||
SESSION_ENGINE = "common.sessions.{}".format(CONFIG.SESSION_ENGINE)
|
||||
|
||||
MESSAGE_STORAGE = 'django.contrib.messages.storage.cookie.CookieStorage'
|
||||
# Database
|
||||
|
@ -319,9 +317,7 @@ MEDIA_ROOT = os.path.join(PROJECT_DIR, 'data', 'media').replace('\\', '/') + '/'
|
|||
PRIVATE_STORAGE_ROOT = MEDIA_ROOT
|
||||
PRIVATE_STORAGE_AUTH_FUNCTION = 'jumpserver.rewriting.storage.permissions.allow_access'
|
||||
PRIVATE_STORAGE_INTERNAL_URL = '/private-media/'
|
||||
PRIVATE_STORAGE_SERVER = 'nginx'
|
||||
if DEBUG_DEV:
|
||||
PRIVATE_STORAGE_SERVER = 'django'
|
||||
PRIVATE_STORAGE_SERVER = 'jumpserver.rewriting.storage.servers.StaticFileServer'
|
||||
|
||||
# Use django-bootstrap-form to format template, input max width arg
|
||||
# BOOTSTRAP_COLUMN_COUNT = 11
|
||||
|
|
|
@ -214,6 +214,9 @@ PERM_TREE_REGEN_INTERVAL = CONFIG.PERM_TREE_REGEN_INTERVAL
|
|||
MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS
|
||||
LIMIT_SUPER_PRIV = CONFIG.LIMIT_SUPER_PRIV
|
||||
|
||||
# Asset account may be too many
|
||||
ASSET_SIZE = 'small'
|
||||
|
||||
# Chat AI
|
||||
CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED
|
||||
GPT_API_KEY = CONFIG.GPT_API_KEY
|
||||
|
@ -224,3 +227,5 @@ GPT_MODEL = CONFIG.GPT_MODEL
|
|||
VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED
|
||||
|
||||
FILE_UPLOAD_SIZE_LIMIT_MB = CONFIG.FILE_UPLOAD_SIZE_LIMIT_MB
|
||||
|
||||
TICKET_APPLY_ASSET_SCOPE = CONFIG.TICKET_APPLY_ASSET_SCOPE
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
from django.contrib.contenttypes.fields import GenericRelation
|
||||
from django.db import models
|
||||
from django.db.models import OneToOneField
|
||||
from django.db.models import OneToOneField, Count
|
||||
|
||||
from common.utils import lazyproperty
|
||||
from .models import LabeledResource
|
||||
|
||||
__all__ = ['LabeledMixin']
|
||||
|
||||
|
||||
class LabeledMixin(models.Model):
|
||||
_labels = GenericRelation(LabeledResource, object_id_field='res_id', content_type_field='res_type')
|
||||
labels = GenericRelation(LabeledResource, object_id_field='res_id', content_type_field='res_type')
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
@ -21,7 +22,7 @@ class LabeledMixin(models.Model):
|
|||
model = pk_field.related_model
|
||||
return model
|
||||
|
||||
@property
|
||||
@lazyproperty
|
||||
def real(self):
|
||||
pk_field = self._meta.pk
|
||||
if isinstance(pk_field, OneToOneField):
|
||||
|
@ -29,9 +30,43 @@ class LabeledMixin(models.Model):
|
|||
return self
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
return self.real._labels
|
||||
def res_labels(self):
|
||||
return self.real.labels
|
||||
|
||||
@labels.setter
|
||||
def labels(self, value):
|
||||
self.real._labels.set(value, bulk=False)
|
||||
@res_labels.setter
|
||||
def res_labels(self, value):
|
||||
self.real.labels.set(value, bulk=False)
|
||||
|
||||
@classmethod
|
||||
def filter_resources_by_labels(cls, resources, label_ids):
|
||||
return cls._get_filter_res_by_labels_m2m_all(resources, label_ids)
|
||||
|
||||
@classmethod
|
||||
def _get_filter_res_by_labels_m2m_in(cls, resources, label_ids):
|
||||
return resources.filter(label_id__in=label_ids)
|
||||
|
||||
@classmethod
|
||||
def _get_filter_res_by_labels_m2m_all(cls, resources, label_ids):
|
||||
if len(label_ids) == 1:
|
||||
return cls._get_filter_res_by_labels_m2m_in(resources, label_ids)
|
||||
|
||||
resources = resources.filter(label_id__in=label_ids) \
|
||||
.values('res_id') \
|
||||
.order_by('res_id') \
|
||||
.annotate(count=Count('res_id', distinct=True)) \
|
||||
.values('res_id', 'count') \
|
||||
.filter(count=len(label_ids))
|
||||
return resources
|
||||
|
||||
@classmethod
|
||||
def get_labels_filter_attr_q(cls, value, match):
|
||||
resources = LabeledResource.objects.all()
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if match != 'm2m_all':
|
||||
resources = cls._get_filter_res_by_labels_m2m_in(resources, value)
|
||||
else:
|
||||
resources = cls._get_filter_res_by_labels_m2m_all(resources, value)
|
||||
res_ids = set(resources.values_list('res_id', flat=True))
|
||||
return models.Q(id__in=res_ids)
|
||||
|
|
|
@ -34,7 +34,7 @@ class LabelSerializer(BulkOrgResourceModelSerializer):
|
|||
@classmethod
|
||||
def setup_eager_loading(cls, queryset):
|
||||
""" Perform necessary eager loading of data. """
|
||||
queryset = queryset.annotate(res_count=Count('labeled_resources'))
|
||||
queryset = queryset.annotate(res_count=Count('labeled_resources', distinct=True))
|
||||
return queryset
|
||||
|
||||
|
||||
|
|
|
@ -1,28 +1,32 @@
|
|||
import json
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
from channels.generic.websocket import JsonWebsocketConsumer
|
||||
from django.core.cache import cache
|
||||
from django.conf import settings
|
||||
|
||||
from common.db.utils import safe_db_connection
|
||||
from common.sessions.cache import user_session_manager
|
||||
from common.utils import get_logger
|
||||
from .signal_handlers import new_site_msg_chan
|
||||
from .site_msg import SiteMessageUtil
|
||||
|
||||
logger = get_logger(__name__)
|
||||
WS_SESSION_KEY = 'ws_session_key'
|
||||
|
||||
|
||||
class SiteMsgWebsocket(JsonWebsocketConsumer):
|
||||
sub = None
|
||||
refresh_every_seconds = 10
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return self.scope['session']
|
||||
|
||||
def connect(self):
|
||||
user = self.scope["user"]
|
||||
if user.is_authenticated:
|
||||
self.accept()
|
||||
session = self.scope['session']
|
||||
redis_client = cache.client.get_client()
|
||||
redis_client.sadd(WS_SESSION_KEY, session.session_key)
|
||||
user_session_manager.add_or_increment(self.session.session_key)
|
||||
self.sub = self.watch_recv_new_site_msg()
|
||||
else:
|
||||
self.close()
|
||||
|
@ -66,6 +70,32 @@ class SiteMsgWebsocket(JsonWebsocketConsumer):
|
|||
if not self.sub:
|
||||
return
|
||||
self.sub.unsubscribe()
|
||||
session = self.scope['session']
|
||||
redis_client = cache.client.get_client()
|
||||
redis_client.srem(WS_SESSION_KEY, session.session_key)
|
||||
|
||||
user_session_manager.decrement_or_remove(self.session.session_key)
|
||||
if self.should_delete_session():
|
||||
thread = Thread(target=self.delay_delete_session)
|
||||
thread.start()
|
||||
|
||||
def should_delete_session(self):
|
||||
return (self.session.modified or settings.SESSION_SAVE_EVERY_REQUEST) and \
|
||||
not self.session.is_empty() and \
|
||||
self.session.get_expire_at_browser_close() and \
|
||||
not user_session_manager.check_active(self.session.session_key)
|
||||
|
||||
def delay_delete_session(self):
|
||||
timeout = 3
|
||||
check_interval = 0.5
|
||||
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
time.sleep(check_interval)
|
||||
if user_session_manager.check_active(self.session.session_key):
|
||||
return
|
||||
|
||||
self.delete_session()
|
||||
|
||||
def delete_session(self):
|
||||
try:
|
||||
self.session.delete()
|
||||
except Exception as e:
|
||||
logger.info(f'delete session error: {e}')
|
||||
|
|
|
@ -4,6 +4,21 @@ import time
|
|||
import paramiko
|
||||
from sshtunnel import SSHTunnelForwarder
|
||||
|
||||
from packaging import version
|
||||
|
||||
if version.parse(paramiko.__version__) > version.parse("2.8.1"):
|
||||
_preferred_pubkeys = (
|
||||
"ssh-ed25519",
|
||||
"ecdsa-sha2-nistp256",
|
||||
"ecdsa-sha2-nistp384",
|
||||
"ecdsa-sha2-nistp521",
|
||||
"ssh-rsa",
|
||||
"rsa-sha2-256",
|
||||
"rsa-sha2-512",
|
||||
"ssh-dss",
|
||||
)
|
||||
paramiko.transport.Transport._preferred_pubkeys = _preferred_pubkeys
|
||||
|
||||
|
||||
def common_argument_spec():
|
||||
options = dict(
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
#
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
from celery.result import AsyncResult
|
||||
from django.shortcuts import get_object_or_404
|
||||
|
@ -166,16 +167,58 @@ class CeleryTaskViewSet(
|
|||
i.next_exec_time = now + next_run_at
|
||||
return queryset
|
||||
|
||||
def generate_summary_state(self, execution_qs):
|
||||
model = self.get_queryset().model
|
||||
executions = execution_qs.order_by('-date_published').values('name', 'state')
|
||||
summary_state_dict = defaultdict(
|
||||
lambda: {
|
||||
'states': [], 'state': 'green',
|
||||
'summary': {'total': 0, 'success': 0}
|
||||
}
|
||||
)
|
||||
for execution in executions:
|
||||
name = execution['name']
|
||||
state = execution['state']
|
||||
|
||||
summary = summary_state_dict[name]['summary']
|
||||
|
||||
summary['total'] += 1
|
||||
summary['success'] += 1 if state == 'SUCCESS' else 0
|
||||
|
||||
states = summary_state_dict[name].get('states')
|
||||
if states is not None and len(states) >= 5:
|
||||
color = model.compute_state_color(states)
|
||||
summary_state_dict[name]['state'] = color
|
||||
summary_state_dict[name].pop('states', None)
|
||||
elif isinstance(states, list):
|
||||
states.append(state)
|
||||
|
||||
return summary_state_dict
|
||||
|
||||
def loading_summary_state(self, queryset):
|
||||
if isinstance(queryset, list):
|
||||
names = [i.name for i in queryset]
|
||||
execution_qs = CeleryTaskExecution.objects.filter(name__in=names)
|
||||
else:
|
||||
execution_qs = CeleryTaskExecution.objects.all()
|
||||
summary_state_dict = self.generate_summary_state(execution_qs)
|
||||
for i in queryset:
|
||||
i.summary = summary_state_dict.get(i.name, {}).get('summary', {})
|
||||
i.state = summary_state_dict.get(i.name, {}).get('state', 'green')
|
||||
return queryset
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
|
||||
page = self.paginate_queryset(queryset)
|
||||
if page is not None:
|
||||
page = self.generate_execute_time(page)
|
||||
page = self.loading_summary_state(page)
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
queryset = self.generate_execute_time(queryset)
|
||||
queryset = self.loading_summary_state(queryset)
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
|
|
|
@ -246,6 +246,6 @@ class UsernameHintsAPI(APIView):
|
|||
.filter(username__icontains=query) \
|
||||
.filter(asset__in=assets) \
|
||||
.values('username') \
|
||||
.annotate(total=Count('username')) \
|
||||
.annotate(total=Count('username', distinct=True)) \
|
||||
.order_by('total', '-username')[:10]
|
||||
return Response(data=top_accounts)
|
||||
|
|
|
@ -15,6 +15,9 @@ class CeleryTask(models.Model):
|
|||
name = models.CharField(max_length=1024, verbose_name=_('Name'))
|
||||
date_last_publish = models.DateTimeField(null=True, verbose_name=_("Date last publish"))
|
||||
|
||||
__summary = None
|
||||
__state = None
|
||||
|
||||
@property
|
||||
def meta(self):
|
||||
task = app.tasks.get(self.name, None)
|
||||
|
@ -25,25 +28,43 @@ class CeleryTask(models.Model):
|
|||
|
||||
@property
|
||||
def summary(self):
|
||||
if self.__summary is not None:
|
||||
return self.__summary
|
||||
executions = CeleryTaskExecution.objects.filter(name=self.name)
|
||||
total = executions.count()
|
||||
success = executions.filter(state='SUCCESS').count()
|
||||
return {'total': total, 'success': success}
|
||||
|
||||
@summary.setter
|
||||
def summary(self, value):
|
||||
self.__summary = value
|
||||
|
||||
@staticmethod
|
||||
def compute_state_color(states: list, default_count=5):
|
||||
color = 'green'
|
||||
states = states[:default_count]
|
||||
if not states:
|
||||
return color
|
||||
if states[0] == 'FAILURE':
|
||||
color = 'red'
|
||||
elif 'FAILURE' in states:
|
||||
color = 'yellow'
|
||||
return color
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
last_five_executions = CeleryTaskExecution.objects \
|
||||
.filter(name=self.name) \
|
||||
.order_by('-date_published')[:5]
|
||||
if self.__state is not None:
|
||||
return self.__state
|
||||
last_five_executions = CeleryTaskExecution.objects.filter(
|
||||
name=self.name
|
||||
).order_by('-date_published').values('state')[:5]
|
||||
states = [i['state'] for i in last_five_executions]
|
||||
color = self.compute_state_color(states)
|
||||
return color
|
||||
|
||||
if len(last_five_executions) > 0:
|
||||
if last_five_executions[0].state == 'FAILURE':
|
||||
return "red"
|
||||
|
||||
for execution in last_five_executions:
|
||||
if execution.state == 'FAILURE':
|
||||
return "yellow"
|
||||
return "green"
|
||||
@state.setter
|
||||
def state(self, value):
|
||||
self.__state = value
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Celery Task")
|
||||
|
|
|
@ -67,6 +67,7 @@ class JMSPermedInventory(JMSInventory):
|
|||
'postgresql': ['postgresql'],
|
||||
'sqlserver': ['sqlserver'],
|
||||
'ssh': ['shell', 'python', 'win_shell', 'raw'],
|
||||
'winrm': ['win_shell', 'shell'],
|
||||
}
|
||||
|
||||
if self.module not in protocol_supported_modules_mapping.get(protocol.name, []):
|
||||
|
|
|
@ -87,7 +87,8 @@ class OrgResourceStatisticsRefreshUtil:
|
|||
if not cache_field_name:
|
||||
return
|
||||
org = getattr(instance, 'org', None)
|
||||
cls.refresh_org_fields(((org, cache_field_name),))
|
||||
cache_field_name = tuple(cache_field_name)
|
||||
cls.refresh_org_fields.delay(org_fields=((org, cache_field_name),))
|
||||
|
||||
|
||||
@receiver(post_save)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import abc
|
||||
|
||||
from django.conf import settings
|
||||
from rest_framework.generics import ListAPIView, RetrieveAPIView
|
||||
|
||||
from assets.api.asset.asset import AssetFilterSet
|
||||
|
@ -7,8 +8,7 @@ from assets.models import Asset, Node
|
|||
from common.utils import get_logger, lazyproperty, is_uuid
|
||||
from orgs.utils import tmp_to_root_org
|
||||
from perms import serializers
|
||||
from perms.pagination import AllPermedAssetPagination
|
||||
from perms.pagination import NodePermedAssetPagination
|
||||
from perms.pagination import NodePermedAssetPagination, AllPermedAssetPagination
|
||||
from perms.utils import UserPermAssetUtil, PermAssetDetailUtil
|
||||
from .mixin import (
|
||||
SelfOrPKUserMixin
|
||||
|
@ -39,7 +39,7 @@ class UserPermedAssetRetrieveApi(SelfOrPKUserMixin, RetrieveAPIView):
|
|||
|
||||
|
||||
class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
|
||||
ordering = ('name',)
|
||||
ordering = []
|
||||
search_fields = ('name', 'address', 'comment')
|
||||
ordering_fields = ("name", "address")
|
||||
filterset_class = AssetFilterSet
|
||||
|
@ -48,6 +48,8 @@ class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
|
|||
def get_queryset(self):
|
||||
if getattr(self, 'swagger_fake_view', False):
|
||||
return Asset.objects.none()
|
||||
if settings.ASSET_SIZE == 'small':
|
||||
self.ordering = ['name']
|
||||
assets = self.get_assets()
|
||||
assets = self.serializer_class.setup_eager_loading(assets)
|
||||
return assets
|
||||
|
|
|
@ -14,6 +14,7 @@ from assets.api import SerializeToTreeNodeMixin
|
|||
from assets.models import Asset
|
||||
from assets.utils import KubernetesTree
|
||||
from authentication.models import ConnectionToken
|
||||
from common.exceptions import JMSException
|
||||
from common.utils import get_object_or_none, lazyproperty
|
||||
from common.utils.common import timeit
|
||||
from perms.hands import Node
|
||||
|
@ -181,6 +182,8 @@ class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(BaseUserNodeWithAssetAsT
|
|||
return self.query_asset_util.get_all_assets()
|
||||
|
||||
def _get_tree_nodes_async(self):
|
||||
if self.request.query_params.get('lv') == '0':
|
||||
return [], []
|
||||
if not self.tp or not all(self.tp):
|
||||
nodes = UserPermAssetUtil.get_type_nodes_tree_or_cached(self.user)
|
||||
return nodes, []
|
||||
|
@ -262,5 +265,8 @@ class UserGrantedK8sAsTreeApi(SelfOrPKUserMixin, ListAPIView):
|
|||
if not any([namespace, pod]) and not key:
|
||||
asset_node = k8s_tree_instance.as_asset_tree_node()
|
||||
tree.append(asset_node)
|
||||
tree.extend(k8s_tree_instance.async_tree_node(namespace, pod))
|
||||
return Response(data=tree)
|
||||
try:
|
||||
tree.extend(k8s_tree_instance.async_tree_node(namespace, pod))
|
||||
return Response(data=tree)
|
||||
except Exception as e:
|
||||
raise JMSException(e)
|
||||
|
|
|
@ -130,7 +130,7 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel):
|
|||
qs1_ids = User.objects.filter(id__in=user_ids).distinct().values_list('id', flat=True)
|
||||
qs2_ids = User.objects.filter(groups__id__in=group_ids).distinct().values_list('id', flat=True)
|
||||
qs_ids = list(qs1_ids) + list(qs2_ids)
|
||||
qs = User.objects.filter(id__in=qs_ids)
|
||||
qs = User.objects.filter(id__in=qs_ids, is_service_account=False)
|
||||
return qs
|
||||
|
||||
def get_all_assets(self, flat=False):
|
||||
|
|
|
@ -9,7 +9,7 @@ class PermedAssetsWillExpireUserMsg(UserMessage):
|
|||
def __init__(self, user, assets, day_count=0):
|
||||
super().__init__(user)
|
||||
self.assets = assets
|
||||
self.day_count = _('today') if day_count == 0 else day_count + _('day')
|
||||
self.day_count = _('today') if day_count == 0 else str(day_count) + _('day')
|
||||
|
||||
def get_html_msg(self) -> dict:
|
||||
subject = _("You permed assets is about to expire")
|
||||
|
@ -41,7 +41,7 @@ class AssetPermsWillExpireForOrgAdminMsg(UserMessage):
|
|||
super().__init__(user)
|
||||
self.perms = perms
|
||||
self.org = org
|
||||
self.day_count = _('today') if day_count == 0 else day_count + _('day')
|
||||
self.day_count = _('today') if day_count == 0 else str(day_count) + _('day')
|
||||
|
||||
def get_items_with_url(self):
|
||||
items_with_url = []
|
||||
|
|
|
@ -197,9 +197,9 @@ class AssetPermissionListSerializer(AssetPermissionSerializer):
|
|||
"""Perform necessary eager loading of data."""
|
||||
queryset = queryset \
|
||||
.prefetch_related('labels', 'labels__label') \
|
||||
.annotate(users_amount=Count("users"),
|
||||
user_groups_amount=Count("user_groups"),
|
||||
assets_amount=Count("assets"),
|
||||
nodes_amount=Count("nodes"),
|
||||
.annotate(users_amount=Count("users", distinct=True),
|
||||
user_groups_amount=Count("user_groups", distinct=True),
|
||||
assets_amount=Count("assets", distinct=True),
|
||||
nodes_amount=Count("nodes", distinct=True),
|
||||
)
|
||||
return queryset
|
||||
|
|
|
@ -8,9 +8,9 @@ from rest_framework import serializers
|
|||
from accounts.models import Account
|
||||
from assets.const import Category, AllTypes
|
||||
from assets.models import Node, Asset, Platform
|
||||
from assets.serializers.asset.common import AssetLabelSerializer, AssetProtocolsPermsSerializer
|
||||
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
|
||||
from assets.serializers.asset.common import AssetProtocolsPermsSerializer
|
||||
from common.serializers import ResourceLabelsMixin
|
||||
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
|
||||
from orgs.mixins.serializers import OrgResourceModelSerializerMixin
|
||||
from perms.serializers.permission import ActionChoicesField
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ class AssetPermissionUtil(object):
|
|||
""" 资产授权相关的方法工具 """
|
||||
|
||||
@timeit
|
||||
def get_permissions_for_user(self, user, with_group=True, flat=False):
|
||||
def get_permissions_for_user(self, user, with_group=True, flat=False, with_expired=False):
|
||||
""" 获取用户的授权规则 """
|
||||
perm_ids = set()
|
||||
# user
|
||||
|
@ -25,7 +25,7 @@ class AssetPermissionUtil(object):
|
|||
groups = user.groups.all()
|
||||
group_perm_ids = self.get_permissions_for_user_groups(groups, flat=True)
|
||||
perm_ids.update(group_perm_ids)
|
||||
perms = self.get_permissions(ids=perm_ids)
|
||||
perms = self.get_permissions(ids=perm_ids, with_expired=with_expired)
|
||||
if flat:
|
||||
return perms.values_list('id', flat=True)
|
||||
return perms
|
||||
|
@ -102,6 +102,8 @@ class AssetPermissionUtil(object):
|
|||
return model.objects.filter(id__in=ids)
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(ids):
|
||||
perms = AssetPermission.objects.filter(id__in=ids).valid().order_by('-date_expired')
|
||||
return perms
|
||||
def get_permissions(ids, with_expired=False):
|
||||
perms = AssetPermission.objects.filter(id__in=ids)
|
||||
if not with_expired:
|
||||
perms = perms.valid()
|
||||
return perms.order_by('-date_expired')
|
||||
|
|
|
@ -7,10 +7,10 @@ from django.db.models import Q
|
|||
from rest_framework.utils.encoders import JSONEncoder
|
||||
|
||||
from assets.const import AllTypes
|
||||
from assets.models import FavoriteAsset, Asset
|
||||
from assets.models import FavoriteAsset, Asset, Node
|
||||
from common.utils.common import timeit, get_logger
|
||||
from orgs.utils import current_org, tmp_to_root_org
|
||||
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation
|
||||
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation, AssetPermission
|
||||
from .permission import AssetPermissionUtil
|
||||
|
||||
__all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil']
|
||||
|
@ -21,36 +21,37 @@ logger = get_logger(__name__)
|
|||
class AssetPermissionPermAssetUtil:
|
||||
|
||||
def __init__(self, perm_ids):
|
||||
self.perm_ids = perm_ids
|
||||
self.perm_ids = set(perm_ids)
|
||||
|
||||
def get_all_assets(self):
|
||||
""" 获取所有授权的资产 """
|
||||
node_assets = self.get_perm_nodes_assets()
|
||||
direct_assets = self.get_direct_assets()
|
||||
# 比原来的查到所有 asset id 再搜索块很多,因为当资产量大的时候,搜索会很慢
|
||||
return (node_assets | direct_assets).distinct()
|
||||
return (node_assets | direct_assets).order_by().distinct()
|
||||
|
||||
def get_perm_nodes(self):
|
||||
""" 获取所有授权节点 """
|
||||
nodes_ids = AssetPermission.objects \
|
||||
.filter(id__in=self.perm_ids) \
|
||||
.values_list('nodes', flat=True)
|
||||
nodes_ids = set(nodes_ids)
|
||||
nodes = Node.objects.filter(id__in=nodes_ids).only('id', 'key')
|
||||
return nodes
|
||||
|
||||
@timeit
|
||||
def get_perm_nodes_assets(self, flat=False):
|
||||
def get_perm_nodes_assets(self):
|
||||
""" 获取所有授权节点下的资产 """
|
||||
from assets.models import Node
|
||||
nodes = Node.objects \
|
||||
.prefetch_related('granted_by_permissions') \
|
||||
.filter(granted_by_permissions__in=self.perm_ids) \
|
||||
.only('id', 'key')
|
||||
assets = PermNode.get_nodes_all_assets(*nodes)
|
||||
if flat:
|
||||
return set(assets.values_list('id', flat=True))
|
||||
nodes = self.get_perm_nodes()
|
||||
assets = PermNode.get_nodes_all_assets(*nodes, distinct=False)
|
||||
return assets
|
||||
|
||||
@timeit
|
||||
def get_direct_assets(self, flat=False):
|
||||
def get_direct_assets(self):
|
||||
""" 获取直接授权的资产 """
|
||||
assets = Asset.objects.order_by() \
|
||||
.filter(granted_by_permissions__id__in=self.perm_ids) \
|
||||
.distinct()
|
||||
if flat:
|
||||
return set(assets.values_list('id', flat=True))
|
||||
asset_ids = AssetPermission.assets.through.objects \
|
||||
.filter(assetpermission_id__in=self.perm_ids) \
|
||||
.values_list('asset_id', flat=True)
|
||||
assets = Asset.objects.filter(id__in=asset_ids)
|
||||
return assets
|
||||
|
||||
|
||||
|
@ -152,6 +153,7 @@ class UserPermAssetUtil(AssetPermissionPermAssetUtil):
|
|||
assets = assets.filter(nodes__id=node.id).order_by().distinct()
|
||||
return assets
|
||||
|
||||
@timeit
|
||||
def _get_indirect_perm_node_all_assets(self, node):
|
||||
""" 获取间接授权节点下的所有资产
|
||||
此算法依据 `UserAssetGrantedTreeNodeRelation` 的数据查询
|
||||
|
|
|
@ -72,7 +72,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
|
|||
|
||||
@timeit
|
||||
def refresh_if_need(self, force=False):
|
||||
built_just_now = cache.get(self.cache_key_time)
|
||||
built_just_now = False if settings.ASSET_SIZE == 'small' else cache.get(self.cache_key_time)
|
||||
if built_just_now:
|
||||
logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
|
||||
return
|
||||
|
@ -80,12 +80,18 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
|
|||
if not to_refresh_orgs:
|
||||
logger.info('Not have to refresh orgs')
|
||||
return
|
||||
|
||||
logger.info("Delay refresh user orgs: {} {}".format(self.user, [o.name for o in to_refresh_orgs]))
|
||||
refresh_user_orgs_perm_tree(user_orgs=((self.user, tuple(to_refresh_orgs)),))
|
||||
refresh_user_favorite_assets(users=(self.user,))
|
||||
sync = True if settings.ASSET_SIZE == 'small' else False
|
||||
refresh_user_orgs_perm_tree.apply(sync=sync, user_orgs=((self.user, tuple(to_refresh_orgs)),))
|
||||
refresh_user_favorite_assets.apply(sync=sync, users=(self.user,))
|
||||
|
||||
@timeit
|
||||
def refresh_tree_manual(self):
|
||||
"""
|
||||
用来手动 debug
|
||||
:return:
|
||||
"""
|
||||
built_just_now = cache.get(self.cache_key_time)
|
||||
if built_just_now:
|
||||
logger.info('Refresh just now, pass: {}'.format(built_just_now))
|
||||
|
@ -105,8 +111,9 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
|
|||
return
|
||||
|
||||
self._clean_user_perm_tree_for_legacy_org()
|
||||
ttl = settings.PERM_TREE_REGEN_INTERVAL
|
||||
cache.set(self.cache_key_time, int(time.time()), ttl)
|
||||
if settings.ASSET_SIZE != 'small':
|
||||
ttl = settings.PERM_TREE_REGEN_INTERVAL
|
||||
cache.set(self.cache_key_time, int(time.time()), ttl)
|
||||
|
||||
lock = UserGrantedTreeRebuildLock(self.user.id)
|
||||
got = lock.acquire(blocking=False)
|
||||
|
@ -193,7 +200,13 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):
|
|||
cache_key = self.get_cache_key(uid)
|
||||
p.srem(cache_key, *org_ids)
|
||||
p.execute()
|
||||
logger.info('Expire perm tree for users: [{}], orgs: [{}]'.format(user_ids, org_ids))
|
||||
users_display = ','.join([str(i) for i in user_ids[:3]])
|
||||
if len(user_ids) > 3:
|
||||
users_display += '...'
|
||||
orgs_display = ','.join([str(i) for i in org_ids[:3]])
|
||||
if len(org_ids) > 3:
|
||||
orgs_display += '...'
|
||||
logger.info('Expire perm tree for users: [{}], orgs: [{}]'.format(users_display, orgs_display))
|
||||
|
||||
def expire_perm_tree_for_all_user(self):
|
||||
keys = self.client.keys(self.cache_key_all_user)
|
||||
|
|
|
@ -80,9 +80,11 @@ class RoleViewSet(JMSModelViewSet):
|
|||
queryset = Role.objects.filter(id__in=ids).order_by(*self.ordering)
|
||||
org_id = current_org.id
|
||||
q = Q(role__scope=Role.Scope.system) | Q(role__scope=Role.Scope.org, org_id=org_id)
|
||||
role_bindings = RoleBinding.objects.filter(q).values_list('role_id').annotate(user_count=Count('user_id'))
|
||||
role_bindings = RoleBinding.objects.filter(q).values_list('role_id').annotate(
|
||||
user_count=Count('user_id', distinct=True)
|
||||
)
|
||||
role_user_amount_mapper = {role_id: user_count for role_id, user_count in role_bindings}
|
||||
queryset = queryset.annotate(permissions_amount=Count('permissions'))
|
||||
queryset = queryset.annotate(permissions_amount=Count('permissions', distinct=True))
|
||||
queryset = list(queryset)
|
||||
for role in queryset:
|
||||
role.users_amount = role_user_amount_mapper.get(role.id, 0)
|
||||
|
|
|
@ -137,7 +137,7 @@ class LDAPUserImportAPI(APIView):
|
|||
return Response({'msg': _('Get ldap users is None')}, status=400)
|
||||
|
||||
orgs = self.get_orgs()
|
||||
errors = LDAPImportUtil().perform_import(users, orgs)
|
||||
new_users, errors = LDAPImportUtil().perform_import(users, orgs)
|
||||
if errors:
|
||||
return Response({'errors': errors}, status=400)
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ from rest_framework import generics
|
|||
from rest_framework.permissions import AllowAny
|
||||
|
||||
from authentication.permissions import IsValidUserOrConnectionToken
|
||||
from common.const.choices import COUNTRY_CALLING_CODES
|
||||
from common.utils import get_logger, lazyproperty
|
||||
from common.utils.timezone import local_now
|
||||
from .. import serializers
|
||||
|
@ -24,7 +25,8 @@ class OpenPublicSettingApi(generics.RetrieveAPIView):
|
|||
def get_object(self):
|
||||
return {
|
||||
"XPACK_ENABLED": settings.XPACK_ENABLED,
|
||||
"INTERFACE": self.interface_setting
|
||||
"INTERFACE": self.interface_setting,
|
||||
"COUNTRY_CALLING_CODES": COUNTRY_CALLING_CODES
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
from django.template.loader import render_to_string
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from common.utils import get_logger
|
||||
from common.utils.timezone import local_now_display
|
||||
from notifications.notifications import UserMessage
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class LDAPImportMessage(UserMessage):
|
||||
def __init__(self, user, extra_kwargs):
|
||||
super().__init__(user)
|
||||
self.orgs = extra_kwargs.pop('orgs', [])
|
||||
self.end_time = extra_kwargs.pop('end_time', '')
|
||||
self.start_time = extra_kwargs.pop('start_time', '')
|
||||
self.time_start_display = extra_kwargs.pop('time_start_display', '')
|
||||
self.new_users = extra_kwargs.pop('new_users', [])
|
||||
self.errors = extra_kwargs.pop('errors', [])
|
||||
self.cost_time = extra_kwargs.pop('cost_time', '')
|
||||
|
||||
def get_html_msg(self) -> dict:
|
||||
subject = _('Notification of Synchronized LDAP User Task Results')
|
||||
context = {
|
||||
'orgs': self.orgs,
|
||||
'start_time': self.time_start_display,
|
||||
'end_time': local_now_display(),
|
||||
'cost_time': self.cost_time,
|
||||
'users': self.new_users,
|
||||
'errors': self.errors
|
||||
}
|
||||
message = render_to_string('ldap/_msg_import_ldap_user.html', context)
|
||||
return {
|
||||
'subject': subject,
|
||||
'message': message
|
||||
}
|
|
@ -77,6 +77,9 @@ class LDAPSettingSerializer(serializers.Serializer):
|
|||
required=False, label=_('Connect timeout (s)'),
|
||||
)
|
||||
AUTH_LDAP_SEARCH_PAGED_SIZE = serializers.IntegerField(required=False, label=_('Search paged size (piece)'))
|
||||
AUTH_LDAP_SYNC_RECEIVERS = serializers.ListField(
|
||||
required=False, label=_('Recipient'), max_length=36
|
||||
)
|
||||
|
||||
AUTH_LDAP = serializers.BooleanField(required=False, label=_('Enable LDAP auth'))
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ class OAuth2SettingSerializer(serializers.Serializer):
|
|||
)
|
||||
AUTH_OAUTH2_ACCESS_TOKEN_METHOD = serializers.ChoiceField(
|
||||
default='GET', label=_('Client authentication method'),
|
||||
choices=(('GET', 'GET'), ('POST', 'POST'))
|
||||
choices=(('GET', 'GET'), ('POST', 'POST-DATA'), ('POST_JSON', 'POST-JSON'))
|
||||
)
|
||||
AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT = serializers.CharField(
|
||||
required=True, max_length=1024, label=_('Provider userinfo endpoint')
|
||||
|
|
|
@ -11,6 +11,7 @@ __all__ = [
|
|||
class PublicSettingSerializer(serializers.Serializer):
|
||||
XPACK_ENABLED = serializers.BooleanField()
|
||||
INTERFACE = serializers.DictField()
|
||||
COUNTRY_CALLING_CODES = serializers.ListField()
|
||||
|
||||
|
||||
class PrivateSettingSerializer(PublicSettingSerializer):
|
||||
|
|
|
@ -1,15 +1,19 @@
|
|||
# coding: utf-8
|
||||
#
|
||||
import time
|
||||
from celery import shared_task
|
||||
from django.conf import settings
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from common.utils import get_logger
|
||||
from common.utils.timezone import local_now_display
|
||||
from ops.celery.decorator import after_app_ready_start
|
||||
from ops.celery.utils import (
|
||||
create_or_update_celery_periodic_tasks, disable_celery_periodic_task
|
||||
)
|
||||
from orgs.models import Organization
|
||||
from settings.notifications import LDAPImportMessage
|
||||
from users.models import User
|
||||
from ..utils import LDAPSyncUtil, LDAPServerUtil, LDAPImportUtil
|
||||
|
||||
__all__ = ['sync_ldap_user', 'import_ldap_user_periodic', 'import_ldap_user']
|
||||
|
@ -23,6 +27,8 @@ def sync_ldap_user():
|
|||
|
||||
@shared_task(verbose_name=_('Periodic import ldap user'))
|
||||
def import_ldap_user():
|
||||
start_time = time.time()
|
||||
time_start_display = local_now_display()
|
||||
logger.info("Start import ldap user task")
|
||||
util_server = LDAPServerUtil()
|
||||
util_import = LDAPImportUtil()
|
||||
|
@ -35,11 +41,26 @@ def import_ldap_user():
|
|||
org_ids = [Organization.DEFAULT_ID]
|
||||
default_org = Organization.default()
|
||||
orgs = list(set([Organization.get_instance(org_id, default=default_org) for org_id in org_ids]))
|
||||
errors = util_import.perform_import(users, orgs)
|
||||
new_users, errors = util_import.perform_import(users, orgs)
|
||||
if errors:
|
||||
logger.error("Imported LDAP users errors: {}".format(errors))
|
||||
else:
|
||||
logger.info('Imported {} users successfully'.format(len(users)))
|
||||
if settings.AUTH_LDAP_SYNC_RECEIVERS:
|
||||
user_ids = settings.AUTH_LDAP_SYNC_RECEIVERS
|
||||
recipient_list = User.objects.filter(id__in=list(user_ids))
|
||||
end_time = time.time()
|
||||
extra_kwargs = {
|
||||
'orgs': orgs,
|
||||
'end_time': end_time,
|
||||
'start_time': start_time,
|
||||
'time_start_display': time_start_display,
|
||||
'new_users': new_users,
|
||||
'errors': errors,
|
||||
'cost_time': end_time - start_time,
|
||||
}
|
||||
for user in recipient_list:
|
||||
LDAPImportMessage(user, extra_kwargs).publish()
|
||||
|
||||
|
||||
@shared_task(verbose_name=_('Registration periodic import ldap user task'))
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
{% load i18n %}
|
||||
<p>{% trans "Sync task Finish" %}</p>
|
||||
<b>{% trans 'Time' %}:</b>
|
||||
<ul>
|
||||
<li>{% trans 'Date start' %}: {{ start_time }}</li>
|
||||
<li>{% trans 'Date end' %}: {{ end_time }}</li>
|
||||
<li>{% trans 'Time cost' %}: {{ cost_time| floatformat:0 }}s</li>
|
||||
</ul>
|
||||
<b>{% trans "Synced Organization" %}:</b>
|
||||
<ul>
|
||||
{% for org in orgs %}
|
||||
<li>{{ org }}</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
<b>{% trans "Synced User" %}:</b>
|
||||
<ul>
|
||||
{% if users %}
|
||||
{% for user in users %}
|
||||
<li>{{ user }}</li>
|
||||
{% endfor %}
|
||||
{% else %}
|
||||
<li>{% trans 'No user synchronization required' %}</li>
|
||||
{% endif %}
|
||||
</ul>
|
||||
{% if errors %}
|
||||
<b>{% trans 'Error' %}:</b>
|
||||
<ul>
|
||||
{% for error in errors %}
|
||||
<li>{{ error }}</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
{% endif %}
|
||||
|
||||
|
|
@ -400,11 +400,14 @@ class LDAPImportUtil(object):
|
|||
logger.info('Start perform import ldap users, count: {}'.format(len(users)))
|
||||
errors = []
|
||||
objs = []
|
||||
new_users = []
|
||||
group_users_mapper = defaultdict(set)
|
||||
for user in users:
|
||||
groups = user.pop('groups', [])
|
||||
try:
|
||||
obj, created = self.update_or_create(user)
|
||||
if created:
|
||||
new_users.append(obj)
|
||||
objs.append(obj)
|
||||
except Exception as e:
|
||||
errors.append({user['username']: str(e)})
|
||||
|
@ -421,7 +424,7 @@ class LDAPImportUtil(object):
|
|||
for org in orgs:
|
||||
self.bind_org(org, objs, group_users_mapper)
|
||||
logger.info('End perform import ldap users')
|
||||
return errors
|
||||
return new_users, errors
|
||||
|
||||
def exit_user_group(self, user_groups_mapper):
|
||||
# 通过对比查询本次导入用户需要移除的用户组
|
||||
|
|
|
@ -42,7 +42,7 @@ class SmartEndpointViewMixin:
|
|||
return endpoint
|
||||
|
||||
def match_endpoint_by_label(self):
|
||||
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol)
|
||||
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol, self.request)
|
||||
|
||||
def match_endpoint_by_target_ip(self):
|
||||
target_ip = self.request.GET.get('target_ip', '') # 支持target_ip参数,用来方便测试
|
||||
|
|
|
@ -75,7 +75,20 @@ class Endpoint(JMSBaseModel):
|
|||
return endpoint
|
||||
|
||||
@classmethod
|
||||
def match_by_instance_label(cls, instance, protocol):
|
||||
def handle_endpoint_host(cls, endpoint, request=None):
|
||||
if not endpoint.host and request:
|
||||
# 动态添加 current request host
|
||||
host_port = request.get_host()
|
||||
# IPv6
|
||||
if host_port.startswith('['):
|
||||
host = host_port.split(']:')[0].rstrip(']') + ']'
|
||||
else:
|
||||
host = host_port.split(':')[0]
|
||||
endpoint.host = host
|
||||
return endpoint
|
||||
|
||||
@classmethod
|
||||
def match_by_instance_label(cls, instance, protocol, request=None):
|
||||
from assets.models import Asset
|
||||
from terminal.models import Session
|
||||
if isinstance(instance, Session):
|
||||
|
@ -88,6 +101,7 @@ class Endpoint(JMSBaseModel):
|
|||
endpoints = cls.objects.filter(name__in=list(values)).order_by('-date_updated')
|
||||
for endpoint in endpoints:
|
||||
if endpoint.is_valid_for(instance, protocol):
|
||||
endpoint = cls.handle_endpoint_host(endpoint, request)
|
||||
return endpoint
|
||||
|
||||
|
||||
|
@ -130,13 +144,5 @@ class EndpointRule(JMSBaseModel):
|
|||
endpoint = endpoint_rule.endpoint
|
||||
else:
|
||||
endpoint = Endpoint.get_or_create_default(request)
|
||||
if not endpoint.host and request:
|
||||
# 动态添加 current request host
|
||||
host_port = request.get_host()
|
||||
# IPv6
|
||||
if host_port.startswith('['):
|
||||
host = host_port.split(']:')[0].rstrip(']') + ']'
|
||||
else:
|
||||
host = host_port.split(':')[0]
|
||||
endpoint.host = host
|
||||
endpoint = Endpoint.handle_endpoint_host(endpoint, request)
|
||||
return endpoint
|
||||
|
|
|
@ -5,3 +5,4 @@ from .ticket import *
|
|||
from .comment import *
|
||||
from .relation import *
|
||||
from .super_ticket import *
|
||||
from .perms import *
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
from django.conf import settings
|
||||
|
||||
from assets.models import Asset, Node
|
||||
from assets.serializers.asset.common import MiniAssetSerializer
|
||||
from assets.serializers.node import NodeSerializer
|
||||
from common.api import SuggestionMixin
|
||||
from orgs.mixins.api import OrgReadonlyModelViewSet
|
||||
from perms.utils import AssetPermissionPermAssetUtil
|
||||
from perms.utils.permission import AssetPermissionUtil
|
||||
from tickets.const import TicketApplyAssetScope
|
||||
|
||||
__all__ = ['ApplyAssetsViewSet', 'ApplyNodesViewSet']
|
||||
|
||||
|
||||
class ApplyAssetsViewSet(OrgReadonlyModelViewSet, SuggestionMixin):
|
||||
model = Asset
|
||||
serializer_class = MiniAssetSerializer
|
||||
rbac_perms = (
|
||||
("match", "assets.match_asset"),
|
||||
)
|
||||
|
||||
search_fields = ("name", "address", "comment")
|
||||
|
||||
def get_queryset(self):
|
||||
if TicketApplyAssetScope.is_permed():
|
||||
queryset = self.get_assets(with_expired=True)
|
||||
elif TicketApplyAssetScope.is_permed_valid():
|
||||
queryset = self.get_assets()
|
||||
else:
|
||||
queryset = super().get_queryset()
|
||||
return queryset
|
||||
|
||||
def get_assets(self, with_expired=False):
|
||||
perms = AssetPermissionUtil().get_permissions_for_user(
|
||||
self.request.user, flat=True, with_expired=with_expired
|
||||
)
|
||||
util = AssetPermissionPermAssetUtil(perms)
|
||||
assets = util.get_all_assets()
|
||||
return assets
|
||||
|
||||
|
||||
class ApplyNodesViewSet(OrgReadonlyModelViewSet, SuggestionMixin):
|
||||
model = Node
|
||||
serializer_class = NodeSerializer
|
||||
rbac_perms = (
|
||||
("match", "assets.match_node"),
|
||||
)
|
||||
|
||||
search_fields = ('full_value',)
|
||||
|
||||
def get_queryset(self):
|
||||
if TicketApplyAssetScope.is_permed():
|
||||
queryset = self.get_nodes(with_expired=True)
|
||||
elif TicketApplyAssetScope.is_permed_valid():
|
||||
queryset = self.get_nodes()
|
||||
else:
|
||||
queryset = super().get_queryset()
|
||||
return queryset
|
||||
|
||||
def get_nodes(self, with_expired=False):
|
||||
perms = AssetPermissionUtil().get_permissions_for_user(
|
||||
self.request.user, flat=True, with_expired=with_expired
|
||||
)
|
||||
util = AssetPermissionPermAssetUtil(perms)
|
||||
nodes = util.get_perm_nodes()
|
||||
return nodes
|
|
@ -4,6 +4,7 @@ from django.utils.translation import gettext_lazy as _
|
|||
from rest_framework import viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import MethodNotAllowed
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from audits.handler import create_or_update_operate_log
|
||||
|
@ -41,7 +42,6 @@ class TicketViewSet(CommonApiMixin, viewsets.ModelViewSet):
|
|||
ordering = ('-date_created',)
|
||||
rbac_perms = {
|
||||
'open': 'tickets.view_ticket',
|
||||
'bulk': 'tickets.change_ticket',
|
||||
}
|
||||
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
|
@ -122,7 +122,7 @@ class TicketViewSet(CommonApiMixin, viewsets.ModelViewSet):
|
|||
self._record_operate_log(instance, TicketAction.close)
|
||||
return Response('ok')
|
||||
|
||||
@action(detail=False, methods=[PUT], permission_classes=[RBACPermission, ])
|
||||
@action(detail=False, methods=[PUT], permission_classes=[IsAuthenticated, ])
|
||||
def bulk(self, request, *args, **kwargs):
|
||||
self.ticket_not_allowed()
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from django.conf import settings
|
||||
from django.db.models import TextChoices, IntegerChoices
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
@ -56,3 +57,21 @@ class TicketApprovalStrategy(TextChoices):
|
|||
custom_user = 'custom_user', _("Custom user")
|
||||
super_admin = 'super_admin', _("Super admin")
|
||||
super_org_admin = 'super_org_admin', _("Super admin and org admin")
|
||||
|
||||
|
||||
class TicketApplyAssetScope(TextChoices):
|
||||
all = 'all', _("All assets")
|
||||
permed = 'permed', _("Permed assets")
|
||||
permed_valid = 'permed_valid', _('Permed valid assets')
|
||||
|
||||
@classmethod
|
||||
def get_scope(cls):
|
||||
return settings.TICKET_APPLY_ASSET_SCOPE.lower()
|
||||
|
||||
@classmethod
|
||||
def is_permed(cls):
|
||||
return cls.get_scope() == cls.permed
|
||||
|
||||
@classmethod
|
||||
def is_permed_valid(cls):
|
||||
return cls.get_scope() == cls.permed_valid
|
||||
|
|
|
@ -57,7 +57,7 @@ class TicketStep(JMSBaseModel):
|
|||
assignees.update(state=state)
|
||||
self.status = StepStatus.closed
|
||||
self.state = state
|
||||
self.save(update_fields=['state', 'status'])
|
||||
self.save(update_fields=['state', 'status', 'date_updated'])
|
||||
|
||||
def set_active(self):
|
||||
self.status = StepStatus.active
|
||||
|
|
|
@ -16,6 +16,8 @@ router.register('apply-login-tickets', api.ApplyLoginTicketViewSet, 'apply-login
|
|||
router.register('apply-command-tickets', api.ApplyCommandTicketViewSet, 'apply-command-ticket')
|
||||
router.register('apply-login-asset-tickets', api.ApplyLoginAssetTicketViewSet, 'apply-login-asset-ticket')
|
||||
router.register('ticket-session-relation', api.TicketSessionRelationViewSet, 'ticket-session-relation')
|
||||
router.register('apply-assets', api.ApplyAssetsViewSet, 'ticket-session-relation')
|
||||
router.register('apply-nodes', api.ApplyNodesViewSet, 'ticket-session-relation')
|
||||
|
||||
urlpatterns = [
|
||||
path('tickets/<uuid:ticket_id>/session/', api.TicketSessionApi.as_view(), name='ticket-session'),
|
||||
|
|
|
@ -729,7 +729,7 @@ class JSONFilterMixin:
|
|||
|
||||
bindings = RoleBinding.objects.filter(**kwargs, role__in=value)
|
||||
if match == 'm2m_all':
|
||||
user_id = bindings.values('user_id').annotate(count=Count('user_id')) \
|
||||
user_id = bindings.values('user_id').annotate(count=Count('user_id', distinct=True)) \
|
||||
.filter(count=len(value)).values_list('user_id', flat=True)
|
||||
else:
|
||||
user_id = bindings.values_list('user_id', flat=True)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from django.db.models import Count
|
||||
from django.db.models import Count, Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
|
@ -46,7 +46,7 @@ class UserGroupSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
|
|||
def setup_eager_loading(cls, queryset):
|
||||
""" Perform necessary eager loading of data. """
|
||||
queryset = queryset.prefetch_related('labels', 'labels__label') \
|
||||
.annotate(users_amount=Count('users'))
|
||||
.annotate(users_amount=Count('users', distinct=True, filter=Q(users__is_service_account=False)))
|
||||
return queryset
|
||||
|
||||
|
||||
|
|
|
@ -163,9 +163,9 @@ def on_openid_create_or_update_user(sender, request, user, created, name, userna
|
|||
user.save()
|
||||
|
||||
|
||||
@shared_task(verbose_name=_('Clean audits session task log'))
|
||||
@shared_task(verbose_name=_('Clean up expired user sessions'))
|
||||
@register_as_period_task(crontab=CRONTAB_AT_PM_TWO)
|
||||
def clean_audits_log_period():
|
||||
def clean_expired_user_session_period():
|
||||
UserSession.clear_expired_sessions()
|
||||
|
||||
|
||||
|
|
|
@ -86,7 +86,7 @@ def check_user_expired_periodic():
|
|||
@tmp_to_root_org()
|
||||
def check_unused_users():
|
||||
uncommon_users_ttl = settings.SECURITY_UNCOMMON_USERS_TTL
|
||||
if not uncommon_users_ttl or not uncommon_users_ttl.isdigit():
|
||||
if not uncommon_users_ttl:
|
||||
return
|
||||
|
||||
uncommon_users_ttl = int(uncommon_users_ttl)
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
.margin-bottom {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
|
||||
.input-style {
|
||||
width: 100%;
|
||||
display: inline-block;
|
||||
|
@ -22,6 +23,19 @@
|
|||
height: 100%;
|
||||
vertical-align: top;
|
||||
}
|
||||
|
||||
.scrollable-menu {
|
||||
height: auto;
|
||||
max-height: 18rem;
|
||||
overflow-x: hidden;
|
||||
}
|
||||
|
||||
.input-group {
|
||||
.input-group-btn .btn-secondary {
|
||||
color: #464a4c;
|
||||
background-color: #eceeef;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
{% endblock %}
|
||||
{% block html_title %}{% trans 'Forgot password' %}{% endblock %}
|
||||
|
@ -57,9 +71,26 @@
|
|||
placeholder="{% trans 'Email account' %}" value="{{ email }}">
|
||||
</div>
|
||||
<div id="validate-sms" class="validate-field margin-bottom">
|
||||
<input type="tel" id="sms" name="sms" class="form-control input-style"
|
||||
placeholder="{% trans 'Mobile number' %}" value="{{ sms }}">
|
||||
<small style="color: #999; margin-left: 5px">{{ form.sms.help_text }}</small>
|
||||
<div class="input-group">
|
||||
<div class="input-group-btn">
|
||||
<button type="button" class="btn btn-secondary dropdown-toggle" data-toggle="dropdown"
|
||||
aria-haspopup="true" aria-expanded="false">
|
||||
<span class="country-code-value">+86</span>
|
||||
</button>
|
||||
<ul class="dropdown-menu scrollable-menu">
|
||||
{% for country in countries %}
|
||||
<li>
|
||||
<a href="#" class="dropdown-item d-flex justify-content-between">
|
||||
<span class="country-name text-left">{{ country.name }}</span>
|
||||
<span class="country-code">{{ country.value }}</span>
|
||||
</a>
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</div>
|
||||
<input type="tel" id="sms" name="sms" class="form-control input-style"
|
||||
placeholder="{% trans 'Mobile number' %}" value="{{ sms }}">
|
||||
</div>
|
||||
</div>
|
||||
<div class="margin-bottom challenge-required">
|
||||
<input type="text" id="code" name="code" class="form-control input-style"
|
||||
|
@ -76,7 +107,7 @@
|
|||
</div>
|
||||
</div>
|
||||
<script>
|
||||
$(function (){
|
||||
$(function () {
|
||||
const validateSelectRef = $('#validate-backend-select')
|
||||
const formType = $('input[name="form_type"]').val()
|
||||
validateSelectRef.val(formType)
|
||||
|
@ -84,19 +115,31 @@
|
|||
selectChange(formType);
|
||||
}
|
||||
})
|
||||
|
||||
$(".dropdown-menu li a").click(function (evt) {
|
||||
const inputGroup = $('.input-group');
|
||||
const inputGroupAddon = inputGroup.find('.country-code-value');
|
||||
const selectedCountry = $(evt.target).closest('li');
|
||||
const selectedCountryCode = selectedCountry.find('.country-code').html();
|
||||
inputGroupAddon.html(selectedCountryCode)
|
||||
});
|
||||
|
||||
|
||||
function getQueryString(name) {
|
||||
const reg = new RegExp("(^|&)"+ name +"=([^&]*)(&|$)");
|
||||
const reg = new RegExp("(^|&)" + name + "=([^&]*)(&|$)");
|
||||
const r = window.location.search.substr(1).match(reg);
|
||||
if(r !== null)
|
||||
if (r !== null)
|
||||
return unescape(r[2])
|
||||
return null
|
||||
}
|
||||
|
||||
function selectChange(name) {
|
||||
$('.validate-field').hide()
|
||||
$('#validate-' + name).show()
|
||||
$('#validate-' + name + '-tip').show()
|
||||
$('input[name="form_type"]').attr('value', name)
|
||||
}
|
||||
|
||||
function sendChallengeCode(currentBtn) {
|
||||
let time = 60;
|
||||
const token = getQueryString('token')
|
||||
|
@ -104,7 +147,7 @@
|
|||
|
||||
const formType = $('input[name="form_type"]').val()
|
||||
const email = $('#email').val()
|
||||
const sms = $('#sms').val()
|
||||
let sms = $('#sms').val();
|
||||
const errMsg = "{% trans 'The {} cannot be empty' %}"
|
||||
|
||||
if (formType === 'sms') {
|
||||
|
@ -118,10 +161,11 @@
|
|||
return
|
||||
}
|
||||
}
|
||||
|
||||
sms = $(".input-group .country-code-value").html() + sms
|
||||
const data = {
|
||||
form_type: formType, email: email, sms: sms,
|
||||
}
|
||||
|
||||
function onSuccess() {
|
||||
const originBtnText = currentBtn.innerHTML;
|
||||
currentBtn.disabled = true
|
||||
|
|
|
@ -14,22 +14,24 @@
|
|||
</strong>
|
||||
</p>
|
||||
<div>
|
||||
<img src="{% static 'img/authenticator_android.png' %}" width="128" height="128" alt="">
|
||||
<img src="{{ authenticator_android_url }}" width="128" height="128" alt="">
|
||||
<p>{% trans 'Android downloads' %}</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<img src="{% static 'img/authenticator_iphone.png' %}" width="128" height="128" alt="">
|
||||
<img src="{{ authenticator_iphone_url }}" width="128" height="128" alt="">
|
||||
<p>{% trans 'iPhone downloads' %}</p>
|
||||
</div>
|
||||
|
||||
<p style="margin: 20px auto;"><strong style="color: #000000">{% trans 'After installation, click the next step to enter the binding page (if installed, go to the next step directly).' %}</strong></p>
|
||||
<p style="margin: 20px auto;"><strong
|
||||
style="color: #000000">{% trans 'After installation, click the next step to enter the binding page (if installed, go to the next step directly).' %}</strong>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<a href="{% url 'authentication:user-otp-enable-bind' %}" class="next">{% trans 'Next' %}</a>
|
||||
|
||||
<script>
|
||||
$(function(){
|
||||
$(function () {
|
||||
$('.change-color li:eq(1) i').css('color', '{{ INTERFACE.primary_color }}')
|
||||
})
|
||||
</script>
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
# ~*~ coding: utf-8 ~*~
|
||||
import os
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import logout as auth_logout
|
||||
from django.http.response import HttpResponseRedirect
|
||||
from django.shortcuts import redirect
|
||||
from django.templatetags.static import static
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext as _
|
||||
from django.utils._os import safe_join
|
||||
from django.views.generic.base import TemplateView
|
||||
from django.views.generic.edit import FormView
|
||||
|
||||
|
@ -45,9 +49,26 @@ class UserOtpEnableStartView(AuthMixin, TemplateView):
|
|||
class UserOtpEnableInstallAppView(TemplateView):
|
||||
template_name = 'users/user_otp_enable_install_app.html'
|
||||
|
||||
@staticmethod
|
||||
def replace_authenticator_png(platform):
|
||||
media_url = settings.MEDIA_URL
|
||||
base_path = f'img/authenticator_{platform}.png'
|
||||
authenticator_media_path = safe_join(settings.MEDIA_ROOT, base_path)
|
||||
if os.path.exists(authenticator_media_path):
|
||||
authenticator_url = f'{media_url}{base_path}'
|
||||
else:
|
||||
authenticator_url = static(base_path)
|
||||
return authenticator_url
|
||||
|
||||
def get_context_data(self, **kwargs):
|
||||
user = get_user_or_pre_auth_user(self.request)
|
||||
context = {'user': user}
|
||||
authenticator_android_url = self.replace_authenticator_png('android')
|
||||
authenticator_iphone_url = self.replace_authenticator_png('iphone')
|
||||
context = {
|
||||
'user': user,
|
||||
'authenticator_android_url': authenticator_android_url,
|
||||
'authenticator_iphone_url': authenticator_iphone_url
|
||||
}
|
||||
kwargs.update(context)
|
||||
return super().get_context_data(**kwargs)
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ from django.views.generic import FormView, RedirectView
|
|||
|
||||
from authentication.errors import IntervalTooShort
|
||||
from authentication.utils import check_user_property_is_correct
|
||||
from common.const.choices import COUNTRY_CALLING_CODES
|
||||
from common.utils import FlashMessageUtil, get_object_or_none, random_string
|
||||
from common.utils.verify_code import SendAndVerifyCodeUtil
|
||||
from users.notifications import ResetPasswordSuccessMsg
|
||||
|
@ -108,7 +109,7 @@ class UserForgotPasswordView(FormView):
|
|||
for k, v in cleaned_data.items():
|
||||
if v:
|
||||
context[k] = v
|
||||
|
||||
context['countries'] = COUNTRY_CALLING_CODES
|
||||
context['form_type'] = 'email'
|
||||
context['XPACK_ENABLED'] = settings.XPACK_ENABLED
|
||||
validate_backends = self.get_validate_backends_context(has_phone)
|
||||
|
|
|
@ -85,7 +85,7 @@ REDIS_PORT: 6379
|
|||
# SECURITY_WATERMARK_ENABLED: False
|
||||
|
||||
# 浏览器关闭页面后,会话过期
|
||||
# SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE: False
|
||||
# SESSION_EXPIRE_AT_BROWSER_CLOSE: False
|
||||
|
||||
# 每次api请求,session续期
|
||||
# SESSION_SAVE_EVERY_REQUEST: True
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue