mirror of https://github.com/jumpserver/jumpserver
perf: 优化订阅处理,形成框架
parent
de006bc664
commit
c85249be36
|
@ -10,7 +10,6 @@ from django.dispatch import receiver
|
||||||
from django.utils.functional import LazyObject
|
from django.utils.functional import LazyObject
|
||||||
|
|
||||||
from common.signals import django_ready
|
from common.signals import django_ready
|
||||||
from common.db.utils import close_old_connections
|
|
||||||
from common.utils.connection import RedisPubSub
|
from common.utils.connection import RedisPubSub
|
||||||
from common.utils import get_logger
|
from common.utils import get_logger
|
||||||
from assets.models import Asset, Node
|
from assets.models import Asset, Node
|
||||||
|
@ -78,31 +77,17 @@ def on_node_asset_change(sender, instance, **kwargs):
|
||||||
def subscribe_node_assets_mapping_expire(sender, **kwargs):
|
def subscribe_node_assets_mapping_expire(sender, **kwargs):
|
||||||
logger.debug("Start subscribe for expire node assets id mapping from memory")
|
logger.debug("Start subscribe for expire node assets id mapping from memory")
|
||||||
|
|
||||||
|
def handle_node_relation_change(org_id):
|
||||||
|
root_org_id = Organization.ROOT_ID
|
||||||
|
Node.expire_node_all_asset_ids_mapping_from_memory(org_id)
|
||||||
|
Node.expire_node_all_asset_ids_mapping_from_memory(root_org_id)
|
||||||
|
logger.debug(
|
||||||
|
"Expire node assets id mapping from memory of org={}, pid={}"
|
||||||
|
"".format(str(org_id), os.getpid())
|
||||||
|
)
|
||||||
|
|
||||||
def keep_subscribe_node_assets_relation():
|
def keep_subscribe_node_assets_relation():
|
||||||
while True:
|
node_assets_mapping_for_memory_pub_sub.keep_handle_msg(handle_node_relation_change)
|
||||||
try:
|
|
||||||
subscribe = node_assets_mapping_for_memory_pub_sub.subscribe()
|
|
||||||
msgs = subscribe.listen()
|
|
||||||
# 开始之前关闭连接,因为server端可能关闭了连接,而 client 还在 CONN_MAX_AGE 中
|
|
||||||
for message in msgs:
|
|
||||||
if message["type"] != "message":
|
|
||||||
continue
|
|
||||||
close_old_connections()
|
|
||||||
org_id = message['data'].decode()
|
|
||||||
root_org_id = Organization.ROOT_ID
|
|
||||||
Node.expire_node_all_asset_ids_mapping_from_memory(org_id)
|
|
||||||
Node.expire_node_all_asset_ids_mapping_from_memory(root_org_id)
|
|
||||||
logger.debug(
|
|
||||||
"Expire node assets id mapping from memory of org={}, pid={}"
|
|
||||||
"".format(str(org_id), os.getpid())
|
|
||||||
)
|
|
||||||
close_old_connections()
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f'subscribe_node_assets_mapping_expire: {e}')
|
|
||||||
Node.expire_all_orgs_node_all_asset_ids_mapping_from_memory()
|
|
||||||
finally:
|
|
||||||
# 请求结束,关闭连接
|
|
||||||
close_old_connections()
|
|
||||||
|
|
||||||
t = threading.Thread(target=keep_subscribe_node_assets_relation)
|
t = threading.Thread(target=keep_subscribe_node_assets_relation)
|
||||||
t.daemon = True
|
t.daemon = True
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
from common.utils import get_logger
|
from contextlib import contextmanager
|
||||||
|
|
||||||
from django.db import connections
|
from django.db import connections
|
||||||
|
|
||||||
|
from common.utils import get_logger
|
||||||
|
|
||||||
logger = get_logger(__file__)
|
logger = get_logger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,3 +47,10 @@ def get_objects(model, pks):
|
||||||
def close_old_connections():
|
def close_old_connections():
|
||||||
for conn in connections.all():
|
for conn in connections.all():
|
||||||
conn.close_if_unusable_or_obsolete()
|
conn.close_if_unusable_or_obsolete()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def safe_db_connection():
|
||||||
|
close_old_connections()
|
||||||
|
yield
|
||||||
|
close_old_connections()
|
||||||
|
|
|
@ -1,6 +1,13 @@
|
||||||
|
import json
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
|
||||||
|
from common.db.utils import safe_db_connection
|
||||||
|
from common.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_redis_client(db):
|
def get_redis_client(db):
|
||||||
rc = redis.StrictRedis(
|
rc = redis.StrictRedis(
|
||||||
|
@ -23,5 +30,38 @@ class RedisPubSub:
|
||||||
return ps
|
return ps
|
||||||
|
|
||||||
def publish(self, data):
|
def publish(self, data):
|
||||||
self.redis.publish(self.ch, data)
|
data_json = json.dumps(data)
|
||||||
|
self.redis.publish(self.ch, data_json)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def keep_handle_msg(self, handle):
|
||||||
|
"""
|
||||||
|
handle arg is the pub published
|
||||||
|
|
||||||
|
:param handle: lambda item: do_something
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
sub = self.subscribe()
|
||||||
|
msgs = sub.listen()
|
||||||
|
|
||||||
|
try:
|
||||||
|
for msg in msgs:
|
||||||
|
if msg["type"] != "message":
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
item_json = msg['data'].decode()
|
||||||
|
item = json.loads(item_json)
|
||||||
|
|
||||||
|
with safe_db_connection():
|
||||||
|
handle(item)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error('Subscribe handler handle msg error: ', e)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error('Consume msg error: ', e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
sub.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Redis observer close error: ", e)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils.translation import gettext_lazy as _
|
|
||||||
|
|
||||||
from common.db.models import JMSModel
|
from common.db.models import JMSModel
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,6 @@ def on_site_message_create(sender, instance, created, **kwargs):
|
||||||
'message': instance.message,
|
'message': instance.message,
|
||||||
'users': user_ids
|
'users': user_ids
|
||||||
}
|
}
|
||||||
data = json.dumps(data)
|
|
||||||
new_site_msg_chan.publish(data)
|
new_site_msg_chan.publish(data)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
import threading
|
import threading
|
||||||
import json
|
import json
|
||||||
from redis.exceptions import ConnectionError
|
|
||||||
from channels.generic.websocket import JsonWebsocketConsumer
|
from channels.generic.websocket import JsonWebsocketConsumer
|
||||||
|
|
||||||
from common.db.utils import close_old_connections
|
|
||||||
from common.utils import get_logger
|
from common.utils import get_logger
|
||||||
|
from common.db.utils import safe_db_connection
|
||||||
from .site_msg import SiteMessageUtil
|
from .site_msg import SiteMessageUtil
|
||||||
from .signals_handler import new_site_msg_chan
|
from .signals_handler import new_site_msg_chan
|
||||||
|
|
||||||
|
@ -13,15 +12,13 @@ logger = get_logger(__name__)
|
||||||
|
|
||||||
class SiteMsgWebsocket(JsonWebsocketConsumer):
|
class SiteMsgWebsocket(JsonWebsocketConsumer):
|
||||||
refresh_every_seconds = 10
|
refresh_every_seconds = 10
|
||||||
chan = None
|
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
user = self.scope["user"]
|
user = self.scope["user"]
|
||||||
if user.is_authenticated:
|
if user.is_authenticated:
|
||||||
self.accept()
|
self.accept()
|
||||||
self.chan = new_site_msg_chan.subscribe()
|
|
||||||
|
|
||||||
thread = threading.Thread(target=self.unread_site_msg_count)
|
thread = threading.Thread(target=self.watch_recv_new_site_msg)
|
||||||
thread.start()
|
thread.start()
|
||||||
else:
|
else:
|
||||||
self.close()
|
self.close()
|
||||||
|
@ -45,45 +42,18 @@ class SiteMsgWebsocket(JsonWebsocketConsumer):
|
||||||
logger.debug('Send unread count to user: {} {}'.format(user_id, unread_count))
|
logger.debug('Send unread count to user: {} {}'.format(user_id, unread_count))
|
||||||
self.send_json({'type': 'unread_count', 'unread_count': unread_count})
|
self.send_json({'type': 'unread_count', 'unread_count': unread_count})
|
||||||
|
|
||||||
def unread_site_msg_count(self):
|
def watch_recv_new_site_msg(self):
|
||||||
|
ws = self
|
||||||
user_id = str(self.scope["user"].id)
|
user_id = str(self.scope["user"].id)
|
||||||
self.send_unread_msg_count()
|
|
||||||
|
|
||||||
try:
|
|
||||||
msgs = self.chan.listen()
|
|
||||||
# 开始之前关闭连接,因为server端可能关闭了连接,而 client 还在 CONN_MAX_AGE 中
|
|
||||||
for message in msgs:
|
|
||||||
if message['type'] != 'message':
|
|
||||||
continue
|
|
||||||
close_old_connections()
|
|
||||||
try:
|
|
||||||
msg = json.loads(message['data'].decode())
|
|
||||||
except json.JSONDecoder as e:
|
|
||||||
logger.debug('Decode json error: ', e)
|
|
||||||
continue
|
|
||||||
if not msg:
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.debug('New site msg recv, may be mine: {}'.format(msg))
|
|
||||||
users = msg.get('users', [])
|
|
||||||
logger.debug('Message users: {}'.format(users))
|
|
||||||
if user_id in users:
|
|
||||||
self.send_unread_msg_count()
|
|
||||||
close_old_connections()
|
|
||||||
except ConnectionError:
|
|
||||||
logger.error('Redis chan closed')
|
|
||||||
finally:
|
|
||||||
logger.info('Notification ws thread end')
|
|
||||||
close_old_connections()
|
|
||||||
|
|
||||||
def disconnect(self, close_code):
|
|
||||||
try:
|
|
||||||
if self.chan is not None:
|
|
||||||
self.chan.close()
|
|
||||||
self.close()
|
|
||||||
finally:
|
|
||||||
close_old_connections()
|
|
||||||
logger.info('Notification websocket disconnect')
|
|
||||||
|
|
||||||
|
# 先发一个消息再说
|
||||||
|
with safe_db_connection():
|
||||||
|
self.send_unread_msg_count()
|
||||||
|
|
||||||
|
def handle_new_site_msg_recv(msg):
|
||||||
|
users = msg.get('users', [])
|
||||||
|
logger.debug('New site msg recv, message users: {}'.format(users))
|
||||||
|
if user_id in users:
|
||||||
|
ws.send_unread_msg_count()
|
||||||
|
|
||||||
|
new_site_msg_chan.keep_handle_msg(handle_new_site_msg_recv)
|
||||||
|
|
|
@ -6,9 +6,8 @@ from functools import partial
|
||||||
|
|
||||||
from django.dispatch import receiver
|
from django.dispatch import receiver
|
||||||
from django.utils.functional import LazyObject
|
from django.utils.functional import LazyObject
|
||||||
from common.db.utils import close_old_connections
|
|
||||||
from django.db.models.signals import m2m_changed
|
from django.db.models.signals import m2m_changed
|
||||||
from django.db.models.signals import post_save, post_delete, pre_delete
|
from django.db.models.signals import post_save, pre_delete
|
||||||
|
|
||||||
from orgs.utils import tmp_to_org
|
from orgs.utils import tmp_to_org
|
||||||
from orgs.models import Organization, OrganizationMember
|
from orgs.models import Organization, OrganizationMember
|
||||||
|
@ -47,25 +46,9 @@ def subscribe_orgs_mapping_expire(sender, **kwargs):
|
||||||
logger.debug("Start subscribe for expire orgs mapping from memory")
|
logger.debug("Start subscribe for expire orgs mapping from memory")
|
||||||
|
|
||||||
def keep_subscribe_org_mapping():
|
def keep_subscribe_org_mapping():
|
||||||
while True:
|
orgs_mapping_for_memory_pub_sub.keep_handle_msg(
|
||||||
try:
|
lambda org_id: Organization.expire_orgs_mapping()
|
||||||
subscribe = orgs_mapping_for_memory_pub_sub.subscribe()
|
)
|
||||||
msgs = subscribe.listen()
|
|
||||||
# 开始之前关闭连接,因为server端可能关闭了连接,而 client 还在 CONN_MAX_AGE 中
|
|
||||||
close_old_connections()
|
|
||||||
for message in msgs:
|
|
||||||
if message['type'] != 'message':
|
|
||||||
continue
|
|
||||||
if message['data'] == b'error':
|
|
||||||
raise ValueError
|
|
||||||
Organization.expire_orgs_mapping()
|
|
||||||
logger.debug('Expire orgs mapping: ' + str(message['data']))
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f'subscribe_orgs_mapping_expire: {e}')
|
|
||||||
Organization.expire_orgs_mapping()
|
|
||||||
finally:
|
|
||||||
# 结束收关闭连接
|
|
||||||
close_old_connections()
|
|
||||||
|
|
||||||
t = threading.Thread(target=keep_subscribe_org_mapping)
|
t = threading.Thread(target=keep_subscribe_org_mapping)
|
||||||
t.daemon = True
|
t.daemon = True
|
||||||
|
|
|
@ -11,7 +11,6 @@ from jumpserver.utils import current_request
|
||||||
from common.decorator import on_transaction_commit
|
from common.decorator import on_transaction_commit
|
||||||
from common.utils import get_logger, ssh_key_gen
|
from common.utils import get_logger, ssh_key_gen
|
||||||
from common.utils.connection import RedisPubSub
|
from common.utils.connection import RedisPubSub
|
||||||
from common.db.utils import close_old_connections
|
|
||||||
from common.signals import django_ready
|
from common.signals import django_ready
|
||||||
from .models import Setting
|
from .models import Setting
|
||||||
|
|
||||||
|
@ -81,23 +80,9 @@ def subscribe_settings_change(sender, **kwargs):
|
||||||
logger.debug("Start subscribe setting change")
|
logger.debug("Start subscribe setting change")
|
||||||
|
|
||||||
def keep_subscribe_settings_change():
|
def keep_subscribe_settings_change():
|
||||||
while True:
|
setting_pub_sub.keep_handle_msg(
|
||||||
try:
|
lambda name: Setting.refresh_item(name)
|
||||||
sub = setting_pub_sub.subscribe()
|
)
|
||||||
msgs = sub.listen()
|
|
||||||
# 开始之前关闭连接,因为server端可能关闭了连接,而 client 还在 CONN_MAX_AGE 中
|
|
||||||
for msg in msgs:
|
|
||||||
if msg["type"] != "message":
|
|
||||||
continue
|
|
||||||
close_old_connections()
|
|
||||||
item = msg['data'].decode()
|
|
||||||
logger.debug("Found setting change: {}".format(str(item)))
|
|
||||||
Setting.refresh_item(item)
|
|
||||||
close_old_connections()
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f'subscribe_settings_change: {e}')
|
|
||||||
Setting.refresh_all_settings()
|
|
||||||
close_old_connections()
|
|
||||||
|
|
||||||
t = threading.Thread(target=keep_subscribe_settings_change)
|
t = threading.Thread(target=keep_subscribe_settings_change)
|
||||||
t.daemon = True
|
t.daemon = True
|
||||||
|
|
Loading…
Reference in New Issue