mirror of https://github.com/jumpserver/jumpserver
commit cf31cbfb07
@@ -19,11 +19,11 @@ ARG BUILD_DEPENDENCIES=" \

 ARG DEPENDENCIES=" \
     freetds-dev \
     libpq-dev \
     libffi-dev \
     libjpeg-dev \
     libkrb5-dev \
     libldap2-dev \
     libpq-dev \
     libsasl2-dev \
     libssl-dev \
     libxml2-dev \

@@ -75,6 +75,7 @@ ENV LANG=zh_CN.UTF-8 \

 ARG DEPENDENCIES=" \
     libjpeg-dev \
     libpq-dev \
     libx11-dev \
     freerdp2-dev \
     libxmlsec1-openssl"

@@ -1,11 +1,12 @@
+from django.db.models import Q
 from rest_framework.generics import CreateAPIView

 from accounts import serializers
 from accounts.models import Account
 from accounts.permissions import AccountTaskActionPermission
 from accounts.tasks import (
     remove_accounts_task, verify_accounts_connectivity_task, push_accounts_to_assets_task
 )
 from assets.exceptions import NotSupportedTemporarilyError
 from authentication.permissions import UserConfirmation, ConfirmType

 __all__ = [

@@ -26,25 +27,35 @@ class AccountsTaskCreateAPI(CreateAPIView):
         ]
         return super().get_permissions()

-    def perform_create(self, serializer):
-        data = serializer.validated_data
-        accounts = data.get('accounts', [])
-        params = data.get('params')
+    @staticmethod
+    def get_account_ids(data, action):
+        account_type = 'gather_accounts' if action == 'remove' else 'accounts'
+        accounts = data.get(account_type, [])
         account_ids = [str(a.id) for a in accounts]

-        if data['action'] == 'push':
-            task = push_accounts_to_assets_task.delay(account_ids, params)
-        elif data['action'] == 'remove':
-            gather_accounts = data.get('gather_accounts', [])
-            gather_account_ids = [str(a.id) for a in gather_accounts]
-            task = remove_accounts_task.delay(gather_account_ids)
+        if action == 'remove':
+            return account_ids
+
+        assets = data.get('assets', [])
+        asset_ids = [str(a.id) for a in assets]
+        ids = Account.objects.filter(
+            Q(id__in=account_ids) | Q(asset_id__in=asset_ids)
+        ).distinct().values_list('id', flat=True)
+        return [str(_id) for _id in ids]
+
+    def perform_create(self, serializer):
+        data = serializer.validated_data
+        action = data['action']
+        ids = self.get_account_ids(data, action)
+
+        if action == 'push':
+            task = push_accounts_to_assets_task.delay(ids, data.get('params'))
+        elif action == 'remove':
+            task = remove_accounts_task.delay(ids)
+        elif action == 'verify':
+            task = verify_accounts_connectivity_task.delay(ids)
         else:
-            account = accounts[0]
-            asset = account.asset
-            if not asset.auto_config['ansible_enabled'] or \
-                    not asset.auto_config['ping_enabled']:
-                raise NotSupportedTemporarilyError()
-            task = verify_accounts_connectivity_task.delay(account_ids)
+            raise ValueError(f"Invalid action: {action}")

         data = getattr(serializer, '_data', {})
         data["task"] = task.id

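Note: the Q-object union in the new `get_account_ids` resolves targets in one query: accounts picked directly by id, plus every account belonging to the listed assets. A minimal sketch of the same pattern (the function wrapper is illustrative, not part of the commit):

    from django.db.models import Q

    def resolve_account_ids(Account, account_ids, asset_ids):
        # One query: accounts matched by id, union accounts of the given assets.
        ids = Account.objects.filter(
            Q(id__in=account_ids) | Q(asset_id__in=asset_ids)
        ).distinct().values_list('id', flat=True)
        return [str(i) for i in ids]
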
@@ -168,9 +168,8 @@ class AccountBackupHandler:
         if not user.secret_key:
             attachment_list = []
         else:
-            password = user.secret_key.encode('utf8')
             attachment = os.path.join(PATH, f'{plan_name}-{local_now_filename()}-{time.time()}.zip')
-            encrypt_and_compress_zip_file(attachment, password, files)
+            encrypt_and_compress_zip_file(attachment, user.secret_key, files)
             attachment_list = [attachment, ]
         AccountBackupExecutionTaskMsg(plan_name, user).publish(attachment_list)
         print('邮件已发送至{}({})'.format(user, user.email))

@@ -191,7 +190,6 @@ class AccountBackupHandler:
         attachment = os.path.join(PATH, f'{plan_name}-{local_now_filename()}-{time.time()}.zip')
         if password:
             print('\033[32m>>> 使用加密密码对文件进行加密中\033[0m')
-            password = password.encode('utf8')
             encrypt_and_compress_zip_file(attachment, password, files)
         else:
             zip_files(attachment, files)

@@ -7,6 +7,7 @@ type:
   - all
 method: change_secret
 protocol: ssh
+priority: 50
 params:
   - name: commands
     type: list

@@ -39,3 +39,4 @@
     login_host: "{{ jms_asset.address }}"
     login_port: "{{ jms_asset.port }}"
     login_database: "{{ jms_asset.spec_info.db_name }}"
+    mode: "{{ account.mode }}"

@@ -5,6 +5,7 @@ method: change_secret
 category: host
 type:
   - windows
+priority: 49
 params:
   - name: groups
     type: str

@@ -4,6 +4,7 @@ from copy import deepcopy

 from django.conf import settings
 from django.utils import timezone
+from django.utils.translation import gettext_lazy as _
 from xlsxwriter import Workbook

 from accounts.const import AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy

@@ -118,6 +119,10 @@ class ChangeSecretManager(AccountBasePlaybookManager):
             else:
                 new_secret = self.get_secret(secret_type)

+            if new_secret is None:
+                print(f'new_secret is None, account: {account}')
+                continue
+
             if self.record_id is None:
                 recorder = ChangeSecretRecord(
                     asset=asset, account=account, execution=self.execution,

@@ -183,17 +188,33 @@ class ChangeSecretManager(AccountBasePlaybookManager):
                 return False
         return True

+    @staticmethod
+    def get_summary(recorders):
+        total, succeed, failed = 0, 0, 0
+        for recorder in recorders:
+            if recorder.status == 'success':
+                succeed += 1
+            else:
+                failed += 1
+            total += 1
+
+        summary = _('Success: %s, Failed: %s, Total: %s') % (succeed, failed, total)
+        return summary
+
     def run(self, *args, **kwargs):
         if self.secret_type and not self.check_secret():
             return
         super().run(*args, **kwargs)
+        recorders = list(self.name_recorder_mapper.values())
+        summary = self.get_summary(recorders)
+        print(summary, end='')
+
         if self.record_id:
             return
-        recorders = self.name_recorder_mapper.values()
-        recorders = list(recorders)
-        self.send_recorder_mail(recorders)
+        self.send_recorder_mail(recorders, summary)

-    def send_recorder_mail(self, recorders):
+    def send_recorder_mail(self, recorders, summary):
         recipients = self.execution.recipients
         if not recorders or not recipients:
             return

@@ -209,11 +230,10 @@ class ChangeSecretManager(AccountBasePlaybookManager):
         for user in recipients:
             attachments = []
             if user.secret_key:
-                password = user.secret_key.encode('utf8')
                 attachment = os.path.join(path, f'{name}-{local_now_filename()}-{time.time()}.zip')
-                encrypt_and_compress_zip_file(attachment, password, [filename])
+                encrypt_and_compress_zip_file(attachment, user.secret_key, [filename])
                 attachments = [attachment]
-            ChangeSecretExecutionTaskMsg(name, user).publish(attachments)
+            ChangeSecretExecutionTaskMsg(name, user, summary).publish(attachments)
             os.remove(filename)

     @staticmethod

@@ -1,9 +1,10 @@
 - hosts: demo
   gather_facts: no
   tasks:
-    - name: Gather posix account
+    - name: Gather windows account
       ansible.builtin.win_shell: net user
       register: result
+      ignore_errors: true

     - name: Define info by set_fact
       ansible.builtin.set_fact:

@@ -39,3 +39,4 @@
     login_host: "{{ jms_asset.address }}"
     login_port: "{{ jms_asset.port }}"
     login_database: "{{ jms_asset.spec_info.db_name }}"
+    mode: "{{ account.mode }}"

@@ -5,6 +5,7 @@ method: push_account
 category: host
 type:
   - windows
+priority: 49
 params:
   - name: groups
     type: str

@@ -6,6 +6,7 @@ type:
   - windows
 method: verify_account
 protocol: rdp
+priority: 1

 i18n:
   Windows rdp account verify:

@@ -7,6 +7,7 @@ type:
   - all
 method: verify_account
 protocol: ssh
+priority: 50

 i18n:
   SSH account verify:

@@ -51,6 +51,9 @@ class VerifyAccountManager(AccountBasePlaybookManager):
             h['name'] += '(' + account.username + ')'
             self.host_account_mapper[h['name']] = account
             secret = account.secret
+            if secret is None:
+                print(f'account {account.name} secret is None')
+                continue

             private_key_path = None
             if account.secret_type == SecretType.SSH_KEY:

@@ -62,7 +65,7 @@ class VerifyAccountManager(AccountBasePlaybookManager):
                 'name': account.name,
                 'username': account.username,
                 'secret_type': account.secret_type,
-                'secret': account.escape_jinja2_syntax(secret),
+                'secret': account.escape_jinja2_syntax(secret),
                 'private_key_path': private_key_path,
                 'become': account.get_ansible_become_auth(),
             }

@@ -52,6 +52,7 @@ class AccountFilterSet(BaseFilterSet):

 class GatheredAccountFilterSet(BaseFilterSet):
     node_id = drf_filters.CharFilter(method='filter_nodes')
     asset_id = drf_filters.CharFilter(field_name='asset_id', lookup_expr='exact')
+    asset_name = drf_filters.CharFilter(field_name='asset__name', lookup_expr='icontains')

     @staticmethod
     def filter_nodes(queryset, name, value):

@@ -54,20 +54,23 @@ class AccountBackupByObjStorageExecutionTaskMsg(object):
 class ChangeSecretExecutionTaskMsg(object):
     subject = _('Notification of implementation result of encryption change plan')

-    def __init__(self, name: str, user: User):
+    def __init__(self, name: str, user: User, summary):
         self.name = name
         self.user = user
+        self.summary = summary

     @property
     def message(self):
         name = self.name
         if self.user.secret_key:
-            return _('{} - The encryption change task has been completed. '
-                     'See the attachment for details').format(name)
+            default_message = _('{} - The encryption change task has been completed. '
+                                'See the attachment for details').format(name)
         else:
-            return _("{} - The encryption change task has been completed: the encryption "
-                     "password has not been set - please go to personal information -> "
-                     "file encryption password to set the encryption password").format(name)
+            default_message = _("{} - The encryption change task has been completed: the encryption "
+                                "password has not been set - please go to personal information -> "
+                                "set encryption password in preferences").format(name)
+        return self.summary + '\n' + default_message

     def publish(self, attachments=None):
         send_mail_attachment_async(

@@ -58,7 +58,7 @@ class AccountCreateUpdateSerializerMixin(serializers.Serializer):
         for data in initial_data:
             if not data.get('asset') and not self.instance:
                 raise serializers.ValidationError({'asset': UniqueTogetherValidator.missing_message})
-            asset = data.get('asset') or self.instance.asset
+            asset = data.get('asset') or getattr(self.instance, 'asset', None)
             self.from_template_if_need(data)
             self.set_uniq_name_if_need(data, asset)

@@ -455,12 +455,14 @@ class AccountHistorySerializer(serializers.ModelSerializer):

 class AccountTaskSerializer(serializers.Serializer):
     ACTION_CHOICES = (
-        ('test', 'test'),
         ('verify', 'verify'),
         ('push', 'push'),
         ('remove', 'remove'),
     )
     action = serializers.ChoiceField(choices=ACTION_CHOICES, write_only=True)
+    assets = serializers.PrimaryKeyRelatedField(
+        queryset=Asset.objects, required=False, allow_empty=True, many=True
+    )
     accounts = serializers.PrimaryKeyRelatedField(
         queryset=Account.objects, required=False, allow_empty=True, many=True
     )

@@ -63,7 +63,7 @@ def create_accounts_activities(account, action='create'):
 def on_account_create_by_template(sender, instance, created=False, **kwargs):
     if not created or instance.source != 'template':
         return
-    push_accounts_if_need(accounts=(instance,))
+    push_accounts_if_need.delay(accounts=(instance,))
     create_accounts_activities(instance, action='create')

@@ -55,7 +55,7 @@ def clean_historical_accounts():
     history_model = Account.history.model
     history_id_mapper = defaultdict(list)

-    ids = history_model.objects.values('id').annotate(count=Count('id')) \
+    ids = history_model.objects.values('id').annotate(count=Count('id', distinct=True)) \
         .filter(count__gte=limit).values_list('id', flat=True)

     if not ids:

@@ -41,21 +41,21 @@ class UserLoginReminderMsg(UserMessage):
 class AssetLoginReminderMsg(UserMessage):
     subject = _('Asset login reminder')

-    def __init__(self, user, asset: Asset, login_user: User, account_username):
+    def __init__(self, user, asset: Asset, login_user: User, account: Account, input_username):
         self.asset = asset
         self.login_user = login_user
-        self.account_username = account_username
+        self.account = account
+        self.input_username = input_username
         super().__init__(user)

     def get_html_msg(self) -> dict:
-        account = Account.objects.get(asset=self.asset, username=self.account_username)
         context = {
             'recipient': self.user,
             'username': self.login_user.username,
             'name': self.login_user.name,
             'asset': str(self.asset),
-            'account': self.account_username,
-            'account_name': account.name,
+            'account': self.input_username,
+            'account_name': self.account.name,
         }
         message = render_to_string('acls/asset_login_reminder.html', context)

@@ -92,6 +92,7 @@ class AssetViewSet(SuggestionMixin, OrgBulkModelViewSet):
     model = Asset
     filterset_class = AssetFilterSet
     search_fields = ("name", "address", "comment")
+    ordering = ('name',)
     ordering_fields = ('name', 'address', 'connectivity', 'platform', 'date_updated', 'date_created')
     serializer_classes = (
         ("default", serializers.AssetSerializer),

@@ -48,7 +48,7 @@ class AssetPermUserListApi(BaseAssetPermUserOrUserGroupListApi):

     def get_queryset(self):
         perms = self.get_asset_related_perms()
-        users = User.objects.filter(
+        users = User.get_queryset().filter(
             Q(assetpermissions__in=perms) | Q(groups__assetpermissions__in=perms)
         ).distinct()
         return users

@@ -1,2 +1,2 @@
 from .endpoint import ExecutionManager
-from .methods import platform_automation_methods, filter_platform_methods
+from .methods import platform_automation_methods, filter_platform_methods, sorted_methods

@@ -68,6 +68,10 @@ def filter_platform_methods(category, tp_name, method=None, methods=None):
     return methods


+def sorted_methods(methods):
+    return sorted(methods, key=lambda x: x.get('priority', 10))
+
+
 BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 platform_automation_methods = get_platform_automation_methods(BASE_DIR)

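Note: the `priority` keys added to the method manifests in this commit feed the new `sorted_methods` helper: a lower number sorts first, and a method without an explicit `priority` falls back to 10. A small illustration (the method dicts are invented for the example):

    methods = [
        {'name': 'verify_account_by_ssh', 'priority': 50},
        {'name': 'verify_account_by_rdp', 'priority': 1},
        {'name': 'legacy_method'},  # no priority -> defaults to 10
    ]
    # Lower value wins: rdp (1), legacy (10), ssh (50)
    ordered = sorted(methods, key=lambda x: x.get('priority', 10))
    print([m['name'] for m in ordered])
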
@@ -7,6 +7,7 @@ type:
   - windows
 method: ping
 protocol: rdp
+priority: 1

 i18n:
   Ping by pyfreerdp:

@@ -7,6 +7,7 @@ type:
   - all
 method: ping
 protocol: ssh
+priority: 50

 i18n:
   Ping by paramiko:

@@ -90,7 +90,7 @@ class AllTypes(ChoicesMixin):

     @classmethod
     def set_automation_methods(cls, category, tp_name, constraints):
-        from assets.automations import filter_platform_methods
+        from assets.automations import filter_platform_methods, sorted_methods
         automation = constraints.get('automation', {})
         automation_methods = {}
         platform_automation_methods = cls.get_automation_methods()

@@ -101,6 +101,7 @@ class AllTypes(ChoicesMixin):
             methods = filter_platform_methods(
                 category, tp_name, item_name, methods=platform_automation_methods
             )
+            methods = sorted_methods(methods)
             methods = [{'name': m['name'], 'id': m['id']} for m in methods]
             automation_methods[item_name + '_methods'] = methods
         automation.update(automation_methods)

@@ -12,6 +12,6 @@ class Migration(migrations.Migration):
     operations = [
         migrations.AlterModelOptions(
             name='asset',
-            options={'ordering': ['name'], 'permissions': [('refresh_assethardwareinfo', 'Can refresh asset hardware info'), ('test_assetconnectivity', 'Can test asset connectivity'), ('match_asset', 'Can match asset'), ('change_assetnodes', 'Can change asset nodes')], 'verbose_name': 'Asset'},
+            options={'ordering': [], 'permissions': [('refresh_assethardwareinfo', 'Can refresh asset hardware info'), ('test_assetconnectivity', 'Can test asset connectivity'), ('match_asset', 'Can match asset'), ('change_assetnodes', 'Can change asset nodes')], 'verbose_name': 'Asset'},
         ),
     ]

@@ -348,7 +348,7 @@ class Asset(NodesRelationMixin, LabeledMixin, AbsConnectivity, JSONFilterMixin, JMSOrgBaseModel):
     class Meta:
         unique_together = [('org_id', 'name')]
         verbose_name = _("Asset")
-        ordering = ["name", ]
+        ordering = []
         permissions = [
             ('refresh_assethardwareinfo', _('Can refresh asset hardware info')),
             ('test_assetconnectivity', _('Can test asset connectivity')),

@@ -429,7 +429,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):

     @classmethod
     @timeit
-    def get_nodes_all_assets(cls, *nodes):
+    def get_nodes_all_assets(cls, *nodes, distinct=True):
         from .asset import Asset
         node_ids = set()
         descendant_node_query = Q()

@@ -439,7 +439,10 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
         if descendant_node_query:
             _ids = Node.objects.order_by().filter(descendant_node_query).values_list('id', flat=True)
             node_ids.update(_ids)
-        return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct()
+        assets = Asset.objects.order_by().filter(nodes__id__in=node_ids)
+        if distinct:
+            assets = assets.distinct()
+        return assets

     def get_all_asset_ids(self):
         asset_ids = self.get_all_asset_ids_by_node_key(org_id=self.org_id, node_key=self.key)

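Note: the new `distinct` switch exists because `SELECT DISTINCT` over a large node join is expensive; callers that deduplicate or aggregate later can now opt out. A hedged sketch of the call-site difference (the call form is illustrative):

    # Default behaviour unchanged: duplicates from the node join are removed.
    assets = Node.get_nodes_all_assets(node_a, node_b)

    # Callers that aggregate anyway can skip the costly DISTINCT pass.
    raw = Node.get_nodes_all_assets(node_a, node_b, distinct=False)
    asset_ids = set(raw.values_list('id', flat=True))  # dedupe in Python instead
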
@@ -58,7 +58,7 @@ class DomainListSerializer(DomainSerializer):
     @classmethod
     def setup_eager_loading(cls, queryset):
         queryset = queryset.annotate(
-            assets_amount=Count('assets'),
+            assets_amount=Count('assets', distinct=True),
        )
         return queryset

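Note: the `distinct=True` arguments added throughout this commit guard against the classic multi-join inflation: once a queryset joins more than one to-many relation, a plain `Count('assets')` counts duplicated join rows. A hedged sketch of the failure mode (model name taken from the hunk above, the second relation is hypothetical):

    from django.db.models import Count

    def asset_counts(Domain):
        # If another to-many relation is joined on the same queryset, each
        # asset row is repeated per related row, inflating the plain count.
        inflated = Domain.objects.annotate(n=Count('assets'))
        # COUNT(DISTINCT assets.id) stays correct regardless of extra joins.
        correct = Domain.objects.annotate(n=Count('assets', distinct=True))
        return inflated, correct
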
@@ -63,13 +63,13 @@ def on_asset_create(sender, instance=None, created=False, **kwargs):
         return
     logger.info("Asset create signal recv: {}".format(instance))

-    ensure_asset_has_node(assets=(instance,))
+    ensure_asset_has_node.delay(assets=(instance,))

     # 获取资产硬件信息
     auto_config = instance.auto_config
     if auto_config.get('ping_enabled'):
         logger.debug('Asset {} ping enabled, test connectivity'.format(instance.name))
-        test_assets_connectivity_handler(assets=(instance,))
+        test_assets_connectivity_handler.delay(assets=(instance,))
     if auto_config.get('gather_facts_enabled'):
         logger.debug('Asset {} gather facts enabled, gather facts'.format(instance.name))
         gather_assets_facts_handler(assets=(instance,))

@@ -2,14 +2,16 @@
 #
 from operator import add, sub

+from django.conf import settings
 from django.db.models.signals import m2m_changed
 from django.dispatch import receiver

 from assets.models import Asset, Node
 from common.const.signals import PRE_CLEAR, POST_ADD, PRE_REMOVE
 from common.decorators import on_transaction_commit, merge_delay_run
+from common.signals import django_ready
 from common.utils import get_logger
-from orgs.utils import tmp_to_org
+from orgs.utils import tmp_to_org, tmp_to_root_org
 from ..tasks import check_node_assets_amount_task

 logger = get_logger(__file__)

@@ -34,7 +36,7 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
         node_ids = [instance.id]
     else:
         node_ids = list(pk_set)
-    update_nodes_assets_amount(node_ids=node_ids)
+    update_nodes_assets_amount.delay(node_ids=node_ids)


 @merge_delay_run(ttl=30)

@@ -52,3 +54,18 @@ def update_nodes_assets_amount(node_ids=()):
         node.assets_amount = node.get_assets_amount()

     Node.objects.bulk_update(nodes, ['assets_amount'])
+
+
+@receiver(django_ready)
+def set_assets_size_to_setting(sender, **kwargs):
+    from assets.models import Asset
+    try:
+        with tmp_to_root_org():
+            amount = Asset.objects.order_by().count()
+    except:
+        amount = 0
+
+    if amount > 20000:
+        settings.ASSET_SIZE = 'large'
+    elif amount > 2000:
+        settings.ASSET_SIZE = 'medium'

@@ -44,18 +44,18 @@ def on_node_post_create(sender, instance, created, update_fields, **kwargs):
         need_expire = False

     if need_expire:
-        expire_node_assets_mapping(org_ids=(instance.org_id,))
+        expire_node_assets_mapping.delay(org_ids=(instance.org_id,))


 @receiver(post_delete, sender=Node)
 def on_node_post_delete(sender, instance, **kwargs):
-    expire_node_assets_mapping(org_ids=(instance.org_id,))
+    expire_node_assets_mapping.delay(org_ids=(instance.org_id,))


 @receiver(m2m_changed, sender=Asset.nodes.through)
 def on_node_asset_change(sender, instance, action='pre_remove', **kwargs):
     if action.startswith('post'):
-        expire_node_assets_mapping(org_ids=(instance.org_id,))
+        expire_node_assets_mapping.delay(org_ids=(instance.org_id,))


 @receiver(django_ready)

@@ -20,6 +20,7 @@ from common.const.http import GET, POST
 from common.drf.filters import DatetimeRangeFilterBackend
 from common.permissions import IsServiceAccount
 from common.plugins.es import QuerySet as ESQuerySet
+from common.sessions.cache import user_session_manager
 from common.storage.ftp_file import FTPFileStorageHandler
 from common.utils import is_uuid, get_logger, lazyproperty
 from orgs.mixins.api import OrgReadonlyModelViewSet, OrgModelViewSet

@@ -30,7 +31,7 @@ from terminal.models import default_storage
 from users.models import User
 from .backends import TYPE_ENGINE_MAPPING
 from .const import ActivityChoices
-from .filters import UserSessionFilterSet
+from .filters import UserSessionFilterSet, OperateLogFilterSet
 from .models import (
     FTPLog, UserLoginLog, OperateLog, PasswordChangeLog,
     ActivityLog, JobLog, UserSession

@@ -204,10 +205,7 @@ class OperateLogViewSet(OrgReadonlyModelViewSet):
     date_range_filter_fields = [
         ('datetime', ('date_from', 'date_to'))
     ]
-    filterset_fields = [
-        'user', 'action', 'resource_type', 'resource',
-        'remote_addr'
-    ]
+    filterset_class = OperateLogFilterSet
     search_fields = ['resource', 'user']
     ordering = ['-datetime']

@@ -289,8 +287,7 @@ class UserSessionViewSet(CommonApiMixin, viewsets.ModelViewSet):
             return Response(status=status.HTTP_200_OK)

         keys = queryset.values_list('key', flat=True)
-        session_store_cls = import_module(settings.SESSION_ENGINE).SessionStore
         for key in keys:
-            session_store_cls(key).delete()
+            user_session_manager.decrement_or_remove(key)
         queryset.delete()
         return Response(status=status.HTTP_200_OK)

@@ -1,12 +1,13 @@
-from django.core.cache import cache
+from django.apps import apps
+from django.utils import translation

 from django_filters import rest_framework as drf_filters
 from rest_framework import filters
 from rest_framework.compat import coreapi, coreschema

 from common.drf.filters import BaseFilterSet
-from notifications.ws import WS_SESSION_KEY
+from common.sessions.cache import user_session_manager
 from orgs.utils import current_org
-from .models import UserSession
+from .models import UserSession, OperateLog

 __all__ = ['CurrentOrgMembersFilter']

@@ -41,15 +42,32 @@ class UserSessionFilterSet(BaseFilterSet):

     @staticmethod
     def filter_is_active(queryset, name, is_active):
-        redis_client = cache.client.get_client()
-        members = redis_client.smembers(WS_SESSION_KEY)
-        members = [member.decode('utf-8') for member in members]
+        keys = user_session_manager.get_active_keys()
         if is_active:
-            queryset = queryset.filter(key__in=members)
+            queryset = queryset.filter(key__in=keys)
         else:
-            queryset = queryset.exclude(key__in=members)
+            queryset = queryset.exclude(key__in=keys)
         return queryset

     class Meta:
         model = UserSession
         fields = ['id', 'ip', 'city', 'type']


+class OperateLogFilterSet(BaseFilterSet):
+    resource_type = drf_filters.CharFilter(method='filter_resource_type')
+
+    @staticmethod
+    def filter_resource_type(queryset, name, resource_type):
+        current_lang = translation.get_language()
+        with translation.override(current_lang):
+            mapper = {str(m._meta.verbose_name): m._meta.verbose_name_raw for m in apps.get_models()}
+            tp = mapper.get(resource_type)
+            queryset = queryset.filter(resource_type=tp)
+        return queryset
+
+    class Meta:
+        model = OperateLog
+        fields = [
+            'user', 'action', 'resource', 'remote_addr'
+        ]

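Note: `filter_resource_type` above lets the API accept the localized resource-type label shown in the UI and map it back to the raw model name stored on the log row. A hedged sketch of the same verbose-name round-trip (the wrapper function is illustrative):

    from django.apps import apps

    def resource_type_to_raw(label):
        # Map the (possibly translated) verbose_name shown to users back to the
        # untranslated verbose_name_raw that OperateLog.resource_type stores.
        mapper = {
            str(m._meta.verbose_name): m._meta.verbose_name_raw
            for m in apps.get_models()
        }
        return mapper.get(label)
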
@@ -4,15 +4,15 @@ from datetime import timedelta
 from importlib import import_module

 from django.conf import settings
-from django.core.cache import caches, cache
+from django.core.cache import caches
 from django.db import models
 from django.db.models import Q
 from django.utils import timezone
 from django.utils.translation import gettext, gettext_lazy as _

 from common.db.encoder import ModelJSONFieldEncoder
+from common.sessions.cache import user_session_manager
 from common.utils import lazyproperty, i18n_trans
-from notifications.ws import WS_SESSION_KEY
 from ops.models import JobExecution
 from orgs.mixins.models import OrgModelMixin, Organization
 from orgs.utils import current_org

@@ -278,8 +278,7 @@ class UserSession(models.Model):

     @property
     def is_active(self):
-        redis_client = cache.client.get_client()
-        return redis_client.sismember(WS_SESSION_KEY, self.key)
+        return user_session_manager.check_active(self.key)

     @property
     def date_expired(self):

@@ -23,7 +23,7 @@ class JobLogSerializer(JobExecutionSerializer):
     class Meta:
         model = models.JobLog
         read_only_fields = [
-            "id", "material", "time_cost", 'date_start',
+            "id", "material", 'job_type', "time_cost", 'date_start',
             'date_finished', 'date_created',
             'is_finished', 'is_success',
             'task_id', 'creator_name'

@@ -19,7 +19,7 @@ from ops.celery.decorator import (
 from ops.models import CeleryTaskExecution
 from terminal.models import Session, Command
 from terminal.backends import server_replay_storage
-from .models import UserLoginLog, OperateLog, FTPLog, ActivityLog
+from .models import UserLoginLog, OperateLog, FTPLog, ActivityLog, PasswordChangeLog

 logger = get_logger(__name__)

@@ -38,6 +38,14 @@ def clean_operation_log_period():
     OperateLog.objects.filter(datetime__lt=expired_day).delete()


+def clean_password_change_log_period():
+    now = timezone.now()
+    days = get_log_keep_day('PASSWORD_CHANGE_LOG_KEEP_DAYS')
+    expired_day = now - datetime.timedelta(days=days)
+    PasswordChangeLog.objects.filter(datetime__lt=expired_day).delete()
+    logger.info("Clean password change log done")
+
+
 def clean_activity_log_period():
     now = timezone.now()
     days = get_log_keep_day('ACTIVITY_LOG_KEEP_DAYS')

@@ -109,6 +117,7 @@ def clean_audits_log_period():
     clean_activity_log_period()
     clean_celery_tasks_period()
     clean_expired_session_period()
+    clean_password_change_log_period()


 @shared_task(verbose_name=_('Upload FTP file to external storage'))

@@ -205,7 +205,7 @@ class RDPFileClientProtocolURLMixin:
         return data

     def get_smart_endpoint(self, protocol, asset=None):
-        endpoint = Endpoint.match_by_instance_label(asset, protocol)
+        endpoint = Endpoint.match_by_instance_label(asset, protocol, self.request)
         if not endpoint:
             target_ip = asset.get_target_ip() if asset else ''
             endpoint = EndpointRule.match_endpoint(

@@ -443,7 +443,7 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelViewSet):
         self._record_operate_log(acl, asset)
         for reviewer in reviewers:
             AssetLoginReminderMsg(
-                reviewer, asset, user, self.input_username
+                reviewer, asset, user, account, self.input_username
             ).publish_async()

     def create(self, request, *args, **kwargs):

@@ -10,6 +10,7 @@ from rest_framework import authentication, exceptions
 from common.auth import signature
 from common.decorators import merge_delay_run
 from common.utils import get_object_or_none, get_request_ip_or_data, contains_ip
+from users.models import User
 from ..models import AccessKey, PrivateToken

@@ -19,22 +20,23 @@ def date_more_than(d, seconds):

 @merge_delay_run(ttl=60)
 def update_token_last_used(tokens=()):
-    for token in tokens:
-        token.date_last_used = timezone.now()
-        token.save(update_fields=['date_last_used'])
+    access_keys_ids = [token.id for token in tokens if isinstance(token, AccessKey)]
+    private_token_keys = [token.key for token in tokens if isinstance(token, PrivateToken)]
+    if len(access_keys_ids) > 0:
+        AccessKey.objects.filter(id__in=access_keys_ids).update(date_last_used=timezone.now())
+    if len(private_token_keys) > 0:
+        PrivateToken.objects.filter(key__in=private_token_keys).update(date_last_used=timezone.now())


 @merge_delay_run(ttl=60)
 def update_user_last_used(users=()):
-    for user in users:
-        user.date_api_key_last_used = timezone.now()
-        user.save(update_fields=['date_api_key_last_used'])
+    User.objects.filter(id__in=users).update(date_api_key_last_used=timezone.now())


 def after_authenticate_update_date(user, token=None):
-    update_user_last_used(users=(user,))
+    update_user_last_used.delay(users=(user.id,))
     if token:
-        update_token_last_used(tokens=(token,))
+        update_token_last_used.delay(tokens=(token,))


 class AccessTokenAuthentication(authentication.BaseAuthentication):

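Note: the rewrite above replaces a per-row `save()` loop with one `UPDATE ... WHERE id IN (...)` per model: a single query instead of N, and no per-instance save overhead or model signals. A hedged sketch of the pattern (the wrapper function is illustrative; `AccessKey` comes from the hunk above):

    from django.utils import timezone

    def touch_access_keys(access_key_ids):
        # One UPDATE statement for the whole batch; no per-instance save()
        # loop, no save signals fired, no extra SELECTs.
        AccessKey.objects.filter(id__in=access_key_ids).update(
            date_last_used=timezone.now()
        )
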
@@ -98,16 +98,19 @@ class OAuth2Backend(JMSModelBackend):
         access_token_url = '{url}{separator}{query}'.format(
             url=settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT, separator=separator, query=urlencode(query_dict)
         )
+        # token_method -> get, post(post_data), post_json
         token_method = settings.AUTH_OAUTH2_ACCESS_TOKEN_METHOD.lower()
-        requests_func = getattr(requests, token_method, requests.get)
         logger.debug(log_prompt.format('Call the access token endpoint[method: %s]' % token_method))
         headers = {
             'Accept': 'application/json'
         }
-        if token_method == 'post':
-            access_token_response = requests_func(access_token_url, headers=headers, data=query_dict)
+        if token_method.startswith('post'):
+            body_key = 'json' if token_method.endswith('json') else 'data'
+            access_token_response = requests.post(
+                access_token_url, headers=headers, **{body_key: query_dict}
+            )
         else:
-            access_token_response = requests_func(access_token_url, headers=headers)
+            access_token_response = requests.get(access_token_url, headers=headers)
         try:
             access_token_response.raise_for_status()
             access_token_response_data = access_token_response.json()

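Note: the `body_key` trick above handles both POST styles at one call site: `post` sends the token request form-encoded (`data=`), while `post_json` sends a JSON body (`json=`). A hedged sketch of the dispatch:

    import requests

    def request_token(url, params, token_method='post'):
        headers = {'Accept': 'application/json'}
        if token_method.startswith('post'):
            # 'post'      -> data=params  (application/x-www-form-urlencoded)
            # 'post_json' -> json=params  (application/json)
            body_key = 'json' if token_method.endswith('json') else 'data'
            return requests.post(url, headers=headers, **{body_key: params})
        return requests.get(url, headers=headers, params=params)
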
@@ -18,7 +18,7 @@ class EncryptedField(forms.CharField):

 class UserLoginForm(forms.Form):
     days_auto_login = int(settings.SESSION_COOKIE_AGE / 3600 / 24)
-    disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE \
+    disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE \
                               or days_auto_login < 1

     username = forms.CharField(

@@ -142,23 +142,7 @@ class SessionCookieMiddleware(MiddlewareMixin):
             return response
         response.set_cookie(key, value)

-    @staticmethod
-    def set_cookie_session_expire(request, response):
-        if not request.session.get('auth_session_expiration_required'):
-            return
-        value = 'age'
-        if settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE or \
-                not request.session.get('auto_login', False):
-            value = 'close'
-
-        age = request.session.get_expiry_age()
-        expire_timestamp = request.session.get_expiry_date().timestamp()
-        response.set_cookie('jms_session_expire_timestamp', expire_timestamp)
-        response.set_cookie('jms_session_expire', value, max_age=age)
-        request.session.pop('auth_session_expiration_required', None)
-
     def process_response(self, request, response: HttpResponse):
         self.set_cookie_session_prefix(request, response)
         self.set_cookie_public_key(request, response)
-        self.set_cookie_session_expire(request, response)
         return response

@@ -37,9 +37,6 @@ def on_user_auth_login_success(sender, user, request, **kwargs):
         UserSession.objects.filter(key=session_key).delete()
     cache.set(lock_key, request.session.session_key, None)

-    # 标记登录,设置 cookie,前端可以控制刷新, Middleware 会拦截这个生成 cookie
-    request.session['auth_session_expiration_required'] = 1
-

 @receiver(cas_user_authenticated)
 def on_cas_user_login_success(sender, request, user, **kwargs):

@@ -407,6 +407,15 @@
         $('#password-hidden').val(passwordEncrypted); //返回给密码输入input
         $('#login-form').submit(); //post提交
     }
+    function checkHealth() {
+        let url = "{% url 'health' %}";
+        requestApi({
+            url: url,
+            method: "GET",
+            flash_message: false,
+        })
+    }
+    setInterval(checkHealth, 30 * 1000);
 </script>
 </html>

@@ -70,11 +70,12 @@ class DingTalkQRMixin(DingTalkBaseMixin, View):
         self.request.session[DINGTALK_STATE_SESSION_KEY] = state

         params = {
-            'appid': settings.DINGTALK_APPKEY,
+            'client_id': settings.DINGTALK_APPKEY,
             'response_type': 'code',
-            'scope': 'snsapi_login',
+            'scope': 'openid',
             'state': state,
             'redirect_uri': redirect_uri,
+            'prompt': 'consent'
         }
         url = URL.QR_CONNECT + '?' + urlencode(params)
         return url

@@ -19,3 +19,17 @@ class Status(models.TextChoices):
     failed = 'failed', _("Failed")
     error = 'error', _("Error")
     canceled = 'canceled', _("Canceled")
+
+
+COUNTRY_CALLING_CODES = [
+    {'name': 'China(中国)', 'value': '+86'},
+    {'name': 'HongKong(中国香港)', 'value': '+852'},
+    {'name': 'Macao(中国澳门)', 'value': '+853'},
+    {'name': 'Taiwan(中国台湾)', 'value': '+886'},
+    {'name': 'America(America)', 'value': '+1'}, {'name': 'Russia(Россия)', 'value': '+7'},
+    {'name': 'France(français)', 'value': '+33'},
+    {'name': 'Britain(Britain)', 'value': '+44'},
+    {'name': 'Germany(Deutschland)', 'value': '+49'},
+    {'name': 'Japan(日本)', 'value': '+81'}, {'name': 'Korea(한국)', 'value': '+82'},
+    {'name': 'India(भारत)', 'value': '+91'}
+]

@@ -362,11 +362,15 @@ class RelatedManager:
             if name is None or val is None:
                 continue

-            if custom_attr_filter:
+            custom_filter_q = None
+            spec_attr_filter = getattr(to_model, "get_{}_filter_attr_q".format(name), None)
+            if spec_attr_filter:
+                custom_filter_q = spec_attr_filter(val, match)
+            elif custom_attr_filter:
                 custom_filter_q = custom_attr_filter(name, val, match)
-                if custom_filter_q:
-                    filters.append(custom_filter_q)
-                    continue
+            if custom_filter_q:
+                filters.append(custom_filter_q)
+                continue

             if match == 'ip_in':
                 q = cls.get_ip_in_q(name, val)

@@ -464,11 +468,15 @@ class JSONManyToManyDescriptor:
         rule_value = rule.get('value', '')
         rule_match = rule.get('match', 'exact')

-        if custom_attr_filter:
-            q = custom_attr_filter(rule['name'], rule_value, rule_match)
-            if q:
-                custom_q &= q
-                continue
+        custom_filter_q = None
+        spec_attr_filter = getattr(to_model, "get_filter_{}_attr_q".format(rule['name']), None)
+        if spec_attr_filter:
+            custom_filter_q = spec_attr_filter(rule_value, rule_match)
+        elif custom_attr_filter:
+            custom_filter_q = custom_attr_filter(rule['name'], rule_value, rule_match)
+        if custom_filter_q:
+            custom_q &= custom_filter_q
+            continue

         if rule_match == 'in':
             res &= value in rule_value or '*' in rule_value

@@ -517,7 +525,6 @@ class JSONManyToManyDescriptor:
                 res &= rule_value.issubset(value)
             else:
                 res &= bool(value & rule_value)
-
         else:
             logging.error("unknown match: {}".format(rule['match']))
             res &= False

@@ -3,6 +3,7 @@
 import asyncio
 import functools
+import inspect
 import os
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor

@@ -101,7 +102,11 @@ def run_debouncer_func(cache_key, org, ttl, func, *args, **kwargs):
         first_run_time = current

     if current - first_run_time > ttl:
         _loop_debouncer_func_args_cache.pop(cache_key, None)
+        _loop_debouncer_func_task_time_cache.pop(cache_key, None)
+        executor.submit(run_func_partial, *args, **kwargs)
+        logger.debug('pid {} executor submit run {}'.format(
+            os.getpid(), func.__name__, ))
         return

     loop = _loop_thread.get_loop()

@@ -133,13 +138,26 @@ class Debouncer(object):
         return await self.loop.run_in_executor(self.executor, func)


+ignore_err_exceptions = (
+    "(3101, 'Plugin instructed the server to rollback the current transaction.')",
+)
+
+
 def _run_func_with_org(key, org, func, *args, **kwargs):
     from orgs.utils import set_current_org
     try:
-        set_current_org(org)
-        func(*args, **kwargs)
+        with transaction.atomic():
+            set_current_org(org)
+            func(*args, **kwargs)
     except Exception as e:
-        logger.error('delay run error: %s' % e)
+        msg = str(e)
+        log_func = logger.error
+        if msg in ignore_err_exceptions:
+            log_func = logger.info
+        pid = os.getpid()
+        thread_name = threading.current_thread()
+        log_func('pid {} thread {} delay run {} error: {}'.format(
+            pid, thread_name, func.__name__, msg))
     _loop_debouncer_func_task_cache.pop(key, None)
     _loop_debouncer_func_args_cache.pop(key, None)
     _loop_debouncer_func_task_time_cache.pop(key, None)

@@ -181,6 +199,32 @@ def merge_delay_run(ttl=5, key=None):
     :return:
     """

+    def delay(func, *args, **kwargs):
+        from orgs.utils import get_current_org
+        suffix_key_func = key if key else default_suffix_key
+        org = get_current_org()
+        func_name = f'{func.__module__}_{func.__name__}'
+        key_suffix = suffix_key_func(*args, **kwargs)
+        cache_key = f'MERGE_DELAY_RUN_{func_name}_{key_suffix}'
+        cache_kwargs = _loop_debouncer_func_args_cache.get(cache_key, {})
+
+        for k, v in kwargs.items():
+            if not isinstance(v, (tuple, list, set)):
+                raise ValueError('func kwargs value must be list or tuple: %s %s' % (func.__name__, v))
+            v = set(v)
+            if k not in cache_kwargs:
+                cache_kwargs[k] = v
+            else:
+                cache_kwargs[k] = cache_kwargs[k].union(v)
+        _loop_debouncer_func_args_cache[cache_key] = cache_kwargs
+        run_debouncer_func(cache_key, org, ttl, func, *args, **cache_kwargs)
+
+    def apply(func, sync=False, *args, **kwargs):
+        if sync:
+            return func(*args, **kwargs)
+        else:
+            return delay(func, *args, **kwargs)
+
     def inner(func):
         sigs = inspect.signature(func)
         if len(sigs.parameters) != 1:

@@ -188,27 +232,12 @@ def merge_delay_run(ttl=5, key=None):
         param = list(sigs.parameters.values())[0]
         if not isinstance(param.default, tuple):
             raise ValueError('func default must be tuple: %s' % param.default)
-        suffix_key_func = key if key else default_suffix_key
+        func.delay = functools.partial(delay, func)
+        func.apply = functools.partial(apply, func)

         @functools.wraps(func)
         def wrapper(*args, **kwargs):
-            from orgs.utils import get_current_org
-            org = get_current_org()
-            func_name = f'{func.__module__}_{func.__name__}'
-            key_suffix = suffix_key_func(*args, **kwargs)
-            cache_key = f'MERGE_DELAY_RUN_{func_name}_{key_suffix}'
-            cache_kwargs = _loop_debouncer_func_args_cache.get(cache_key, {})
-
-            for k, v in kwargs.items():
-                if not isinstance(v, (tuple, list, set)):
-                    raise ValueError('func kwargs value must be list or tuple: %s %s' % (func.__name__, v))
-                v = set(v)
-                if k not in cache_kwargs:
-                    cache_kwargs[k] = v
-                else:
-                    cache_kwargs[k] = cache_kwargs[k].union(v)
-            _loop_debouncer_func_args_cache[cache_key] = cache_kwargs
-            run_debouncer_func(cache_key, org, ttl, func, *args, **cache_kwargs)
+            return func(*args, **kwargs)

         return wrapper

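Note: after this refactor the debounce/merge behaviour moves off the direct call path. Calling the decorated function runs it immediately; the new `func.delay(...)` entry point merges keyword arguments across calls within the ttl window; `func.apply(sync=True, ...)` is an explicit synchronous escape hatch. A hedged usage sketch based on the code above:

    @merge_delay_run(ttl=30)
    def update_nodes_assets_amount(node_ids=()):
        ...  # real work

    # Runs immediately, bypassing the debouncer (new behaviour of a plain call):
    update_nodes_assets_amount(node_ids=(1, 2))

    # Debounced: calls inside the ttl window are merged into one execution
    # whose node_ids is the union of all submitted batches.
    update_nodes_assets_amount.delay(node_ids=(1, 2))
    update_nodes_assets_amount.delay(node_ids=(3,))

    # Force synchronous execution regardless:
    update_nodes_assets_amount.apply(sync=True, node_ids=(4,))
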
@@ -6,7 +6,7 @@ import logging

 from django.core.cache import cache
 from django.core.exceptions import ImproperlyConfigured
-from django.db.models import Q, Count
+from django.db.models import Q
 from django_filters import rest_framework as drf_filters
 from rest_framework import filters
 from rest_framework.compat import coreapi, coreschema

@@ -180,36 +180,30 @@ class LabelFilterBackend(filters.BaseFilterBackend):
     ]

     @staticmethod
-    def filter_resources(resources, labels_id):
+    def parse_label_ids(labels_id):
+        from labels.models import Label
         label_ids = [i.strip() for i in labels_id.split(',')]
+        cleaned = []

         args = []
         for label_id in label_ids:
             kwargs = {}
             if ':' in label_id:
                 k, v = label_id.split(':', 1)
-                kwargs['label__name'] = k.strip()
+                kwargs['name'] = k.strip()
                 if v != '*':
-                    kwargs['label__value'] = v.strip()
+                    kwargs['value'] = v.strip()
                 args.append(kwargs)
             else:
-                kwargs['label_id'] = label_id
-                args.append(kwargs)
+                cleaned.append(label_id)

-        if len(args) == 1:
-            resources = resources.filter(**args[0])
-            return resources
-
-        q = Q()
-        for kwarg in args:
-            q |= Q(**kwarg)
-
-        resources = resources.filter(q) \
-            .values('res_id') \
-            .order_by('res_id') \
-            .annotate(count=Count('res_id')) \
-            .values('res_id', 'count') \
-            .filter(count=len(args))
-        return resources
+        if len(args) != 0:
+            q = Q()
+            for kwarg in args:
+                q |= Q(**kwarg)
+            ids = Label.objects.filter(q).values_list('id', flat=True)
+            cleaned.extend(list(ids))
+        return cleaned

     def filter_queryset(self, request, queryset, view):
         labels_id = request.query_params.get('labels')

@@ -230,7 +224,8 @@ class LabelFilterBackend(filters.BaseFilterBackend):
         resources = labeled_resource_cls.objects.filter(
             res_type__app_label=app_label, res_type__model=model_name,
         )
-        resources = self.filter_resources(resources, labels_id)
+        label_ids = self.parse_label_ids(labels_id)
+        resources = model.filter_resources_by_labels(resources, label_ids)
         res_ids = resources.values_list('res_id', flat=True)
         queryset = queryset.filter(id__in=set(res_ids))
         return queryset

@@ -87,7 +87,7 @@ class BaseFileRenderer(BaseRenderer):
         if value is None:
             return '-'
         pk = str(value.get('id', '') or value.get('pk', ''))
-        name = value.get('name') or value.get('display_name', '')
+        name = value.get('display_name', '') or value.get('name', '')
         return '{}({})'.format(name, pk)

     @staticmethod

@@ -28,9 +28,10 @@ class ErrorCode:

 class URL:
-    QR_CONNECT = 'https://oapi.dingtalk.com/connect/qrconnect'
+    QR_CONNECT = 'https://login.dingtalk.com/oauth2/auth'
     OAUTH_CONNECT = 'https://oapi.dingtalk.com/connect/oauth2/sns_authorize'
-    GET_USER_INFO_BY_CODE = 'https://oapi.dingtalk.com/sns/getuserinfo_bycode'
+    GET_USER_ACCESSTOKEN = 'https://api.dingtalk.com/v1.0/oauth2/userAccessToken'
+    GET_USER_INFO = 'https://api.dingtalk.com/v1.0/contact/users/me'
     GET_TOKEN = 'https://oapi.dingtalk.com/gettoken'
     SEND_MESSAGE_BY_TEMPLATE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/sendbytemplate'
     SEND_MESSAGE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/asyncsend_v2'

@@ -72,8 +73,9 @@ class DingTalkRequests(BaseRequest):
     def get(self, url, params=None,
             with_token=False, with_sign=False,
             check_errcode_is_0=True,
-            **kwargs):
+            **kwargs) -> dict:
+        pass

     get = as_request(get)

     def post(self, url, json=None, params=None,

@@ -81,6 +83,7 @@ class DingTalkRequests(BaseRequest):
             check_errcode_is_0=True,
             **kwargs) -> dict:
+        pass

     post = as_request(post)

     def _add_sign(self, kwargs: dict):

@@ -123,17 +126,22 @@ class DingTalk:
         )

     def get_userinfo_bycode(self, code):
-        # https://developers.dingtalk.com/document/app/obtain-the-user-information-based-on-the-sns-temporary-authorization?spm=ding_open_doc.document.0.0.3a256573y8Y7yg#topic-1995619
         body = {
-            "tmp_auth_code": code
+            'clientId': self._appid,
+            'clientSecret': self._appsecret,
+            'code': code,
+            'grantType': 'authorization_code'
         }
+        data = self._request.post(URL.GET_USER_ACCESSTOKEN, json=body, check_errcode_is_0=False)
+        token = data['accessToken']

-        data = self._request.post(URL.GET_USER_INFO_BY_CODE, json=body, with_sign=True)
-        return data['user_info']
+        user = self._request.get(URL.GET_USER_INFO,
+                                 headers={'x-acs-dingtalk-access-token': token}, check_errcode_is_0=False)
+        return user

     def get_user_id_by_code(self, code):
         user_info = self.get_userinfo_bycode(code)
-        unionid = user_info['unionid']
+        unionid = user_info['unionId']
         userid = self.get_userid_by_unionid(unionid)
         return userid, None

@@ -0,0 +1,56 @@
+import re
+
+from django.contrib.sessions.backends.cache import (
+    SessionStore as DjangoSessionStore
+)
+from django.core.cache import cache
+
+from jumpserver.utils import get_current_request
+
+
+class SessionStore(DjangoSessionStore):
+    ignore_urls = [
+        r'^/api/v1/users/profile/'
+    ]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ignore_pattern = re.compile('|'.join(self.ignore_urls))
+
+    def save(self, *args, **kwargs):
+        request = get_current_request()
+        if request is None or not self.ignore_pattern.match(request.path):
+            super().save(*args, **kwargs)
+
+
+class RedisUserSessionManager:
+    JMS_SESSION_KEY = 'jms_session_key'
+
+    def __init__(self):
+        self.client = cache.client.get_client()
+
+    def add_or_increment(self, session_key):
+        self.client.hincrby(self.JMS_SESSION_KEY, session_key, 1)
+
+    def decrement_or_remove(self, session_key):
+        new_count = self.client.hincrby(self.JMS_SESSION_KEY, session_key, -1)
+        if new_count <= 0:
+            self.client.hdel(self.JMS_SESSION_KEY, session_key)
+
+    def check_active(self, session_key):
+        count = self.client.hget(self.JMS_SESSION_KEY, session_key)
+        count = 0 if count is None else int(count.decode('utf-8'))
+        return count > 0
+
+    def get_active_keys(self):
+        session_keys = []
+        for k, v in self.client.hgetall(self.JMS_SESSION_KEY).items():
+            count = int(v.decode('utf-8'))
+            if count <= 0:
+                continue
+            key = k.decode('utf-8')
+            session_keys.append(key)
+        return session_keys
+
+
+user_session_manager = RedisUserSessionManager()

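Note: the new manager keeps a single Redis hash whose fields are Django session keys and whose values are reference counts, since one session can back several live connections. A hedged sketch of the life cycle (assumes a running Redis behind Django's cache client):

    mgr = RedisUserSessionManager()

    mgr.add_or_increment('abc123')     # first connection  -> count 1
    mgr.add_or_increment('abc123')     # second connection -> count 2
    print(mgr.check_active('abc123'))  # True

    mgr.decrement_or_remove('abc123')  # one connection closed -> count 1
    mgr.decrement_or_remove('abc123')  # last one closed -> hash field deleted
    print(mgr.check_active('abc123'))  # False
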
@@ -69,7 +69,7 @@ def digest_sql_query():

     for query in queries:
         sql = query['sql']
-        print(" # {}: {}".format(query['time'], sql[:1000]))
+        print(" # {}: {}".format(query['time'], sql[:1000]))
     if len(queries) < 3:
         continue
     print("- Table: {}".format(table_name))

@@ -21,6 +21,8 @@ def encrypt_and_compress_zip_file(filename, secret_password, encrypted_filenames):
     with pyzipper.AESZipFile(
             filename, 'w', compression=pyzipper.ZIP_LZMA, encryption=pyzipper.WZ_AES
     ) as zf:
+        if secret_password and isinstance(secret_password, str):
+            secret_password = secret_password.encode('utf8')
         zf.setpassword(secret_password)
         for encrypted_filename in encrypted_filenames:
             with open(encrypted_filename, 'rb') as f:

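Note: centralizing the str-to-bytes conversion here is what lets the callers earlier in this commit drop their scattered `user.secret_key.encode('utf8')` lines; the helper now accepts either type. A hedged usage sketch (paths and password are illustrative):

    # Both calls now behave identically; the helper encodes str passwords itself.
    encrypt_and_compress_zip_file('/tmp/backup.zip', 'S3cret!', ['/tmp/accounts.xlsx'])
    encrypt_and_compress_zip_file('/tmp/backup.zip', b'S3cret!', ['/tmp/accounts.xlsx'])
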
@@ -547,7 +547,6 @@ class Config(dict):
     'REFERER_CHECK_ENABLED': False,
     'SESSION_ENGINE': 'cache',
     'SESSION_SAVE_EVERY_REQUEST': True,
-    'SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE': False,
     'SERVER_REPLAY_STORAGE': {},
     'SECURITY_DATA_CRYPTO_ALGO': None,
     'GMSSL_ENABLED': False,

@@ -564,8 +563,10 @@ class Config(dict):
     'FTP_LOG_KEEP_DAYS': 180,
     'CLOUD_SYNC_TASK_EXECUTION_KEEP_DAYS': 180,
     'JOB_EXECUTION_KEEP_DAYS': 180,
+    'PASSWORD_CHANGE_LOG_KEEP_DAYS': 999,

     'TICKETS_ENABLED': True,
     'TICKETS_DIRECT_APPROVE': False,

     # 废弃的
     'DEFAULT_ORG_SHOW_ALL_USERS': True,

@@ -606,7 +607,9 @@ class Config(dict):
     'GPT_MODEL': 'gpt-3.5-turbo',
     'VIRTUAL_APP_ENABLED': False,

-    'FILE_UPLOAD_SIZE_LIMIT_MB': 200
+    'FILE_UPLOAD_SIZE_LIMIT_MB': 200,
+
+    'TICKET_APPLY_ASSET_SCOPE': 'all'
 }

 old_config_map = {

@@ -701,7 +704,8 @@ class Config(dict):

     def compatible_redis(self):
         redis_config = {
-            'REDIS_PASSWORD': quote(str(self.REDIS_PASSWORD)),
+            'REDIS_PASSWORD': str(self.REDIS_PASSWORD),
+            'REDIS_PASSWORD_QUOTE': quote(str(self.REDIS_PASSWORD)),
         }
         for key, value in redis_config.items():
             self[key] = value

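Note: keeping both `REDIS_PASSWORD` (raw) and `REDIS_PASSWORD_QUOTE` (URL-escaped) matters because the Redis server expects the raw value on AUTH, while anything embedded in a `redis://` URL must be percent-encoded. A hedged illustration:

    from urllib.parse import quote

    password = 'p@ss/word#1'
    url = 'redis://:{}@127.0.0.1:6379/1'.format(quote(password))
    # quote() turns 'p@ss/word#1' into 'p%40ss%2Fword%231', so the URL parser
    # no longer misreads '@', '/' or '#' as structural characters.
    print(url)
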
@@ -66,11 +66,6 @@ class RequestMiddleware:
     def __call__(self, request):
         set_current_request(request)
         response = self.get_response(request)
-        is_request_api = request.path.startswith('/api')
-        if not settings.SESSION_EXPIRE_AT_BROWSER_CLOSE and \
-                not is_request_api:
-            age = request.session.get_expiry_age()
-            request.session.set_expiry(age)
         return response

@@ -3,6 +3,7 @@
 path_perms_map = {
     'xpack': '*',
     'settings': '*',
+    'img': '*',
     'replay': 'default',
     'applets': 'terminal.view_applet',
     'virtual_apps': 'terminal.view_virtualapp',

@@ -234,11 +234,9 @@ CSRF_COOKIE_NAME = '{}csrftoken'.format(SESSION_COOKIE_NAME_PREFIX)
 SESSION_COOKIE_NAME = '{}sessionid'.format(SESSION_COOKIE_NAME_PREFIX)

 SESSION_COOKIE_AGE = CONFIG.SESSION_COOKIE_AGE
-SESSION_EXPIRE_AT_BROWSER_CLOSE = True
-# 自定义的配置,SESSION_EXPIRE_AT_BROWSER_CLOSE 始终为 True, 下面这个来控制是否强制关闭后过期 cookie
-SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE = CONFIG.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE
 SESSION_SAVE_EVERY_REQUEST = CONFIG.SESSION_SAVE_EVERY_REQUEST
-SESSION_ENGINE = "django.contrib.sessions.backends.{}".format(CONFIG.SESSION_ENGINE)
+SESSION_EXPIRE_AT_BROWSER_CLOSE = CONFIG.SESSION_EXPIRE_AT_BROWSER_CLOSE
+SESSION_ENGINE = "common.sessions.{}".format(CONFIG.SESSION_ENGINE)

 MESSAGE_STORAGE = 'django.contrib.messages.storage.cookie.CookieStorage'
 # Database

@@ -408,7 +406,7 @@ if REDIS_SENTINEL_SERVICE_NAME and REDIS_SENTINELS:
 else:
     REDIS_LOCATION_NO_DB = '%(protocol)s://:%(password)s@%(host)s:%(port)s/{}' % {
         'protocol': REDIS_PROTOCOL,
-        'password': CONFIG.REDIS_PASSWORD,
+        'password': CONFIG.REDIS_PASSWORD_QUOTE,
         'host': CONFIG.REDIS_HOST,
         'port': CONFIG.REDIS_PORT,
     }

@@ -122,11 +122,11 @@ WS_LISTEN_PORT = CONFIG.WS_LISTEN_PORT
 LOGIN_LOG_KEEP_DAYS = CONFIG.LOGIN_LOG_KEEP_DAYS
 TASK_LOG_KEEP_DAYS = CONFIG.TASK_LOG_KEEP_DAYS
 OPERATE_LOG_KEEP_DAYS = CONFIG.OPERATE_LOG_KEEP_DAYS
+PASSWORD_CHANGE_LOG_KEEP_DAYS = CONFIG.PASSWORD_CHANGE_LOG_KEEP_DAYS
 ACTIVITY_LOG_KEEP_DAYS = CONFIG.ACTIVITY_LOG_KEEP_DAYS
 FTP_LOG_KEEP_DAYS = CONFIG.FTP_LOG_KEEP_DAYS
 CLOUD_SYNC_TASK_EXECUTION_KEEP_DAYS = CONFIG.CLOUD_SYNC_TASK_EXECUTION_KEEP_DAYS
 JOB_EXECUTION_KEEP_DAYS = CONFIG.JOB_EXECUTION_KEEP_DAYS

 ORG_CHANGE_TO_URL = CONFIG.ORG_CHANGE_TO_URL
 WINDOWS_SKIP_ALL_MANUAL_PASSWORD = CONFIG.WINDOWS_SKIP_ALL_MANUAL_PASSWORD

@@ -137,6 +137,7 @@ CHANGE_AUTH_PLAN_SECURE_MODE_ENABLED = CONFIG.CHANGE_AUTH_PLAN_SECURE_MODE_ENABLED
 DATETIME_DISPLAY_FORMAT = '%Y-%m-%d %H:%M:%S'

 TICKETS_ENABLED = CONFIG.TICKETS_ENABLED
+TICKETS_DIRECT_APPROVE = CONFIG.TICKETS_DIRECT_APPROVE
 REFERER_CHECK_ENABLED = CONFIG.REFERER_CHECK_ENABLED

 CONNECTION_TOKEN_ENABLED = CONFIG.CONNECTION_TOKEN_ENABLED

@@ -214,6 +215,9 @@ PERM_TREE_REGEN_INTERVAL = CONFIG.PERM_TREE_REGEN_INTERVAL
 MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS
 LIMIT_SUPER_PRIV = CONFIG.LIMIT_SUPER_PRIV

+# Asset account may be too many
+ASSET_SIZE = 'small'
+
 # Chat AI
 CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED
 GPT_API_KEY = CONFIG.GPT_API_KEY

@@ -224,3 +228,5 @@ GPT_MODEL = CONFIG.GPT_MODEL
 VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED

 FILE_UPLOAD_SIZE_LIMIT_MB = CONFIG.FILE_UPLOAD_SIZE_LIMIT_MB
+
+TICKET_APPLY_ASSET_SCOPE = CONFIG.TICKET_APPLY_ASSET_SCOPE

@@ -82,7 +82,6 @@ BOOTSTRAP3 = {
 # Django channels support websocket
 REDIS_LAYERS_HOST = {
     'db': CONFIG.REDIS_DB_WS,
-    'password': CONFIG.REDIS_PASSWORD or None,
 }

 REDIS_LAYERS_SSL_PARAMS = {}

@@ -97,6 +96,7 @@ if REDIS_USE_SSL:

 if REDIS_SENTINEL_SERVICE_NAME and REDIS_SENTINELS:
     REDIS_LAYERS_HOST['sentinels'] = REDIS_SENTINELS
+    REDIS_LAYERS_HOST['password'] = CONFIG.REDIS_PASSWORD or None
     REDIS_LAYERS_HOST['master_name'] = REDIS_SENTINEL_SERVICE_NAME
     REDIS_LAYERS_HOST['sentinel_kwargs'] = {
         'password': REDIS_SENTINEL_PASSWORD,

@@ -111,7 +111,7 @@ else:
     # More info see: https://github.com/django/channels_redis/issues/334
     # REDIS_LAYERS_HOST['address'] = (CONFIG.REDIS_HOST, CONFIG.REDIS_PORT)
     REDIS_LAYERS_ADDRESS = '{protocol}://:{password}@{host}:{port}/{db}'.format(
-        protocol=REDIS_PROTOCOL, password=CONFIG.REDIS_PASSWORD,
+        protocol=REDIS_PROTOCOL, password=CONFIG.REDIS_PASSWORD_QUOTE,
         host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, db=CONFIG.REDIS_DB_WS
     )
     REDIS_LAYERS_HOST['address'] = REDIS_LAYERS_ADDRESS

@@ -153,7 +153,7 @@ if REDIS_SENTINEL_SERVICE_NAME and REDIS_SENTINELS:
 else:
     CELERY_BROKER_URL = CELERY_BROKER_URL_FORMAT % {
         'protocol': REDIS_PROTOCOL,
-        'password': CONFIG.REDIS_PASSWORD,
+        'password': CONFIG.REDIS_PASSWORD_QUOTE,
         'host': CONFIG.REDIS_HOST,
         'port': CONFIG.REDIS_PORT,
         'db': CONFIG.REDIS_DB_CELERY,

@@ -187,6 +187,7 @@ ANSIBLE_LOG_DIR = os.path.join(PROJECT_DIR, 'data', 'ansible')
 REDIS_HOST = CONFIG.REDIS_HOST
 REDIS_PORT = CONFIG.REDIS_PORT
 REDIS_PASSWORD = CONFIG.REDIS_PASSWORD
+REDIS_PASSWORD_QUOTE = CONFIG.REDIS_PASSWORD_QUOTE

 DJANGO_REDIS_SCAN_ITERSIZE = 1000

@ -1,6 +1,6 @@
from django.contrib.contenttypes.fields import GenericRelation
from django.db import models
from django.db.models import OneToOneField
from django.db.models import OneToOneField, Count

from common.utils import lazyproperty
from .models import LabeledResource
@ -36,3 +36,37 @@ class LabeledMixin(models.Model):
    @res_labels.setter
    def res_labels(self, value):
        self.real.labels.set(value, bulk=False)

    @classmethod
    def filter_resources_by_labels(cls, resources, label_ids):
        return cls._get_filter_res_by_labels_m2m_all(resources, label_ids)

    @classmethod
    def _get_filter_res_by_labels_m2m_in(cls, resources, label_ids):
        return resources.filter(label_id__in=label_ids)

    @classmethod
    def _get_filter_res_by_labels_m2m_all(cls, resources, label_ids):
        if len(label_ids) == 1:
            return cls._get_filter_res_by_labels_m2m_in(resources, label_ids)

        resources = resources.filter(label_id__in=label_ids) \
            .values('res_id') \
            .order_by('res_id') \
            .annotate(count=Count('res_id', distinct=True)) \
            .values('res_id', 'count') \
            .filter(count=len(label_ids))
        return resources

    @classmethod
    def get_labels_filter_attr_q(cls, value, match):
        resources = LabeledResource.objects.all()
        if not value:
            return None

        if match != 'm2m_all':
            resources = cls._get_filter_res_by_labels_m2m_in(resources, value)
        else:
            resources = cls._get_filter_res_by_labels_m2m_all(resources, value)
        res_ids = set(resources.values_list('res_id', flat=True))
        return models.Q(id__in=res_ids)
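The `_get_filter_res_by_labels_m2m_all` hunk implements "resource has *all* of the requested labels" in one aggregation: restrict the through rows to the candidate labels, group them by `res_id`, and keep only groups whose matched-label count equals the number of labels requested. A runnable miniature of the same idea over plain `(res_id, label_id)` pairs (the data below is illustrative):

```python
# "Has ALL labels" via grouping and counting, mirroring the ORM hunk above.
from collections import defaultdict

def resources_with_all_labels(rows, label_ids):
    wanted = set(label_ids)
    matched = defaultdict(set)
    for res_id, label_id in rows:
        if label_id in wanted:
            matched[res_id].add(label_id)
    # Keep groups whose matched-label count equals len(label_ids),
    # like .annotate(count=...).filter(count=len(label_ids)).
    return sorted(r for r, labels in matched.items() if len(labels) == len(wanted))

rows = [(1, 'a'), (1, 'b'), (2, 'a'), (3, 'a'), (3, 'b')]
print(resources_with_all_labels(rows, ['a', 'b']))  # [1, 3]
```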
@ -34,7 +34,7 @@ class LabelSerializer(BulkOrgResourceModelSerializer):
    @classmethod
    def setup_eager_loading(cls, queryset):
        """ Perform necessary eager loading of data. """
        queryset = queryset.annotate(res_count=Count('labeled_resources'))
        queryset = queryset.annotate(res_count=Count('labeled_resources', distinct=True))
        return queryset
@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7879f4eeb499e920ad6c4bfdb0b1f334936ca344c275be056f12fcf7485f2bf6
size 170948
oid sha256:d04781f4f0b0de3ac5f707febb222e239553d6103bca0cec41ab2fd5ab044571
size 173799
File diff suppressed because it is too large
@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19d3a111cc245f9a9d36b860fd95447df916ad66c918bef672bacdad6bc77a8f
size 140119
oid sha256:e66a6fa05d25f1c502f95001b5ff0d0a310affd32eac939fd7b840845028074f
size 142298
File diff suppressed because it is too large
@ -1,28 +1,32 @@
import json
import time
from threading import Thread

from channels.generic.websocket import JsonWebsocketConsumer
from django.core.cache import cache
from django.conf import settings

from common.db.utils import safe_db_connection
from common.sessions.cache import user_session_manager
from common.utils import get_logger
from .signal_handlers import new_site_msg_chan
from .site_msg import SiteMessageUtil

logger = get_logger(__name__)
WS_SESSION_KEY = 'ws_session_key'


class SiteMsgWebsocket(JsonWebsocketConsumer):
    sub = None
    refresh_every_seconds = 10

    @property
    def session(self):
        return self.scope['session']

    def connect(self):
        user = self.scope["user"]
        if user.is_authenticated:
            self.accept()
            session = self.scope['session']
            redis_client = cache.client.get_client()
            redis_client.sadd(WS_SESSION_KEY, session.session_key)
            user_session_manager.add_or_increment(self.session.session_key)
            self.sub = self.watch_recv_new_site_msg()
        else:
            self.close()
@ -66,6 +70,32 @@ class SiteMsgWebsocket(JsonWebsocketConsumer):
        if not self.sub:
            return
        self.sub.unsubscribe()
        session = self.scope['session']
        redis_client = cache.client.get_client()
        redis_client.srem(WS_SESSION_KEY, session.session_key)

        user_session_manager.decrement_or_remove(self.session.session_key)
        if self.should_delete_session():
            thread = Thread(target=self.delay_delete_session)
            thread.start()

    def should_delete_session(self):
        return (self.session.modified or settings.SESSION_SAVE_EVERY_REQUEST) and \
            not self.session.is_empty() and \
            self.session.get_expire_at_browser_close() and \
            not user_session_manager.check_active(self.session.session_key)

    def delay_delete_session(self):
        timeout = 6
        check_interval = 0.5

        start_time = time.time()
        while time.time() - start_time < timeout:
            time.sleep(check_interval)
            if user_session_manager.check_active(self.session.session_key):
                return

        self.delete_session()

    def delete_session(self):
        try:
            self.session.delete()
        except Exception as e:
            logger.info(f'delete session error: {e}')
@ -1,3 +1,4 @@
import os
from collections import defaultdict
from functools import reduce

@ -29,6 +30,8 @@ class DefaultCallback:
        )
        self.status = 'running'
        self.finished = False
        self.local_pid = 0
        self.private_data_dir = None

    @property
    def host_results(self):

@ -45,6 +48,9 @@ class DefaultCallback:
        event = data.get('event', None)
        if not event:
            return
        pid = data.get('pid', None)
        if pid:
            self.write_pid(pid)
        event_data = data.get('event_data', {})
        host = event_data.get('remote_addr', '')
        task = event_data.get('task', '')

@ -152,3 +158,11 @@ class DefaultCallback:
    def status_handler(self, data, **kwargs):
        status = data.get('status', '')
        self.status = self.STATUS_MAPPER.get(status, 'unknown')

        rc = kwargs.get('runner_config', None)
        self.private_data_dir = rc.private_data_dir if rc else '/tmp/'

    def write_pid(self, pid):
        pid_filepath = os.path.join(self.private_data_dir, 'local.pid')
        with open(pid_filepath, 'w') as f:
            f.write(str(pid))
@ -2,6 +2,7 @@
#
import os
import re
from collections import defaultdict

from celery.result import AsyncResult
from django.shortcuts import get_object_or_404

@ -166,16 +167,58 @@ class CeleryTaskViewSet(
            i.next_exec_time = now + next_run_at
        return queryset

    def generate_summary_state(self, execution_qs):
        model = self.get_queryset().model
        executions = execution_qs.order_by('-date_published').values('name', 'state')
        summary_state_dict = defaultdict(
            lambda: {
                'states': [], 'state': 'green',
                'summary': {'total': 0, 'success': 0}
            }
        )
        for execution in executions:
            name = execution['name']
            state = execution['state']

            summary = summary_state_dict[name]['summary']

            summary['total'] += 1
            summary['success'] += 1 if state == 'SUCCESS' else 0

            states = summary_state_dict[name].get('states')
            if states is not None and len(states) >= 5:
                color = model.compute_state_color(states)
                summary_state_dict[name]['state'] = color
                summary_state_dict[name].pop('states', None)
            elif isinstance(states, list):
                states.append(state)

        return summary_state_dict

    def loading_summary_state(self, queryset):
        if isinstance(queryset, list):
            names = [i.name for i in queryset]
            execution_qs = CeleryTaskExecution.objects.filter(name__in=names)
        else:
            execution_qs = CeleryTaskExecution.objects.all()
        summary_state_dict = self.generate_summary_state(execution_qs)
        for i in queryset:
            i.summary = summary_state_dict.get(i.name, {}).get('summary', {})
            i.state = summary_state_dict.get(i.name, {}).get('state', 'green')
        return queryset

    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())

        page = self.paginate_queryset(queryset)
        if page is not None:
            page = self.generate_execute_time(page)
            page = self.loading_summary_state(page)
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)

        queryset = self.generate_execute_time(queryset)
        queryset = self.loading_summary_state(queryset)
        serializer = self.get_serializer(queryset, many=True)
        return Response(serializer.data)
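The new `generate_summary_state` above makes a single pass over executions ordered newest-first, accumulating per-task totals while buffering only the first five states it sees for the later color computation. A runnable miniature of that fold with plain tuples (the data is illustrative):

```python
# Miniature of the fold above: rows arrive newest-first per task name.
from collections import defaultdict

rows = [
    ('backup', 'SUCCESS'), ('backup', 'FAILURE'),
    ('sync', 'SUCCESS'), ('sync', 'SUCCESS'),
]
summary = defaultdict(lambda: {'states': [], 'summary': {'total': 0, 'success': 0}})
for name, state in rows:
    item = summary[name]
    item['summary']['total'] += 1
    item['summary']['success'] += 1 if state == 'SUCCESS' else 0
    if len(item['states']) < 5:   # only the five most recent states matter
        item['states'].append(state)

print(dict(summary['backup']))
# {'states': ['SUCCESS', 'FAILURE'], 'summary': {'total': 2, 'success': 1}}
```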
@ -1,9 +1,11 @@
import json
import os

from celery.result import AsyncResult
from django.conf import settings
from django.db import transaction
from django.db.models import Count
from django.http import Http404
from django.shortcuts import get_object_or_404
from django.utils._os import safe_join
from django.utils.translation import gettext_lazy as _

@ -14,9 +16,10 @@ from rest_framework.views import APIView
from assets.models import Asset
from common.const.http import POST
from common.permissions import IsValidUser
from ops.celery import app
from ops.const import Types
from ops.models import Job, JobExecution
from ops.serializers.job import JobSerializer, JobExecutionSerializer, FileSerializer
from ops.serializers.job import JobSerializer, JobExecutionSerializer, FileSerializer, JobTaskStopSerializer

__all__ = [
    'JobViewSet', 'JobExecutionViewSet', 'JobRunVariableHelpAPIView',

@ -187,6 +190,33 @@ class JobExecutionViewSet(OrgBulkModelViewSet):
            queryset = queryset.filter(creator=self.request.user)
        return queryset

    @action(methods=[POST], detail=False, serializer_class=JobTaskStopSerializer, permission_classes=[IsValidUser, ],
            url_path='stop')
    def stop(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        if not serializer.is_valid():
            return Response({'error': serializer.errors}, status=400)
        task_id = serializer.validated_data['task_id']
        try:
            instance = get_object_or_404(JobExecution, task_id=task_id, creator=request.user)
        except Http404:
            return Response(
                {'error': _('The task is being created and cannot be interrupted. Please try again later.')},
                status=400
            )

        task = AsyncResult(task_id, app=app)
        inspect = app.control.inspect()
        for worker in inspect.registered().keys():
            if task_id not in [at['id'] for at in inspect.active().get(worker, [])]:
                # Not yet executing on a worker, so stop it with revoke
                task.revoke(terminate=True)
                instance.set_error('Job stop by "revoke task {}"'.format(task_id))
                return Response({'task_id': task_id}, status=200)

        instance.stop()
        return Response({'task_id': task_id}, status=200)


class JobAssetDetail(APIView):
    rbac_perms = {

@ -246,6 +276,6 @@ class UsernameHintsAPI(APIView):
            .filter(username__icontains=query) \
            .filter(asset__in=assets) \
            .values('username') \
            .annotate(total=Count('username')) \
            .annotate(total=Count('username', distinct=True)) \
            .order_by('total', '-username')[:10]
        return Response(data=top_accounts)
@ -15,6 +15,9 @@ class CeleryTask(models.Model):
    name = models.CharField(max_length=1024, verbose_name=_('Name'))
    date_last_publish = models.DateTimeField(null=True, verbose_name=_("Date last publish"))

    __summary = None
    __state = None

    @property
    def meta(self):
        task = app.tasks.get(self.name, None)

@ -25,23 +28,43 @@ class CeleryTask(models.Model):

    @property
    def summary(self):
        if self.__summary is not None:
            return self.__summary
        executions = CeleryTaskExecution.objects.filter(name=self.name)
        total = executions.count()
        success = executions.filter(state='SUCCESS').count()
        return {'total': total, 'success': success}

    @summary.setter
    def summary(self, value):
        self.__summary = value

    @staticmethod
    def compute_state_color(states: list, default_count=5):
        color = 'green'
        states = states[:default_count]
        if not states:
            return color
        if states[0] == 'FAILURE':
            color = 'red'
        elif 'FAILURE' in states:
            color = 'yellow'
        return color

    @property
    def state(self):
        last_five_executions = CeleryTaskExecution.objects.filter(name=self.name).order_by('-date_published')[:5]
        if self.__state is not None:
            return self.__state
        last_five_executions = CeleryTaskExecution.objects.filter(
            name=self.name
        ).order_by('-date_published').values('state')[:5]
        states = [i['state'] for i in last_five_executions]
        color = self.compute_state_color(states)
        return color

        if len(last_five_executions) > 0:
            if last_five_executions[0].state == 'FAILURE':
                return "red"

            for execution in last_five_executions:
                if execution.state == 'FAILURE':
                    return "yellow"
        return "green"
    @state.setter
    def state(self, value):
        self.__state = value

    class Meta:
        verbose_name = _("Celery Task")
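`compute_state_color` reduces the five most recent execution states, newest first, to a traffic-light color: red if the latest run failed, yellow if any older one did, green otherwise. A standalone copy of the logic to make the mapping concrete (runnable without the model):

```python
def compute_state_color(states, default_count=5):
    # Same logic as CeleryTask.compute_state_color above, outside the model.
    color = 'green'
    states = states[:default_count]
    if not states:
        return color
    if states[0] == 'FAILURE':
        color = 'red'
    elif 'FAILURE' in states:
        color = 'yellow'
    return color

print(compute_state_color(['FAILURE', 'SUCCESS']))  # red: the latest run failed
print(compute_state_color(['SUCCESS', 'FAILURE']))  # yellow: an older run failed
print(compute_state_color(['SUCCESS'] * 5))         # green: all recent runs succeeded
```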
@ -67,6 +67,7 @@ class JMSPermedInventory(JMSInventory):
            'postgresql': ['postgresql'],
            'sqlserver': ['sqlserver'],
            'ssh': ['shell', 'python', 'win_shell', 'raw'],
            'winrm': ['win_shell', 'shell'],
        }

        if self.module not in protocol_supported_modules_mapping.get(protocol.name, []):

@ -553,6 +554,15 @@ class JobExecution(JMSOrgBaseModel):
        finally:
            ssh_tunnel.local_gateway_clean(runner)

    def stop(self):
        with open(os.path.join(self.private_dir, 'local.pid')) as f:
            try:
                pid = f.read()
                os.kill(int(pid), 9)
            except Exception as e:
                print(e)
            self.set_error('Job stop by "kill -9 {}"'.format(pid))

    class Meta:
        verbose_name = _("Job Execution")
        ordering = ['-date_created']
@ -57,6 +57,13 @@ class FileSerializer(serializers.Serializer):
        ref_name = "JobFileSerializer"


class JobTaskStopSerializer(serializers.Serializer):
    task_id = serializers.CharField(max_length=128)

    class Meta:
        ref_name = "JobTaskStopSerializer"


class JobExecutionSerializer(BulkOrgResourceModelSerializer):
    creator = ReadableHiddenField(default=serializers.CurrentUserDefault())
    job_type = serializers.ReadOnlyField(label=_("Job type"))
@ -173,6 +173,9 @@ class Organization(OrgRoleMixin, JMSBaseModel):
    def is_default(self):
        return str(self.id) == self.DEFAULT_ID

    def is_system(self):
        return str(self.id) == self.SYSTEM_ID

    @property
    def internal(self):
        return str(self.id) in self.INTERNAL_IDS
@ -87,7 +87,8 @@ class OrgResourceStatisticsRefreshUtil:
        if not cache_field_name:
            return
        org = getattr(instance, 'org', None)
        cls.refresh_org_fields(((org, cache_field_name),))
        cache_field_name = tuple(cache_field_name)
        cls.refresh_org_fields.delay(org_fields=((org, cache_field_name),))


@receiver(post_save)
@ -6,6 +6,7 @@ from functools import wraps
from inspect import signature

from werkzeug.local import LocalProxy
from django.conf import settings

from common.local import thread_local
from .models import Organization

@ -14,7 +15,6 @@ from .models import Organization
def get_org_from_request(request):
    # The query string has the highest priority
    oid = request.GET.get("oid")

    # Then the header
    if not oid:
        oid = request.META.get("HTTP_X_JMS_ORG")

@ -24,14 +24,33 @@ def get_org_from_request(request):
    # Then the session
    if not oid:
        oid = request.session.get("oid")

    if oid and oid.lower() == 'default':
        return Organization.default()

    if oid and oid.lower() == 'root':
        return Organization.root()

    if oid and oid.lower() == 'system':
        return Organization.system()

    org = Organization.get_instance(oid)

    if org and org.internal:
        # Built-in orgs are returned directly
        return org

    if not settings.XPACK_ENABLED:
        # Community edition users can only use the default org
        return Organization.default()

    if not org and request.user.is_authenticated:
        # Enterprise edition: prefer an org the user has permissions in
        org = request.user.orgs.first()

    if not org:
        org = Organization.default()

    if not oid:
        oid = Organization.DEFAULT_ID
    if oid.lower() == "default":
        oid = Organization.DEFAULT_ID
    elif oid.lower() == "root":
        oid = Organization.ROOT_ID
    org = Organization.get_instance(oid, default=Organization.default())
    return org
@ -1,9 +1,11 @@
import abc

from django.conf import settings
from rest_framework.generics import ListAPIView, RetrieveAPIView

from assets.api.asset.asset import AssetFilterSet
from assets.models import Asset, Node
from common.api.mixin import ExtraFilterFieldsMixin
from common.utils import get_logger, lazyproperty, is_uuid
from orgs.utils import tmp_to_root_org
from perms import serializers

@ -37,8 +39,8 @@ class UserPermedAssetRetrieveApi(SelfOrPKUserMixin, RetrieveAPIView):
        return asset


class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
    ordering = ('name',)
class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ExtraFilterFieldsMixin, ListAPIView):
    ordering = []
    search_fields = ('name', 'address', 'comment')
    ordering_fields = ("name", "address")
    filterset_class = AssetFilterSet

@ -47,6 +49,8 @@ class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
    def get_queryset(self):
        if getattr(self, 'swagger_fake_view', False):
            return Asset.objects.none()
        if settings.ASSET_SIZE == 'small':
            self.ordering = ['name']
        assets = self.get_assets()
        assets = self.serializer_class.setup_eager_loading(assets)
        return assets
@ -9,7 +9,7 @@ class PermedAssetsWillExpireUserMsg(UserMessage):
    def __init__(self, user, assets, day_count=0):
        super().__init__(user)
        self.assets = assets
        self.day_count = _('today') if day_count == 0 else day_count + _('day')
        self.day_count = _('today') if day_count == 0 else str(day_count) + _('day')

    def get_html_msg(self) -> dict:
        subject = _("You permed assets is about to expire")

@ -41,7 +41,7 @@ class AssetPermsWillExpireForOrgAdminMsg(UserMessage):
        super().__init__(user)
        self.perms = perms
        self.org = org
        self.day_count = _('today') if day_count == 0 else day_count + _('day')
        self.day_count = _('today') if day_count == 0 else str(day_count) + _('day')

    def get_items_with_url(self):
        items_with_url = []
@ -198,9 +198,9 @@ class AssetPermissionListSerializer(AssetPermissionSerializer):
        """Perform necessary eager loading of data."""
        queryset = queryset \
            .prefetch_related('labels', 'labels__label') \
            .annotate(users_amount=Count("users"),
                      user_groups_amount=Count("user_groups"),
                      assets_amount=Count("assets"),
                      nodes_amount=Count("nodes"),
            .annotate(users_amount=Count("users", distinct=True),
                      user_groups_amount=Count("user_groups", distinct=True),
                      assets_amount=Count("assets", distinct=True),
                      nodes_amount=Count("nodes", distinct=True),
                      )
        return queryset
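Adding `distinct=True` to every `Count` here (and in the similar hunks above) is not cosmetic: annotating one queryset with several multi-valued relations makes the SQL join them all at once, so each count would otherwise be inflated by the cross product of the other joins. A tiny illustration of the inflation, with plain Python lists standing in for joined rows:

```python
# Two m2m relations joined onto one permission row produce a cross product.
users = ['u1', 'u2']               # 2 related users
assets = ['a1', 'a2', 'a3']        # 3 related assets

joined_rows = [(u, a) for u in users for a in assets]    # 6 rows, like the SQL JOIN

count_users = len([u for u, _ in joined_rows])           # 6  -> Count("users")
count_users_distinct = len({u for u, _ in joined_rows})  # 2  -> Count("users", distinct=True)
print(count_users, count_users_distinct)
```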
@ -8,9 +8,9 @@ from rest_framework import serializers
from accounts.models import Account
from assets.const import Category, AllTypes
from assets.models import Node, Asset, Platform
from assets.serializers.asset.common import AssetLabelSerializer, AssetProtocolsPermsSerializer
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from assets.serializers.asset.common import AssetProtocolsPermsSerializer
from common.serializers import ResourceLabelsMixin
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from orgs.mixins.serializers import OrgResourceModelSerializerMixin
from perms.serializers.permission import ActionChoicesField
@ -13,7 +13,7 @@ class AssetPermissionUtil(object):
    """ Utility methods for asset permissions """

    @timeit
    def get_permissions_for_user(self, user, with_group=True, flat=False):
    def get_permissions_for_user(self, user, with_group=True, flat=False, with_expired=False):
        """ Get the user's permission rules """
        perm_ids = set()
        # user

@ -25,7 +25,7 @@ class AssetPermissionUtil(object):
        groups = user.groups.all()
        group_perm_ids = self.get_permissions_for_user_groups(groups, flat=True)
        perm_ids.update(group_perm_ids)
        perms = self.get_permissions(ids=perm_ids)
        perms = self.get_permissions(ids=perm_ids, with_expired=with_expired)
        if flat:
            return perms.values_list('id', flat=True)
        return perms

@ -102,6 +102,8 @@ class AssetPermissionUtil(object):
        return model.objects.filter(id__in=ids)

    @staticmethod
    def get_permissions(ids):
        perms = AssetPermission.objects.filter(id__in=ids).valid().order_by('-date_expired')
        return perms
    def get_permissions(ids, with_expired=False):
        perms = AssetPermission.objects.filter(id__in=ids)
        if not with_expired:
            perms = perms.valid()
        return perms.order_by('-date_expired')
@ -7,10 +7,10 @@ from django.db.models import Q
from rest_framework.utils.encoders import JSONEncoder

from assets.const import AllTypes
from assets.models import FavoriteAsset, Asset
from assets.models import FavoriteAsset, Asset, Node
from common.utils.common import timeit, get_logger
from orgs.utils import current_org, tmp_to_root_org
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation, AssetPermission
from .permission import AssetPermissionUtil

__all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil']

@ -21,38 +21,37 @@ logger = get_logger(__name__)
class AssetPermissionPermAssetUtil:

    def __init__(self, perm_ids):
        self.perm_ids = perm_ids
        self.perm_ids = set(perm_ids)

    def get_all_assets(self):
        node_assets = self.get_perm_nodes_assets()
        direct_assets = self.get_direct_assets()
        # Much faster than the old way of fetching all asset ids first and then searching, which slows down when there are many assets
        return (node_assets | direct_assets).distinct()
        return (node_assets | direct_assets).order_by().distinct()

    @timeit
    def get_perm_nodes_assets(self, flat=False):
        """ Get the assets under all granted nodes """
        from assets.models import Node
        from ..models import AssetPermission
    def get_perm_nodes(self):
        """ Get all granted nodes """
        nodes_ids = AssetPermission.objects \
            .filter(id__in=self.perm_ids) \
            .values_list('nodes', flat=True)
        nodes_ids = set(nodes_ids)
        nodes = Node.objects.filter(id__in=nodes_ids).only('id', 'key')
        assets = PermNode.get_nodes_all_assets(*nodes)
        if flat:
            return set(assets.values_list('id', flat=True))
        return nodes

    @timeit
    def get_perm_nodes_assets(self):
        """ Get the assets under all granted nodes """
        nodes = self.get_perm_nodes()
        assets = PermNode.get_nodes_all_assets(*nodes, distinct=False)
        return assets

    @timeit
    def get_direct_assets(self, flat=False):
    def get_direct_assets(self):
        """ Get directly granted assets """
        from ..models import AssetPermission
        asset_ids = AssetPermission.objects \
            .filter(id__in=self.perm_ids) \
            .values_list('assets', flat=True)
        assets = Asset.objects.filter(id__in=asset_ids).distinct()
        if flat:
            return set(assets.values_list('id', flat=True))
        asset_ids = AssetPermission.assets.through.objects \
            .filter(assetpermission_id__in=self.perm_ids) \
            .values_list('asset_id', flat=True)
        assets = Asset.objects.filter(id__in=asset_ids)
        return assets
@ -72,7 +72,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):

    @timeit
    def refresh_if_need(self, force=False):
        built_just_now = cache.get(self.cache_key_time)
        built_just_now = False if settings.ASSET_SIZE == 'small' else cache.get(self.cache_key_time)
        if built_just_now:
            logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
            return

@ -80,12 +80,18 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
        if not to_refresh_orgs:
            logger.info('Not have to refresh orgs')
            return

        logger.info("Delay refresh user orgs: {} {}".format(self.user, [o.name for o in to_refresh_orgs]))
        refresh_user_orgs_perm_tree(user_orgs=((self.user, tuple(to_refresh_orgs)),))
        refresh_user_favorite_assets(users=(self.user,))
        sync = True if settings.ASSET_SIZE == 'small' else False
        refresh_user_orgs_perm_tree.apply(sync=sync, user_orgs=((self.user, tuple(to_refresh_orgs)),))
        refresh_user_favorite_assets.apply(sync=sync, users=(self.user,))

    @timeit
    def refresh_tree_manual(self):
        """
        For manual debugging
        :return:
        """
        built_just_now = cache.get(self.cache_key_time)
        if built_just_now:
            logger.info('Refresh just now, pass: {}'.format(built_just_now))

@ -105,8 +111,9 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
            return

        self._clean_user_perm_tree_for_legacy_org()
        ttl = settings.PERM_TREE_REGEN_INTERVAL
        cache.set(self.cache_key_time, int(time.time()), ttl)
        if settings.ASSET_SIZE != 'small':
            ttl = settings.PERM_TREE_REGEN_INTERVAL
            cache.set(self.cache_key_time, int(time.time()), ttl)

        lock = UserGrantedTreeRebuildLock(self.user.id)
        got = lock.acquire(blocking=False)

@ -187,13 +194,20 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):

    @on_transaction_commit
    def expire_perm_tree_for_users_orgs(self, user_ids, org_ids):
        user_ids = list(user_ids)
        org_ids = [str(oid) for oid in org_ids]
        with self.client.pipeline() as p:
            for uid in user_ids:
                cache_key = self.get_cache_key(uid)
                p.srem(cache_key, *org_ids)
            p.execute()
        logger.info('Expire perm tree for users: [{}], orgs: [{}]'.format(user_ids, org_ids))
        users_display = ','.join([str(i) for i in user_ids[:3]])
        if len(user_ids) > 3:
            users_display += '...'
        orgs_display = ','.join([str(i) for i in org_ids[:3]])
        if len(org_ids) > 3:
            orgs_display += '...'
        logger.info('Expire perm tree for users: [{}], orgs: [{}]'.format(users_display, orgs_display))

    def expire_perm_tree_for_all_user(self):
        keys = self.client.keys(self.cache_key_all_user)
@ -80,9 +80,11 @@ class RoleViewSet(JMSModelViewSet):
        queryset = Role.objects.filter(id__in=ids).order_by(*self.ordering)
        org_id = current_org.id
        q = Q(role__scope=Role.Scope.system) | Q(role__scope=Role.Scope.org, org_id=org_id)
        role_bindings = RoleBinding.objects.filter(q).values_list('role_id').annotate(user_count=Count('user_id'))
        role_bindings = RoleBinding.objects.filter(q).values_list('role_id').annotate(
            user_count=Count('user_id', distinct=True)
        )
        role_user_amount_mapper = {role_id: user_count for role_id, user_count in role_bindings}
        queryset = queryset.annotate(permissions_amount=Count('permissions'))
        queryset = queryset.annotate(permissions_amount=Count('permissions', distinct=True))
        queryset = list(queryset)
        for role in queryset:
            role.users_amount = role_user_amount_mapper.get(role.id, 0)
@ -1,28 +1,16 @@
# -*- coding: utf-8 -*-
#

import threading

from django.conf import settings
from django.utils.translation import gettext_lazy as _
from rest_framework import generics
from rest_framework.generics import CreateAPIView
from rest_framework.views import Response, APIView
from rest_framework.views import Response

from common.api import AsyncApiMixin
from common.utils import get_logger
from orgs.models import Organization
from orgs.utils import current_org
from users.models import User
from ..models import Setting
from ..serializers import (
    LDAPTestConfigSerializer, LDAPUserSerializer,
    LDAPTestLoginSerializer
)
from ..tasks import sync_ldap_user
from ..serializers import LDAPUserSerializer
from ..utils import (
    LDAPServerUtil, LDAPCacheUtil, LDAPImportUtil, LDAPSyncUtil,
    LDAP_USE_CACHE_FLAGS, LDAPTestUtil
    LDAPServerUtil, LDAPCacheUtil,
    LDAP_USE_CACHE_FLAGS
)

logger = get_logger(__file__)

@ -100,49 +88,3 @@ class LDAPUserListApi(generics.ListAPIView):
        else:
            data = {'msg': _('Users are not synchronized, please click the user synchronization button')}
        return Response(data=data, status=400)


class LDAPUserImportAPI(APIView):
    perm_model = Setting
    rbac_perms = {
        'POST': 'settings.change_auth'
    }

    def get_orgs(self):
        org_ids = self.request.data.get('org_ids')
        if org_ids:
            orgs = list(Organization.objects.filter(id__in=org_ids))
        else:
            orgs = [current_org]
        return orgs

    def get_ldap_users(self):
        username_list = self.request.data.get('username_list', [])
        cache_police = self.request.query_params.get('cache_police', True)
        if '*' in username_list:
            users = LDAPServerUtil().search()
        elif cache_police in LDAP_USE_CACHE_FLAGS:
            users = LDAPCacheUtil().search(search_users=username_list)
        else:
            users = LDAPServerUtil().search(search_users=username_list)
        return users

    def post(self, request):
        try:
            users = self.get_ldap_users()
        except Exception as e:
            return Response({'error': str(e)}, status=400)

        if users is None:
            return Response({'msg': _('Get ldap users is None')}, status=400)

        orgs = self.get_orgs()
        new_users, errors = LDAPImportUtil().perform_import(users, orgs)
        if errors:
            return Response({'errors': errors}, status=400)

        count = users if users is None else len(users)
        orgs_name = ', '.join([str(org) for org in orgs])
        return Response({
            'msg': _('Imported {} users successfully (Organization: {})').format(count, orgs_name)
        })
@ -3,6 +3,7 @@ from rest_framework import generics
from rest_framework.permissions import AllowAny

from authentication.permissions import IsValidUserOrConnectionToken
from common.const.choices import COUNTRY_CALLING_CODES
from common.utils import get_logger, lazyproperty
from common.utils.timezone import local_now
from .. import serializers

@ -24,7 +25,8 @@ class OpenPublicSettingApi(generics.RetrieveAPIView):
    def get_object(self):
        return {
            "XPACK_ENABLED": settings.XPACK_ENABLED,
            "INTERFACE": self.interface_setting
            "INTERFACE": self.interface_setting,
            "COUNTRY_CALLING_CODES": COUNTRY_CALLING_CODES
        }
@ -43,7 +43,7 @@ class OAuth2SettingSerializer(serializers.Serializer):
    )
    AUTH_OAUTH2_ACCESS_TOKEN_METHOD = serializers.ChoiceField(
        default='GET', label=_('Client authentication method'),
        choices=(('GET', 'GET'), ('POST', 'POST'))
        choices=(('GET', 'GET'), ('POST', 'POST-DATA'), ('POST_JSON', 'POST-JSON'))
    )
    AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT = serializers.CharField(
        required=True, max_length=1024, label=_('Provider userinfo endpoint')
@ -22,6 +22,10 @@ class CleaningSerializer(serializers.Serializer):
        min_value=MIN_VALUE, max_value=9999,
        label=_("Operate log keep days (day)"),
    )
    PASSWORD_CHANGE_LOG_KEEP_DAYS = serializers.IntegerField(
        min_value=MIN_VALUE, max_value=9999,
        label=_("password change log keep days (day)"),
    )
    FTP_LOG_KEEP_DAYS = serializers.IntegerField(
        min_value=MIN_VALUE, max_value=9999,
        label=_("FTP log keep days (day)"),
@ -109,6 +109,7 @@ class TicketSettingSerializer(serializers.Serializer):
    PREFIX_TITLE = _('Ticket')

    TICKETS_ENABLED = serializers.BooleanField(required=False, default=True, label=_("Enable tickets"))
    TICKETS_DIRECT_APPROVE = serializers.BooleanField(required=False, default=False, label=_("No login approval"))
    TICKET_AUTHORIZE_DEFAULT_TIME = serializers.IntegerField(
        min_value=1, max_value=999999, required=False,
        label=_("Ticket authorize default time")
@ -11,6 +11,7 @@ __all__ = [
|
|||
class PublicSettingSerializer(serializers.Serializer):
|
||||
XPACK_ENABLED = serializers.BooleanField()
|
||||
INTERFACE = serializers.DictField()
|
||||
COUNTRY_CALLING_CODES = serializers.ListField()
|
||||
|
||||
|
||||
class PrivateSettingSerializer(PublicSettingSerializer):
|
||||
|
@ -50,6 +51,7 @@ class PrivateSettingSerializer(PublicSettingSerializer):
|
|||
ANNOUNCEMENT = serializers.DictField()
|
||||
|
||||
TICKETS_ENABLED = serializers.BooleanField()
|
||||
TICKETS_DIRECT_APPROVE = serializers.BooleanField()
|
||||
CONNECTION_TOKEN_REUSABLE = serializers.BooleanField()
|
||||
CACHE_LOGIN_PASSWORD_ENABLED = serializers.BooleanField()
|
||||
VAULT_ENABLED = serializers.BooleanField()
|
||||
|
|
|
@ -14,9 +14,13 @@
</ul>
<b>{% trans "Synced User" %}:</b>
<ul>
{% for user in users %}
<li>{{ user }}</li>
{% endfor %}
{% if users %}
    {% for user in users %}
        <li>{{ user }}</li>
    {% endfor %}
{% else %}
    <li>{% trans 'No user synchronization required' %}</li>
{% endif %}
</ul>
{% if errors %}
<b>{% trans 'Error' %}:</b>
@ -12,7 +12,6 @@ router.register(r'chatai-prompts', api.ChatPromptViewSet, 'chatai-prompt')
urlpatterns = [
    path('mail/testing/', api.MailTestingAPI.as_view(), name='mail-testing'),
    path('ldap/users/', api.LDAPUserListApi.as_view(), name='ldap-user-list'),
    path('ldap/users/import/', api.LDAPUserImportAPI.as_view(), name='ldap-user-import'),
    path('wecom/testing/', api.WeComTestingAPI.as_view(), name='wecom-testing'),
    path('dingtalk/testing/', api.DingTalkTestingAPI.as_view(), name='dingtalk-testing'),
    path('feishu/testing/', api.FeiShuTestingAPI.as_view(), name='feishu-testing'),
@ -6,6 +6,7 @@ import asyncio
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from django.core.cache import cache
from django.conf import settings
from django.utils.translation import gettext_lazy as _

from common.db.utils import close_old_connections
from common.utils import get_logger

@ -13,9 +14,12 @@ from settings.serializers import (
    LDAPTestConfigSerializer,
    LDAPTestLoginSerializer
)
from orgs.models import Organization
from orgs.utils import current_org
from settings.tasks import sync_ldap_user
from settings.utils import (
    LDAPSyncUtil, LDAPTestUtil
    LDAPServerUtil, LDAPCacheUtil, LDAPImportUtil, LDAPSyncUtil,
    LDAP_USE_CACHE_FLAGS, LDAPTestUtil
)
from .tools import (
    verbose_ping, verbose_telnet, verbose_nmap,

@ -27,9 +31,11 @@ logger = get_logger(__name__)
CACHE_KEY_LDAP_TEST_CONFIG_MSG = 'CACHE_KEY_LDAP_TEST_CONFIG_MSG'
CACHE_KEY_LDAP_TEST_LOGIN_MSG = 'CACHE_KEY_LDAP_TEST_LOGIN_MSG'
CACHE_KEY_LDAP_SYNC_USER_MSG = 'CACHE_KEY_LDAP_SYNC_USER_MSG'
CACHE_KEY_LDAP_IMPORT_USER_MSG = 'CACHE_KEY_LDAP_IMPORT_USER_MSG'
CACHE_KEY_LDAP_TEST_CONFIG_TASK_STATUS = 'CACHE_KEY_LDAP_TEST_CONFIG_TASK_STATUS'
CACHE_KEY_LDAP_TEST_LOGIN_TASK_STATUS = 'CACHE_KEY_LDAP_TEST_LOGIN_TASK_STATUS'
CACHE_KEY_LDAP_SYNC_USER_TASK_STATUS = 'CACHE_KEY_LDAP_SYNC_USER_TASK_STATUS'
CACHE_KEY_LDAP_IMPORT_USER_TASK_STATUS = 'CACHE_KEY_LDAP_IMPORT_USER_TASK_STATUS'
TASK_STATUS_IS_RUNNING = 'RUNNING'
TASK_STATUS_IS_OVER = 'OVER'

@ -117,6 +123,8 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
            ok, msg = cache.get(CACHE_KEY_LDAP_TEST_CONFIG_MSG)
        elif msg_type == 'sync_user':
            ok, msg = cache.get(CACHE_KEY_LDAP_SYNC_USER_MSG)
        elif msg_type == 'import_user':
            ok, msg = cache.get(CACHE_KEY_LDAP_IMPORT_USER_MSG)
        else:
            ok, msg = cache.get(CACHE_KEY_LDAP_TEST_LOGIN_MSG)
        await self.send_msg(ok, msg)

@ -165,8 +173,8 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
        cache.set(task_key, TASK_STATUS_IS_OVER, ttl)

    @staticmethod
    def set_task_msg(task_key, ok, msg):
        cache.set(task_key, (ok, msg), 120)
    def set_task_msg(task_key, ok, msg, ttl=120):
        cache.set(task_key, (ok, msg), ttl)

    def run_testing_config(self, data):
        while True:

@ -207,3 +215,53 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
        ok = False if msg else True
        self.set_task_status_over(CACHE_KEY_LDAP_SYNC_USER_TASK_STATUS)
        self.set_task_msg(CACHE_KEY_LDAP_SYNC_USER_MSG, ok, msg)

    def run_import_user(self, data):
        while True:
            if self.task_is_over(CACHE_KEY_LDAP_IMPORT_USER_TASK_STATUS):
                break
            else:
                ok, msg = self.import_user(data)
                self.set_task_status_over(CACHE_KEY_LDAP_IMPORT_USER_TASK_STATUS, 3)
                self.set_task_msg(CACHE_KEY_LDAP_IMPORT_USER_MSG, ok, msg, 3)

    def import_user(self, data):
        ok = False
        org_ids = data.get('org_ids')
        username_list = data.get('username_list', [])
        cache_police = data.get('cache_police', True)
        try:
            users = self.get_ldap_users(username_list, cache_police)
            if users is None:
                msg = _('Get ldap users is None')

            orgs = self.get_orgs(org_ids)
            new_users, error_msg = LDAPImportUtil().perform_import(users, orgs)
            if error_msg:
                msg = error_msg

            count = users if users is None else len(users)
            orgs_name = ', '.join([str(org) for org in orgs])
            ok = True
            msg = _('Imported {} users successfully (Organization: {})').format(count, orgs_name)
        except Exception as e:
            msg = str(e)
        return ok, msg

    @staticmethod
    def get_orgs(org_ids):
        if org_ids:
            orgs = list(Organization.objects.filter(id__in=org_ids))
        else:
            orgs = [current_org]
        return orgs

    @staticmethod
    def get_ldap_users(username_list, cache_police):
        if '*' in username_list:
            users = LDAPServerUtil().search()
        elif cache_police in LDAP_USE_CACHE_FLAGS:
            users = LDAPCacheUtil().search(search_users=username_list)
        else:
            users = LDAPServerUtil().search(search_users=username_list)
        return users
@ -9,7 +9,7 @@ from django.conf import settings
from django.core.files.storage import default_storage
from django.http import HttpResponse
from django.shortcuts import get_object_or_404
from django.utils.translation import gettext as _
from django.utils.translation import gettext as _, get_language
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.request import Request

@ -19,6 +19,8 @@ from rest_framework.serializers import ValidationError
from common.api import JMSBulkModelViewSet
from common.serializers import FileSerializer
from common.utils import is_uuid
from common.utils.http import is_true
from common.utils.yml import yaml_load_with_i18n
from terminal import serializers
from terminal.models import AppletPublication, Applet

@ -106,9 +108,66 @@ class AppletViewSet(DownloadUploadMixin, JMSBulkModelViewSet):
    def get_object(self):
        pk = self.kwargs.get('pk')
        if not is_uuid(pk):
            return get_object_or_404(Applet, name=pk)
            obj = get_object_or_404(Applet, name=pk)
        else:
            return get_object_or_404(Applet, pk=pk)
            obj = get_object_or_404(Applet, pk=pk)
        return self.trans_object(obj)

    def get_queryset(self):
        queryset = super().get_queryset()
        queryset = self.trans_queryset(queryset)
        return queryset

    @staticmethod
    def read_manifest_with_i18n(obj, lang='zh'):
        path = os.path.join(obj.path, 'manifest.yml')
        if os.path.exists(path):
            with open(path, encoding='utf8') as f:
                manifest = yaml_load_with_i18n(f, lang)
        else:
            manifest = {}
        return manifest

    def trans_queryset(self, queryset):
        for obj in queryset:
            self.trans_object(obj)
        return queryset

    @staticmethod
    def readme(obj, lang=''):
        lang = lang[:2]
        readme_file = os.path.join(obj.path, f'README_{lang.upper()}.md')
        if os.path.isfile(readme_file):
            with open(readme_file, 'r') as f:
                return f.read()
        return ''

    def trans_object(self, obj):
        lang = get_language()
        manifest = self.read_manifest_with_i18n(obj, lang)
        obj.display_name = manifest.get('display_name', obj.display_name)
        obj.comment = manifest.get('comment', obj.comment)
        obj.readme = self.readme(obj, lang)
        return obj

    def is_record_found(self, obj, search):
        combine_fields = ' '.join([getattr(obj, f, '') for f in self.search_fields])
        return search in combine_fields

    def filter_queryset(self, queryset):
        search = self.request.query_params.get('search')
        if search:
            queryset = [i for i in queryset if self.is_record_found(i, search)]

        for field in self.filterset_fields:
            field_value = self.request.query_params.get(field)
            if not field_value:
                continue
            if field in ['is_active', 'builtin']:
                field_value = is_true(field_value)
            queryset = [i for i in queryset if getattr(i, field, '') == field_value]

        return queryset

    def perform_destroy(self, instance):
        if not instance.name:
@ -42,7 +42,7 @@ class SmartEndpointViewMixin:
        return endpoint

    def match_endpoint_by_label(self):
        return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol)
        return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol, self.request)

    def match_endpoint_by_target_ip(self):
        target_ip = self.request.GET.get('target_ip', '')  # Support a target_ip parameter, to make testing easier
@ -18,10 +18,11 @@ from rest_framework.response import Response

from audits.const import ActionChoices
from common.api import AsyncApiMixin
from common.const.http import GET
from common.const.http import GET, POST
from common.drf.filters import BaseFilterSet
from common.drf.filters import DatetimeRangeFilterBackend
from common.drf.renders import PassthroughRenderer
from common.permissions import IsServiceAccount
from common.storage.replay import ReplayStorageHandler
from common.utils import data_to_json, is_uuid, i18n_fmt
from common.utils import get_logger, get_object_or_none

@ -33,6 +34,7 @@ from terminal import serializers
from terminal.const import TerminalType
from terminal.models import Session
from terminal.permissions import IsSessionAssignee
from terminal.session_lifecycle import lifecycle_events_map, reasons_map
from terminal.utils import is_session_approver
from users.models import User

@ -79,6 +81,7 @@ class SessionViewSet(RecordViewLogMixin, OrgBulkModelViewSet):
    serializer_classes = {
        'default': serializers.SessionSerializer,
        'display': serializers.SessionDisplaySerializer,
        'lifecycle_log': serializers.SessionLifecycleLogSerializer,
    }
    search_fields = [
        "user", "asset", "account", "remote_addr",

@ -168,6 +171,23 @@ class SessionViewSet(RecordViewLogMixin, OrgBulkModelViewSet):
        count = queryset.count()
        return Response({'count': count})

    @action(methods=[POST], detail=True, permission_classes=[IsServiceAccount], url_path='lifecycle_log',
            url_name='lifecycle_log')
    def lifecycle_log(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        event = validated_data.pop('event', None)
        event_class = lifecycle_events_map.get(event, None)
        if not event_class:
            return Response({'msg': f'event_name {event} invalid'}, status=400)
        session = self.get_object()
        reason = validated_data.pop('reason', None)
        reason = reasons_map.get(reason, reason)
        event_obj = event_class(session, reason, **validated_data)
        activity_log = event_obj.create_activity_log()
        return Response({'msg': 'ok', 'id': activity_log.id})

    def get_queryset(self):
        queryset = super().get_queryset() \
            .prefetch_related('terminal') \
@ -0,0 +1,9 @@
## Selenium Version

- Selenium == 4.4.0
- Chrome and ChromeDriver versions must match
- Driver [download address](https://chromedriver.chromium.org/downloads)

## ChangeLog

Refer to [ChangeLog](./ChangeLog) for some important updates.
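Since the applet pins Selenium 4.4.0, driver setup goes through the `Service` object rather than the legacy `executable_path` argument of `webdriver.Chrome`. A minimal usage sketch; the driver path below is an assumption and must point at a chromedriver matching the installed Chrome's major version:

```python
# Minimal Selenium 4.x sketch; the chromedriver path is illustrative.
from selenium import webdriver
from selenium.webdriver.chrome.service import Service

# chromedriver's major version must match the installed Chrome's major version
service = Service(executable_path='/usr/local/bin/chromedriver')
driver = webdriver.Chrome(service=service)
try:
    driver.get('https://example.com')
    print(driver.title)
finally:
    driver.quit()
```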
@ -0,0 +1,9 @@
## Selenium Version

- Selenium == 4.4.0
- The Chrome and ChromeDriver versions must match
- Driver [download address](https://chromedriver.chromium.org/downloads)

## ChangeLog

Refer to the [ChangeLog](./ChangeLog) for important updates.
Some files were not shown because too many files have changed in this diff