From 79d7c392a6b4fdce3ace3a9a22ad7f3362a3edc9 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Thu, 28 Sep 2023 13:44:14 +0800
Subject: [PATCH] fix bugs for pipeline

Take the pipeline communication dtype from the model config (defaulting to
torch.half) instead of inferring it from the parameter dtypes, assert that
every tensor sent to another pipeline stage matches that dtype, and make the
empty-optimizer-group check collective: a group is dropped only when it is
empty on every rank of the model parallel group, so all ranks keep the same
set of param groups.

---
 internlm/core/scheduler/pipeline_scheduler.py | 15 ++++++++-------
 internlm/train/utils.py                       |  7 ++++++-
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 4ef8c86..6f2558c 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -154,13 +154,7 @@ class PipelineScheduler(BaseScheduler):
         self._tensor_shape = tensor_shape
 
     def pre_processing(self, engine):
-        types = set()
-
-        for param in engine.model.parameters():
-            types.add(param.dtype)
-        assert len(types) == 1, f"Mixed types of parameter detected, {types}"
-
-        self.dtype = types.pop()
+        self.dtype = gpc.config.model.get("dtype", torch.half)
 
     @staticmethod
     def _call_engine(engine, data):  # pylint: disable=W0237
@@ -430,6 +424,7 @@ class PipelineScheduler(BaseScheduler):
                     comm.send_obj_meta(output_obj)
                     need_forward_meta = False  # send only once.
                 # Send the forward computation output to the next stage
+                assert output_obj.dtype == self.dtype
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
@@ -537,6 +532,7 @@ class PipelineScheduler(BaseScheduler):
             # Send the output of forward computation of this pipeline stage to the next pipeline stage as input for
             # forward computation
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
+                assert output_obj.dtype == self.dtype
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
             input_objs.append(input_obj)
@@ -572,6 +568,7 @@ class PipelineScheduler(BaseScheduler):
             if gpc.is_last_rank(ParallelMode.PIPELINE):
                 output_obj_grad = None
             else:
+                assert output_obj.dtype == self.dtype
                 output_obj_grad = comm.send_forward_recv_backward(
                     output_obj,
                     backward_recv_shapes,
@@ -984,6 +981,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             # in this iteration; receive tensors for next iteration).
             if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
                 # Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
+                assert output_obj.dtype == self.dtype
                 input_obj = comm.send_forward_recv_forward(
                     output_obj,
                     input_shape,
@@ -995,6 +993,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 if self._communication_overlap:
                     # In this case, we should handle forward and backward communication separately, consistent with the
                     # overlap version of the 1F1B stage
+                    assert output_obj.dtype == self.dtype
                     input_obj = comm.send_forward_recv_forward(
                         output_obj,
                         input_shape,
@@ -1011,6 +1010,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 else:
                     # In this case, we should handle forward and backward communication together, consistent with the
                     # non-overlap version of the 1F1B stage
+                    assert output_obj.dtype == self.dtype
                     input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                         output_obj,
                         None,  # no backward grad to send
@@ -1203,6 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
 
             # Communicate objs.
+            assert output_obj.dtype == self.dtype
             input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                 output_obj,
                 input_obj_grad,
diff --git a/internlm/train/utils.py b/internlm/train/utils.py
index a05a6b2..14874f3 100644
--- a/internlm/train/utils.py
+++ b/internlm/train/utils.py
@@ -1,7 +1,9 @@
 from typing import Dict, Tuple
 
 import torch
+import torch.distributed as dist
 
+from internlm.core.context.parallel_context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param
 
@@ -74,7 +76,10 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
         pgroup["params"] = origin_params
 
     for _, g in new_groups.items():
-        if g["params"]:
+        # remove empty groups (especially the fp32 group), but only if empty on all ranks
+        has_params = torch.tensor(bool(g["params"]), device=torch.cuda.current_device())
+        dist.all_reduce(has_params, group=gpc.get_group(ParallelMode.MODEL))
+        if has_params:
             param_groups.append(g)
 
     return tuple(param_groups)
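
Note on the empty-group check: with tensor/pipeline parallelism, a group such
as the fp32 norm group can be empty on some ranks but not on others, so each
rank deciding locally whether to drop it would leave the ranks with different
optimizer param groups. The all-reduce makes that decision collective. Below
is a minimal, self-contained sketch of the idea, not part of the patch: the
two-rank gloo setup and the keep_group helper are illustrative assumptions,
whereas the patch itself reduces over gpc.get_group(ParallelMode.MODEL).

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def keep_group(rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29501"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Pretend rank 0 holds no fp32 params locally while rank 1 holds one.
        local_params = [] if rank == 0 else [torch.zeros(1)]

        # A purely local check would drop the group on rank 0 but keep it on
        # rank 1; the SUM all-reduce keeps it everywhere if any rank has params.
        has_params = torch.tensor(int(bool(local_params)))
        dist.all_reduce(has_params)
        print(f"rank {rank}: keep group -> {bool(has_params.item())}")

        dist.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(keep_group, args=(2,), nprocs=2)

The dtype asserts are cheap guards for the other half of the change:
pre_processing now trusts the configured model dtype (default torch.half)
rather than scanning parameter dtypes, which can legitimately be mixed once
fp32 norm params exist, so each send site double-checks that the communicated
activation matches the configured dtype.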