diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 4ef8c86..6f2558c 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -154,13 +154,7 @@ class PipelineScheduler(BaseScheduler):
         self._tensor_shape = tensor_shape
 
     def pre_processing(self, engine):
-        types = set()
-
-        for param in engine.model.parameters():
-            types.add(param.dtype)
-        assert len(types) == 1, f"Mixed types of parameter detected, {types}"
-
-        self.dtype = types.pop()
+        self.dtype = gpc.config.model.get("dtype", torch.half)
 
     @staticmethod
     def _call_engine(engine, data):  # pylint: disable=W0237
@@ -430,6 +424,7 @@ class PipelineScheduler(BaseScheduler):
                     comm.send_obj_meta(output_obj)
                     need_forward_meta = False  # send only once.
                 # Send the forward computation output to the next stage
+                assert output_obj.dtype == self.dtype
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
@@ -537,6 +532,7 @@ class PipelineScheduler(BaseScheduler):
             # Send the output of forward computation of this pipeline stage to the next pipeline stage as input for
             # forward computation
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
+                assert output_obj.dtype == self.dtype
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
             input_objs.append(input_obj)
@@ -572,6 +568,7 @@ class PipelineScheduler(BaseScheduler):
             if gpc.is_last_rank(ParallelMode.PIPELINE):
                 output_obj_grad = None
             else:
+                assert output_obj.dtype == self.dtype
                 output_obj_grad = comm.send_forward_recv_backward(
                     output_obj,
                     backward_recv_shapes,
@@ -984,6 +981,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             # in this iteration; receive tensors for next iteration).
             if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
                 # Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
+                assert output_obj.dtype == self.dtype
                 input_obj = comm.send_forward_recv_forward(
                     output_obj,
                     input_shape,
@@ -995,6 +993,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 if self._communication_overlap:
                     # In this case, we should handle forward and backward communication separately, consistent with the
                     # overlap version of the 1F1B stage
+                    assert output_obj.dtype == self.dtype
                     input_obj = comm.send_forward_recv_forward(
                         output_obj,
                         input_shape,
@@ -1011,6 +1010,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 else:
                     # In this case, we should handle forward and backward communication together, consistent with the
                     # non-overlap version of the 1F1B stage
+                    assert output_obj.dtype == self.dtype
                     input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                         output_obj,
                         None,  # no backward grad to send
@@ -1203,6 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
 
             # Communicate objs.
+            assert output_obj.dtype == self.dtype
             input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                 output_obj,
                 input_obj_grad,
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 7abca14..6894945 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -115,7 +115,6 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         super().__init__(optim=optimizer)
 
-        self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._cpu_offload = cpu_offload
         self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
         self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
@@ -157,8 +156,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # need to record the rank in which parameter groups are not assigned parameters.
         self.param_group_has_params = []
         self.param_group_no_params_ranks = []
-        self.padding_grad = torch.zeros([32], dtype=self._dtype, device=get_current_device())
-        self.padding_tensor = torch.zeros([32], dtype=self._dtype, device=get_current_device())
+        self.padding_grad = torch.zeros([32], dtype=gpc.config.model.dtype, device=get_current_device())
+        self.padding_tensor = torch.zeros([32], dtype=gpc.config.model.dtype, device=get_current_device())
 
         self.rank_unique_id = (
             f"gpus-{gpc.get_world_size(ParallelMode.GLOBAL)}_"
@@ -177,6 +176,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id, param_group in enumerate(self.optim.param_groups):
             group_params = param_group["params"]
 
+            # set the dtype for each param group
+            param_group["dtype"] = group_params[0].dtype if len(group_params) != 0 else None
+
             # add the fp16 params to fp16_param_groups for bookkeeping
             self._fp16_param_groups[group_id] = group_params
 
@@ -253,10 +255,6 @@ class HybridZeroOptimizer(BaseOptimizer):
     def zero_world_size(self):
         return self._zero_world_size
 
-    @property
-    def dtype(self):
-        return self._dtype
-
     @property
     def loss_scale(self):
         return self.grad_scaler.scale
@@ -528,8 +526,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         # compute norm for gradients that have been reduced
         params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
         if len(params) == 0:
-            grads = [self.padding_grad]
-            params = [self.padding_tensor]
+            dtype = self.param_groups[group_id]["dtype"]
+            grads = [self.padding_grad.to(dtype)]
+            params = [self.padding_tensor.to(dtype)]
 
         norm = 0
         if self._clip_grad_norm > 0:
diff --git a/internlm/train/utils.py b/internlm/train/utils.py
index a05a6b2..0e249fe 100644
--- a/internlm/train/utils.py
+++ b/internlm/train/utils.py
@@ -73,9 +73,8 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
 
         # bf16 param group, which is the first group in the param groups
         pgroup["params"] = origin_params
 
-    for _, g in new_groups.items():
-        if g["params"]:
-            param_groups.append(g)
+    # param groups may contain empty groups, such as fp32
+    param_groups.extend(new_groups.values())
 
     return tuple(param_groups)
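
Note on the scheduler change, as a minimal sketch rather than code from this PR (the plain dict stands in for `gpc.config.model`, and the toy tensor shapes are illustrative): the pipeline dtype now comes from the model config with a `torch.half` fallback instead of being inferred from `engine.model.parameters()`, which would trip the old single-dtype assertion once a model legitimately mixes dtypes (e.g. bf16 weights with fp32 norm params). The new asserts then check each stage's output against the configured dtype before every p2p send, so a mismatch fails fast at the sender rather than surfacing downstream in the communication.

```python
import torch

# Stand-in for gpc.config.model; in InternLM it is a dict-like config object.
model_cfg = {"dtype": torch.bfloat16}

# Mirrors the new PipelineScheduler.pre_processing: config-driven dtype,
# defaulting to torch.half when the config does not specify one.
scheduler_dtype = model_cfg.get("dtype", torch.half)

# A forward output about to be sent to the next pipeline stage.
output_obj = torch.zeros(2, 4, dtype=torch.bfloat16)

# Mirrors the asserts added before each comm.send_forward(...) call:
# verify the send matches the dtype the schedule was configured with.
assert output_obj.dtype == scheduler_dtype
```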
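Similarly, a sketch of the optimizer-side bookkeeping under the same assumptions (toy parameter groups, no distributed setup; `MODEL_DTYPE` stands in for `gpc.config.model.dtype`): each param group records its own dtype at construction, the shared padding tensors are allocated once at the model dtype, and they are cast to the group's dtype whenever a rank holds no reduced parameters for that group.

```python
import torch

MODEL_DTYPE = torch.bfloat16  # stand-in for gpc.config.model.dtype

# Toy groups: a populated bf16 group plus an fp32 group (e.g. norm params).
param_groups = [
    {"name": "default", "params": [torch.nn.Parameter(torch.zeros(4, dtype=MODEL_DTYPE))]},
    {"name": "fp32", "params": [torch.nn.Parameter(torch.zeros(4, dtype=torch.float32))]},
]

# Mirrors the HybridZeroOptimizer change: record each group's dtype once;
# a group that is empty everywhere gets None.
for group in param_groups:
    group["dtype"] = group["params"][0].dtype if len(group["params"]) != 0 else None

# The padding tensors are allocated once at the model dtype...
padding_grad = torch.zeros([32], dtype=MODEL_DTYPE)
padding_tensor = torch.zeros([32], dtype=MODEL_DTYPE)

# ...and cast to the group dtype when this rank has no reduced params for a
# group, so the grad-norm computation sees a uniform dtype within the group.
for group_id, group in enumerate(param_groups):
    params, grads = [], []  # pretend get_reduced_param_for_compute_norm() returned nothing
    if len(params) == 0:
        dtype = param_groups[group_id]["dtype"]
        grads = [padding_grad.to(dtype)]
        params = [padding_tensor.to(dtype)]
    assert all(t.dtype == group["dtype"] for t in params + grads)
```

This per-group cast is what lets `split_params_into_different_groups_for_optimizer` keep empty groups (the `param_groups.extend(new_groups.values())` change), presumably so that group ids stay aligned across ranks even when, say, the fp32 group has no parameters on some of them.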