From 606b0891ed8de51d517c9ac0436438def687614c Mon Sep 17 00:00:00 2001 From: hxwang Date: Thu, 25 Jul 2024 06:34:22 +0000 Subject: [PATCH] [chore] change moe_pg_mesh to private --- colossalai/booster/plugin/moe_hybrid_parallel_plugin.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 9b6eae0d0..7f6608086 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -214,14 +214,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) self.moe_dp_axis, self.ep_axis = 0, 1 - self.moe_pg_mesh = ProcessGroupMesh( + self.__moe_pg_mesh = ProcessGroupMesh( self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size ) else: self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) self.moe_dp_axis, self.ep_axis = 1, 2 - self.moe_pg_mesh = ProcessGroupMesh( + self.__moe_pg_mesh = ProcessGroupMesh( self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size ) @@ -269,8 +269,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) - self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis) - self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis) + self.moe_dp_group = self.__moe_pg_mesh.get_group_along_axis(self.moe_dp_axis) + self.ep_group = self.__moe_pg_mesh.get_group_along_axis(self.ep_axis) if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) else: