U 优化websocket连接

pull/517/head
vapao 2022-07-04 11:23:59 +08:00
parent a5a5970001
commit 3e2357ae50
8 changed files with 89 additions and 129 deletions

View File

@ -1,43 +1,18 @@
# Copyright: (c) OpenSpug Organization. https://github.com/openspug/spug
# Copyright: (c) <spug.dev@gmail.com>
# Released under the AGPL-3.0 License.
from channels.generic.websocket import WebsocketConsumer
from django.conf import settings
from django_redis import get_redis_connection
from asgiref.sync import async_to_sync
from apps.host.models import Host
from consumer.utils import BaseConsumer
from apps.account.utils import has_host_perm
from threading import Thread
import time
import json
class ExecConsumer(WebsocketConsumer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.token = self.scope['url_route']['kwargs']['token']
self.rds = get_redis_connection()
def connect(self):
self.accept()
def disconnect(self, code):
self.rds.close()
def get_response(self):
response = self.rds.brpop(self.token, timeout=5)
return response[1] if response else None
def receive(self, **kwargs):
response = self.get_response()
while response:
data = response.decode()
self.send(text_data=data)
response = self.get_response()
self.send(text_data='pong')
class ComConsumer(WebsocketConsumer):
class ComConsumer(BaseConsumer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
token = self.scope['url_route']['kwargs']['token']
@ -52,9 +27,6 @@ class ComConsumer(WebsocketConsumer):
raise TypeError(f'unknown module for {module}')
self.rds = get_redis_connection()
def connect(self):
self.accept()
def disconnect(self, code):
self.rds.close()
@ -78,18 +50,17 @@ class ComConsumer(WebsocketConsumer):
self.send(text_data='pong')
class SSHConsumer(WebsocketConsumer):
class SSHConsumer(BaseConsumer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.user = self.scope['user']
self.id = self.scope['url_route']['kwargs']['id']
self.chan = None
self.ssh = None
def loop_read(self):
is_ready = False
while True:
data = self.chan.recv(32 * 1024)
# print('read: {!r}'.format(data))
if not data:
self.close(3333)
break
@ -97,6 +68,9 @@ class SSHConsumer(WebsocketConsumer):
text = data.decode()
except UnicodeDecodeError:
text = data.decode(encoding='GBK', errors='ignore')
if not is_ready:
self.send(text_data='\033[2J\033[3J\033[1;1H')
is_ready = True
self.send(text_data=text)
def receive(self, text_data=None, bytes_data=None):
@ -116,34 +90,28 @@ class SSHConsumer(WebsocketConsumer):
if self.ssh:
self.ssh.close()
def connect(self):
def init(self):
if has_host_perm(self.user, self.id):
self.accept()
self._init()
self.send(text_data='\r\n正在连接至主机 ...')
host = Host.objects.filter(pk=self.id).first()
if not host:
return self.close_with_message('未找到指定主机,请刷新页面重试。')
try:
self.ssh = host.get_ssh().get_client()
except Exception as e:
return self.close_with_message(f'连接主机失败: {e}')
self.chan = self.ssh.invoke_shell(term='xterm')
self.chan.transport.set_keepalive(30)
Thread(target=self.loop_read).start()
else:
self.close()
def _init(self):
self.send(text_data='\r\33[KConnecting ...\r')
host = Host.objects.filter(pk=self.id).first()
if not host:
self.send(text_data='Unknown host\r\n')
self.close()
try:
self.ssh = host.get_ssh().get_client()
except Exception as e:
self.send(text_data=f'Exception: {e}\r\n'.encode())
self.close()
return
self.chan = self.ssh.invoke_shell(term='xterm')
self.chan.transport.set_keepalive(30)
Thread(target=self.loop_read).start()
self.close_with_message('你当前无权限操作该主机,请联系管理员授权。')
class NotifyConsumer(WebsocketConsumer):
def connect(self):
class NotifyConsumer(BaseConsumer):
def init(self):
async_to_sync(self.channel_layer.group_add)('notify', self.channel_name)
self.accept()
def disconnect(self, code):
async_to_sync(self.channel_layer.group_discard)('notify', self.channel_name)
@ -155,7 +123,7 @@ class NotifyConsumer(WebsocketConsumer):
self.send(text_data=json.dumps(event))
class PubSubConsumer(WebsocketConsumer):
class PubSubConsumer(BaseConsumer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.token = self.scope['url_route']['kwargs']['token']
@ -163,9 +131,6 @@ class PubSubConsumer(WebsocketConsumer):
self.p = self.rds.pubsub(ignore_subscribe_messages=True)
self.p.subscribe(self.token)
def connect(self):
self.accept()
def disconnect(self, code):
self.p.close()
self.rds.close()

View File

@ -1,50 +0,0 @@
# Copyright: (c) OpenSpug Organization. https://github.com/openspug/spug
# Copyright: (c) <spug.dev@gmail.com>
# Released under the AGPL-3.0 License.
from django.db import close_old_connections
from channels.security.websocket import WebsocketDenier
from apps.account.models import User
from apps.setting.utils import AppSetting
from libs.utils import get_request_real_ip
from urllib.parse import parse_qs
import time
class AuthMiddleware:
def __init__(self, application):
self.application = application
def __call__(self, scope):
# Make sure the scope is of type websocket
if scope["type"] != "websocket":
raise ValueError(
"You cannot use AuthMiddleware on a non-WebSocket connection"
)
headers = dict(scope.get('headers', []))
is_ok, message = self.verify_user(scope, headers)
if is_ok:
return self.application(scope)
else:
print(message)
return WebsocketDenier(scope)
def get_real_ip(self, headers):
decode_headers = {
'x-forwarded-for': headers.get(b'x-forwarded-for', b'').decode(),
'x-real-ip': headers.get(b'x-real-ip', b'').decode()
}
return get_request_real_ip(decode_headers)
def verify_user(self, scope, headers):
close_old_connections()
query_string = scope['query_string'].decode()
x_real_ip = self.get_real_ip(headers)
token = parse_qs(query_string).get('x-token', [''])[0]
if token and len(token) == 32:
user = User.objects.filter(access_token=token).first()
if user and user.token_expired >= time.time() and user.is_active:
if x_real_ip == user.last_ip or AppSetting.get_default('bind_ip') is False:
scope['user'] = user
return True, None
return False, f'Verify failed: {x_real_ip} <> {user.last_ip if user else None}'
return False, 'Token is invalid'

View File

@ -3,15 +3,11 @@
# Released under the AGPL-3.0 License.
from django.urls import path
from channels.routing import URLRouter
from consumer.middleware import AuthMiddleware
from consumer.consumers import *
ws_router = AuthMiddleware(
URLRouter([
path('ws/exec/<str:token>/', ExecConsumer),
path('ws/ssh/<int:id>/', SSHConsumer),
path('ws/subscribe/<str:token>/', PubSubConsumer),
path('ws/<str:module>/<str:token>/', ComConsumer),
path('ws/notify/', NotifyConsumer),
])
)
ws_router = URLRouter([
path('ws/ssh/<int:id>/', SSHConsumer),
path('ws/subscribe/<str:token>/', PubSubConsumer),
path('ws/<str:module>/<str:token>/', ComConsumer),
path('ws/notify/', NotifyConsumer),
])

View File

@ -0,0 +1,42 @@
# Copyright: (c) OpenSpug Organization. https://github.com/openspug/spug
# Copyright: (c) <spug.dev@gmail.com>
# Released under the AGPL-3.0 License.
from django.db import close_old_connections
from channels.generic.websocket import WebsocketConsumer
from apps.account.models import User
from apps.setting.utils import AppSetting
from libs.utils import get_request_real_ip
from urllib.parse import parse_qs
import time
def get_real_ip(headers):
decode_headers = {k.decode(): v.decode() for k, v in headers}
return get_request_real_ip(decode_headers)
class BaseConsumer(WebsocketConsumer):
def __init__(self, *args, **kwargs):
super(BaseConsumer, self).__init__(*args, **kwargs)
self.user = None
def close_with_message(self, content):
self.send(text_data=f'\r\n\x1b[31m{content}\x1b[0m\r\n')
self.close()
def connect(self):
self.accept()
close_old_connections()
query_string = self.scope['query_string'].decode()
x_real_ip = get_real_ip(self.scope['headers'])
token = parse_qs(query_string).get('x-token', [''])[0]
if token and len(token) == 32:
user = User.objects.filter(access_token=token).first()
if user and user.token_expired >= time.time() and user.is_active:
if x_real_ip == user.last_ip or AppSetting.get_default('bind_ip') is False:
self.user = user
if hasattr(self, 'init'):
self.init()
return None
self.close_with_message('触发登录IP绑定安全策略请在系统设置/安全设置中查看配置。')
self.close_with_message('用户身份验证失败,请重新登录或刷新页面。')

View File

@ -70,11 +70,15 @@ export default function () {
ws.onmessage = e => {
if (e.data !== 'pong') {
fetch();
const {title, content} = JSON.parse(e.data);
const key = `open${Date.now()}`;
const description = <div style={{whiteSpace: 'pre-wrap'}}>{content}</div>;
const btn = <Button type="primary" size="small" onClick={() => notification.close(key)}>知道了</Button>;
notification.warning({message: title, description, btn, key, top: 64, duration: null})
try {
const {title, content} = JSON.parse(e.data);
const key = `open${Date.now()}`;
const description = <div style={{whiteSpace: 'pre-wrap'}}>{content}</div>;
const btn = <Button type="primary" size="small" onClick={() => notification.close(key)}>知道了</Button>;
notification.warning({message: title, description, btn, key, top: 64, duration: null})
} catch (e) {
}
}
}
}

View File

@ -16,7 +16,7 @@ import {
import { FitAddon } from 'xterm-addon-fit';
import { Terminal } from 'xterm';
import style from './index.module.less';
import { X_TOKEN } from 'libs';
import { http, X_TOKEN } from 'libs';
import store from './store';
import gStore from 'gStore';
@ -55,7 +55,7 @@ function OutView(props) {
useEffect(() => {
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const socket = new WebSocket(`${protocol}//${window.location.host}/api/ws/exec/${store.token}/?x-token=${X_TOKEN}`);
const socket = new WebSocket(`${protocol}//${window.location.host}/api/ws/subscribe/${store.token}/?x-token=${X_TOKEN}`);
socket.onopen = () => {
const message = '\r\x1b[K\x1b[36m### Waiting for scheduling ...\x1b[0m'
for (let key of Object.keys(store.outputs)) {
@ -64,6 +64,7 @@ function OutView(props) {
term.write(message)
socket.send('ok');
fitPlugin.fit()
http.patch('/api/exec/do/', {token: store.token})
}
socket.onmessage = e => {
if (e.data === 'pong') {

View File

@ -15,6 +15,7 @@ import Output from './Output';
import { http, cleanCommand } from 'libs';
import moment from 'moment';
import store from './store';
import gStore from 'gStore';
import style from './index.module.less';
function TaskIndex() {
@ -28,7 +29,7 @@ function TaskIndex() {
useEffect(() => {
if (!loading) {
http.get('/api/exec/history/')
http.get('/api/exec/do/')
.then(res => setHistories(res))
}
}, [loading])
@ -40,6 +41,7 @@ function TaskIndex() {
}, [command])
useEffect(() => {
gStore.fetchUserSettings()
return () => {
store.host_ids = []
if (store.showConsole) {

View File

@ -174,7 +174,7 @@ class FileManager extends React.Component {
_updatePercent = token => {
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
this.socket = new WebSocket(`${protocol}//${window.location.host}/api/ws/exec/${token}/?x-token=${X_TOKEN}`);
this.socket = new WebSocket(`${protocol}//${window.location.host}/api/ws/subscribe/${token}/?x-token=${X_TOKEN}`);
this.socket.onopen = () => this.socket.send('ok');
this.socket.onmessage = e => {
if (e.data === 'pong') {