mirror of https://github.com/jumpserver/jumpserver
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
80 lines
2.7 KiB
80 lines
2.7 KiB
import datetime
|
|
|
|
from channels.generic.websocket import JsonWebsocketConsumer
|
|
from django.utils import timezone
|
|
from rest_framework.renderers import JSONRenderer
|
|
|
|
from common.db.utils import safe_db_connection
|
|
from common.utils import get_logger
|
|
from common.utils.connection import Subscription
|
|
from terminal.models import Terminal
|
|
from terminal.models import Session
|
|
from terminal.serializers import TaskSerializer, StatSerializer
|
|
from .signal_handlers import component_event_chan
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class TerminalTaskWebsocket(JsonWebsocketConsumer):
|
|
sub: Subscription = None
|
|
terminal: Terminal = None
|
|
|
|
def connect(self):
|
|
user = self.scope["user"]
|
|
if user.is_authenticated and user.terminal:
|
|
self.accept()
|
|
self.terminal = user.terminal
|
|
self.sub = self.watch_component_event()
|
|
else:
|
|
self.close()
|
|
|
|
def receive_json(self, content, **kwargs):
|
|
req_type = content.get('type')
|
|
if req_type == "status":
|
|
payload = content.get('payload')
|
|
self.handle_status(payload)
|
|
|
|
def handle_status(self, content):
|
|
serializer = StatSerializer(data=content)
|
|
if not serializer.is_valid():
|
|
logger.error('Invalid status data: {}'.format(serializer.errors))
|
|
return
|
|
serializer.validated_data["terminal"] = self.terminal
|
|
session_ids = serializer.validated_data.pop('sessions', [])
|
|
Session.set_sessions_active(session_ids)
|
|
with safe_db_connection():
|
|
serializer.save()
|
|
|
|
def send_kill_tasks_msg(self, task_id=None):
|
|
content = self.get_terminal_tasks(task_id)
|
|
self.send(bytes_data=content)
|
|
|
|
def get_terminal_tasks(self, task_id=None):
|
|
with safe_db_connection():
|
|
critical_time = timezone.now() - datetime.timedelta(minutes=10)
|
|
tasks = self.terminal.task_set.filter(is_finished=False, date_created__gte=critical_time)
|
|
if task_id:
|
|
tasks = tasks.filter(id=task_id)
|
|
serializer = TaskSerializer(tasks, many=True)
|
|
return JSONRenderer().render(serializer.data)
|
|
|
|
def watch_component_event(self):
|
|
# 先发一次已有的任务
|
|
self.send_kill_tasks_msg()
|
|
|
|
ws = self
|
|
|
|
def handle_task_msg_recv(msg):
|
|
logger.debug('New component task msg recv: {}'.format(msg))
|
|
msg_type = msg.get('type')
|
|
payload = msg.get('payload')
|
|
if msg_type == "kill_session":
|
|
ws.send_kill_tasks_msg(payload.get('id'))
|
|
|
|
return component_event_chan.subscribe(handle_task_msg_recv)
|
|
|
|
def disconnect(self, code):
|
|
if self.sub is None:
|
|
return
|
|
self.sub.unsubscribe()
|