diff --git a/apps/notifications/ws.py b/apps/notifications/ws.py index cbbb25d2d..45cbb6d00 100644 --- a/apps/notifications/ws.py +++ b/apps/notifications/ws.py @@ -1,6 +1,6 @@ import threading import json - +from redis.exceptions import ConnectionError from channels.generic.websocket import JsonWebsocketConsumer from common.utils import get_logger @@ -12,13 +12,14 @@ logger = get_logger(__name__) class SiteMsgWebsocket(JsonWebsocketConsumer): - disconnected = False refresh_every_seconds = 10 + chan = None def connect(self): user = self.scope["user"] if user.is_authenticated: self.accept() + self.chan = new_site_msg_chan.subscribe() thread = threading.Thread(target=self.unread_site_msg_count) thread.start() @@ -48,9 +49,8 @@ class SiteMsgWebsocket(JsonWebsocketConsumer): user_id = str(self.scope["user"].id) self.send_unread_msg_count() - while not self.disconnected: - subscribe = new_site_msg_chan.subscribe() - for message in subscribe.listen(): + try: + for message in self.chan.listen(): if message['type'] != 'message': continue try: @@ -64,7 +64,10 @@ class SiteMsgWebsocket(JsonWebsocketConsumer): self.send_unread_msg_count() except json.JSONDecoder as e: logger.debug('Decode json error: ', e) + except ConnectionError: + logger.debug('Redis chan closed') def disconnect(self, close_code): - self.disconnected = True + if self.chan is not None: + self.chan.close() self.close()