mirror of https://github.com/jumpserver/jumpserver
				
				
				
			
		
			
				
	
	
		
			80 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			80 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
import datetime
 | 
						|
 | 
						|
from channels.generic.websocket import JsonWebsocketConsumer
 | 
						|
from django.utils import timezone
 | 
						|
from rest_framework.renderers import JSONRenderer
 | 
						|
 | 
						|
from common.db.utils import safe_db_connection
 | 
						|
from common.utils import get_logger
 | 
						|
from common.utils.connection import Subscription
 | 
						|
from terminal.const import TaskNameType
 | 
						|
from terminal.models import Session, Terminal
 | 
						|
from terminal.serializers import TaskSerializer, StatSerializer
 | 
						|
from .signal_handlers import component_event_chan
 | 
						|
 | 
						|
logger = get_logger(__name__)
 | 
						|
 | 
						|
 | 
						|
class TerminalTaskWebsocket(JsonWebsocketConsumer):
    """WebSocket consumer that pushes unfinished tasks to a terminal component.

    On connect, the consumer binds to the ``terminal`` of the authenticated
    user, sends the terminal's recent unfinished tasks, and subscribes to
    ``component_event_chan`` so later task events trigger further pushes.
    Incoming ``{"type": "status", "payload": ...}`` messages are validated with
    ``StatSerializer`` and saved for the bound terminal.
    """

    # Handle returned by ``component_event_chan.subscribe``; remains None when
    # the connection was rejected, so ``disconnect`` has nothing to cancel.
    sub: Subscription = None
    # Terminal of the connected user; assigned in ``connect``.
    terminal: Terminal = None

    def connect(self):
        """Accept the socket for an authenticated user that has a terminal.

        Any other connection is closed without subscribing to task events.
        """
        user = self.scope["user"]
        # getattr with a default also covers user objects that have no
        # ``terminal`` attribute/relation (an AttributeError), so such
        # connections are closed instead of failing with an exception.
        terminal = None
        if user.is_authenticated:
            terminal = getattr(user, "terminal", None)

        if not terminal:
            self.close()
            return

        self.accept()
        self.terminal = terminal
        self.sub = self.watch_component_event()

    def receive_json(self, content, **kwargs):
        """Dispatch a decoded JSON message sent by the client.

        Only messages whose ``type`` is ``"status"`` are handled; their
        ``payload`` is forwarded to ``handle_status``. Other types are ignored.
        """
        # Decoded JSON may be a list, string, number, ... which has no ``get``.
        if not isinstance(content, dict):
            logger.error('Invalid message, expected a JSON object: {!r}'.format(content))
            return

        if content.get('type') == "status":
            self.handle_status(content.get('payload'))

    def handle_status(self, content):
        """Validate a status payload and persist it for this terminal.

        Invalid payloads are logged and dropped. For valid ones, the ids under
        ``sessions`` are removed from the data and passed to
        ``Session.set_sessions_active`` before the serializer is saved.
        """
        serializer = StatSerializer(data=content)
        if not serializer.is_valid():
            logger.error('Invalid status data: {}'.format(serializer.errors))
            return

        serializer.validated_data["terminal"] = self.terminal
        # ``sessions`` is popped so it is not part of the data that gets saved.
        session_ids = serializer.validated_data.pop('sessions', [])
        Session.set_sessions_active(session_ids)
        with safe_db_connection():
            serializer.save()

    def send_tasks_msg(self, task_id=None):
        """Send the terminal's pending tasks to the client as a binary frame.

        :param task_id: when given, only the task with this id is considered.
        """
        content = self.get_terminal_tasks(task_id)
        self.send(bytes_data=content)

    def get_terminal_tasks(self, task_id=None):
        """Return the rendered JSON list of this terminal's unfinished tasks.

        Only tasks created within the last 10 minutes are included.

        :param task_id: optional id restricting the result to a single task.
        :return: the output of ``JSONRenderer().render`` for the task list.
        """
        # Serialization happens inside the context manager as well, since the
        # queryset is only evaluated when ``serializer.data`` is accessed.
        with safe_db_connection():
            critical_time = timezone.now() - datetime.timedelta(minutes=10)
            tasks = self.terminal.task_set.filter(
                is_finished=False, date_created__gte=critical_time
            )
            if task_id:
                tasks = tasks.filter(id=task_id)
            serializer = TaskSerializer(tasks, many=True)
            return JSONRenderer().render(serializer.data)

    def watch_component_event(self):
        """Send existing tasks once, then subscribe to component task events.

        :return: the value returned by ``component_event_chan.subscribe``,
            kept by ``connect`` so ``disconnect`` can unsubscribe.
        """
        # Send the already existing tasks once before listening for new ones.
        self.send_tasks_msg()

        def handle_task_msg_recv(msg):
            logger.debug('New component task msg recv: {}'.format(msg))
            if msg.get('type') not in TaskNameType.names:
                return
            # A message without a (dict) payload carries no task id; fall back
            # to sending all pending tasks, same as a payload lacking ``id``.
            payload = msg.get('payload')
            task_id = payload.get('id') if isinstance(payload, dict) else None
            self.send_tasks_msg(task_id)

        return component_event_chan.subscribe(handle_task_msg_recv)

    def disconnect(self, code):
        """Cancel the component event subscription, if one was created."""
        if self.sub is None:
            return
        self.sub.unsubscribe()
 |