diff --git a/apps/terminal/api/component/connect_methods.py b/apps/terminal/api/component/connect_methods.py index d928c7329..4891ec5bd 100644 --- a/apps/terminal/api/component/connect_methods.py +++ b/apps/terminal/api/component/connect_methods.py @@ -17,28 +17,15 @@ class ConnectMethodListApi(generics.ListAPIView): serializer_class = serializers.ConnectMethodSerializer permission_classes = [IsValidUser] - def filter_user_connect_methods(self, d): - from acls.models import ConnectMethodACL - # 这里要根据用户来了,受 acl 影响 - acls = ConnectMethodACL.get_user_acls(self.request.user) - disabled_connect_methods = acls.values_list('connect_methods', flat=True) - disabled_connect_methods = set(itertools.chain.from_iterable(disabled_connect_methods)) - new_queryset = {} - for protocol, methods in d.items(): - new_queryset[protocol] = [x for x in methods if x['value'] not in disabled_connect_methods] - return new_queryset - def get_queryset(self): os = self.request.query_params.get('os') or get_request_os(self.request) - queryset = ConnectMethodUtil.get_filtered_protocols_connect_methods(os) flat = self.request.query_params.get('flat') - - # 先这么处理, 这里不用过滤包含的事所有 if is_true(flat): + queryset = ConnectMethodUtil.get_filtered_protocols_connect_methods(os) queryset = itertools.chain.from_iterable(queryset.values()) queryset = distinct(queryset, key=lambda x: x['value']) else: - queryset = self.filter_queryset(queryset) + queryset = ConnectMethodUtil.get_user_allowed_connect_methods(os, self.request.user) return queryset def list(self, request, *args, **kwargs): diff --git a/apps/terminal/connect_methods.py b/apps/terminal/connect_methods.py index aefdc7675..0d1571489 100644 --- a/apps/terminal/connect_methods.py +++ b/apps/terminal/connect_methods.py @@ -227,6 +227,19 @@ class ConnectMethodUtil: methods = cls._filter_disable_protocols_connect_methods(methods) return methods + @classmethod + def get_user_allowed_connect_methods(cls, os, user): + from acls.models import ConnectMethodACL + methods = cls.get_filtered_protocols_connect_methods(os) + acls = ConnectMethodACL.get_user_acls(user) + disabled_connect_methods = acls.values_list('connect_methods', flat=True) + disabled_connect_methods = set(itertools.chain.from_iterable(disabled_connect_methods)) + + new_queryset = {} + for protocol, methods in methods.items(): + new_queryset[protocol] = [x for x in methods if x['value'] not in disabled_connect_methods] + return new_queryset + @classmethod def _filter_disable_components_connect_methods(cls, methods): component_setting = {