@ -44,8 +44,8 @@ class MessageType(type):
@shared_task ( verbose_name = _ ( ' Publish the station message ' ) )
def publish_task ( msg ) :
msg . publish ( )
def publish_task ( receive_user_ids , backends_ msg_mapper ) :
Message . send_msg ( receive_user_ids , backends_msg_mapper )
class Message ( metaclass = MessageType ) :
@ -65,27 +65,35 @@ class Message(metaclass=MessageType):
return cls . __name__
def publish_async ( self ) :
return publish_task . delay ( self )
self . publish ( is_async = True )
@classmethod
def gen_test_msg ( cls ) :
raise NotImplementedError
def publish ( self ) :
def publish ( self , is_async = False ) :
raise NotImplementedError
def send_msg ( self , users : Iterable , backends : Iterable = BACKEND ) :
def get_backend_msg_mapper ( self , backends ) :
backends = set ( backends )
backends . add ( BACKEND . SITE_MSG ) # 站内信必须发
backends_msg_mapper = { }
for backend in backends :
try :
backend = BACKEND ( backend )
if not backend . is_enable :
continue
get_msg_method = getattr ( self , f ' get_ { backend } _msg ' , self . get_common_msg )
msg = get_msg_method ( )
backends_msg_mapper [ backend ] = msg
return backends_msg_mapper
@staticmethod
def send_msg ( receive_user_ids , backends_msg_mapper ) :
for backend , msg in backends_msg_mapper . items ( ) :
try :
backend = BACKEND ( backend )
client = backend . client ( )
users = User . objects . filter ( id__in = receive_user_ids ) . all ( )
client . send_msg ( users , * * msg )
except NotImplementedError :
continue
@ -238,7 +246,7 @@ class Message(metaclass=MessageType):
class SystemMessage ( Message ) :
def publish ( self ) :
def publish ( self , is_async = False ) :
subscription = SystemMsgSubscription . objects . get (
message_type = self . get_message_type ( )
)
@ -251,7 +259,13 @@ class SystemMessage(Message):
* subscription . users . all ( ) ,
* chain ( * [ g . users . all ( ) for g in subscription . groups . all ( ) ] )
]
self . send_msg ( users , receive_backends )
receive_user_ids = [ u . id for u in users ]
backends_msg_mapper = self . get_backend_msg_mapper ( receive_backends )
if is_async :
publish_task . delay ( receive_user_ids , backends_msg_mapper )
else :
self . send_msg ( receive_user_ids , backends_msg_mapper )
@classmethod
def post_insert_to_db ( cls , subscription : SystemMsgSubscription ) :
@ -268,12 +282,17 @@ class UserMessage(Message):
def __init__ ( self , user ) :
self . user = user
def publish ( self ) :
def publish ( self , is_async = False ) :
"""
发送消息到每个用户配置的接收方式上
"""
sub = UserMsgSubscription . objects . get ( user = self . user )
self . send_msg ( [ self . user ] , sub . receive_backends )
backends_msg_mapper = self . get_backend_msg_mapper ( sub . receive_backends )
receive_user_ids = [ self . user . id ]
if is_async :
publish_task . delay ( receive_user_ids , backends_msg_mapper )
else :
self . send_msg ( receive_user_ids , backends_msg_mapper )
@classmethod
def get_test_user ( cls ) :