[chore] change moe_pg_mesh to private

colossalchat
hxwang 2024-07-25 06:34:22 +00:00 committed by Hongxin Liu
parent 5b4c12381b
commit 606b0891ed
1 changed files with 4 additions and 4 deletions

View File

@ -214,14 +214,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
self.moe_dp_axis, self.ep_axis = 0, 1 self.moe_dp_axis, self.ep_axis = 0, 1
self.moe_pg_mesh = ProcessGroupMesh( self.__moe_pg_mesh = ProcessGroupMesh(
self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size
) )
else: else:
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
self.moe_dp_axis, self.ep_axis = 1, 2 self.moe_dp_axis, self.ep_axis = 1, 2
self.moe_pg_mesh = ProcessGroupMesh( self.__moe_pg_mesh = ProcessGroupMesh(
self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size
) )
@ -269,8 +269,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis) self.moe_dp_group = self.__moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis) self.ep_group = self.__moe_pg_mesh.get_group_along_axis(self.ep_axis)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else: else: