mirror of https://github.com/jumpserver/jumpserver
119 lines
4.0 KiB
Python
119 lines
4.0 KiB
Python
import datetime
|
|
from threading import Thread
|
|
|
|
from channels.generic.websocket import JsonWebsocketConsumer
|
|
from django.utils import timezone
|
|
from rest_framework.renderers import JSONRenderer
|
|
|
|
from common.db.utils import safe_db_connection
|
|
from common.utils import get_logger
|
|
from common.utils.connection import Subscription
|
|
from terminal.const import TaskNameType
|
|
from terminal.models import Session, Terminal
|
|
from terminal.serializers import TaskSerializer, StatSerializer
|
|
from .mixin import LokiMixin
|
|
from .signal_handlers import component_event_chan
|
|
|
|
# Module-level logger used by the websocket consumers defined in this file.
logger = get_logger(__name__)
|
|
|
|
|
|
class TerminalTaskWebsocket(JsonWebsocketConsumer):
    """WebSocket consumer that exchanges status reports and tasks with a terminal.

    The connection is accepted only when the scope user is authenticated and
    has a truthy ``terminal`` attribute. Once accepted, the consumer sends the
    terminal's pending tasks and subscribes to ``component_event_chan`` so
    that later task events are forwarded as they arrive.
    """

    # Handle returned by ``component_event_chan.subscribe``; None until the
    # connection has been accepted.
    sub: Subscription = None
    # Terminal taken from the connecting user; None until accepted.
    terminal: Terminal = None

    def connect(self):
        """Accept the connection for an authenticated terminal user, else close it."""
        user = self.scope["user"]
        if user.is_authenticated and user.terminal:
            self.accept()
            self.terminal = user.terminal
            self.sub = self.watch_component_event()
        else:
            self.close()

    def receive_json(self, content, **kwargs):
        """Dispatch an incoming JSON message; only ``type == "status"`` is handled."""
        req_type = content.get('type')
        if req_type == "status":
            payload = content.get('payload')
            self.handle_status(payload)

    def handle_status(self, content):
        """Validate a status payload and save it for this terminal.

        Invalid payloads are logged and dropped. For valid ones, the
        ``sessions`` entry (default ``[]``) is removed from the validated data
        and handed to ``Session.set_sessions_active`` before the serializer is
        saved inside ``safe_db_connection()``.
        """
        serializer = StatSerializer(data=content)
        if not serializer.is_valid():
            logger.error('Invalid status data: {}'.format(serializer.errors))
            return
        serializer.validated_data["terminal"] = self.terminal
        session_ids = serializer.validated_data.pop('sessions', [])
        Session.set_sessions_active(session_ids)
        with safe_db_connection():
            serializer.save()

    def send_tasks_msg(self, task_id=None):
        """Send the rendered task list (see ``get_terminal_tasks``) as ``bytes_data``."""
        content = self.get_terminal_tasks(task_id)
        self.send(bytes_data=content)

    def get_terminal_tasks(self, task_id=None):
        """Return this terminal's pending tasks rendered as JSON bytes.

        Only tasks that are not finished and were created within the last ten
        minutes are included. When ``task_id`` is truthy the result is further
        restricted to that task id.
        """
        with safe_db_connection():
            critical_time = timezone.now() - datetime.timedelta(minutes=10)
            tasks = self.terminal.task_set.filter(
                is_finished=False, date_created__gte=critical_time
            )
            if task_id:
                tasks = tasks.filter(id=task_id)
            serializer = TaskSerializer(tasks, many=True)
            return JSONRenderer().render(serializer.data)

    def watch_component_event(self):
        """Send the existing tasks once, then subscribe to component task events.

        Returns whatever ``component_event_chan.subscribe`` returns, which
        ``disconnect`` later uses to unsubscribe.
        """
        # Send the already-existing tasks once first.
        self.send_tasks_msg()

        ws = self

        def handle_task_msg_recv(msg):
            logger.debug('New component task msg recv: {}'.format(msg))
            msg_type = msg.get('type')
            if msg_type not in TaskNameType.names:
                return
            payload = msg.get('payload')
            # A task event without a dict payload carries no task id; calling
            # ``.get`` on it would raise inside the subscription callback, so
            # log it and ignore the event instead.
            if not isinstance(payload, dict):
                logger.error('Invalid component task msg payload: {}'.format(msg))
                return
            ws.send_tasks_msg(payload.get('id'))

        return component_event_chan.subscribe(handle_task_msg_recv)

    def disconnect(self, code):
        """Unsubscribe from component events if a subscription was created."""
        if self.sub is None:
            return
        self.sub.unsubscribe()
|
|
|
|
|
|
class LokiTailWebsocket(JsonWebsocketConsumer, LokiMixin):
    """WebSocket consumer that streams Loki tail output to a superuser.

    Each received JSON message describes a query (``components`` and an
    optional ``search`` string). The consumer opens a Loki tail connection for
    it and forwards every tailed message to the client from a background
    thread. At most one tail connection is kept per consumer.
    """

    # Currently open Loki tail connection, or None when no query is running.
    loki_tail_ws = None

    def connect(self):
        """Accept the connection for an authenticated superuser, else close it."""
        user = self.scope["user"]
        if user.is_authenticated and user.is_superuser:
            self.accept()
            logger.info('Loki tail websocket connected')
        else:
            self.close()

    def receive_json(self, content, **kwargs):
        """Build a Loki query from the message and start tailing it.

        Empty messages are ignored.
        """
        if not content:
            return
        components = content.get('components')
        search = content.get('search', '')
        query = self.create_loki_query(components, search)
        self.handle_query(query)

    def send_tail_msg(self, tail_ws):
        """Forward every message yielded by ``tail_ws.messages()`` to the client.

        Runs in the thread started by ``handle_query`` and returns once the
        message iterator is exhausted.
        """
        for message in tail_ws.messages():
            self.send(text_data=message)
        logger.info('Loki tail thread finished')

    def _close_tail_ws(self):
        """Close the current Loki tail connection, if any, and forget it.

        Returns True when a connection was actually closed.
        """
        tail_ws, self.loki_tail_ws = self.loki_tail_ws, None
        if not tail_ws:
            return False
        tail_ws.close()
        return True

    def handle_query(self, query):
        """Open a tail connection for ``query`` and forward it from a new thread.

        Any tail connection left over from an earlier query is closed first;
        otherwise it would keep running unreferenced and its messages would
        be interleaved with those of the new query.
        """
        if self._close_tail_ws():
            logger.debug('Previous loki tail websocket client closed')
        loki_client = self.get_loki_client()
        self.loki_tail_ws = loki_client.create_tail_ws(query)
        threader = Thread(target=self.send_tail_msg, args=(self.loki_tail_ws,))
        threader.start()
        logger.debug('Start loki tail thread')

    def disconnect(self, close_code):
        """Close the Loki tail connection, if one is open."""
        if self._close_tail_ws():
            logger.info('Loki tail websocket client closed')
|