mirror of https://github.com/hpcaitech/ColossalAI
[chore] make moe_pg_mesh private (rename self.moe_pg_mesh to self.__moe_pg_mesh)
parent
5b4c12381b
commit
606b0891ed
|
@ -214,14 +214,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||||
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
|
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
|
||||||
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
|
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
|
||||||
self.moe_dp_axis, self.ep_axis = 0, 1
|
self.moe_dp_axis, self.ep_axis = 0, 1
|
||||||
self.moe_pg_mesh = ProcessGroupMesh(
|
self.__moe_pg_mesh = ProcessGroupMesh(
|
||||||
self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size
|
self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
|
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
|
||||||
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
|
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
|
||||||
self.moe_dp_axis, self.ep_axis = 1, 2
|
self.moe_dp_axis, self.ep_axis = 1, 2
|
||||||
self.moe_pg_mesh = ProcessGroupMesh(
|
self.__moe_pg_mesh = ProcessGroupMesh(
|
||||||
self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size
|
self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -269,8 +269,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||||
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
|
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
|
||||||
self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
|
self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
|
||||||
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
|
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
|
||||||
self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
|
self.moe_dp_group = self.__moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
|
||||||
self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
|
self.ep_group = self.__moe_pg_mesh.get_group_along_axis(self.ep_axis)
|
||||||
if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
|
if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
|
||||||
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
|
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue