[bug] fix: somehow logger hangs the program

colossalchat
botbw 2024-07-23 06:17:51 +00:00 committed by Hongxin Liu
parent 067e18f7e9
commit 96d0fbc531
2 changed files with 0 additions and 17 deletions
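The commit message does not state the root cause, but every line removed below is a rank-filtered self.logger.info(..., ranks=[0]) call (or the logger setup for it) executed during distributed plugin initialization. As a hedged illustration only — a minimal sketch of the general failure class, not of ColossalAI's logger internals — rank-conditional code that directly or indirectly triggers a collective on only some ranks deadlocks the whole job:

# deadlock_sketch.py -- hypothetical reproduction of the failure class, not the actual bug.
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:
        # If a "log only on rank 0" helper ends up issuing a collective
        # (a barrier, an allgather, ...), the other ranks never enter it,
        # so every rank blocks here forever.
        dist.barrier()

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)  # intentionally hangs to show the failure mode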

View File

@@ -27,7 +27,6 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
-from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1020,8 +1019,6 @@ class HybridParallelPlugin(PipelinePluginBase):
     ) -> None:
         super().__init__()
-        self.logger = get_dist_logger(type(self).__name__)
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
         ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
@@ -1070,10 +1067,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
-        self.logger.info(
-            f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0]
-        )
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
@@ -1123,10 +1116,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
-        self.logger.info(
-            f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
-            ranks=[0],
-        )
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,

View File

@@ -226,12 +226,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size"
             )
-        self.logger.info(
-            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}\n"
-            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
-            ranks=[0],
-        )
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
             self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
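If the removed progress messages are still wanted, one alternative (an assumption on my part, not what this commit or ColossalAI adopts) is standard-library logging behind a local rank check, which performs no distributed call at log time:

# rank0_logging_sketch.py -- minimal sketch; `log_on_rank0` is a hypothetical helper,
# not part of ColossalAI's API. Assumes torch.distributed is initialized by the caller.
import logging

import torch.distributed as dist

logger = logging.getLogger("hybrid_parallel_plugin")


def log_on_rank0(message: str) -> None:
    # dist.get_rank() is a purely local lookup: no collective is issued, so no chance to hang.
    if not dist.is_initialized() or dist.get_rank() == 0:
        logger.info(message)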