You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
jumpserver/apps/terminal/api/applet/applet.py

188 lines
6.9 KiB

import os
import os.path
import re
import shutil
import tempfile
import zipfile
from typing import Callable
from urllib.parse import quote

from django.conf import settings
from django.core.files.storage import default_storage
from django.http import HttpResponse
from django.shortcuts import get_object_or_404
from django.utils.translation import gettext as _, get_language
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import ValidationError

from common.api import JMSBulkModelViewSet
from common.serializers import FileSerializer
from common.utils import is_uuid
from common.utils.http import is_true
from common.utils.yml import yaml_load_with_i18n
from terminal import serializers
from terminal.models import AppletPublication, Applet
__all__ = ['AppletViewSet', 'AppletPublicationViewSet']
class DownloadUploadMixin:
    """Mixin adding zip ``upload`` and ``download`` actions to an applet viewset.

    The host view is expected to provide ``get_serializer``, ``request`` and
    ``get_object`` (declared below only as annotations).
    """
    get_serializer: Callable
    request: Request
    get_object: Callable

    def extract_and_check_file(self, request):
        """Save the uploaded zip, extract it and validate the applet package.

        The upload is stored as ``applets/<name>.tmp.zip`` and extracted into
        ``applets/<name>.tmp``; leftovers from a previous upload with the same
        file name are removed first. The temporary zip is deleted once
        extraction has finished (successfully or not).

        :param request: the current request; its data must contain ``file``.
        :returns: ``(manifest, tmp_dir)`` where ``manifest`` is the value
            returned by ``Applet.validate_pkg(tmp_dir)``.
        :raises ValidationError: if the upload is not a readable zip archive
            or no package directory name can be derived from the file name.
        """
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        file = serializer.validated_data['file']

        save_to = 'applets/{}'.format(file.name + '.tmp.zip')
        if default_storage.exists(save_to):
            default_storage.delete(save_to)
        rel_path = default_storage.save(save_to, file)
        path = default_storage.path(rel_path)

        extract_to = default_storage.path('applets/{}.tmp'.format(file.name))
        if os.path.exists(extract_to):
            shutil.rmtree(extract_to)

        try:
            with zipfile.ZipFile(path) as zp:
                if zp.testzip() is not None:
                    raise ValidationError({'error': _('Invalid zip file')})
                zp.extractall(extract_to)
        except (RuntimeError, zipfile.BadZipFile) as e:
            # RuntimeError: e.g. encrypted members; BadZipFile: not a zip at all.
            # Both are client errors, not server errors.
            raise ValidationError({'error': _('Invalid zip file') + ': {}'.format(e)})
        finally:
            # The archive is only needed for extraction; don't leave it behind.
            default_storage.delete(rel_path)

        # The package is expected in a top-level folder named after the zip
        # (only a trailing ".zip" is stripped).
        base_name = file.name[:-len('.zip')] if file.name.endswith('.zip') else file.name
        tmp_dir = os.path.join(extract_to, base_name)
        if not os.path.exists(tmp_dir):
            # Fallback: leading word characters, e.g. "demo(1).zip" -> "demo".
            matched = re.match(r"(\w+)", file.name)
            if matched is None:
                raise ValidationError({'error': _('Invalid zip file')})
            tmp_dir = os.path.join(extract_to, matched.group())
        manifest = Applet.validate_pkg(tmp_dir)
        return manifest, tmp_dir

    @action(detail=False, methods=['post'], serializer_class=FileSerializer)
    def upload(self, request, *args, **kwargs):
        """Install an applet from an uploaded zip package.

        Query params:
            update: parsed with ``is_true``; when true an existing applet with
                the same manifest name is reinstalled, otherwise HTTP 400 is
                returned for duplicates.

        :returns: the serializer data from ``Applet.install_from_dir`` with
            HTTP 201.
        :raises ValidationError: for invalid packages, or for enterprise
            applets when the XPack license is not valid.
        """
        manifest, tmp_dir = self.extract_and_check_file(request)
        name = manifest['name']
        # Parse the flag: the raw strings "false"/"0" are truthy in Python.
        update = is_true(request.query_params.get('update', ''))

        is_enterprise = manifest.get('edition') == Applet.Edition.enterprise
        if is_enterprise and not settings.XPACK_LICENSE_IS_VALID:
            raise ValidationError({'error': _('This is enterprise edition applet')})

        already_exists = Applet.objects.filter(name=name).exists()
        if already_exists and not update:
            return Response({'error': 'Applet already exists: {}'.format(name)}, status=400)

        _applet, serializer = Applet.install_from_dir(tmp_dir, builtin=False)
        return Response(serializer.data, status=201)

    @action(detail=True, methods=['get'])
    def download(self, request, *args, **kwargs):
        """Return the applet's directory packed as a zip attachment.

        Built-in applets are read from ``APPS_DIR/terminal/applets/<name>``;
        others from ``applets/<name>`` in the default storage.

        :raises ValidationError: if the applet directory does not exist.
        """
        instance = self.get_object()
        if instance.builtin:
            path = os.path.join(settings.APPS_DIR, 'terminal', 'applets', instance.name)
        else:
            path = default_storage.path('applets/{}'.format(instance.name))
        if not os.path.exists(path):
            raise ValidationError({'error': _('Applet not found in path: {}').format(path)})

        # Build the archive in a private temp dir rather than next to the
        # applet: no writes into the source tree for built-ins, no clashes
        # between concurrent downloads, and cleanup even if reading fails.
        work_dir = tempfile.mkdtemp()
        try:
            zip_path = shutil.make_archive(os.path.join(work_dir, instance.name), 'zip', path)
            with open(zip_path, 'rb') as f:
                content = f.read()
        finally:
            shutil.rmtree(work_dir, ignore_errors=True)

        response = HttpResponse(content, status=200, content_type='application/octet-stream')
        # RFC 5987 ext-value: the filename must be percent-encoded.
        filename = quote('{}.zip'.format(instance.name), safe='')
        response['Content-Disposition'] = "attachment; filename*=UTF-8''{}".format(filename)
        return response
class AppletViewSet(DownloadUploadMixin, JMSBulkModelViewSet):
    """CRUD plus zip upload/download API for terminal applets.

    Each applet object is decorated in place with a localized
    ``display_name``, ``comment`` and ``readme`` read from files in the
    applet directory. Search and field filtering are therefore done in
    Python on those decorated objects (see ``filter_queryset``).
    """
    queryset = Applet.objects.all()
    serializer_class = serializers.AppletSerializer
    filterset_fields = ['name', 'version', 'builtin', 'is_active']
    search_fields = ['name', 'display_name', 'author']
    rbac_perms = {
        'upload': 'terminal.add_applet',
        'download': 'terminal.view_applet',
    }

    def get_object(self):
        """Look up an applet by UUID primary key, or by name otherwise.

        :raises Http404: if no matching applet exists.
        """
        pk = self.kwargs.get('pk')
        if is_uuid(pk):
            obj = get_object_or_404(Applet, pk=pk)
        else:
            obj = get_object_or_404(Applet, name=pk)
        # DRF only runs object-level permission checks inside its default
        # get_object(); an override has to call them explicitly.
        self.check_object_permissions(self.request, obj)
        return self.trans_object(obj)

    def get_queryset(self):
        """Return the base queryset with every applet translated in place."""
        queryset = super().get_queryset()
        queryset = self.trans_queryset(queryset)
        return queryset

    @staticmethod
    def read_manifest_with_i18n(obj, lang='zh'):
        """Load ``<obj.path>/manifest.yml`` for *lang*; ``{}`` if it is missing."""
        path = os.path.join(obj.path, 'manifest.yml')
        if os.path.exists(path):
            with open(path, encoding='utf8') as f:
                manifest = yaml_load_with_i18n(f, lang)
        else:
            manifest = {}
        return manifest

    def trans_queryset(self, queryset):
        """Translate each object of *queryset* in place and return it.

        Iterating evaluates the queryset, so the translated instances are the
        ones held in its result cache.
        """
        for obj in queryset:
            self.trans_object(obj)
        return queryset

    @staticmethod
    def readme(obj, lang=''):
        """Return the text of ``README_<LANG>.md`` for *obj*, or ``''``.

        Only the first two characters of *lang* are used, upper-cased
        (e.g. ``'zh-hans'`` -> ``README_ZH.md``).
        """
        # get_language() may return None when translations are deactivated.
        lang = (lang or '')[:2]
        readme_file = os.path.join(obj.path, f'README_{lang.upper()}.md')
        if os.path.isfile(readme_file):
            # Explicit encoding (as for manifest.yml); don't depend on the locale.
            with open(readme_file, 'r', encoding='utf8') as f:
                return f.read()
        return ''

    def trans_object(self, obj):
        """Overlay localized display_name/comment/readme onto *obj* and return it."""
        lang = get_language()
        manifest = self.read_manifest_with_i18n(obj, lang)
        obj.display_name = manifest.get('display_name', obj.display_name)
        obj.comment = manifest.get('comment', obj.comment)
        obj.readme = self.readme(obj, lang)
        return obj

    def is_record_found(self, obj, search):
        """Return True if *search* is a substring of the joined ``search_fields``."""
        # NULL/missing fields (e.g. an applet without author) count as empty
        # text instead of making str.join raise TypeError.
        values = [str(getattr(obj, f, None) or '') for f in self.search_fields]
        return search in ' '.join(values)

    def filter_queryset(self, queryset):
        """Apply ``?search=`` and ``filterset_fields`` query params in Python.

        Returns a list (not a QuerySet) whenever any filter is applied.
        ``is_active`` and ``builtin`` values are parsed with ``is_true``.
        """
        search = self.request.query_params.get('search')
        if search:
            queryset = [i for i in queryset if self.is_record_found(i, search)]
        for field in self.filterset_fields:
            field_value = self.request.query_params.get(field)
            if not field_value:
                continue
            if field in ['is_active', 'builtin']:
                field_value = is_true(field_value)
            queryset = [i for i in queryset if getattr(i, field, '') == field_value]
        return queryset

    def perform_destroy(self, instance):
        """Remove the applet's storage directory (if present), then the row.

        :raises ValidationError: if the instance has an empty name. This
            guard keeps the rmtree below from targeting the ``applets/`` root.
        """
        if not instance.name:
            raise ValidationError('Applet is not null')
        path = default_storage.path('applets/{}'.format(instance.name))
        if os.path.exists(path):
            shutil.rmtree(path)
        instance.delete()
class AppletPublicationViewSet(viewsets.ModelViewSet):
    """Model CRUD API for ``AppletPublication`` records.

    Filtering is declared on the host, applet and status fields; text search
    covers the related applet's name and display name and the host's name.
    """
    serializer_class = serializers.AppletPublicationSerializer
    queryset = AppletPublication.objects.all()
    search_fields = ['applet__name', 'applet__display_name', 'host__name']
    filterset_fields = ['host', 'applet', 'status']