mirror of https://github.com/hpcaitech/ColossalAI
[bug] fix: somehow logger hangs the program
parent 067e18f7e9
commit 96d0fbc531
@@ -27,7 +27,6 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
-from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1020,8 +1019,6 @@ class HybridParallelPlugin(PipelinePluginBase):
     ) -> None:
         super().__init__()
-
-        self.logger = get_dist_logger(type(self).__name__)
 
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
         ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
@@ -1070,10 +1067,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
-
-        self.logger.info(
-            f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0]
-        )
 
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
@@ -1123,10 +1116,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
 
-        self.logger.info(
-            f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
-            ranks=[0],
-        )
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,
@@ -226,12 +226,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size"
         )
 
-        self.logger.info(
-            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}\n"
-            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
-            ranks=[0],
-        )
-
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
             self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
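For reference, every removed call follows the self.logger.info(..., ranks=[0]) pattern, i.e. "emit this message on global rank 0 only". The sketch below shows that same rank-filtered logging built only on Python's standard logging module and torch.distributed; it is an illustration, not ColossalAI's get_dist_logger API, and the helper name log_on_rank0 is hypothetical.

import logging

import torch.distributed as dist

# Hypothetical helper for illustration only; not part of ColossalAI.
_logger = logging.getLogger("hybrid_parallel_plugin")
logging.basicConfig(level=logging.INFO)


def log_on_rank0(message: str) -> None:
    """Emit `message` on global rank 0 only, using no collective communication."""
    # dist.get_rank() is a purely local query, so this check cannot block other ranks;
    # if the process group is not initialized yet, just log unconditionally.
    if not dist.is_initialized() or dist.get_rank() == 0:
        _logger.info(message)

The commit itself simply drops the logging calls; the sketch is only meant to show that rank-filtered logging does not inherently require any synchronization between ranks.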