mirror of https://github.com/InternLM/InternLM

fix bugs for pipeline

parent e34e7307c9
commit 79d7c392a6
@@ -154,13 +154,7 @@ class PipelineScheduler(BaseScheduler):
         self._tensor_shape = tensor_shape
 
     def pre_processing(self, engine):
-        types = set()
-
-        for param in engine.model.parameters():
-            types.add(param.dtype)
-        assert len(types) == 1, f"Mixed types of parameter detected, {types}"
-
-        self.dtype = types.pop()
+        self.dtype = gpc.config.model.get("dtype", torch.half)
 
     @staticmethod
     def _call_engine(engine, data):  # pylint: disable=W0237
@@ -430,6 +424,7 @@ class PipelineScheduler(BaseScheduler):
                     comm.send_obj_meta(output_obj)
                     need_forward_meta = False  # send only once.
                 # Send the forward computation output to the next stage
+                assert output_obj.dtype == self.dtype
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
@@ -537,6 +532,7 @@ class PipelineScheduler(BaseScheduler):
             # Send the output of forward computation of this pipeline stage to the next pipeline stage as input for
             # forward computation
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
+                assert output_obj.dtype == self.dtype
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
             input_objs.append(input_obj)
@@ -572,6 +568,7 @@ class PipelineScheduler(BaseScheduler):
             if gpc.is_last_rank(ParallelMode.PIPELINE):
                 output_obj_grad = None
             else:
+                assert output_obj.dtype == self.dtype
                 output_obj_grad = comm.send_forward_recv_backward(
                     output_obj,
                     backward_recv_shapes,
@@ -984,6 +981,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             # in this iteration; receive tensors for next iteration).
             if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
                 # Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
+                assert output_obj.dtype == self.dtype
                 input_obj = comm.send_forward_recv_forward(
                     output_obj,
                     input_shape,
@@ -995,6 +993,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 if self._communication_overlap:
                     # In this case, we should handle forward and backward communication separately, consistent with the
                     # overlap version of the 1F1B stage
+                    assert output_obj.dtype == self.dtype
                     input_obj = comm.send_forward_recv_forward(
                         output_obj,
                         input_shape,
@@ -1011,6 +1010,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 else:
                     # In this case, we should handle forward and backward communication together, consistent with the
                     # non-overlap version of the 1F1B stage
+                    assert output_obj.dtype == self.dtype
                     input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                         output_obj,
                         None,  # no backward grad to send
@@ -1203,6 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
 
             # Communicate objs.
+            assert output_obj.dtype == self.dtype
             input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                 output_obj,
                 input_obj_grad,
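The hunks above change how PipelineScheduler picks its communication dtype: instead of inferring it from the model parameters in pre_processing, the scheduler now reads it from gpc.config.model (falling back to torch.half), and every point-to-point send first asserts that the outgoing tensor matches that dtype. A minimal sketch of the same pattern follows; model_config and send_forward are stand-ins for gpc.config.model and comm.send_forward, not InternLM APIs.

import torch

# Stand-in for gpc.config.model; in InternLM this value comes from the training config.
model_config = {"dtype": torch.bfloat16}

# Mirrors the new pre_processing: take the dtype from config, defaulting to torch.half.
comm_dtype = model_config.get("dtype", torch.half)


def send_forward(tensor):
    # Stand-in for comm.send_forward; a real implementation would do a p2p send here.
    pass


output_obj = torch.randn(4, 8, dtype=comm_dtype)

# Mirrors the added asserts: fail fast if a stage produced a tensor in an unexpected
# dtype instead of silently sending mismatched data to the next pipeline stage.
assert output_obj.dtype == comm_dtype
send_forward(output_obj)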
@@ -1,7 +1,9 @@
 from typing import Dict, Tuple
 
 import torch
+import torch.distributed as dist
 
+from internlm.core.context.parallel_context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.model.utils import is_gate_param, is_moe_param, is_norm_param
 
@@ -74,7 +76,10 @@ def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict])
         pgroup["params"] = origin_params
 
     for _, g in new_groups.items():
-        if g["params"]:
+        # remove empty group, especially for fp32 group
+        is_empty = torch.tensor(bool(g["params"]), device=torch.cuda.current_device())
+        dist.all_reduce(is_empty, group=gpc.get_group(ParallelMode.MODEL))
+        if is_empty:
             param_groups.append(g)
 
     return tuple(param_groups)
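This second hunk makes the keep-or-drop decision for an empty optimizer group collective: each rank turns "this group has parameters" into a tensor, all-reduces it across the model parallel group, and keeps the group if any rank owns parameters for it, presumably so every rank ends up with the same set of param groups. A small sketch of that pattern, runnable on a single process; keep_group and its arguments are illustrative names, not part of InternLM.

import torch
import torch.distributed as dist


def keep_group(params, group=None):
    # Returns True if any rank in `group` has parameters for this optimizer group.
    # (The diff stores this flag under the name is_empty; a clearer name is used here.)
    has_params = torch.tensor(
        bool(params),
        dtype=torch.int64,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    if dist.is_available() and dist.is_initialized():
        # Sum the flag across ranks so all ranks agree on whether the group survives.
        dist.all_reduce(has_params, group=group)
    return bool(has_params.item())


# Single-process usage: an empty fp32 group is dropped, a non-empty one is kept.
print(keep_group([]))                 # False
print(keep_group([torch.zeros(3)]))   # True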