From 97a0e27307a4a4c2b9b408ee3e6594cb3f13961a Mon Sep 17 00:00:00 2001 From: ibuler Date: Fri, 11 Jun 2021 17:11:29 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E4=B8=AD=E5=BF=83=E6=9C=AA=E8=AF=BB=E6=95=B0=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/jumpserver/routing.py | 4 +- apps/jumpserver/settings/base.py | 2 +- apps/jumpserver/urls.py | 2 +- apps/notifications/api/site_msgs.py | 12 ++-- apps/notifications/apps.py | 4 ++ apps/notifications/backends/site_msg.py | 2 +- apps/notifications/migrations/0001_initial.py | 2 +- apps/notifications/signals_handler.py | 43 ++++++++++++ apps/notifications/site_msg.py | 43 ++++++------ .../urls/{notifications.py => api_urls.py} | 0 apps/notifications/urls/ws_urls.py | 9 +++ apps/notifications/ws.py | 70 +++++++++++++++++++ 12 files changed, 161 insertions(+), 32 deletions(-) create mode 100644 apps/notifications/signals_handler.py rename apps/notifications/urls/{notifications.py => api_urls.py} (100%) create mode 100644 apps/notifications/urls/ws_urls.py create mode 100644 apps/notifications/ws.py diff --git a/apps/jumpserver/routing.py b/apps/jumpserver/routing.py index d76f1ccee..5deae804e 100644 --- a/apps/jumpserver/routing.py +++ b/apps/jumpserver/routing.py @@ -2,9 +2,11 @@ from channels.auth import AuthMiddlewareStack from channels.routing import ProtocolTypeRouter, URLRouter from ops.urls.ws_urls import urlpatterns as ops_urlpatterns +from notifications.urls.ws_urls import urlpatterns as notifications_urlpatterns urlpatterns = [] -urlpatterns += ops_urlpatterns +urlpatterns += ops_urlpatterns \ + + notifications_urlpatterns application = ProtocolTypeRouter({ 'websocket': AuthMiddlewareStack( diff --git a/apps/jumpserver/settings/base.py b/apps/jumpserver/settings/base.py index 268bafa44..ef69cec8e 100644 --- a/apps/jumpserver/settings/base.py +++ b/apps/jumpserver/settings/base.py @@ -48,7 +48,7 @@ INSTALLED_APPS = [ 'applications.apps.ApplicationsConfig', 'tickets.apps.TicketsConfig', 'acls.apps.AclsConfig', - 'notifications', + 'notifications.apps.NotificationsConfig', 'common.apps.CommonConfig', 'jms_oidc_rp', 'rest_framework', diff --git a/apps/jumpserver/urls.py b/apps/jumpserver/urls.py index c2ffea6ec..43d0e6cb0 100644 --- a/apps/jumpserver/urls.py +++ b/apps/jumpserver/urls.py @@ -23,7 +23,7 @@ api_v1 = [ path('applications/', include('applications.urls.api_urls', namespace='api-applications')), path('tickets/', include('tickets.urls.api_urls', namespace='api-tickets')), path('acls/', include('acls.urls.api_urls', namespace='api-acls')), - path('notifications/', include('notifications.urls.notifications', namespace='api-notifications')), + path('notifications/', include('notifications.urls.api_urls', namespace='api-notifications')), path('prometheus/metrics/', api.PrometheusMetricsApi.as_view()), ] diff --git a/apps/notifications/api/site_msgs.py b/apps/notifications/api/site_msgs.py index 6ee856922..2f8ba7e15 100644 --- a/apps/notifications/api/site_msgs.py +++ b/apps/notifications/api/site_msgs.py @@ -10,7 +10,7 @@ from ..serializers import ( SiteMessageDetailSerializer, SiteMessageIdsSerializer, SiteMessageSendSerializer, ) -from ..site_msg import SiteMessage +from ..site_msg import SiteMessageUtil from ..filters import SiteMsgFilter __all__ = ('SiteMessageViewSet', ) @@ -30,15 +30,15 @@ class SiteMessageViewSet(ListModelMixin, RetrieveModelMixin, JmsGenericViewSet): has_read = self.request.query_params.get('has_read') if has_read is None: - msgs = SiteMessage.get_user_all_msgs(user.id) + msgs = SiteMessageUtil.get_user_all_msgs(user.id) else: - msgs = SiteMessage.filter_user_msgs(user.id, has_read=is_true(has_read)) + msgs = SiteMessageUtil.filter_user_msgs(user.id, has_read=is_true(has_read)) return msgs @action(methods=[GET], detail=False, url_path='unread-total') def unread_total(self, request, **kwargs): user = request.user - msgs = SiteMessage.filter_user_msgs(user.id, has_read=False) + msgs = SiteMessageUtil.filter_user_msgs(user.id, has_read=False) return Response(data={'total': msgs.count()}) @action(methods=[PATCH], detail=False, url_path='mark-as-read') @@ -47,12 +47,12 @@ class SiteMessageViewSet(ListModelMixin, RetrieveModelMixin, JmsGenericViewSet): seri = self.get_serializer(data=request.data) seri.is_valid(raise_exception=True) ids = seri.validated_data['ids'] - SiteMessage.mark_msgs_as_read(user.id, ids) + SiteMessageUtil.mark_msgs_as_read(user.id, ids) return Response({'detail': 'ok'}) @action(methods=[POST], detail=False) def send(self, request, **kwargs): seri = self.get_serializer(data=request.data) seri.is_valid(raise_exception=True) - SiteMessage.send_msg(**seri.validated_data, sender=request.user) + SiteMessageUtil.send_msg(**seri.validated_data, sender=request.user) return Response({'detail': 'ok'}) diff --git a/apps/notifications/apps.py b/apps/notifications/apps.py index 9c260e0b1..f14a8ebe9 100644 --- a/apps/notifications/apps.py +++ b/apps/notifications/apps.py @@ -3,3 +3,7 @@ from django.apps import AppConfig class NotificationsConfig(AppConfig): name = 'notifications' + + def ready(self): + from . import signals_handler + super().ready() diff --git a/apps/notifications/backends/site_msg.py b/apps/notifications/backends/site_msg.py index 33032843a..0f7468f48 100644 --- a/apps/notifications/backends/site_msg.py +++ b/apps/notifications/backends/site_msg.py @@ -1,4 +1,4 @@ -from notifications.site_msg import SiteMessage as Client +from notifications.site_msg import SiteMessageUtil as Client from .base import BackendBase diff --git a/apps/notifications/migrations/0001_initial.py b/apps/notifications/migrations/0001_initial.py index ebe79f304..8e55bb305 100644 --- a/apps/notifications/migrations/0001_initial.py +++ b/apps/notifications/migrations/0001_initial.py @@ -17,7 +17,7 @@ class Migration(migrations.Migration): operations = [ migrations.CreateModel( - name='SiteMessage', + name='SiteMessageUtil', fields=[ ('created_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Created by')), ('updated_by', models.CharField(blank=True, max_length=32, null=True, verbose_name='Updated by')), diff --git a/apps/notifications/signals_handler.py b/apps/notifications/signals_handler.py new file mode 100644 index 000000000..13ebdc4bc --- /dev/null +++ b/apps/notifications/signals_handler.py @@ -0,0 +1,43 @@ +import json + +from django.utils.functional import LazyObject +from django.db.models.signals import post_save +from django.dispatch import receiver + +from common.utils.connection import RedisPubSub +from common.utils import get_logger +from common.decorator import on_transaction_commit +from .models import SiteMessage + + +logger = get_logger(__name__) + + +def new_site_msg_pub_sub(): + return RedisPubSub('notifications.SiteMessageCome') + + +class NewSiteMsgSubPub(LazyObject): + def _setup(self): + self._wrapped = new_site_msg_pub_sub() + + +new_site_msg_chan = NewSiteMsgSubPub() + + +@receiver(post_save, sender=SiteMessage) +@on_transaction_commit +def on_site_message_create(sender, instance, created, **kwargs): + if not created: + return + logger.debug('New site msg created, publish it') + user_ids = instance.users.all().values_list('id', flat=True) + user_ids = [str(i) for i in user_ids] + data = { + 'id': str(instance.id), + 'subject': instance.subject, + 'message': instance.message, + 'users': user_ids + } + data = json.dumps(data) + new_site_msg_chan.publish(data) diff --git a/apps/notifications/site_msg.py b/apps/notifications/site_msg.py index b78d3c7f4..1a5c9dc23 100644 --- a/apps/notifications/site_msg.py +++ b/apps/notifications/site_msg.py @@ -1,11 +1,12 @@ from django.db.models import F +from django.db import transaction from common.utils.timezone import now from users.models import User from .models import SiteMessage as SiteMessageModel, SiteMessageUsers -class SiteMessage: +class SiteMessageUtil: @classmethod def send_msg(cls, subject, message, user_ids=(), group_ids=(), @@ -13,24 +14,24 @@ class SiteMessage: if not any((user_ids, group_ids, is_broadcast)): raise ValueError('No recipient is specified') - site_msg = SiteMessageModel.objects.create( - subject=subject, message=message, - is_broadcast=is_broadcast, sender=sender, - ) + with transaction.atomic(): + site_msg = SiteMessageModel.objects.create( + subject=subject, message=message, + is_broadcast=is_broadcast, sender=sender, + ) - if is_broadcast: - user_ids = User.objects.all().values_list('id', flat=True) - else: - if group_ids: - site_msg.groups.add(*group_ids) + if is_broadcast: + user_ids = User.objects.all().values_list('id', flat=True) + else: + if group_ids: + site_msg.groups.add(*group_ids) - user_ids_from_group = User.groups.through.objects.filter( - usergroup_id__in=group_ids - ).values_list('user_id', flat=True) + user_ids_from_group = User.groups.through.objects.filter( + usergroup_id__in=group_ids + ).values_list('user_id', flat=True) + user_ids = [*user_ids, *user_ids_from_group] - user_ids = [*user_ids, *user_ids_from_group] - - site_msg.users.add(*user_ids) + site_msg.users.add(*user_ids) @classmethod def get_user_all_msgs(cls, user_id): @@ -72,14 +73,14 @@ class SiteMessage: @classmethod def mark_msgs_as_read(cls, user_id, msg_ids): - sitemsg_users = SiteMessageUsers.objects.filter( + site_msg_users = SiteMessageUsers.objects.filter( user_id=user_id, sitemessage_id__in=msg_ids, has_read=False ) - for sitemsg_user in sitemsg_users: - sitemsg_user.has_read = True - sitemsg_user.read_at = now() + for site_msg_user in site_msg_users: + site_msg_user.has_read = True + site_msg_user.read_at = now() SiteMessageUsers.objects.bulk_update( - sitemsg_users, fields=('has_read', 'read_at')) + site_msg_users, fields=('has_read', 'read_at')) diff --git a/apps/notifications/urls/notifications.py b/apps/notifications/urls/api_urls.py similarity index 100% rename from apps/notifications/urls/notifications.py rename to apps/notifications/urls/api_urls.py diff --git a/apps/notifications/urls/ws_urls.py b/apps/notifications/urls/ws_urls.py new file mode 100644 index 000000000..dfd457e52 --- /dev/null +++ b/apps/notifications/urls/ws_urls.py @@ -0,0 +1,9 @@ +from django.urls import path + +from .. import ws + +app_name = 'notifications' + +urlpatterns = [ + path('ws/notifications/site-msg/', ws.SiteMsgWebsocket, name='site-msg-ws'), +] \ No newline at end of file diff --git a/apps/notifications/ws.py b/apps/notifications/ws.py new file mode 100644 index 000000000..cbbb25d2d --- /dev/null +++ b/apps/notifications/ws.py @@ -0,0 +1,70 @@ +import threading +import json + +from channels.generic.websocket import JsonWebsocketConsumer + +from common.utils import get_logger +from .models import SiteMessage +from .site_msg import SiteMessageUtil +from .signals_handler import new_site_msg_chan + +logger = get_logger(__name__) + + +class SiteMsgWebsocket(JsonWebsocketConsumer): + disconnected = False + refresh_every_seconds = 10 + + def connect(self): + user = self.scope["user"] + if user.is_authenticated: + self.accept() + + thread = threading.Thread(target=self.unread_site_msg_count) + thread.start() + else: + self.close() + + def receive(self, text_data=None, bytes_data=None, **kwargs): + data = json.loads(text_data) + refresh_every_seconds = data.get('refresh_every_seconds') + + try: + refresh_every_seconds = int(refresh_every_seconds) + except Exception as e: + logger.error(e) + return + + if refresh_every_seconds > 0: + self.refresh_every_seconds = refresh_every_seconds + + def send_unread_msg_count(self): + user_id = self.scope["user"].id + unread_count = SiteMessageUtil.get_user_unread_msgs_count(user_id) + logger.debug('Send unread count to user: {} {}'.format(user_id, unread_count)) + self.send_json({'type': 'unread_count', 'unread_count': unread_count}) + + def unread_site_msg_count(self): + user_id = str(self.scope["user"].id) + self.send_unread_msg_count() + + while not self.disconnected: + subscribe = new_site_msg_chan.subscribe() + for message in subscribe.listen(): + if message['type'] != 'message': + continue + try: + msg = json.loads(message['data'].decode()) + logger.debug('New site msg recv, may be mine: {}'.format(msg)) + if not msg: + continue + users = msg.get('users', []) + logger.debug('Message users: {}'.format(users)) + if user_id in users: + self.send_unread_msg_count() + except json.JSONDecoder as e: + logger.debug('Decode json error: ', e) + + def disconnect(self, close_code): + self.disconnected = True + self.close()