mirror of https://github.com/InternLM/InternLM
fix(pipeline): fix bugs for pipeline when enable mixed precision (#382)
* fix bugs for pipeline
* restore logic for empty fp32 group
* move optim.dtype to each param group
parent 9aef11e89c
commit 916647c0a1
@@ -154,13 +154,7 @@ class PipelineScheduler(BaseScheduler):
         self._tensor_shape = tensor_shape

     def pre_processing(self, engine):
-        types = set()
-
-        for param in engine.model.parameters():
-            types.add(param.dtype)
-        assert len(types) == 1, f"Mixed types of parameter detected, {types}"
-
-        self.dtype = types.pop()
+        self.dtype = gpc.config.model.get("dtype", torch.half)

     @staticmethod
     def _call_engine(engine, data):  # pylint: disable=W0237
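The hunk above replaces the parameter scan in pre_processing: instead of collecting every parameter's dtype and asserting that only one exists, the scheduler now reads its dtype from the model config, defaulting to torch.half. A minimal sketch of that lookup, using a plain dict as a stand-in for gpc.config.model (any mapping with .get() behaves the same); the fp32/half mix shown for the old assertion is illustrative:

    import torch

    # Stand-in for gpc.config.model.
    model_config = {"dtype": torch.bfloat16}

    # New behaviour: trust the configured dtype, fall back to torch.half.
    scheduler_dtype = model_config.get("dtype", torch.half)

    # Old behaviour: scan parameter dtypes and require exactly one. Under
    # mixed precision the set can contain more than one entry, so the
    # assertion cannot hold.
    param_dtypes = {torch.half, torch.float32}
    # assert len(param_dtypes) == 1  # -> fails for a mixed set like the above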
@@ -430,6 +424,7 @@ class PipelineScheduler(BaseScheduler):
                     comm.send_obj_meta(output_obj)
                     need_forward_meta = False  # send only once.
                 # Send the forward computation output to the next stage
+                assert output_obj.dtype == self.dtype
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)

         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
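The assert added above reappears before every point-to-point send in the hunks that follow (send_forward, send_forward_recv_backward, send_forward_recv_forward, send_forward_backward_recv_forward_backward): the forward output must match the scheduler's configured dtype before it is handed to the communication layer. A hedged sketch of that guard as a standalone helper; checked_send and send_fn are illustrative names, not the repo's comm API:

    import torch

    def checked_send(output_obj: torch.Tensor, expected_dtype: torch.dtype, send_fn) -> None:
        # Fail at the sender, with a clear message, rather than letting the
        # next pipeline stage receive a tensor of an unexpected dtype.
        assert output_obj.dtype == expected_dtype, (
            f"pipeline send dtype mismatch: got {output_obj.dtype}, expected {expected_dtype}"
        )
        send_fn(output_obj)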
@@ -537,6 +532,7 @@ class PipelineScheduler(BaseScheduler):
             # Send the output of forward computation of this pipeline stage to the next pipeline stage as input for
             # forward computation
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
+                assert output_obj.dtype == self.dtype
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)

             input_objs.append(input_obj)
@@ -572,6 +568,7 @@ class PipelineScheduler(BaseScheduler):
                 if gpc.is_last_rank(ParallelMode.PIPELINE):
                     output_obj_grad = None
                 else:
+                    assert output_obj.dtype == self.dtype
                     output_obj_grad = comm.send_forward_recv_backward(
                         output_obj,
                         backward_recv_shapes,
@@ -984,6 +981,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             # in this iteration; receive tensors for next iteration).
             if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
                 # Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
+                assert output_obj.dtype == self.dtype
                 input_obj = comm.send_forward_recv_forward(
                     output_obj,
                     input_shape,
@@ -995,6 +993,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 if self._communication_overlap:
                     # In this case, we should handle forward and backward communication separately, consistent with the
                     # overlap version of the 1F1B stage
+                    assert output_obj.dtype == self.dtype
                     input_obj = comm.send_forward_recv_forward(
                         output_obj,
                         input_shape,
@@ -1011,6 +1010,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 else:
                     # In this case, we should handle forward and backward communication together, consistent with the
                     # non-overlap version of the 1F1B stage
+                    assert output_obj.dtype == self.dtype
                     input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                         output_obj,
                         None,  # no backward grad to send
@@ -1203,6 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None

             # Communicate objs.
+            assert output_obj.dtype == self.dtype
             input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                 output_obj,
                 input_obj_grad,
@@ -115,7 +115,6 @@ class HybridZeroOptimizer(BaseOptimizer):

         super().__init__(optim=optimizer)

-        self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._cpu_offload = cpu_offload
         self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
         self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
@@ -157,8 +156,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # need to record the rank in which parameter groups are not assigned parameters.
         self.param_group_has_params = []
         self.param_group_no_params_ranks = []
-        self.padding_grad = torch.zeros([32], dtype=self._dtype, device=get_current_device())
-        self.padding_tensor = torch.zeros([32], dtype=self._dtype, device=get_current_device())
+        self.padding_grad = torch.zeros([32], dtype=gpc.config.model.dtype, device=get_current_device())
+        self.padding_tensor = torch.zeros([32], dtype=gpc.config.model.dtype, device=get_current_device())

         self.rank_unique_id = (
             f"gpus-{gpc.get_world_size(ParallelMode.GLOBAL)}_"
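The two optimizer hunks above drop the optimizer-wide self._dtype, which was taken from the first parameter of the first param group, and allocate the padding buffers with gpc.config.model.dtype instead. A small sketch of why the old derivation breaks down once param groups carry different dtypes; the group contents are illustrative:

    import torch

    # e.g. a half-precision group plus an fp32 group
    groups = [
        {"params": [torch.nn.Parameter(torch.zeros(2, dtype=torch.half))]},
        {"params": [torch.nn.Parameter(torch.zeros(2, dtype=torch.float32))]},
    ]

    # The removed self._dtype only ever saw the first parameter of the first
    # group, so every group using another dtype was mislabelled.
    single_dtype = groups[0]["params"][0].dtype          # torch.float16
    assert groups[1]["params"][0].dtype != single_dtype  # the fp32 group disagrees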
@@ -177,6 +176,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id, param_group in enumerate(self.optim.param_groups):
             group_params = param_group["params"]

+            # set the dtype for each param group
+            param_group["dtype"] = group_params[0].dtype if len(group_params) != 0 else None
+
             # add the fp16 params to fp16_param_groups for bookkeeping
             self._fp16_param_groups[group_id] = group_params

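This hunk is the "move optim.dtype to each param group" part of the commit: each param group records its own dtype, and a group that is empty on this rank records None. A minimal standalone sketch of the same bookkeeping; the groups themselves are placeholders:

    import torch

    param_groups = [
        {"params": [torch.nn.Parameter(torch.zeros(4, dtype=torch.half))]},
        {"params": []},  # e.g. an fp32 group with no parameters assigned here
    ]

    for group in param_groups:
        group_params = group["params"]
        # same rule as the diff: first parameter's dtype, or None when empty
        group["dtype"] = group_params[0].dtype if len(group_params) != 0 else None

    # param_groups[0]["dtype"] == torch.float16; param_groups[1]["dtype"] is None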
@@ -253,10 +255,6 @@ class HybridZeroOptimizer(BaseOptimizer):
     def zero_world_size(self):
         return self._zero_world_size

-    @property
-    def dtype(self):
-        return self._dtype
-
     @property
     def loss_scale(self):
         return self.grad_scaler.scale
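With the dtype property removed there is no single optimizer-wide dtype left to expose; code that needs one reads it from the relevant param group, as the grad-norm hunk below does. A one-line illustrative accessor (the helper name is not from the repo):

    def group_dtype(optimizer, group_id: int):
        # per-group replacement for the removed HybridZeroOptimizer.dtype property
        return optimizer.param_groups[group_id]["dtype"]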
@@ -528,8 +526,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         # compute norm for gradients that have been reduced
         params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
         if len(params) == 0:
-            grads = [self.padding_grad]
-            params = [self.padding_tensor]
+            dtype = self.param_groups[group_id]["dtype"]
+            grads = [self.padding_grad.to(dtype)]
+            params = [self.padding_tensor.to(dtype)]

         norm = 0
         if self._clip_grad_norm > 0:
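Since the padding buffers are now allocated with the global model dtype, they are cast to the group's own dtype before standing in for a group that produced no reduced gradients on this rank, keeping the norm computation consistently typed per group. A self-contained sketch of that fallback; the function wrapper is illustrative, only the cast mirrors the diff:

    import torch

    def padded_params_and_grads(params, grads, group_dtype, padding_tensor, padding_grad):
        # No reduced parameters for this group on this rank: substitute small
        # padding buffers cast to the group's dtype.
        if len(params) == 0:
            grads = [padding_grad.to(group_dtype)]
            params = [padding_tensor.to(group_dtype)]
        return params, grads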
@@ -73,9 +73,8 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
         # bf16 param group, which is the first group in the param groups
         pgroup["params"] = origin_params

-    for _, g in new_groups.items():
-        if g["params"]:
-            param_groups.append(g)
+    # param groups may contain empty groups, such as fp32
+    param_groups.extend(new_groups.values())

     return tuple(param_groups)

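The final hunk is the "restore logic for empty fp32 group" part: the loop that skipped empty groups is replaced by a plain extend, so an fp32 group that ends up with no parameters is still passed through to the optimizer (where, per the earlier hunks, its dtype is recorded as None). A small sketch contrasting the two behaviours with placeholder group contents:

    # placeholder groups; in the repo these come from splitting params by dtype
    param_groups = [{"name": "default", "params": ["w0"]}]
    new_groups = {"fp32": {"name": "fp32", "params": []}}  # empty on purpose

    # old behaviour: empty groups were silently dropped
    old_result = list(param_groups)
    for _, g in new_groups.items():
        if g["params"]:
            old_result.append(g)

    # new behaviour: keep every group, empty or not
    new_result = list(param_groups)
    new_result.extend(new_groups.values())

    assert len(old_result) == 1 and len(new_result) == 2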