diff --git a/backend/application/asgi.py b/backend/application/asgi.py index 5a5c987..586b729 100644 --- a/backend/application/asgi.py +++ b/backend/application/asgi.py @@ -8,12 +8,84 @@ https://docs.djangoproject.com/en/3.2/howto/deployment/asgi/ """ import os + +import jwt +from channels.db import database_sync_to_async +from channels.middleware import BaseMiddleware +from django.contrib.auth import get_user_model +from django.contrib.auth.models import AnonymousUser from django.core.asgi import get_asgi_application from channels.auth import AuthMiddlewareStack from channels.routing import ProtocolTypeRouter, URLRouter +from django.db import close_old_connections +from rest_framework_simplejwt.authentication import AUTH_HEADER_TYPE_BYTES +from rest_framework_simplejwt.exceptions import InvalidToken, TokenError +from rest_framework_simplejwt.tokens import UntypedToken +from dvadmin.system.models import Users +from application import settings os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'application.settings') os.environ["DJANGO_ALLOW_ASYNC_UNSAFE"] = "true" +@database_sync_to_async +def get_user(validated_token): + try: + user = get_user_model().objects.get(id=validated_token["user_id"]) + # return get_user_model().objects.get(id=toke_id) + return user + + except Users.DoesNotExist: + return AnonymousUser() + + +class JwtAuthMiddleware(BaseMiddleware): + def __init__(self, inner): + self.inner = inner + + async def __call__(self, scope, receive, send): + # Close old database connections to prevent usage of timed out connections + close_old_connections() + parts = dict(scope['headers']).get(b'authorization', b'').split() + print("parts",scope) + if len(parts) == 0: + # Empty AUTHORIZATION header sent + return None + + if parts[0] not in AUTH_HEADER_TYPE_BYTES: + # Assume the header does not contain a JSON web token + return None + + if len(parts) != 2: + raise None + + token = parts[1] + # Get the token + # Try to authenticate the user + try: + # This will automatically validate the token and raise an error if token is invalid + UntypedToken(token) + except (InvalidToken, TokenError) as e: + # Token is invalid + print(e) + return None + else: + # Then token is valid, decode it + decoded_data = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) + print(decoded_data) + # Will return a dictionary like - + # { + # "token_type": "access", + # "exp": 1568770772, + # "jti": "5c15e80d65b04c20ad34d77b6703251b", + # "user_id": 6 + # } + + # Get the user using ID + scope["user"] = await get_user(validated_token=decoded_data) + return await super().__call__(scope, receive, send) + + +def JwtAuthMiddlewareStack(inner): + return JwtAuthMiddleware(AuthMiddlewareStack(inner)) http_application = get_asgi_application() diff --git a/backend/application/settings.py b/backend/application/settings.py index 2e6552b..fe99aef 100644 --- a/backend/application/settings.py +++ b/backend/application/settings.py @@ -170,19 +170,19 @@ CORS_ALLOW_CREDENTIALS = True # 指明在跨域访问中,后端是否支持 # ********************* channels配置 ******************* # # ================================================= # ASGI_APPLICATION = 'application.asgi.application' -CHANNEL_LAYERS = { - "default": { - "BACKEND": "channels.layers.InMemoryChannelLayer" - } -} # CHANNEL_LAYERS = { -# 'default': { -# 'BACKEND': 'channels_redis.core.RedisChannelLayer', -# 'CONFIG': { -# "hosts": [('127.0.0.1', 6379)], #需修改 -# }, -# }, +# "default": { +# "BACKEND": "channels.layers.InMemoryChannelLayer" +# } # } +CHANNEL_LAYERS = { + 'default': { + 'BACKEND': 'channels_redis.core.RedisChannelLayer', + 'CONFIG': { + "hosts": [('127.0.0.1', 6379)], #需修改 + }, + }, +} # ================================================= # diff --git a/backend/application/websocketConfig.py b/backend/application/websocketConfig.py index db520e8..414a2a4 100644 --- a/backend/application/websocketConfig.py +++ b/backend/application/websocketConfig.py @@ -16,12 +16,12 @@ send_dict = {} # 发送消息结构体 -def set_message(sender, msg_type, msg, unread=0): +def set_message(sender, msg_type, msg, refresh_unread=False): text = { 'sender': sender, 'contentType': msg_type, 'content': msg, - 'unread': unread + 'refresh_unread': refresh_unread } return text @@ -59,10 +59,14 @@ class DvadminWebSocket(AsyncJsonWebsocketConsumer): decoded_result = jwt.decode(self.service_uid, settings.SECRET_KEY, algorithms=["HS256"]) if decoded_result: self.user_id = decoded_result.get('user_id') - self.chat_group_name = "user_" + str(self.user_id) + self.room_name = "user_" + str(self.user_id) # 收到连接时候处理, await self.channel_layer.group_add( - self.chat_group_name, + "dvadmin", + self.channel_name + ) + await self.channel_layer.group_add( + self.room_name, self.channel_name ) await self.accept() @@ -74,13 +78,14 @@ class DvadminWebSocket(AsyncJsonWebsocketConsumer): else: await self.send_json( set_message('system', 'SYSTEM', "请查看您的未读消息~", - unread=unread_count)) + refresh_unread=True)) except InvalidSignatureError: await self.disconnect(None) async def disconnect(self, close_code): # Leave room group - await self.channel_layer.group_discard(self.chat_group_name, self.channel_name) + await self.channel_layer.group_discard(self.room_name, self.channel_name) + await self.channel_layer.group_discard("dvadmin", self.channel_name) print("连接关闭") try: await self.close(close_code) @@ -96,27 +101,35 @@ class MegCenter(DvadminWebSocket): async def receive(self, text_data): # 接受客户端的信息,你处理的函数 text_data_json = json.loads(text_data) - message_id = text_data_json.get('message_id', None) - user_list = await _get_message_center_instance(message_id) - for send_user in user_list: - await self.channel_layer.group_send( - "user_" + str(send_user), - {'type': 'push.message', 'json': text_data_json} - ) + # message_id = text_data_json.get('message_id', None) + # user_list = await _get_message_center_instance(message_id) + # for send_user in user_list: + # await self.channel_layer.group_send( + # "user_" + str(send_user), + # {'type': 'push.message', 'json': text_data_json} + # ) async def push_message(self, event): """消息发送""" message = event['json'] - print("进入消息发送") + print("进入消息发送",event) await self.send(text_data=json.dumps(message)) -def websocket_push(user_id,message): - username = "user_" + str(user_id) +def websocket_push(room_name,message): channel_layer = get_channel_layer() + print(channel_layer.__dict__) + # async_to_sync(channel_layer.group_send)( + # "dvadmin", + # { + # "type": "push.message", + # "json": message + # } + # ) + print("进入推送了") async_to_sync(channel_layer.group_send)( - username, + room_name, { "type": "push.message", "json": message diff --git a/backend/dvadmin/system/views/message_center.py b/backend/dvadmin/system/views/message_center.py index f9491ed..67e547b 100644 --- a/backend/dvadmin/system/views/message_center.py +++ b/backend/dvadmin/system/views/message_center.py @@ -116,19 +116,21 @@ class MessageCenterCreateSerializer(CustomModelSerializer): users = Users.objects.filter(dept__id__in=target_dept).values_list('id', flat=True) if target_type in [3]: # 系统通知 users = Users.objects.values_list('id', flat=True) + websocket_push("dvadmin", message={"sender": 'system', "contentType": 'SYSTEM', + "content": '您有一条新消息~', "refresh_unread": True}) targetuser_data = [] for user in users: targetuser_data.append({ "messagecenter": data.id, "users": user }) + if target_type in [1,2]: + room_name = f"user_{user}" + websocket_push(room_name, message={"sender": 'system', "contentType": 'SYSTEM', + "content": '您有一条新消息~', "refresh_unread": True}) targetuser_instance = MessageCenterTargetUserSerializer(data=targetuser_data, many=True, request=self.request) targetuser_instance.is_valid(raise_exception=True) targetuser_instance.save() - for user in users: - unread_count = MessageCenterTargetUser.objects.filter(users__id=user, is_read=False).count() - websocket_push(user, message={"sender": 'system', "contentType": 'SYSTEM', - "content": '您有一条新消息~', "unread": unread_count}) return data class Meta: @@ -169,9 +171,9 @@ class MessageCenterViewSet(CustomModelViewSet): instance = self.get_object() serializer = self.get_serializer(instance) # 主动推送消息 - unread_count = MessageCenterTargetUser.objects.filter(users__id=user_id, is_read=False).count() - websocket_push(user_id, message={"sender": 'system', "contentType": 'TEXT', - "content": '您查看了一条消息~', "unread": unread_count}) + room_name = f"user_{user_id}" + websocket_push(room_name, message={"sender": 'system', "contentType": 'TEXT', + "content": '您查看了一条消息~', "refresh_unread": True}) return DetailResponse(data=serializer.data, msg="获取成功") @action(methods=['GET'], detail=False, permission_classes=[IsAuthenticated]) @@ -203,3 +205,10 @@ class MessageCenterViewSet(CustomModelViewSet): serializer = MessageCenterTargetUserListSerializer(queryset.messagecenter, many=False, request=request) data = serializer.data return DetailResponse(data=data, msg="获取成功") + + @action(methods=['GET'], detail=False, permission_classes=[IsAuthenticated]) + def get_unread_msg(self, request): + """获取未读消息数量""" + self_user_id = self.request.user.id + count = MessageCenterTargetUser.objects.filter(users__id=self_user_id,is_read=False).count() + return DetailResponse(data={"count":count}, msg="获取成功") \ No newline at end of file diff --git a/web/src/api/websocket.js b/web/src/api/websocket.js index 83a7e9d..efc5c9c 100644 --- a/web/src/api/websocket.js +++ b/web/src/api/websocket.js @@ -29,8 +29,10 @@ function webSocketOnError (e) { */ function webSocketOnMessage (e) { const data = JSON.parse(e.data) - const { unread } = data - store.dispatch('d2admin/messagecenter/setUnread', unread || 0) + const { refresh_unread } = data + if (refresh_unread) { + store.dispatch('d2admin/messagecenter/setUnread') + } if (data.contentType === 'SYSTEM') { ElementUI.Notification({ title: '系统消息', diff --git a/web/src/store/modules/d2admin/modules/messagecenter.js b/web/src/store/modules/d2admin/modules/messagecenter.js index 21856d1..02482c3 100644 --- a/web/src/store/modules/d2admin/modules/messagecenter.js +++ b/web/src/store/modules/d2admin/modules/messagecenter.js @@ -1,3 +1,5 @@ +import { request } from '@/api/service' +import { urlPrefix } from '@/views/system/messageCenter/api' export default { namespaced: true, @@ -18,8 +20,22 @@ export default { * @param {String} param type {String} 类型 * @param {Object} payload meta {Object} 附带的信息 */ - async setUnread ({ state, commit }, number) { - commit('set', number) + async setUnread ({ + state, + commit + }, number) { + if (number) { + commit('set', number) + } else { + request({ + url: '/api/system/message_center/get_unread_msg/', + method: 'get', + params: {} + }).then(res => { + const { data } = res + commit('set', data.count) + }) + } } }, mutations: { diff --git a/web/src/views/system/messageCenter/index.vue b/web/src/views/system/messageCenter/index.vue index 8832409..408c888 100644 --- a/web/src/views/system/messageCenter/index.vue +++ b/web/src/views/system/messageCenter/index.vue @@ -56,14 +56,7 @@ export default { return GetObj(query) }, addRequest (row) { - return AddObj(row).then(res => { - const message = { - message_id: res.data.id, - contentType: 'INFO', - content: '您有新的消息,请到消息中心查看~' - } - this.$websocket.webSocketSend(message) - }) + return AddObj(row) }, updateRequest (row) { return UpdateObj(row)