merge: with dev

pull/12643/head
ibuler 2024-02-05 09:49:43 +08:00
commit b284bb60f5
103 changed files with 1967 additions and 1189 deletions

View File

@ -1,11 +1,12 @@
from django.db.models import Q
from rest_framework.generics import CreateAPIView
from accounts import serializers
from accounts.models import Account
from accounts.permissions import AccountTaskActionPermission
from accounts.tasks import (
remove_accounts_task, verify_accounts_connectivity_task, push_accounts_to_assets_task
)
from assets.exceptions import NotSupportedTemporarilyError
from authentication.permissions import UserConfirmation, ConfirmType
__all__ = [
@ -26,25 +27,35 @@ class AccountsTaskCreateAPI(CreateAPIView):
]
return super().get_permissions()
def perform_create(self, serializer):
data = serializer.validated_data
accounts = data.get('accounts', [])
params = data.get('params')
@staticmethod
def get_account_ids(data, action):
account_type = 'gather_accounts' if action == 'remove' else 'accounts'
accounts = data.get(account_type, [])
account_ids = [str(a.id) for a in accounts]
if data['action'] == 'push':
task = push_accounts_to_assets_task.delay(account_ids, params)
elif data['action'] == 'remove':
gather_accounts = data.get('gather_accounts', [])
gather_account_ids = [str(a.id) for a in gather_accounts]
task = remove_accounts_task.delay(gather_account_ids)
if action == 'remove':
return account_ids
assets = data.get('assets', [])
asset_ids = [str(a.id) for a in assets]
ids = Account.objects.filter(
Q(id__in=account_ids) | Q(asset_id__in=asset_ids)
).distinct().values_list('id', flat=True)
return [str(_id) for _id in ids]
def perform_create(self, serializer):
data = serializer.validated_data
action = data['action']
ids = self.get_account_ids(data, action)
if action == 'push':
task = push_accounts_to_assets_task.delay(ids, data.get('params'))
elif action == 'remove':
task = remove_accounts_task.delay(ids)
elif action == 'verify':
task = verify_accounts_connectivity_task.delay(ids)
else:
account = accounts[0]
asset = account.asset
if not asset.auto_config['ansible_enabled'] or \
not asset.auto_config['ping_enabled']:
raise NotSupportedTemporarilyError()
task = verify_accounts_connectivity_task.delay(account_ids)
raise ValueError(f"Invalid action: {action}")
data = getattr(serializer, '_data', {})
data["task"] = task.id

View File

@ -145,9 +145,9 @@ class AccountBackupHandler:
wb = Workbook(filename)
for sheet, data in data_map.items():
ws = wb.add_worksheet(str(sheet))
for row in data:
for col, _data in enumerate(row):
ws.write_string(0, col, _data)
for row_index, row_data in enumerate(data):
for col_index, col_data in enumerate(row_data):
ws.write_string(row_index, col_index, col_data)
wb.close()
files.append(filename)
timedelta = round((time.time() - time_start), 2)
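
The hunk above fixes the backup export: previously every data row was written to worksheet row 0, so each row overwrote the previous one; the row index now comes from enumerate. A minimal standalone sketch of the corrected pattern (file name and rows are illustrative):

from xlsxwriter import Workbook

rows = [
    ['name', 'username', 'secret'],       # header row
    ['web-01', 'root', '***'],
    ['db-01', 'postgres', '***'],
]

wb = Workbook('backup-example.xlsx')
ws = wb.add_worksheet('Sheet1')
for row_index, row_data in enumerate(rows):           # one worksheet row per data row
    for col_index, col_data in enumerate(row_data):
        ws.write_string(row_index, col_index, col_data)
wb.close()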

View File

@ -39,3 +39,4 @@
login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}"
login_database: "{{ jms_asset.spec_info.db_name }}"
mode: "{{ account.mode }}"

View File

@ -4,6 +4,7 @@ from copy import deepcopy
from django.conf import settings
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
from xlsxwriter import Workbook
from accounts.const import AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy
@ -161,7 +162,8 @@ class ChangeSecretManager(AccountBasePlaybookManager):
print("Account not found, deleted ?")
return
account.secret = recorder.new_secret
account.save(update_fields=['secret'])
account.date_updated = timezone.now()
account.save(update_fields=['secret', 'date_updated'])
def on_host_error(self, host, error, result):
recorder = self.name_recorder_mapper.get(host)
@ -182,17 +184,33 @@ class ChangeSecretManager(AccountBasePlaybookManager):
return False
return True
@staticmethod
def get_summary(recorders):
total, succeed, failed = 0, 0, 0
for recorder in recorders:
if recorder.status == 'success':
succeed += 1
else:
failed += 1
total += 1
summary = _('Success: %s, Failed: %s, Total: %s') % (succeed, failed, total)
return summary
def run(self, *args, **kwargs):
if self.secret_type and not self.check_secret():
return
super().run(*args, **kwargs)
recorders = list(self.name_recorder_mapper.values())
summary = self.get_summary(recorders)
print(summary, end='')
if self.record_id:
return
recorders = self.name_recorder_mapper.values()
recorders = list(recorders)
self.send_recorder_mail(recorders)
def send_recorder_mail(self, recorders):
self.send_recorder_mail(recorders, summary)
def send_recorder_mail(self, recorders, summary):
recipients = self.execution.recipients
if not recorders or not recipients:
return
@ -212,7 +230,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
attachment = os.path.join(path, f'{name}-{local_now_filename()}-{time.time()}.zip')
encrypt_and_compress_zip_file(attachment, password, [filename])
attachments = [attachment]
ChangeSecretExecutionTaskMsg(name, user).publish(attachments)
ChangeSecretExecutionTaskMsg(name, user, summary).publish(attachments)
os.remove(filename)
@staticmethod
@ -228,8 +246,8 @@ class ChangeSecretManager(AccountBasePlaybookManager):
rows.insert(0, header)
wb = Workbook(filename)
ws = wb.add_worksheet('Sheet1')
for row in rows:
for col, data in enumerate(row):
ws.write_string(0, col, data)
for row_index, row_data in enumerate(rows):
for col_index, col_data in enumerate(row_data):
ws.write_string(row_index, col_index, col_data)
wb.close()
return True
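
The new get_summary simply tallies recorder statuses so the result can be printed and prepended to the notification mail. A quick illustration with stand-in recorder objects (SimpleNamespace is used purely for demonstration, and the gettext wrapper is omitted for self-containment):

from types import SimpleNamespace

def get_summary(recorders):
    total, succeed, failed = 0, 0, 0
    for recorder in recorders:
        if recorder.status == 'success':
            succeed += 1
        else:
            failed += 1
        total += 1
    return 'Success: %s, Failed: %s, Total: %s' % (succeed, failed, total)

recorders = [SimpleNamespace(status='success'), SimpleNamespace(status='failed'),
             SimpleNamespace(status='success')]
print(get_summary(recorders))   # Success: 2, Failed: 1, Total: 3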

View File

@ -1,9 +1,10 @@
- hosts: demo
gather_facts: no
tasks:
- name: Gather posix account
- name: Gather windows account
ansible.builtin.win_shell: net user
register: result
ignore_errors: true
- name: Define info by set_fact
ansible.builtin.set_fact:

View File

@ -39,3 +39,4 @@
login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}"
login_database: "{{ jms_asset.spec_info.db_name }}"
mode: "{{ account.mode }}"

View File

@ -54,20 +54,23 @@ class AccountBackupByObjStorageExecutionTaskMsg(object):
class ChangeSecretExecutionTaskMsg(object):
subject = _('Notification of implementation result of encryption change plan')
def __init__(self, name: str, user: User):
def __init__(self, name: str, user: User, summary):
self.name = name
self.user = user
self.summary = summary
@property
def message(self):
name = self.name
if self.user.secret_key:
return _('{} - The encryption change task has been completed. '
'See the attachment for details').format(name)
default_message = _('{} - The encryption change task has been completed. '
'See the attachment for details').format(name)
else:
return _("{} - The encryption change task has been completed: the encryption "
"password has not been set - please go to personal information -> "
"file encryption password to set the encryption password").format(name)
default_message = _("{} - The encryption change task has been completed: the encryption "
"password has not been set - please go to personal information -> "
"set encryption password in preferences").format(name)
return self.summary + '\n' + default_message
def publish(self, attachments=None):
send_mail_attachment_async(

View File

@ -60,7 +60,7 @@ class AccountCreateUpdateSerializerMixin(serializers.Serializer):
for data in initial_data:
if not data.get('asset') and not self.instance:
raise serializers.ValidationError({'asset': UniqueTogetherValidator.missing_message})
asset = data.get('asset') or self.instance.asset
asset = data.get('asset') or getattr(self.instance, 'asset', None)
self.from_template_if_need(data)
self.set_uniq_name_if_need(data, asset)
@ -457,12 +457,14 @@ class AccountHistorySerializer(serializers.ModelSerializer):
class AccountTaskSerializer(serializers.Serializer):
ACTION_CHOICES = (
('test', 'test'),
('verify', 'verify'),
('push', 'push'),
('remove', 'remove'),
)
action = serializers.ChoiceField(choices=ACTION_CHOICES, write_only=True)
assets = serializers.PrimaryKeyRelatedField(
queryset=Asset.objects, required=False, allow_empty=True, many=True
)
accounts = serializers.PrimaryKeyRelatedField(
queryset=Account.objects, required=False, allow_empty=True, many=True
)

View File

@ -21,7 +21,8 @@ def on_account_pre_save(sender, instance, **kwargs):
if instance.version == 0:
instance.version = 1
else:
instance.version = instance.history.count()
history_account = instance.history.first()
instance.version = history_account.version + 1 if history_account else 0
@merge_delay_run(ttl=5)
@ -62,7 +63,7 @@ def create_accounts_activities(account, action='create'):
def on_account_create_by_template(sender, instance, created=False, **kwargs):
if not created or instance.source != 'template':
return
push_accounts_if_need(accounts=(instance,))
push_accounts_if_need.delay(accounts=(instance,))
create_accounts_activities(instance, action='create')
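
The pre_save handler now derives the next version from the most recent history record instead of history.count(), so the number keeps increasing even after old history rows are purged by clean_historical_accounts. A rough sketch of that rule (HistoryStub below only imitates django-simple-history's instance.history manager and is not real project code):

class Record:
    def __init__(self, version):
        self.version = version

class HistoryStub:
    """Illustrative stand-in for instance.history; records are newest first."""
    def __init__(self, records):
        self._records = records
    def first(self):
        return self._records[0] if self._records else None

def next_version(history):
    history_account = history.first()
    return history_account.version + 1 if history_account else 0

print(next_version(HistoryStub([Record(7), Record(6)])))   # 8, even if older rows were cleaned up
print(next_version(HistoryStub([])))                        # 0, no history yet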

View File

@ -55,7 +55,7 @@ def clean_historical_accounts():
history_model = Account.history.model
history_id_mapper = defaultdict(list)
ids = history_model.objects.values('id').annotate(count=Count('id')) \
ids = history_model.objects.values('id').annotate(count=Count('id', distinct=True)) \
.filter(count__gte=limit).values_list('id', flat=True)
if not ids:

View File

@ -92,6 +92,7 @@ class AssetViewSet(SuggestionMixin, OrgBulkModelViewSet):
model = Asset
filterset_class = AssetFilterSet
search_fields = ("name", "address", "comment")
ordering = ('name',)
ordering_fields = ('name', 'address', 'connectivity', 'platform', 'date_updated', 'date_created')
serializer_classes = (
("default", serializers.AssetSerializer),

View File

@ -12,6 +12,6 @@ class Migration(migrations.Migration):
operations = [
migrations.AlterModelOptions(
name='asset',
options={'ordering': ['name'], 'permissions': [('refresh_assethardwareinfo', 'Can refresh asset hardware info'), ('test_assetconnectivity', 'Can test asset connectivity'), ('match_asset', 'Can match asset'), ('change_assetnodes', 'Can change asset nodes')], 'verbose_name': 'Asset'},
options={'ordering': [], 'permissions': [('refresh_assethardwareinfo', 'Can refresh asset hardware info'), ('test_assetconnectivity', 'Can test asset connectivity'), ('match_asset', 'Can match asset'), ('change_assetnodes', 'Can change asset nodes')], 'verbose_name': 'Asset'},
),
]

View File

@ -348,7 +348,7 @@ class Asset(NodesRelationMixin, LabeledMixin, AbsConnectivity, JSONFilterMixin,
class Meta:
unique_together = [('org_id', 'name')]
verbose_name = _("Asset")
ordering = ["name", ]
ordering = []
permissions = [
('refresh_assethardwareinfo', _('Can refresh asset hardware info')),
('test_assetconnectivity', _('Can test asset connectivity')),

View File

@ -429,7 +429,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
@classmethod
@timeit
def get_nodes_all_assets(cls, *nodes):
def get_nodes_all_assets(cls, *nodes, distinct=True):
from .asset import Asset
node_ids = set()
descendant_node_query = Q()
@ -439,7 +439,10 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
if descendant_node_query:
_ids = Node.objects.order_by().filter(descendant_node_query).values_list('id', flat=True)
node_ids.update(_ids)
return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct()
assets = Asset.objects.order_by().filter(nodes__id__in=node_ids)
if distinct:
assets = assets.distinct()
return assets
def get_all_asset_ids(self):
asset_ids = self.get_all_asset_ids_by_node_key(org_id=self.org_id, node_key=self.key)

View File

@ -1,8 +1,8 @@
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.request import Request
from common.utils import get_logger
from assets.models import Node
from common.utils import get_logger
logger = get_logger(__name__)
@ -28,6 +28,7 @@ class AssetPaginationBase(LimitOffsetPagination):
'key', 'all', 'show_current_asset',
'cache_policy', 'display', 'draw',
'order', 'node', 'node_id', 'fields_size',
'asset'
}
for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None:

View File

@ -206,9 +206,12 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
""" Perform necessary eager loading of data. """
queryset = queryset.prefetch_related('domain', 'nodes', 'protocols', ) \
.prefetch_related('platform', 'platform__automation') \
.prefetch_related('labels', 'labels__label') \
.annotate(category=F("platform__category")) \
.annotate(type=F("platform__type"))
if queryset.model is Asset:
queryset = queryset.prefetch_related('labels__label', 'labels')
else:
queryset = queryset.prefetch_related('asset_ptr__labels__label', 'asset_ptr__labels')
return queryset
@staticmethod

View File

@ -56,7 +56,14 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
class DomainListSerializer(DomainSerializer):
class Meta(DomainSerializer.Meta):
fields = list(set(DomainSerializer.Meta.fields) - {'assets'})
fields = list(set(DomainSerializer.Meta.fields + ['assets_amount']) - {'assets'})
@classmethod
def setup_eager_loading(cls, queryset):
queryset = queryset.annotate(
assets_amount=Count('assets', distinct=True),
)
return queryset
class DomainWithGatewaySerializer(serializers.ModelSerializer):

View File

@ -191,7 +191,6 @@ class PlatformSerializer(ResourceLabelsMixin, WritableNestedModelSerializer):
def add_type_choices(self, name, label):
tp = self.fields['type']
tp.choices[name] = label
tp.choice_mapper[name] = label
tp.choice_strings_to_values[name] = label
@lazyproperty

View File

@ -63,13 +63,13 @@ def on_asset_create(sender, instance=None, created=False, **kwargs):
return
logger.info("Asset create signal recv: {}".format(instance))
ensure_asset_has_node(assets=(instance,))
ensure_asset_has_node.delay(assets=(instance,))
# Gather asset hardware info
auto_config = instance.auto_config
if auto_config.get('ping_enabled'):
logger.debug('Asset {} ping enabled, test connectivity'.format(instance.name))
test_assets_connectivity_handler(assets=(instance,))
test_assets_connectivity_handler.delay(assets=(instance,))
if auto_config.get('gather_facts_enabled'):
logger.debug('Asset {} gather facts enabled, gather facts'.format(instance.name))
gather_assets_facts_handler(assets=(instance,))

View File

@ -2,14 +2,16 @@
#
from operator import add, sub
from django.conf import settings
from django.db.models.signals import m2m_changed
from django.dispatch import receiver
from assets.models import Asset, Node
from common.const.signals import PRE_CLEAR, POST_ADD, PRE_REMOVE
from common.decorators import on_transaction_commit, merge_delay_run
from common.signals import django_ready
from common.utils import get_logger
from orgs.utils import tmp_to_org
from orgs.utils import tmp_to_org, tmp_to_root_org
from ..tasks import check_node_assets_amount_task
logger = get_logger(__file__)
@ -34,7 +36,7 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
node_ids = [instance.id]
else:
node_ids = list(pk_set)
update_nodes_assets_amount(node_ids=node_ids)
update_nodes_assets_amount.delay(node_ids=node_ids)
@merge_delay_run(ttl=30)
@ -52,3 +54,18 @@ def update_nodes_assets_amount(node_ids=()):
node.assets_amount = node.get_assets_amount()
Node.objects.bulk_update(nodes, ['assets_amount'])
@receiver(django_ready)
def set_assets_size_to_setting(sender, **kwargs):
from assets.models import Asset
try:
with tmp_to_root_org():
amount = Asset.objects.order_by().count()
except:
amount = 0
if amount > 20000:
settings.ASSET_SIZE = 'large'
elif amount > 2000:
settings.ASSET_SIZE = 'medium'
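
On django_ready the handler above samples the total asset count once and promotes settings.ASSET_SIZE from its 'small' default (added in the settings hunk further down), which asset views then use to skip costly ordering and distinct work on large installs. The tiering logic, restated as a tiny function for illustration:

def asset_size_tier(amount):
    # Mirrors the thresholds in the signal handler above.
    if amount > 20000:
        return 'large'
    if amount > 2000:
        return 'medium'
    return 'small'

print(asset_size_tier(150))      # small
print(asset_size_tier(5000))     # medium
print(asset_size_tier(100000))   # large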

View File

@ -44,18 +44,18 @@ def on_node_post_create(sender, instance, created, update_fields, **kwargs):
need_expire = False
if need_expire:
expire_node_assets_mapping(org_ids=(instance.org_id,))
expire_node_assets_mapping.delay(org_ids=(instance.org_id,))
@receiver(post_delete, sender=Node)
def on_node_post_delete(sender, instance, **kwargs):
expire_node_assets_mapping(org_ids=(instance.org_id,))
expire_node_assets_mapping.delay(org_ids=(instance.org_id,))
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, instance, action='pre_remove', **kwargs):
if action.startswith('post'):
expire_node_assets_mapping(org_ids=(instance.org_id,))
expire_node_assets_mapping.delay(org_ids=(instance.org_id,))
@receiver(django_ready)

View File

@ -2,6 +2,7 @@
from django.urls import path
from rest_framework_bulk.routes import BulkRouter
from labels.api import LabelViewSet
from .. import api
app_name = 'assets'
@ -22,6 +23,7 @@ router.register(r'domains', api.DomainViewSet, 'domain')
router.register(r'gateways', api.GatewayViewSet, 'gateway')
router.register(r'favorite-assets', api.FavoriteAssetViewSet, 'favorite-asset')
router.register(r'protocol-settings', api.PlatformProtocolViewSet, 'protocol-setting')
router.register(r'labels', LabelViewSet, 'label')
urlpatterns = [
# path('assets/<uuid:pk>/gateways/', api.AssetGatewayListApi.as_view(), name='asset-gateway-list'),

View File

@ -4,7 +4,6 @@ from urllib.parse import urlencode, urlparse
from kubernetes import client
from kubernetes.client import api_client
from kubernetes.client.api import core_v1_api
from kubernetes.client.exceptions import ApiException
from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError
from common.utils import get_logger
@ -88,8 +87,9 @@ class KubernetesClient:
if hasattr(self, func_name):
try:
data = getattr(self, func_name)(*args)
except ApiException as e:
logger.error(e.reason)
except Exception as e:
logger.error(e)
raise e
if self.server:
self.server.stop()

View File

@ -20,6 +20,7 @@ from common.const.http import GET, POST
from common.drf.filters import DatetimeRangeFilterBackend
from common.permissions import IsServiceAccount
from common.plugins.es import QuerySet as ESQuerySet
from common.sessions.cache import user_session_manager
from common.storage.ftp_file import FTPFileStorageHandler
from common.utils import is_uuid, get_logger, lazyproperty
from orgs.mixins.api import OrgReadonlyModelViewSet, OrgModelViewSet
@ -289,8 +290,7 @@ class UserSessionViewSet(CommonApiMixin, viewsets.ModelViewSet):
return Response(status=status.HTTP_200_OK)
keys = queryset.values_list('key', flat=True)
session_store_cls = import_module(settings.SESSION_ENGINE).SessionStore
for key in keys:
session_store_cls(key).delete()
user_session_manager.decrement_or_remove(key)
queryset.delete()
return Response(status=status.HTTP_200_OK)

View File

@ -1,10 +1,9 @@
from django.core.cache import cache
from django_filters import rest_framework as drf_filters
from rest_framework import filters
from rest_framework.compat import coreapi, coreschema
from common.drf.filters import BaseFilterSet
from notifications.ws import WS_SESSION_KEY
from common.sessions.cache import user_session_manager
from orgs.utils import current_org
from .models import UserSession
@ -41,13 +40,11 @@ class UserSessionFilterSet(BaseFilterSet):
@staticmethod
def filter_is_active(queryset, name, is_active):
redis_client = cache.client.get_client()
members = redis_client.smembers(WS_SESSION_KEY)
members = [member.decode('utf-8') for member in members]
keys = user_session_manager.get_active_keys()
if is_active:
queryset = queryset.filter(key__in=members)
queryset = queryset.filter(key__in=keys)
else:
queryset = queryset.exclude(key__in=members)
queryset = queryset.exclude(key__in=keys)
return queryset
class Meta:

View File

@ -4,15 +4,15 @@ from datetime import timedelta
from importlib import import_module
from django.conf import settings
from django.core.cache import caches, cache
from django.core.cache import caches
from django.db import models
from django.db.models import Q
from django.utils import timezone
from django.utils.translation import gettext, gettext_lazy as _
from common.db.encoder import ModelJSONFieldEncoder
from common.sessions.cache import user_session_manager
from common.utils import lazyproperty, i18n_trans
from notifications.ws import WS_SESSION_KEY
from ops.models import JobExecution
from orgs.mixins.models import OrgModelMixin, Organization
from orgs.utils import current_org
@ -278,8 +278,7 @@ class UserSession(models.Model):
@property
def is_active(self):
redis_client = cache.client.get_client()
return redis_client.sismember(WS_SESSION_KEY, self.key)
return user_session_manager.check_active(self.key)
@property
def date_expired(self):

View File

@ -205,7 +205,7 @@ class RDPFileClientProtocolURLMixin:
return data
def get_smart_endpoint(self, protocol, asset=None):
endpoint = Endpoint.match_by_instance_label(asset, protocol)
endpoint = Endpoint.match_by_instance_label(asset, protocol, self.request)
if not endpoint:
target_ip = asset.get_target_ip() if asset else ''
endpoint = EndpointRule.match_endpoint(

View File

@ -90,6 +90,6 @@ class MFAChallengeVerifyApi(AuthMixin, CreateAPIView):
return Response({'msg': 'ok'})
except errors.AuthFailedError as e:
data = {"error": e.error, "msg": e.msg}
raise ValidationError(data)
return Response(data, status=401)
except errors.NeedMoreInfoError as e:
return Response(e.as_data(), status=200)

View File

@ -10,6 +10,7 @@ from rest_framework import authentication, exceptions
from common.auth import signature
from common.decorators import merge_delay_run
from common.utils import get_object_or_none, get_request_ip_or_data, contains_ip
from users.models import User
from ..models import AccessKey, PrivateToken
@ -19,22 +20,23 @@ def date_more_than(d, seconds):
@merge_delay_run(ttl=60)
def update_token_last_used(tokens=()):
for token in tokens:
token.date_last_used = timezone.now()
token.save(update_fields=['date_last_used'])
access_keys_ids = [token.id for token in tokens if isinstance(token, AccessKey)]
private_token_keys = [token.key for token in tokens if isinstance(token, PrivateToken)]
if len(access_keys_ids) > 0:
AccessKey.objects.filter(id__in=access_keys_ids).update(date_last_used=timezone.now())
if len(private_token_keys) > 0:
PrivateToken.objects.filter(key__in=private_token_keys).update(date_last_used=timezone.now())
@merge_delay_run(ttl=60)
def update_user_last_used(users=()):
for user in users:
user.date_api_key_last_used = timezone.now()
user.save(update_fields=['date_api_key_last_used'])
User.objects.filter(id__in=users).update(date_api_key_last_used=timezone.now())
def after_authenticate_update_date(user, token=None):
update_user_last_used(users=(user,))
update_user_last_used.delay(users=(user.id,))
if token:
update_token_last_used(tokens=(token,))
update_token_last_used.delay(tokens=(token,))
class AccessTokenAuthentication(authentication.BaseAuthentication):

View File

@ -98,16 +98,19 @@ class OAuth2Backend(JMSModelBackend):
access_token_url = '{url}{separator}{query}'.format(
url=settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT, separator=separator, query=urlencode(query_dict)
)
# token_method -> get, post(post_data), post_json
token_method = settings.AUTH_OAUTH2_ACCESS_TOKEN_METHOD.lower()
requests_func = getattr(requests, token_method, requests.get)
logger.debug(log_prompt.format('Call the access token endpoint[method: %s]' % token_method))
headers = {
'Accept': 'application/json'
}
if token_method == 'post':
access_token_response = requests_func(access_token_url, headers=headers, data=query_dict)
if token_method.startswith('post'):
body_key = 'json' if token_method.endswith('json') else 'data'
access_token_response = requests.post(
access_token_url, headers=headers, **{body_key: query_dict}
)
else:
access_token_response = requests_func(access_token_url, headers=headers)
access_token_response = requests.get(access_token_url, headers=headers)
try:
access_token_response.raise_for_status()
access_token_response_data = access_token_response.json()

View File

@ -18,7 +18,7 @@ class EncryptedField(forms.CharField):
class UserLoginForm(forms.Form):
days_auto_login = int(settings.SESSION_COOKIE_AGE / 3600 / 24)
disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE \
disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE \
or days_auto_login < 1
username = forms.CharField(

View File

@ -142,23 +142,7 @@ class SessionCookieMiddleware(MiddlewareMixin):
return response
response.set_cookie(key, value)
@staticmethod
def set_cookie_session_expire(request, response):
if not request.session.get('auth_session_expiration_required'):
return
value = 'age'
if settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE or \
not request.session.get('auto_login', False):
value = 'close'
age = request.session.get_expiry_age()
expire_timestamp = request.session.get_expiry_date().timestamp()
response.set_cookie('jms_session_expire_timestamp', expire_timestamp)
response.set_cookie('jms_session_expire', value, max_age=age)
request.session.pop('auth_session_expiration_required', None)
def process_response(self, request, response: HttpResponse):
self.set_cookie_session_prefix(request, response)
self.set_cookie_public_key(request, response)
self.set_cookie_session_expire(request, response)
return response

View File

@ -37,9 +37,6 @@ def on_user_auth_login_success(sender, user, request, **kwargs):
UserSession.objects.filter(key=session_key).delete()
cache.set(lock_key, request.session.session_key, None)
# Mark the login and set a cookie; the frontend can use it to control refresh, and the middleware intercepts this to generate the cookie
request.session['auth_session_expiration_required'] = 1
@receiver(cas_user_authenticated)
def on_cas_user_login_success(sender, request, user, **kwargs):

View File

@ -70,11 +70,12 @@ class DingTalkQRMixin(DingTalkBaseMixin, View):
self.request.session[DINGTALK_STATE_SESSION_KEY] = state
params = {
'appid': settings.DINGTALK_APPKEY,
'client_id': settings.DINGTALK_APPKEY,
'response_type': 'code',
'scope': 'snsapi_login',
'scope': 'openid',
'state': state,
'redirect_uri': redirect_uri,
'prompt': 'consent'
}
url = URL.QR_CONNECT + '?' + urlencode(params)
return url

View File

@ -104,9 +104,11 @@ class QuerySetMixin:
page = super().paginate_queryset(queryset)
serializer_class = self.get_serializer_class()
if page and serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
ids = [i.id for i in page]
ids = [str(obj.id) for obj in page]
page = self.get_queryset().filter(id__in=ids)
page = serializer_class.setup_eager_loading(page)
page_mapper = {str(obj.id): obj for obj in page}
page = [page_mapper.get(_id) for _id in ids if _id in page_mapper]
return page
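
Re-querying the page by id for eager loading loses the pagination order, so the mapper rebuilds the page in the original id order before serialization. A tiny standalone illustration of that reordering step (Obj is just a stand-in for model instances):

class Obj:
    def __init__(self, id):
        self.id = id

page = [Obj('b'), Obj('a'), Obj('c')]        # order produced by pagination
reloaded = [Obj('a'), Obj('c'), Obj('b')]    # order after the id__in re-query (arbitrary)

ids = [str(o.id) for o in page]
page_mapper = {str(o.id): o for o in reloaded}
page = [page_mapper[_id] for _id in ids if _id in page_mapper]
print([o.id for o in page])                  # ['b', 'a', 'c'] -- original order preserved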

View File

@ -19,3 +19,17 @@ class Status(models.TextChoices):
failed = 'failed', _("Failed")
error = 'error', _("Error")
canceled = 'canceled', _("Canceled")
COUNTRY_CALLING_CODES = [
{'name': 'China(中国)', 'value': '+86'},
{'name': 'HongKong(中国香港)', 'value': '+852'},
{'name': 'Macao(中国澳门)', 'value': '+853'},
{'name': 'Taiwan(中国台湾)', 'value': '+886'},
{'name': 'America(America)', 'value': '+1'}, {'name': 'Russia(Россия)', 'value': '+7'},
{'name': 'France(français)', 'value': '+33'},
{'name': 'Britain(Britain)', 'value': '+44'},
{'name': 'Germany(Deutschland)', 'value': '+49'},
{'name': 'Japan(日本)', 'value': '+81'}, {'name': 'Korea(한국)', 'value': '+82'},
{'name': 'India(भारत)', 'value': '+91'}
]

View File

@ -362,11 +362,15 @@ class RelatedManager:
if name is None or val is None:
continue
if custom_attr_filter:
custom_filter_q = None
spec_attr_filter = getattr(to_model, "get_{}_filter_attr_q".format(name), None)
if spec_attr_filter:
custom_filter_q = spec_attr_filter(val, match)
elif custom_attr_filter:
custom_filter_q = custom_attr_filter(name, val, match)
if custom_filter_q:
filters.append(custom_filter_q)
continue
if custom_filter_q:
filters.append(custom_filter_q)
continue
if match == 'ip_in':
q = cls.get_ip_in_q(name, val)
@ -464,11 +468,15 @@ class JSONManyToManyDescriptor:
rule_value = rule.get('value', '')
rule_match = rule.get('match', 'exact')
if custom_attr_filter:
q = custom_attr_filter(rule['name'], rule_value, rule_match)
if q:
custom_q &= q
continue
custom_filter_q = None
spec_attr_filter = getattr(to_model, "get_filter_{}_attr_q".format(rule['name']), None)
if spec_attr_filter:
custom_filter_q = spec_attr_filter(rule_value, rule_match)
elif custom_attr_filter:
custom_filter_q = custom_attr_filter(rule['name'], rule_value, rule_match)
if custom_filter_q:
custom_q &= custom_filter_q
continue
if rule_match == 'in':
res &= value in rule_value or '*' in rule_value
@ -517,7 +525,6 @@ class JSONManyToManyDescriptor:
res &= rule_value.issubset(value)
else:
res &= bool(value & rule_value)
else:
logging.error("unknown match: {}".format(rule['match']))
res &= False

View File

@ -3,6 +3,7 @@
import asyncio
import functools
import inspect
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
@ -101,7 +102,11 @@ def run_debouncer_func(cache_key, org, ttl, func, *args, **kwargs):
first_run_time = current
if current - first_run_time > ttl:
_loop_debouncer_func_args_cache.pop(cache_key, None)
_loop_debouncer_func_task_time_cache.pop(cache_key, None)
executor.submit(run_func_partial, *args, **kwargs)
logger.debug('pid {} executor submit run {}'.format(
os.getpid(), func.__name__, ))
return
loop = _loop_thread.get_loop()
@ -133,13 +138,26 @@ class Debouncer(object):
return await self.loop.run_in_executor(self.executor, func)
ignore_err_exceptions = (
"(3101, 'Plugin instructed the server to rollback the current transaction.')",
)
def _run_func_with_org(key, org, func, *args, **kwargs):
from orgs.utils import set_current_org
try:
set_current_org(org)
func(*args, **kwargs)
with transaction.atomic():
set_current_org(org)
func(*args, **kwargs)
except Exception as e:
logger.error('delay run error: %s' % e)
msg = str(e)
log_func = logger.error
if msg in ignore_err_exceptions:
log_func = logger.info
pid = os.getpid()
thread_name = threading.current_thread()
log_func('pid {} thread {} delay run {} error: {}'.format(
pid, thread_name, func.__name__, msg))
_loop_debouncer_func_task_cache.pop(key, None)
_loop_debouncer_func_args_cache.pop(key, None)
_loop_debouncer_func_task_time_cache.pop(key, None)
@ -181,6 +199,32 @@ def merge_delay_run(ttl=5, key=None):
:return:
"""
def delay(func, *args, **kwargs):
from orgs.utils import get_current_org
suffix_key_func = key if key else default_suffix_key
org = get_current_org()
func_name = f'{func.__module__}_{func.__name__}'
key_suffix = suffix_key_func(*args, **kwargs)
cache_key = f'MERGE_DELAY_RUN_{func_name}_{key_suffix}'
cache_kwargs = _loop_debouncer_func_args_cache.get(cache_key, {})
for k, v in kwargs.items():
if not isinstance(v, (tuple, list, set)):
raise ValueError('func kwargs value must be list or tuple: %s %s' % (func.__name__, v))
v = set(v)
if k not in cache_kwargs:
cache_kwargs[k] = v
else:
cache_kwargs[k] = cache_kwargs[k].union(v)
_loop_debouncer_func_args_cache[cache_key] = cache_kwargs
run_debouncer_func(cache_key, org, ttl, func, *args, **cache_kwargs)
def apply(func, sync=False, *args, **kwargs):
if sync:
return func(*args, **kwargs)
else:
return delay(func, *args, **kwargs)
def inner(func):
sigs = inspect.signature(func)
if len(sigs.parameters) != 1:
@ -188,27 +232,12 @@ def merge_delay_run(ttl=5, key=None):
param = list(sigs.parameters.values())[0]
if not isinstance(param.default, tuple):
raise ValueError('func default must be tuple: %s' % param.default)
suffix_key_func = key if key else default_suffix_key
func.delay = functools.partial(delay, func)
func.apply = functools.partial(apply, func)
@functools.wraps(func)
def wrapper(*args, **kwargs):
from orgs.utils import get_current_org
org = get_current_org()
func_name = f'{func.__module__}_{func.__name__}'
key_suffix = suffix_key_func(*args, **kwargs)
cache_key = f'MERGE_DELAY_RUN_{func_name}_{key_suffix}'
cache_kwargs = _loop_debouncer_func_args_cache.get(cache_key, {})
for k, v in kwargs.items():
if not isinstance(v, (tuple, list, set)):
raise ValueError('func kwargs value must be list or tuple: %s %s' % (func.__name__, v))
v = set(v)
if k not in cache_kwargs:
cache_kwargs[k] = v
else:
cache_kwargs[k] = cache_kwargs[k].union(v)
_loop_debouncer_func_args_cache[cache_key] = cache_kwargs
run_debouncer_func(cache_key, org, ttl, func, *args, **cache_kwargs)
return func(*args, **kwargs)
return wrapper
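
After this refactor the decorated function exposes an explicit .delay() for the debounced path and .apply(sync=...) for an immediate run, while the wrapper itself now just calls the function directly; that is why many hunks in this commit switch call sites to .delay(...). A usage sketch under the decorator's constraint of a single keyword argument with a tuple default (the function name and ids below are illustrative, and the import assumes the project's common.decorators module):

from common.decorators import merge_delay_run


@merge_delay_run(ttl=5)
def expire_node_assets_mapping(org_ids=()):
    print('expire mapping for orgs:', org_ids)


expire_node_assets_mapping.delay(org_ids=('org-1',))             # debounced; merged within the ttl window
expire_node_assets_mapping.delay(org_ids=('org-2',))             # merged with the call above
expire_node_assets_mapping.apply(sync=True, org_ids=('org-1',))  # explicit synchronous run
expire_node_assets_mapping(org_ids=('org-1',))                   # direct call also runs immediately now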

View File

@ -6,7 +6,7 @@ import logging
from django.core.cache import cache
from django.core.exceptions import ImproperlyConfigured
from django.db.models import Q, Count
from django.db.models import Q
from django_filters import rest_framework as drf_filters
from rest_framework import filters
from rest_framework.compat import coreapi, coreschema
@ -180,36 +180,30 @@ class LabelFilterBackend(filters.BaseFilterBackend):
]
@staticmethod
def filter_resources(resources, labels_id):
def parse_label_ids(labels_id):
from labels.models import Label
label_ids = [i.strip() for i in labels_id.split(',')]
cleaned = []
args = []
for label_id in label_ids:
kwargs = {}
if ':' in label_id:
k, v = label_id.split(':', 1)
kwargs['label__name'] = k.strip()
kwargs['name'] = k.strip()
if v != '*':
kwargs['label__value'] = v.strip()
kwargs['value'] = v.strip()
args.append(kwargs)
else:
kwargs['label_id'] = label_id
args.append(kwargs)
cleaned.append(label_id)
if len(args) == 1:
resources = resources.filter(**args[0])
return resources
q = Q()
for kwarg in args:
q |= Q(**kwarg)
resources = resources.filter(q) \
.values('res_id') \
.order_by('res_id') \
.annotate(count=Count('res_id')) \
.values('res_id', 'count') \
.filter(count=len(args))
return resources
if len(args) != 0:
q = Q()
for kwarg in args:
q |= Q(**kwarg)
ids = Label.objects.filter(q).values_list('id', flat=True)
cleaned.extend(list(ids))
return cleaned
def filter_queryset(self, request, queryset, view):
labels_id = request.query_params.get('labels')
@ -223,14 +217,15 @@ class LabelFilterBackend(filters.BaseFilterBackend):
return queryset
model = queryset.model.label_model()
labeled_resource_cls = model._labels.field.related_model
labeled_resource_cls = model.labels.field.related_model
app_label = model._meta.app_label
model_name = model._meta.model_name
resources = labeled_resource_cls.objects.filter(
res_type__app_label=app_label, res_type__model=model_name,
)
resources = self.filter_resources(resources, labels_id)
label_ids = self.parse_label_ids(labels_id)
resources = model.filter_resources_by_labels(resources, label_ids)
res_ids = resources.values_list('res_id', flat=True)
queryset = queryset.filter(id__in=set(res_ids))
return queryset

View File

@ -14,6 +14,7 @@ class CeleryBaseService(BaseService):
print('\n- Start Celery as Distributed Task Queue: {}'.format(self.queue.capitalize()))
ansible_config_path = os.path.join(settings.APPS_DIR, 'ops', 'ansible', 'ansible.cfg')
ansible_modules_path = os.path.join(settings.APPS_DIR, 'ops', 'ansible', 'modules')
os.environ.setdefault('LC_ALL', 'C.UTF-8')
os.environ.setdefault('PYTHONOPTIMIZE', '1')
os.environ.setdefault('ANSIBLE_FORCE_COLOR', 'True')
os.environ.setdefault('ANSIBLE_CONFIG', ansible_config_path)

View File

@ -28,9 +28,10 @@ class ErrorCode:
class URL:
QR_CONNECT = 'https://oapi.dingtalk.com/connect/qrconnect'
QR_CONNECT = 'https://login.dingtalk.com/oauth2/auth'
OAUTH_CONNECT = 'https://oapi.dingtalk.com/connect/oauth2/sns_authorize'
GET_USER_INFO_BY_CODE = 'https://oapi.dingtalk.com/sns/getuserinfo_bycode'
GET_USER_ACCESSTOKEN = 'https://api.dingtalk.com/v1.0/oauth2/userAccessToken'
GET_USER_INFO = 'https://api.dingtalk.com/v1.0/contact/users/me'
GET_TOKEN = 'https://oapi.dingtalk.com/gettoken'
SEND_MESSAGE_BY_TEMPLATE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/sendbytemplate'
SEND_MESSAGE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/asyncsend_v2'
@ -72,8 +73,9 @@ class DingTalkRequests(BaseRequest):
def get(self, url, params=None,
with_token=False, with_sign=False,
check_errcode_is_0=True,
**kwargs):
**kwargs) -> dict:
pass
get = as_request(get)
def post(self, url, json=None, params=None,
@ -81,6 +83,7 @@ class DingTalkRequests(BaseRequest):
check_errcode_is_0=True,
**kwargs) -> dict:
pass
post = as_request(post)
def _add_sign(self, kwargs: dict):
@ -123,17 +126,22 @@ class DingTalk:
)
def get_userinfo_bycode(self, code):
# https://developers.dingtalk.com/document/app/obtain-the-user-information-based-on-the-sns-temporary-authorization?spm=ding_open_doc.document.0.0.3a256573y8Y7yg#topic-1995619
body = {
"tmp_auth_code": code
'clientId': self._appid,
'clientSecret': self._appsecret,
'code': code,
'grantType': 'authorization_code'
}
data = self._request.post(URL.GET_USER_ACCESSTOKEN, json=body, check_errcode_is_0=False)
token = data['accessToken']
data = self._request.post(URL.GET_USER_INFO_BY_CODE, json=body, with_sign=True)
return data['user_info']
user = self._request.get(URL.GET_USER_INFO,
headers={'x-acs-dingtalk-access-token': token}, check_errcode_is_0=False)
return user
def get_user_id_by_code(self, code):
user_info = self.get_userinfo_bycode(code)
unionid = user_info['unionid']
unionid = user_info['unionId']
userid = self.get_userid_by_unionid(unionid)
return userid, None

View File

@ -394,20 +394,20 @@ class CommonBulkModelSerializer(CommonBulkSerializerMixin, serializers.ModelSeri
class ResourceLabelsMixin(serializers.Serializer):
labels = LabelRelatedField(many=True, label=_('Labels'), required=False, allow_null=True)
labels = LabelRelatedField(many=True, label=_('Labels'), required=False, allow_null=True, source='res_labels')
def update(self, instance, validated_data):
labels = validated_data.pop('labels', None)
labels = validated_data.pop('res_labels', None)
res = super().update(instance, validated_data)
if labels is not None:
instance.labels.set(labels, bulk=False)
instance.res_labels.set(labels, bulk=False)
return res
def create(self, validated_data):
labels = validated_data.pop('labels', None)
labels = validated_data.pop('res_labels', None)
instance = super().create(validated_data)
if labels is not None:
instance.labels.set(labels, bulk=False)
instance.res_labels.set(labels, bulk=False)
return instance
@classmethod

View File

View File

@ -0,0 +1,56 @@
import re
from django.contrib.sessions.backends.cache import (
SessionStore as DjangoSessionStore
)
from django.core.cache import cache
from jumpserver.utils import get_current_request
class SessionStore(DjangoSessionStore):
ignore_urls = [
r'^/api/v1/users/profile/'
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ignore_pattern = re.compile('|'.join(self.ignore_urls))
def save(self, *args, **kwargs):
request = get_current_request()
if request is None or not self.ignore_pattern.match(request.path):
super().save(*args, **kwargs)
class RedisUserSessionManager:
JMS_SESSION_KEY = 'jms_session_key'
def __init__(self):
self.client = cache.client.get_client()
def add_or_increment(self, session_key):
self.client.hincrby(self.JMS_SESSION_KEY, session_key, 1)
def decrement_or_remove(self, session_key):
new_count = self.client.hincrby(self.JMS_SESSION_KEY, session_key, -1)
if new_count <= 0:
self.client.hdel(self.JMS_SESSION_KEY, session_key)
def check_active(self, session_key):
count = self.client.hget(self.JMS_SESSION_KEY, session_key)
count = 0 if count is None else int(count.decode('utf-8'))
return count > 0
def get_active_keys(self):
session_keys = []
for k, v in self.client.hgetall(self.JMS_SESSION_KEY).items():
count = int(v.decode('utf-8'))
if count <= 0:
continue
key = k.decode('utf-8')
session_keys.append(key)
return session_keys
user_session_manager = RedisUserSessionManager()
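
RedisUserSessionManager reference-counts each Django session key in a Redis hash, so a browser session only counts as inactive once every websocket connection using it has disconnected; the UserSession views, filters and model above switch to it from the old WS_SESSION_KEY set. A small usage sketch (the session key is made up; a configured Redis cache backend is assumed):

from common.sessions.cache import user_session_manager

key = 'abc123'                                   # illustrative Django session key
user_session_manager.add_or_increment(key)       # e.g. on websocket connect
user_session_manager.add_or_increment(key)       # a second tab / connection
print(user_session_manager.check_active(key))    # True

user_session_manager.decrement_or_remove(key)    # one connection closes
print(user_session_manager.check_active(key))    # still True (count == 1)

user_session_manager.decrement_or_remove(key)    # last connection closes
print(user_session_manager.check_active(key))    # False, hash entry removed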

View File

@ -69,7 +69,7 @@ def digest_sql_query():
for query in queries:
sql = query['sql']
print(" # {}: {}".format(query['time'], sql[:1000]))
print(" # {}: {}".format(query['time'], sql[:1000]))
if len(queries) < 3:
continue
print("- Table: {}".format(table_name))

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -282,6 +282,7 @@ class Config(dict):
'AUTH_LDAP_SYNC_INTERVAL': None,
'AUTH_LDAP_SYNC_CRONTAB': None,
'AUTH_LDAP_SYNC_ORG_IDS': ['00000000-0000-0000-0000-000000000002'],
'AUTH_LDAP_SYNC_RECEIVERS': [],
'AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS': False,
'AUTH_LDAP_OPTIONS_OPT_REFERRALS': -1,
@ -546,7 +547,6 @@ class Config(dict):
'REFERER_CHECK_ENABLED': False,
'SESSION_ENGINE': 'cache',
'SESSION_SAVE_EVERY_REQUEST': True,
'SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE': False,
'SERVER_REPLAY_STORAGE': {},
'SECURITY_DATA_CRYPTO_ALGO': None,
'GMSSL_ENABLED': False,
@ -605,7 +605,9 @@ class Config(dict):
'GPT_MODEL': 'gpt-3.5-turbo',
'VIRTUAL_APP_ENABLED': False,
'FILE_UPLOAD_SIZE_LIMIT_MB': 200
'FILE_UPLOAD_SIZE_LIMIT_MB': 200,
'TICKET_APPLY_ASSET_SCOPE': 'all'
}
old_config_map = {

View File

@ -66,11 +66,6 @@ class RequestMiddleware:
def __call__(self, request):
set_current_request(request)
response = self.get_response(request)
is_request_api = request.path.startswith('/api')
if not settings.SESSION_EXPIRE_AT_BROWSER_CLOSE and \
not is_request_api:
age = request.session.get_expiry_age()
request.session.set_expiry(age)
return response

View File

@ -3,6 +3,7 @@
path_perms_map = {
'xpack': '*',
'settings': '*',
'img': '*',
'replay': 'default',
'applets': 'terminal.view_applet',
'virtual_apps': 'terminal.view_virtualapp',

View File

@ -0,0 +1,14 @@
from private_storage.servers import NginxXAccelRedirectServer, DjangoServer
class StaticFileServer(object):
@staticmethod
def serve(private_file):
full_path = private_file.full_path
# todo: after nginx handles gzip-ed recording files the browser cannot parse the content,
# which breaks online playback, so for now only mp4 recording files are served via nginx
if full_path.endswith('.mp4'):
return NginxXAccelRedirectServer.serve(private_file)
else:
return DjangoServer.serve(private_file)

View File

@ -50,6 +50,7 @@ AUTH_LDAP_SYNC_IS_PERIODIC = CONFIG.AUTH_LDAP_SYNC_IS_PERIODIC
AUTH_LDAP_SYNC_INTERVAL = CONFIG.AUTH_LDAP_SYNC_INTERVAL
AUTH_LDAP_SYNC_CRONTAB = CONFIG.AUTH_LDAP_SYNC_CRONTAB
AUTH_LDAP_SYNC_ORG_IDS = CONFIG.AUTH_LDAP_SYNC_ORG_IDS
AUTH_LDAP_SYNC_RECEIVERS = CONFIG.AUTH_LDAP_SYNC_RECEIVERS
AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS = CONFIG.AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS
# ==============================================================================

View File

@ -234,11 +234,9 @@ CSRF_COOKIE_NAME = '{}csrftoken'.format(SESSION_COOKIE_NAME_PREFIX)
SESSION_COOKIE_NAME = '{}sessionid'.format(SESSION_COOKIE_NAME_PREFIX)
SESSION_COOKIE_AGE = CONFIG.SESSION_COOKIE_AGE
SESSION_EXPIRE_AT_BROWSER_CLOSE = True
# Custom setting: SESSION_EXPIRE_AT_BROWSER_CLOSE is always True; the one below controls whether the cookie is forcibly expired after the browser closes
SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE = CONFIG.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE
SESSION_SAVE_EVERY_REQUEST = CONFIG.SESSION_SAVE_EVERY_REQUEST
SESSION_ENGINE = "django.contrib.sessions.backends.{}".format(CONFIG.SESSION_ENGINE)
SESSION_EXPIRE_AT_BROWSER_CLOSE = CONFIG.SESSION_EXPIRE_AT_BROWSER_CLOSE
SESSION_ENGINE = "common.sessions.{}".format(CONFIG.SESSION_ENGINE)
MESSAGE_STORAGE = 'django.contrib.messages.storage.cookie.CookieStorage'
# Database
@ -319,9 +317,7 @@ MEDIA_ROOT = os.path.join(PROJECT_DIR, 'data', 'media').replace('\\', '/') + '/'
PRIVATE_STORAGE_ROOT = MEDIA_ROOT
PRIVATE_STORAGE_AUTH_FUNCTION = 'jumpserver.rewriting.storage.permissions.allow_access'
PRIVATE_STORAGE_INTERNAL_URL = '/private-media/'
PRIVATE_STORAGE_SERVER = 'nginx'
if DEBUG_DEV:
PRIVATE_STORAGE_SERVER = 'django'
PRIVATE_STORAGE_SERVER = 'jumpserver.rewriting.storage.servers.StaticFileServer'
# Use django-bootstrap-form to format template, input max width arg
# BOOTSTRAP_COLUMN_COUNT = 11

View File

@ -214,6 +214,9 @@ PERM_TREE_REGEN_INTERVAL = CONFIG.PERM_TREE_REGEN_INTERVAL
MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS
LIMIT_SUPER_PRIV = CONFIG.LIMIT_SUPER_PRIV
# Asset account may be too many
ASSET_SIZE = 'small'
# Chat AI
CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED
GPT_API_KEY = CONFIG.GPT_API_KEY
@ -224,3 +227,5 @@ GPT_MODEL = CONFIG.GPT_MODEL
VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED
FILE_UPLOAD_SIZE_LIMIT_MB = CONFIG.FILE_UPLOAD_SIZE_LIMIT_MB
TICKET_APPLY_ASSET_SCOPE = CONFIG.TICKET_APPLY_ASSET_SCOPE

View File

@ -1,14 +1,15 @@
from django.contrib.contenttypes.fields import GenericRelation
from django.db import models
from django.db.models import OneToOneField
from django.db.models import OneToOneField, Count
from common.utils import lazyproperty
from .models import LabeledResource
__all__ = ['LabeledMixin']
class LabeledMixin(models.Model):
_labels = GenericRelation(LabeledResource, object_id_field='res_id', content_type_field='res_type')
labels = GenericRelation(LabeledResource, object_id_field='res_id', content_type_field='res_type')
class Meta:
abstract = True
@ -21,7 +22,7 @@ class LabeledMixin(models.Model):
model = pk_field.related_model
return model
@property
@lazyproperty
def real(self):
pk_field = self._meta.pk
if isinstance(pk_field, OneToOneField):
@ -29,9 +30,43 @@ class LabeledMixin(models.Model):
return self
@property
def labels(self):
return self.real._labels
def res_labels(self):
return self.real.labels
@labels.setter
def labels(self, value):
self.real._labels.set(value, bulk=False)
@res_labels.setter
def res_labels(self, value):
self.real.labels.set(value, bulk=False)
@classmethod
def filter_resources_by_labels(cls, resources, label_ids):
return cls._get_filter_res_by_labels_m2m_all(resources, label_ids)
@classmethod
def _get_filter_res_by_labels_m2m_in(cls, resources, label_ids):
return resources.filter(label_id__in=label_ids)
@classmethod
def _get_filter_res_by_labels_m2m_all(cls, resources, label_ids):
if len(label_ids) == 1:
return cls._get_filter_res_by_labels_m2m_in(resources, label_ids)
resources = resources.filter(label_id__in=label_ids) \
.values('res_id') \
.order_by('res_id') \
.annotate(count=Count('res_id', distinct=True)) \
.values('res_id', 'count') \
.filter(count=len(label_ids))
return resources
@classmethod
def get_labels_filter_attr_q(cls, value, match):
resources = LabeledResource.objects.all()
if not value:
return None
if match != 'm2m_all':
resources = cls._get_filter_res_by_labels_m2m_in(resources, value)
else:
resources = cls._get_filter_res_by_labels_m2m_all(resources, value)
res_ids = set(resources.values_list('res_id', flat=True))
return models.Q(id__in=res_ids)
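
filter_resources_by_labels implements "resource must carry every requested label" by counting distinct matching labels per res_id and keeping only rows whose count equals the number of requested labels (the single-label case short-circuits to a plain __in filter). An in-memory illustration of the counting rule, with made-up label and resource ids:

from collections import Counter

labeled_rows = [            # (res_id, label_id) pairs, like LabeledResource rows
    ('asset-1', 'env:prod'), ('asset-1', 'team:db'),
    ('asset-2', 'env:prod'),
]
wanted = {'env:prod', 'team:db'}

counts = Counter(res for res, label in labeled_rows if label in wanted)
matching = [res for res, n in counts.items() if n == len(wanted)]
print(matching)   # ['asset-1'] -- only resources carrying every requested label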

View File

@ -34,7 +34,7 @@ class LabelSerializer(BulkOrgResourceModelSerializer):
@classmethod
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
queryset = queryset.annotate(res_count=Count('labeled_resources'))
queryset = queryset.annotate(res_count=Count('labeled_resources', distinct=True))
return queryset

View File

@ -1,28 +1,32 @@
import json
import time
from threading import Thread
from channels.generic.websocket import JsonWebsocketConsumer
from django.core.cache import cache
from django.conf import settings
from common.db.utils import safe_db_connection
from common.sessions.cache import user_session_manager
from common.utils import get_logger
from .signal_handlers import new_site_msg_chan
from .site_msg import SiteMessageUtil
logger = get_logger(__name__)
WS_SESSION_KEY = 'ws_session_key'
class SiteMsgWebsocket(JsonWebsocketConsumer):
sub = None
refresh_every_seconds = 10
@property
def session(self):
return self.scope['session']
def connect(self):
user = self.scope["user"]
if user.is_authenticated:
self.accept()
session = self.scope['session']
redis_client = cache.client.get_client()
redis_client.sadd(WS_SESSION_KEY, session.session_key)
user_session_manager.add_or_increment(self.session.session_key)
self.sub = self.watch_recv_new_site_msg()
else:
self.close()
@ -66,6 +70,32 @@ class SiteMsgWebsocket(JsonWebsocketConsumer):
if not self.sub:
return
self.sub.unsubscribe()
session = self.scope['session']
redis_client = cache.client.get_client()
redis_client.srem(WS_SESSION_KEY, session.session_key)
user_session_manager.decrement_or_remove(self.session.session_key)
if self.should_delete_session():
thread = Thread(target=self.delay_delete_session)
thread.start()
def should_delete_session(self):
return (self.session.modified or settings.SESSION_SAVE_EVERY_REQUEST) and \
not self.session.is_empty() and \
self.session.get_expire_at_browser_close() and \
not user_session_manager.check_active(self.session.session_key)
def delay_delete_session(self):
timeout = 3
check_interval = 0.5
start_time = time.time()
while time.time() - start_time < timeout:
time.sleep(check_interval)
if user_session_manager.check_active(self.session.session_key):
return
self.delete_session()
def delete_session(self):
try:
self.session.delete()
except Exception as e:
logger.info(f'delete session error: {e}')

View File

@ -4,6 +4,21 @@ import time
import paramiko
from sshtunnel import SSHTunnelForwarder
from packaging import version
if version.parse(paramiko.__version__) > version.parse("2.8.1"):
_preferred_pubkeys = (
"ssh-ed25519",
"ecdsa-sha2-nistp256",
"ecdsa-sha2-nistp384",
"ecdsa-sha2-nistp521",
"ssh-rsa",
"rsa-sha2-256",
"rsa-sha2-512",
"ssh-dss",
)
paramiko.transport.Transport._preferred_pubkeys = _preferred_pubkeys
def common_argument_spec():
options = dict(

View File

@ -2,6 +2,7 @@
#
import os
import re
from collections import defaultdict
from celery.result import AsyncResult
from django.shortcuts import get_object_or_404
@ -166,16 +167,58 @@ class CeleryTaskViewSet(
i.next_exec_time = now + next_run_at
return queryset
def generate_summary_state(self, execution_qs):
model = self.get_queryset().model
executions = execution_qs.order_by('-date_published').values('name', 'state')
summary_state_dict = defaultdict(
lambda: {
'states': [], 'state': 'green',
'summary': {'total': 0, 'success': 0}
}
)
for execution in executions:
name = execution['name']
state = execution['state']
summary = summary_state_dict[name]['summary']
summary['total'] += 1
summary['success'] += 1 if state == 'SUCCESS' else 0
states = summary_state_dict[name].get('states')
if states is not None and len(states) >= 5:
color = model.compute_state_color(states)
summary_state_dict[name]['state'] = color
summary_state_dict[name].pop('states', None)
elif isinstance(states, list):
states.append(state)
return summary_state_dict
def loading_summary_state(self, queryset):
if isinstance(queryset, list):
names = [i.name for i in queryset]
execution_qs = CeleryTaskExecution.objects.filter(name__in=names)
else:
execution_qs = CeleryTaskExecution.objects.all()
summary_state_dict = self.generate_summary_state(execution_qs)
for i in queryset:
i.summary = summary_state_dict.get(i.name, {}).get('summary', {})
i.state = summary_state_dict.get(i.name, {}).get('state', 'green')
return queryset
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset)
if page is not None:
page = self.generate_execute_time(page)
page = self.loading_summary_state(page)
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
queryset = self.generate_execute_time(queryset)
queryset = self.loading_summary_state(queryset)
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)

View File

@ -246,6 +246,6 @@ class UsernameHintsAPI(APIView):
.filter(username__icontains=query) \
.filter(asset__in=assets) \
.values('username') \
.annotate(total=Count('username')) \
.annotate(total=Count('username', distinct=True)) \
.order_by('total', '-username')[:10]
return Response(data=top_accounts)

View File

@ -15,6 +15,9 @@ class CeleryTask(models.Model):
name = models.CharField(max_length=1024, verbose_name=_('Name'))
date_last_publish = models.DateTimeField(null=True, verbose_name=_("Date last publish"))
__summary = None
__state = None
@property
def meta(self):
task = app.tasks.get(self.name, None)
@ -25,25 +28,43 @@ class CeleryTask(models.Model):
@property
def summary(self):
if self.__summary is not None:
return self.__summary
executions = CeleryTaskExecution.objects.filter(name=self.name)
total = executions.count()
success = executions.filter(state='SUCCESS').count()
return {'total': total, 'success': success}
@summary.setter
def summary(self, value):
self.__summary = value
@staticmethod
def compute_state_color(states: list, default_count=5):
color = 'green'
states = states[:default_count]
if not states:
return color
if states[0] == 'FAILURE':
color = 'red'
elif 'FAILURE' in states:
color = 'yellow'
return color
@property
def state(self):
last_five_executions = CeleryTaskExecution.objects \
.filter(name=self.name) \
.order_by('-date_published')[:5]
if self.__state is not None:
return self.__state
last_five_executions = CeleryTaskExecution.objects.filter(
name=self.name
).order_by('-date_published').values('state')[:5]
states = [i['state'] for i in last_five_executions]
color = self.compute_state_color(states)
return color
if len(last_five_executions) > 0:
if last_five_executions[0].state == 'FAILURE':
return "red"
for execution in last_five_executions:
if execution.state == 'FAILURE':
return "yellow"
return "green"
@state.setter
def state(self, value):
self.__state = value
class Meta:
verbose_name = _("Celery Task")

View File

@ -67,6 +67,7 @@ class JMSPermedInventory(JMSInventory):
'postgresql': ['postgresql'],
'sqlserver': ['sqlserver'],
'ssh': ['shell', 'python', 'win_shell', 'raw'],
'winrm': ['win_shell', 'shell'],
}
if self.module not in protocol_supported_modules_mapping.get(protocol.name, []):

View File

@ -87,7 +87,8 @@ class OrgResourceStatisticsRefreshUtil:
if not cache_field_name:
return
org = getattr(instance, 'org', None)
cls.refresh_org_fields(((org, cache_field_name),))
cache_field_name = tuple(cache_field_name)
cls.refresh_org_fields.delay(org_fields=((org, cache_field_name),))
@receiver(post_save)

View File

@ -1,5 +1,6 @@
import abc
from django.conf import settings
from rest_framework.generics import ListAPIView, RetrieveAPIView
from assets.api.asset.asset import AssetFilterSet
@ -7,8 +8,7 @@ from assets.models import Asset, Node
from common.utils import get_logger, lazyproperty, is_uuid
from orgs.utils import tmp_to_root_org
from perms import serializers
from perms.pagination import AllPermedAssetPagination
from perms.pagination import NodePermedAssetPagination
from perms.pagination import NodePermedAssetPagination, AllPermedAssetPagination
from perms.utils import UserPermAssetUtil, PermAssetDetailUtil
from .mixin import (
SelfOrPKUserMixin
@ -39,7 +39,7 @@ class UserPermedAssetRetrieveApi(SelfOrPKUserMixin, RetrieveAPIView):
class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
ordering = ('name',)
ordering = []
search_fields = ('name', 'address', 'comment')
ordering_fields = ("name", "address")
filterset_class = AssetFilterSet
@ -48,6 +48,8 @@ class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
if settings.ASSET_SIZE == 'small':
self.ordering = ['name']
assets = self.get_assets()
assets = self.serializer_class.setup_eager_loading(assets)
return assets

View File

@ -14,6 +14,7 @@ from assets.api import SerializeToTreeNodeMixin
from assets.models import Asset
from assets.utils import KubernetesTree
from authentication.models import ConnectionToken
from common.exceptions import JMSException
from common.utils import get_object_or_none, lazyproperty
from common.utils.common import timeit
from perms.hands import Node
@ -181,6 +182,8 @@ class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(BaseUserNodeWithAssetAsT
return self.query_asset_util.get_all_assets()
def _get_tree_nodes_async(self):
if self.request.query_params.get('lv') == '0':
return [], []
if not self.tp or not all(self.tp):
nodes = UserPermAssetUtil.get_type_nodes_tree_or_cached(self.user)
return nodes, []
@ -262,5 +265,8 @@ class UserGrantedK8sAsTreeApi(SelfOrPKUserMixin, ListAPIView):
if not any([namespace, pod]) and not key:
asset_node = k8s_tree_instance.as_asset_tree_node()
tree.append(asset_node)
tree.extend(k8s_tree_instance.async_tree_node(namespace, pod))
return Response(data=tree)
try:
tree.extend(k8s_tree_instance.async_tree_node(namespace, pod))
return Response(data=tree)
except Exception as e:
raise JMSException(e)

View File

@ -130,7 +130,7 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel):
qs1_ids = User.objects.filter(id__in=user_ids).distinct().values_list('id', flat=True)
qs2_ids = User.objects.filter(groups__id__in=group_ids).distinct().values_list('id', flat=True)
qs_ids = list(qs1_ids) + list(qs2_ids)
qs = User.objects.filter(id__in=qs_ids)
qs = User.objects.filter(id__in=qs_ids, is_service_account=False)
return qs
def get_all_assets(self, flat=False):

View File

@ -9,7 +9,7 @@ class PermedAssetsWillExpireUserMsg(UserMessage):
def __init__(self, user, assets, day_count=0):
super().__init__(user)
self.assets = assets
self.day_count = _('today') if day_count == 0 else day_count + _('day')
self.day_count = _('today') if day_count == 0 else str(day_count) + _('day')
def get_html_msg(self) -> dict:
subject = _("You permed assets is about to expire")
@ -41,7 +41,7 @@ class AssetPermsWillExpireForOrgAdminMsg(UserMessage):
super().__init__(user)
self.perms = perms
self.org = org
self.day_count = _('today') if day_count == 0 else day_count + _('day')
self.day_count = _('today') if day_count == 0 else str(day_count) + _('day')
def get_items_with_url(self):
items_with_url = []

View File

@ -197,9 +197,9 @@ class AssetPermissionListSerializer(AssetPermissionSerializer):
"""Perform necessary eager loading of data."""
queryset = queryset \
.prefetch_related('labels', 'labels__label') \
.annotate(users_amount=Count("users"),
user_groups_amount=Count("user_groups"),
assets_amount=Count("assets"),
nodes_amount=Count("nodes"),
.annotate(users_amount=Count("users", distinct=True),
user_groups_amount=Count("user_groups", distinct=True),
assets_amount=Count("assets", distinct=True),
nodes_amount=Count("nodes", distinct=True),
)
return queryset
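Annotating several many-to-many counts on one queryset joins all of those relations at once, and without `distinct=True` each count is inflated by the size of the other joins. A pure-Python picture of the row multiplication (names are made up):

```python
# Two M2M relations joined into one query multiply each other's rows:
users = ["u1", "u2"]            # 2 related users
assets = ["a1", "a2", "a3"]     # 3 related assets
joined = [(u, a) for u in users for a in assets]

assert len(joined) == 6                    # COUNT("users") over the join would report 6
assert len({u for u, _ in joined}) == 2    # COUNT("users", distinct=True) reports 2
```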

View File

@ -8,9 +8,9 @@ from rest_framework import serializers
from accounts.models import Account
from assets.const import Category, AllTypes
from assets.models import Node, Asset, Platform
from assets.serializers.asset.common import AssetLabelSerializer, AssetProtocolsPermsSerializer
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from assets.serializers.asset.common import AssetProtocolsPermsSerializer
from common.serializers import ResourceLabelsMixin
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from orgs.mixins.serializers import OrgResourceModelSerializerMixin
from perms.serializers.permission import ActionChoicesField

View File

@ -13,7 +13,7 @@ class AssetPermissionUtil(object):
""" 资产授权相关的方法工具 """
@timeit
def get_permissions_for_user(self, user, with_group=True, flat=False):
def get_permissions_for_user(self, user, with_group=True, flat=False, with_expired=False):
""" 获取用户的授权规则 """
perm_ids = set()
# user
@ -25,7 +25,7 @@ class AssetPermissionUtil(object):
groups = user.groups.all()
group_perm_ids = self.get_permissions_for_user_groups(groups, flat=True)
perm_ids.update(group_perm_ids)
perms = self.get_permissions(ids=perm_ids)
perms = self.get_permissions(ids=perm_ids, with_expired=with_expired)
if flat:
return perms.values_list('id', flat=True)
return perms
@ -102,6 +102,8 @@ class AssetPermissionUtil(object):
return model.objects.filter(id__in=ids)
@staticmethod
def get_permissions(ids):
perms = AssetPermission.objects.filter(id__in=ids).valid().order_by('-date_expired')
return perms
def get_permissions(ids, with_expired=False):
perms = AssetPermission.objects.filter(id__in=ids)
if not with_expired:
perms = perms.valid()
return perms.order_by('-date_expired')

View File

@ -7,10 +7,10 @@ from django.db.models import Q
from rest_framework.utils.encoders import JSONEncoder
from assets.const import AllTypes
from assets.models import FavoriteAsset, Asset
from assets.models import FavoriteAsset, Asset, Node
from common.utils.common import timeit, get_logger
from orgs.utils import current_org, tmp_to_root_org
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation, AssetPermission
from .permission import AssetPermissionUtil
__all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil']
@ -21,36 +21,37 @@ logger = get_logger(__name__)
class AssetPermissionPermAssetUtil:
def __init__(self, perm_ids):
self.perm_ids = perm_ids
self.perm_ids = set(perm_ids)
def get_all_assets(self):
""" 获取所有授权的资产 """
node_assets = self.get_perm_nodes_assets()
direct_assets = self.get_direct_assets()
# Much faster than the old approach of first collecting all asset ids and then querying by them, which gets very slow when there are many assets
return (node_assets | direct_assets).distinct()
return (node_assets | direct_assets).order_by().distinct()
def get_perm_nodes(self):
""" 获取所有授权节点 """
nodes_ids = AssetPermission.objects \
.filter(id__in=self.perm_ids) \
.values_list('nodes', flat=True)
nodes_ids = set(nodes_ids)
nodes = Node.objects.filter(id__in=nodes_ids).only('id', 'key')
return nodes
@timeit
def get_perm_nodes_assets(self, flat=False):
def get_perm_nodes_assets(self):
""" 获取所有授权节点下的资产 """
from assets.models import Node
nodes = Node.objects \
.prefetch_related('granted_by_permissions') \
.filter(granted_by_permissions__in=self.perm_ids) \
.only('id', 'key')
assets = PermNode.get_nodes_all_assets(*nodes)
if flat:
return set(assets.values_list('id', flat=True))
nodes = self.get_perm_nodes()
assets = PermNode.get_nodes_all_assets(*nodes, distinct=False)
return assets
@timeit
def get_direct_assets(self, flat=False):
def get_direct_assets(self):
""" 获取直接授权的资产 """
assets = Asset.objects.order_by() \
.filter(granted_by_permissions__id__in=self.perm_ids) \
.distinct()
if flat:
return set(assets.values_list('id', flat=True))
asset_ids = AssetPermission.assets.through.objects \
.filter(assetpermission_id__in=self.perm_ids) \
.values_list('asset_id', flat=True)
assets = Asset.objects.filter(id__in=asset_ids)
return assets
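The rewrite reads granted asset ids straight from the permission-asset through table and then filters `Asset` by id, instead of joining the big asset table against `granted_by_permissions` and de-duplicating it. A rough dependency-free picture of the through-table shape (ids are illustrative):

```python
# The M2M "through" rows are just (assetpermission_id, asset_id) pairs; collecting asset ids
# from them first avoids a join plus DISTINCT over the large asset table.
through_rows = [("perm-1", "asset-1"), ("perm-1", "asset-2"), ("perm-2", "asset-2")]
perm_ids = {"perm-1", "perm-2"}

asset_ids = {asset_id for perm_id, asset_id in through_rows if perm_id in perm_ids}
assert asset_ids == {"asset-1", "asset-2"}
```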
@ -152,6 +153,7 @@ class UserPermAssetUtil(AssetPermissionPermAssetUtil):
assets = assets.filter(nodes__id=node.id).order_by().distinct()
return assets
@timeit
def _get_indirect_perm_node_all_assets(self, node):
""" 获取间接授权节点下的所有资产
此算法依据 `UserAssetGrantedTreeNodeRelation` 的数据查询

View File

@ -72,7 +72,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
@timeit
def refresh_if_need(self, force=False):
built_just_now = cache.get(self.cache_key_time)
built_just_now = False if settings.ASSET_SIZE == 'small' else cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
return
@ -80,12 +80,18 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
if not to_refresh_orgs:
logger.info('Not have to refresh orgs')
return
logger.info("Delay refresh user orgs: {} {}".format(self.user, [o.name for o in to_refresh_orgs]))
refresh_user_orgs_perm_tree(user_orgs=((self.user, tuple(to_refresh_orgs)),))
refresh_user_favorite_assets(users=(self.user,))
sync = True if settings.ASSET_SIZE == 'small' else False
refresh_user_orgs_perm_tree.apply(sync=sync, user_orgs=((self.user, tuple(to_refresh_orgs)),))
refresh_user_favorite_assets.apply(sync=sync, users=(self.user,))
@timeit
def refresh_tree_manual(self):
"""
Used for manual debugging
:return:
"""
built_just_now = cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh just now, pass: {}'.format(built_just_now))
@ -105,8 +111,9 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
return
self._clean_user_perm_tree_for_legacy_org()
ttl = settings.PERM_TREE_REGEN_INTERVAL
cache.set(self.cache_key_time, int(time.time()), ttl)
if settings.ASSET_SIZE != 'small':
ttl = settings.PERM_TREE_REGEN_INTERVAL
cache.set(self.cache_key_time, int(time.time()), ttl)
lock = UserGrantedTreeRebuildLock(self.user.id)
got = lock.acquire(blocking=False)
@ -193,7 +200,13 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):
cache_key = self.get_cache_key(uid)
p.srem(cache_key, *org_ids)
p.execute()
logger.info('Expire perm tree for users: [{}], orgs: [{}]'.format(user_ids, org_ids))
users_display = ','.join([str(i) for i in user_ids[:3]])
if len(user_ids) > 3:
users_display += '...'
orgs_display = ','.join([str(i) for i in org_ids[:3]])
if len(org_ids) > 3:
orgs_display += '...'
logger.info('Expire perm tree for users: [{}], orgs: [{}]'.format(users_display, orgs_display))
def expire_perm_tree_for_all_user(self):
keys = self.client.keys(self.cache_key_all_user)
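The expire log above now prints at most three user and org ids with a trailing `...`, instead of dumping the full id lists. A standalone sketch of that truncation helper:

```python
def head_display(items, limit=3):
    """Render at most `limit` items for a log line, appending '...' when truncated."""
    shown = ','.join(str(i) for i in items[:limit])
    return shown + '...' if len(items) > limit else shown


assert head_display(['u1', 'u2', 'u3', 'u4']) == 'u1,u2,u3...'
assert head_display(['u1', 'u2']) == 'u1,u2'
```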

View File

@ -80,9 +80,11 @@ class RoleViewSet(JMSModelViewSet):
queryset = Role.objects.filter(id__in=ids).order_by(*self.ordering)
org_id = current_org.id
q = Q(role__scope=Role.Scope.system) | Q(role__scope=Role.Scope.org, org_id=org_id)
role_bindings = RoleBinding.objects.filter(q).values_list('role_id').annotate(user_count=Count('user_id'))
role_bindings = RoleBinding.objects.filter(q).values_list('role_id').annotate(
user_count=Count('user_id', distinct=True)
)
role_user_amount_mapper = {role_id: user_count for role_id, user_count in role_bindings}
queryset = queryset.annotate(permissions_amount=Count('permissions'))
queryset = queryset.annotate(permissions_amount=Count('permissions', distinct=True))
queryset = list(queryset)
for role in queryset:
role.users_amount = role_user_amount_mapper.get(role.id, 0)

View File

@ -137,7 +137,7 @@ class LDAPUserImportAPI(APIView):
return Response({'msg': _('Get ldap users is None')}, status=400)
orgs = self.get_orgs()
errors = LDAPImportUtil().perform_import(users, orgs)
new_users, errors = LDAPImportUtil().perform_import(users, orgs)
if errors:
return Response({'errors': errors}, status=400)

View File

@ -3,6 +3,7 @@ from rest_framework import generics
from rest_framework.permissions import AllowAny
from authentication.permissions import IsValidUserOrConnectionToken
from common.const.choices import COUNTRY_CALLING_CODES
from common.utils import get_logger, lazyproperty
from common.utils.timezone import local_now
from .. import serializers
@ -24,7 +25,8 @@ class OpenPublicSettingApi(generics.RetrieveAPIView):
def get_object(self):
return {
"XPACK_ENABLED": settings.XPACK_ENABLED,
"INTERFACE": self.interface_setting
"INTERFACE": self.interface_setting,
"COUNTRY_CALLING_CODES": COUNTRY_CALLING_CODES
}

View File

@ -0,0 +1,36 @@
from django.template.loader import render_to_string
from django.utils.translation import gettext as _
from common.utils import get_logger
from common.utils.timezone import local_now_display
from notifications.notifications import UserMessage
logger = get_logger(__file__)
class LDAPImportMessage(UserMessage):
def __init__(self, user, extra_kwargs):
super().__init__(user)
self.orgs = extra_kwargs.pop('orgs', [])
self.end_time = extra_kwargs.pop('end_time', '')
self.start_time = extra_kwargs.pop('start_time', '')
self.time_start_display = extra_kwargs.pop('time_start_display', '')
self.new_users = extra_kwargs.pop('new_users', [])
self.errors = extra_kwargs.pop('errors', [])
self.cost_time = extra_kwargs.pop('cost_time', '')
def get_html_msg(self) -> dict:
subject = _('Notification of Synchronized LDAP User Task Results')
context = {
'orgs': self.orgs,
'start_time': self.time_start_display,
'end_time': local_now_display(),
'cost_time': self.cost_time,
'users': self.new_users,
'errors': self.errors
}
message = render_to_string('ldap/_msg_import_ldap_user.html', context)
return {
'subject': subject,
'message': message
}

View File

@ -77,6 +77,9 @@ class LDAPSettingSerializer(serializers.Serializer):
required=False, label=_('Connect timeout (s)'),
)
AUTH_LDAP_SEARCH_PAGED_SIZE = serializers.IntegerField(required=False, label=_('Search paged size (piece)'))
AUTH_LDAP_SYNC_RECEIVERS = serializers.ListField(
required=False, label=_('Recipient'), max_length=36
)
AUTH_LDAP = serializers.BooleanField(required=False, label=_('Enable LDAP auth'))

View File

@ -43,7 +43,7 @@ class OAuth2SettingSerializer(serializers.Serializer):
)
AUTH_OAUTH2_ACCESS_TOKEN_METHOD = serializers.ChoiceField(
default='GET', label=_('Client authentication method'),
choices=(('GET', 'GET'), ('POST', 'POST'))
choices=(('GET', 'GET'), ('POST', 'POST-DATA'), ('POST_JSON', 'POST-JSON'))
)
AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT = serializers.CharField(
required=True, max_length=1024, label=_('Provider userinfo endpoint')

View File

@ -11,6 +11,7 @@ __all__ = [
class PublicSettingSerializer(serializers.Serializer):
XPACK_ENABLED = serializers.BooleanField()
INTERFACE = serializers.DictField()
COUNTRY_CALLING_CODES = serializers.ListField()
class PrivateSettingSerializer(PublicSettingSerializer):

View File

@ -1,15 +1,19 @@
# coding: utf-8
#
import time
from celery import shared_task
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from common.utils import get_logger
from common.utils.timezone import local_now_display
from ops.celery.decorator import after_app_ready_start
from ops.celery.utils import (
create_or_update_celery_periodic_tasks, disable_celery_periodic_task
)
from orgs.models import Organization
from settings.notifications import LDAPImportMessage
from users.models import User
from ..utils import LDAPSyncUtil, LDAPServerUtil, LDAPImportUtil
__all__ = ['sync_ldap_user', 'import_ldap_user_periodic', 'import_ldap_user']
@ -23,6 +27,8 @@ def sync_ldap_user():
@shared_task(verbose_name=_('Periodic import ldap user'))
def import_ldap_user():
start_time = time.time()
time_start_display = local_now_display()
logger.info("Start import ldap user task")
util_server = LDAPServerUtil()
util_import = LDAPImportUtil()
@ -35,11 +41,26 @@ def import_ldap_user():
org_ids = [Organization.DEFAULT_ID]
default_org = Organization.default()
orgs = list(set([Organization.get_instance(org_id, default=default_org) for org_id in org_ids]))
errors = util_import.perform_import(users, orgs)
new_users, errors = util_import.perform_import(users, orgs)
if errors:
logger.error("Imported LDAP users errors: {}".format(errors))
else:
logger.info('Imported {} users successfully'.format(len(users)))
if settings.AUTH_LDAP_SYNC_RECEIVERS:
user_ids = settings.AUTH_LDAP_SYNC_RECEIVERS
recipient_list = User.objects.filter(id__in=list(user_ids))
end_time = time.time()
extra_kwargs = {
'orgs': orgs,
'end_time': end_time,
'start_time': start_time,
'time_start_display': time_start_display,
'new_users': new_users,
'errors': errors,
'cost_time': end_time - start_time,
}
for user in recipient_list:
LDAPImportMessage(user, extra_kwargs).publish()
@shared_task(verbose_name=_('Registration periodic import ldap user task'))

View File

@ -0,0 +1,34 @@
{% load i18n %}
<p>{% trans "Sync task Finish" %}</p>
<b>{% trans 'Time' %}:</b>
<ul>
<li>{% trans 'Date start' %}: {{ start_time }}</li>
<li>{% trans 'Date end' %}: {{ end_time }}</li>
<li>{% trans 'Time cost' %}: {{ cost_time| floatformat:0 }}s</li>
</ul>
<b>{% trans "Synced Organization" %}:</b>
<ul>
{% for org in orgs %}
<li>{{ org }}</li>
{% endfor %}
</ul>
<b>{% trans "Synced User" %}:</b>
<ul>
{% if users %}
{% for user in users %}
<li>{{ user }}</li>
{% endfor %}
{% else %}
<li>{% trans 'No user synchronization required' %}</li>
{% endif %}
</ul>
{% if errors %}
<b>{% trans 'Error' %}:</b>
<ul>
{% for error in errors %}
<li>{{ error }}</li>
{% endfor %}
</ul>
{% endif %}

View File

@ -400,11 +400,14 @@ class LDAPImportUtil(object):
logger.info('Start perform import ldap users, count: {}'.format(len(users)))
errors = []
objs = []
new_users = []
group_users_mapper = defaultdict(set)
for user in users:
groups = user.pop('groups', [])
try:
obj, created = self.update_or_create(user)
if created:
new_users.append(obj)
objs.append(obj)
except Exception as e:
errors.append({user['username']: str(e)})
@ -421,7 +424,7 @@ class LDAPImportUtil(object):
for org in orgs:
self.bind_org(org, objs, group_users_mapper)
logger.info('End perform import ldap users')
return errors
return new_users, errors
def exit_user_group(self, user_groups_mapper):
# Compare with existing memberships to find the user groups the imported users should be removed from
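`perform_import` now returns both the newly created users and the per-user errors, so callers such as `LDAPUserImportAPI` and the periodic `import_ldap_user` task unpack a tuple. A simplified, dependency-free sketch of that collection loop (the `update_or_create` callable is a stand-in for the real backend):

```python
def import_users(users, update_or_create):
    """Collect newly created users and per-user errors, mirroring the new return shape."""
    new_users, errors = [], []
    for user in users:
        try:
            obj, created = update_or_create(user)
            if created:
                new_users.append(obj)
        except Exception as e:  # keep importing the remaining users instead of aborting the sync
            errors.append({user["username"]: str(e)})
    return new_users, errors


# Usage: new_users, errors = import_users(ldap_users, update_or_create=backend)
```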

View File

@ -42,7 +42,7 @@ class SmartEndpointViewMixin:
return endpoint
def match_endpoint_by_label(self):
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol)
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol, self.request)
def match_endpoint_by_target_ip(self):
target_ip = self.request.GET.get('target_ip', '')  # the target_ip parameter is supported to make testing easier

View File

@ -75,7 +75,20 @@ class Endpoint(JMSBaseModel):
return endpoint
@classmethod
def match_by_instance_label(cls, instance, protocol):
def handle_endpoint_host(cls, endpoint, request=None):
if not endpoint.host and request:
# Dynamically fill in the current request host
host_port = request.get_host()
# IPv6
if host_port.startswith('['):
host = host_port.split(']:')[0].rstrip(']') + ']'
else:
host = host_port.split(':')[0]
endpoint.host = host
return endpoint
@classmethod
def match_by_instance_label(cls, instance, protocol, request=None):
from assets.models import Asset
from terminal.models import Session
if isinstance(instance, Session):
@ -88,6 +101,7 @@ class Endpoint(JMSBaseModel):
endpoints = cls.objects.filter(name__in=list(values)).order_by('-date_updated')
for endpoint in endpoints:
if endpoint.is_valid_for(instance, protocol):
endpoint = cls.handle_endpoint_host(endpoint, request)
return endpoint
@ -130,13 +144,5 @@ class EndpointRule(JMSBaseModel):
endpoint = endpoint_rule.endpoint
else:
endpoint = Endpoint.get_or_create_default(request)
if not endpoint.host and request:
# Dynamically fill in the current request host
host_port = request.get_host()
# IPv6
if host_port.startswith('['):
host = host_port.split(']:')[0].rstrip(']') + ']'
else:
host = host_port.split(':')[0]
endpoint.host = host
endpoint = Endpoint.handle_endpoint_host(endpoint, request)
return endpoint
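`handle_endpoint_host` derives a bare host from `request.get_host()`, which returns `host:port` and wraps IPv6 literals in square brackets. A standalone sketch of the same parsing (example hosts are made up):

```python
def host_from_host_port(host_port: str) -> str:
    """Strip the port from a Host header value, keeping IPv6 brackets intact."""
    if host_port.startswith('['):                      # e.g. "[2001:db8::1]:8080"
        return host_port.split(']:')[0].rstrip(']') + ']'
    return host_port.split(':')[0]                     # e.g. "demo.example.com:443"


assert host_from_host_port('demo.example.com:443') == 'demo.example.com'
assert host_from_host_port('[2001:db8::1]:8080') == '[2001:db8::1]'
assert host_from_host_port('[2001:db8::1]') == '[2001:db8::1]'
```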

View File

@ -5,3 +5,4 @@ from .ticket import *
from .comment import *
from .relation import *
from .super_ticket import *
from .perms import *

apps/tickets/api/perms.py (new file, 66 lines)
View File

@ -0,0 +1,66 @@
from django.conf import settings
from assets.models import Asset, Node
from assets.serializers.asset.common import MiniAssetSerializer
from assets.serializers.node import NodeSerializer
from common.api import SuggestionMixin
from orgs.mixins.api import OrgReadonlyModelViewSet
from perms.utils import AssetPermissionPermAssetUtil
from perms.utils.permission import AssetPermissionUtil
from tickets.const import TicketApplyAssetScope
__all__ = ['ApplyAssetsViewSet', 'ApplyNodesViewSet']
class ApplyAssetsViewSet(OrgReadonlyModelViewSet, SuggestionMixin):
model = Asset
serializer_class = MiniAssetSerializer
rbac_perms = (
("match", "assets.match_asset"),
)
search_fields = ("name", "address", "comment")
def get_queryset(self):
if TicketApplyAssetScope.is_permed():
queryset = self.get_assets(with_expired=True)
elif TicketApplyAssetScope.is_permed_valid():
queryset = self.get_assets()
else:
queryset = super().get_queryset()
return queryset
def get_assets(self, with_expired=False):
perms = AssetPermissionUtil().get_permissions_for_user(
self.request.user, flat=True, with_expired=with_expired
)
util = AssetPermissionPermAssetUtil(perms)
assets = util.get_all_assets()
return assets
class ApplyNodesViewSet(OrgReadonlyModelViewSet, SuggestionMixin):
model = Node
serializer_class = NodeSerializer
rbac_perms = (
("match", "assets.match_node"),
)
search_fields = ('full_value',)
def get_queryset(self):
if TicketApplyAssetScope.is_permed():
queryset = self.get_nodes(with_expired=True)
elif TicketApplyAssetScope.is_permed_valid():
queryset = self.get_nodes()
else:
queryset = super().get_queryset()
return queryset
def get_nodes(self, with_expired=False):
perms = AssetPermissionUtil().get_permissions_for_user(
self.request.user, flat=True, with_expired=with_expired
)
util = AssetPermissionPermAssetUtil(perms)
nodes = util.get_perm_nodes()
return nodes

View File

@ -4,6 +4,7 @@ from django.utils.translation import gettext_lazy as _
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.exceptions import MethodNotAllowed
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from audits.handler import create_or_update_operate_log
@ -41,7 +42,6 @@ class TicketViewSet(CommonApiMixin, viewsets.ModelViewSet):
ordering = ('-date_created',)
rbac_perms = {
'open': 'tickets.view_ticket',
'bulk': 'tickets.change_ticket',
}
def retrieve(self, request, *args, **kwargs):
@ -122,7 +122,7 @@ class TicketViewSet(CommonApiMixin, viewsets.ModelViewSet):
self._record_operate_log(instance, TicketAction.close)
return Response('ok')
@action(detail=False, methods=[PUT], permission_classes=[RBACPermission, ])
@action(detail=False, methods=[PUT], permission_classes=[IsAuthenticated, ])
def bulk(self, request, *args, **kwargs):
self.ticket_not_allowed()

View File

@ -1,3 +1,4 @@
from django.conf import settings
from django.db.models import TextChoices, IntegerChoices
from django.utils.translation import gettext_lazy as _
@ -56,3 +57,21 @@ class TicketApprovalStrategy(TextChoices):
custom_user = 'custom_user', _("Custom user")
super_admin = 'super_admin', _("Super admin")
super_org_admin = 'super_org_admin', _("Super admin and org admin")
class TicketApplyAssetScope(TextChoices):
all = 'all', _("All assets")
permed = 'permed', _("Permed assets")
permed_valid = 'permed_valid', _('Permed valid assets')
@classmethod
def get_scope(cls):
return settings.TICKET_APPLY_ASSET_SCOPE.lower()
@classmethod
def is_permed(cls):
return cls.get_scope() == cls.permed
@classmethod
def is_permed_valid(cls):
return cls.get_scope() == cls.permed_valid
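The scope class simply lowercases the `TICKET_APPLY_ASSET_SCOPE` setting and compares it with these choices; the new `ApplyAssetsViewSet`/`ApplyNodesViewSet` pick their queryset from that result. A plain-Python sketch of the decision (the default value is assumed to be `all`):

```python
from enum import Enum


class ApplyAssetScope(str, Enum):
    """Illustrative mirror of TicketApplyAssetScope."""
    all = 'all'
    permed = 'permed'
    permed_valid = 'permed_valid'


def ticket_asset_source(scope_setting: str) -> str:
    scope = scope_setting.lower()
    if scope == ApplyAssetScope.permed:
        return "permed assets, expired permissions included"
    if scope == ApplyAssetScope.permed_valid:
        return "permed assets, valid permissions only"
    return "all assets in the organization"


assert ticket_asset_source('PERMED') == "permed assets, expired permissions included"
```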

View File

@ -57,7 +57,7 @@ class TicketStep(JMSBaseModel):
assignees.update(state=state)
self.status = StepStatus.closed
self.state = state
self.save(update_fields=['state', 'status'])
self.save(update_fields=['state', 'status', 'date_updated'])
def set_active(self):
self.status = StepStatus.active

View File

@ -16,6 +16,8 @@ router.register('apply-login-tickets', api.ApplyLoginTicketViewSet, 'apply-login
router.register('apply-command-tickets', api.ApplyCommandTicketViewSet, 'apply-command-ticket')
router.register('apply-login-asset-tickets', api.ApplyLoginAssetTicketViewSet, 'apply-login-asset-ticket')
router.register('ticket-session-relation', api.TicketSessionRelationViewSet, 'ticket-session-relation')
router.register('apply-assets', api.ApplyAssetsViewSet, 'apply-asset')
router.register('apply-nodes', api.ApplyNodesViewSet, 'apply-node')
urlpatterns = [
path('tickets/<uuid:ticket_id>/session/', api.TicketSessionApi.as_view(), name='ticket-session'),

View File

@ -729,7 +729,7 @@ class JSONFilterMixin:
bindings = RoleBinding.objects.filter(**kwargs, role__in=value)
if match == 'm2m_all':
user_id = bindings.values('user_id').annotate(count=Count('user_id')) \
user_id = bindings.values('user_id').annotate(count=Count('user_id', distinct=True)) \
.filter(count=len(value)).values_list('user_id', flat=True)
else:
user_id = bindings.values_list('user_id', flat=True)

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
from django.db.models import Count
from django.db.models import Count, Q
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
@ -46,7 +46,7 @@ class UserGroupSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
queryset = queryset.prefetch_related('labels', 'labels__label') \
.annotate(users_amount=Count('users'))
.annotate(users_amount=Count('users', distinct=True, filter=Q(users__is_service_account=False)))
return queryset

View File

@ -163,9 +163,9 @@ def on_openid_create_or_update_user(sender, request, user, created, name, userna
user.save()
@shared_task(verbose_name=_('Clean audits session task log'))
@shared_task(verbose_name=_('Clean up expired user sessions'))
@register_as_period_task(crontab=CRONTAB_AT_PM_TWO)
def clean_audits_log_period():
def clean_expired_user_session_period():
UserSession.clear_expired_sessions()

View File

@ -86,7 +86,7 @@ def check_user_expired_periodic():
@tmp_to_root_org()
def check_unused_users():
uncommon_users_ttl = settings.SECURITY_UNCOMMON_USERS_TTL
if not uncommon_users_ttl or not uncommon_users_ttl.isdigit():
if not uncommon_users_ttl:
return
uncommon_users_ttl = int(uncommon_users_ttl)

View File

@ -7,6 +7,7 @@
.margin-bottom {
margin-bottom: 15px;
}
.input-style {
width: 100%;
display: inline-block;
@ -22,6 +23,19 @@
height: 100%;
vertical-align: top;
}
.scrollable-menu {
height: auto;
max-height: 18rem;
overflow-x: hidden;
}
.input-group {
.input-group-btn .btn-secondary {
color: #464a4c;
background-color: #eceeef;
}
}
</style>
{% endblock %}
{% block html_title %}{% trans 'Forgot password' %}{% endblock %}
@ -57,9 +71,26 @@
placeholder="{% trans 'Email account' %}" value="{{ email }}">
</div>
<div id="validate-sms" class="validate-field margin-bottom">
<input type="tel" id="sms" name="sms" class="form-control input-style"
placeholder="{% trans 'Mobile number' %}" value="{{ sms }}">
<small style="color: #999; margin-left: 5px">{{ form.sms.help_text }}</small>
<div class="input-group">
<div class="input-group-btn">
<button type="button" class="btn btn-secondary dropdown-toggle" data-toggle="dropdown"
aria-haspopup="true" aria-expanded="false">
<span class="country-code-value">+86</span>
</button>
<ul class="dropdown-menu scrollable-menu">
{% for country in countries %}
<li>
<a href="#" class="dropdown-item d-flex justify-content-between">
<span class="country-name text-left">{{ country.name }}</span>
<span class="country-code">{{ country.value }}</span>
</a>
</li>
{% endfor %}
</ul>
</div>
<input type="tel" id="sms" name="sms" class="form-control input-style"
placeholder="{% trans 'Mobile number' %}" value="{{ sms }}">
</div>
</div>
<div class="margin-bottom challenge-required">
<input type="text" id="code" name="code" class="form-control input-style"
@ -76,7 +107,7 @@
</div>
</div>
<script>
$(function (){
$(function () {
const validateSelectRef = $('#validate-backend-select')
const formType = $('input[name="form_type"]').val()
validateSelectRef.val(formType)
@ -84,19 +115,31 @@
selectChange(formType);
}
})
$(".dropdown-menu li a").click(function (evt) {
const inputGroup = $('.input-group');
const inputGroupAddon = inputGroup.find('.country-code-value');
const selectedCountry = $(evt.target).closest('li');
const selectedCountryCode = selectedCountry.find('.country-code').html();
inputGroupAddon.html(selectedCountryCode)
});
function getQueryString(name) {
const reg = new RegExp("(^|&)"+ name +"=([^&]*)(&|$)");
const reg = new RegExp("(^|&)" + name + "=([^&]*)(&|$)");
const r = window.location.search.substr(1).match(reg);
if(r !== null)
if (r !== null)
return unescape(r[2])
return null
}
function selectChange(name) {
$('.validate-field').hide()
$('#validate-' + name).show()
$('#validate-' + name + '-tip').show()
$('input[name="form_type"]').attr('value', name)
}
function sendChallengeCode(currentBtn) {
let time = 60;
const token = getQueryString('token')
@ -104,7 +147,7 @@
const formType = $('input[name="form_type"]').val()
const email = $('#email').val()
const sms = $('#sms').val()
let sms = $('#sms').val();
const errMsg = "{% trans 'The {} cannot be empty' %}"
if (formType === 'sms') {
@ -118,10 +161,11 @@
return
}
}
sms = $(".input-group .country-code-value").html() + sms
const data = {
form_type: formType, email: email, sms: sms,
}
function onSuccess() {
const originBtnText = currentBtn.innerHTML;
currentBtn.disabled = true

View File

@ -14,22 +14,24 @@
</strong>
</p>
<div>
<img src="{% static 'img/authenticator_android.png' %}" width="128" height="128" alt="">
<img src="{{ authenticator_android_url }}" width="128" height="128" alt="">
<p>{% trans 'Android downloads' %}</p>
</div>
<div>
<img src="{% static 'img/authenticator_iphone.png' %}" width="128" height="128" alt="">
<img src="{{ authenticator_iphone_url }}" width="128" height="128" alt="">
<p>{% trans 'iPhone downloads' %}</p>
</div>
<p style="margin: 20px auto;"><strong style="color: #000000">{% trans 'After installation, click the next step to enter the binding page (if installed, go to the next step directly).' %}</strong></p>
<p style="margin: 20px auto;"><strong
style="color: #000000">{% trans 'After installation, click the next step to enter the binding page (if installed, go to the next step directly).' %}</strong>
</p>
</div>
<a href="{% url 'authentication:user-otp-enable-bind' %}" class="next">{% trans 'Next' %}</a>
<script>
$(function(){
$(function () {
$('.change-color li:eq(1) i').css('color', '{{ INTERFACE.primary_color }}')
})
</script>

View File

@ -1,10 +1,14 @@
# ~*~ coding: utf-8 ~*~
import os
from django.conf import settings
from django.contrib.auth import logout as auth_logout
from django.http.response import HttpResponseRedirect
from django.shortcuts import redirect
from django.templatetags.static import static
from django.urls import reverse
from django.utils.translation import gettext as _
from django.utils._os import safe_join
from django.views.generic.base import TemplateView
from django.views.generic.edit import FormView
@ -45,9 +49,26 @@ class UserOtpEnableStartView(AuthMixin, TemplateView):
class UserOtpEnableInstallAppView(TemplateView):
template_name = 'users/user_otp_enable_install_app.html'
@staticmethod
def replace_authenticator_png(platform):
media_url = settings.MEDIA_URL
base_path = f'img/authenticator_{platform}.png'
authenticator_media_path = safe_join(settings.MEDIA_ROOT, base_path)
if os.path.exists(authenticator_media_path):
authenticator_url = f'{media_url}{base_path}'
else:
authenticator_url = static(base_path)
return authenticator_url
def get_context_data(self, **kwargs):
user = get_user_or_pre_auth_user(self.request)
context = {'user': user}
authenticator_android_url = self.replace_authenticator_png('android')
authenticator_iphone_url = self.replace_authenticator_png('iphone')
context = {
'user': user,
'authenticator_android_url': authenticator_android_url,
'authenticator_iphone_url': authenticator_iphone_url
}
kwargs.update(context)
return super().get_context_data(**kwargs)
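`replace_authenticator_png` lets an administrator-uploaded file under `MEDIA_ROOT` override the bundled authenticator download images, falling back to the static asset otherwise. A dependency-free sketch of that pick-override-or-default logic (paths and URL prefixes are illustrative stand-ins for the Django settings):

```python
import os


def resolve_image_url(base_path: str, media_root: str, media_url: str, static_url: str) -> str:
    """Prefer an uploaded override under media_root, else fall back to the bundled static file."""
    if os.path.exists(os.path.join(media_root, base_path)):
        return media_url + base_path
    return static_url + base_path


# e.g. resolve_image_url('img/authenticator_android.png', '/data/media', '/media/', '/static/')
```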

View File

@ -13,6 +13,7 @@ from django.views.generic import FormView, RedirectView
from authentication.errors import IntervalTooShort
from authentication.utils import check_user_property_is_correct
from common.const.choices import COUNTRY_CALLING_CODES
from common.utils import FlashMessageUtil, get_object_or_none, random_string
from common.utils.verify_code import SendAndVerifyCodeUtil
from users.notifications import ResetPasswordSuccessMsg
@ -108,7 +109,7 @@ class UserForgotPasswordView(FormView):
for k, v in cleaned_data.items():
if v:
context[k] = v
context['countries'] = COUNTRY_CALLING_CODES
context['form_type'] = 'email'
context['XPACK_ENABLED'] = settings.XPACK_ENABLED
validate_backends = self.get_validate_backends_context(has_phone)

View File

@ -85,7 +85,7 @@ REDIS_PORT: 6379
# SECURITY_WATERMARK_ENABLED: False
# Expire the session when the browser page is closed
# SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE: False
# SESSION_EXPIRE_AT_BROWSER_CLOSE: False
# Renew the session on every API request
# SESSION_SAVE_EVERY_REQUEST: True

Some files were not shown because too many files have changed in this diff.