[fix] fix stage_indices;

pull/6069/head
duanjunwen 2 months ago
parent 7e6f793c51
commit fc8b016887

@ -430,7 +430,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
with self.stage_manager.switch_model_chunk_id(model_chunk_id): with self.stage_manager.switch_model_chunk_id(model_chunk_id):
# fwd calculate # fwd calculate
internal_inputs = {} if input_obj is None else input_obj internal_inputs = {} if input_obj is None else input_obj
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs) output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
# last layer in model # last layer in model
@ -480,22 +480,26 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# For chunk 0 stage 0, use micro_batch as input_obj_ # For chunk 0 stage 0, use micro_batch as input_obj_
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj_, _ = tree_flatten(micro_batch) input_obj_, _ = tree_flatten({k: v for k, v in micro_batch.items() if isinstance(v, torch.Tensor)})
output_obj_, _ = tree_flatten(output_obj) # y output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy output_obj_grad_, _ = tree_flatten(
{k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
) # dy
# For loss backward; output_obj is loss; output_obj_grad should be None # For loss backward; output_obj is loss; output_obj_grad should be None
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
assert output_obj_grad is None assert output_obj_grad is None
input_obj_, _ = tree_flatten(input_obj) input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
output_obj_.append(output_obj) # LOSS output_obj_.append(output_obj) # LOSS
output_obj_grad_.append(output_obj_grad) # None output_obj_grad_.append(output_obj_grad) # None
# For other chunk stage, use input_obj as input_obj_; # For other chunk stage, use input_obj as input_obj_;
else: else:
input_obj_, _ = tree_flatten(input_obj) input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
output_obj_, _ = tree_flatten(output_obj) # y output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy output_obj_grad_, _ = tree_flatten(
{k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
) # dy
optimizer.backward_by_grad( optimizer.backward_by_grad(
tensor=output_obj_, tensor=output_obj_,
@ -547,8 +551,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj_.append(output_obj) # LOSS output_obj_.append(output_obj) # LOSS
output_obj_grad_.append(None) # None output_obj_grad_.append(None) # None
else: else:
output_obj_, _ = tree_flatten(output_obj) # y output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)}) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy output_obj_grad_, _ = tree_flatten(
{k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
) # dy
optimizer.backward_by_grad( optimizer.backward_by_grad(
tensor=output_obj_, tensor=output_obj_,

@ -39,6 +39,7 @@ def pp_linear_fwd(
forward, forward,
data: torch.Tensor = None, data: torch.Tensor = None,
hidden_states: torch.Tensor = None, hidden_states: torch.Tensor = None,
stage_index=None,
stage_mgr: PipelineStageManager = None, stage_mgr: PipelineStageManager = None,
model_chunk_id: int = None, model_chunk_id: int = None,
): ):
@ -605,6 +606,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
# input_base = [t.clone() for t in data_iter] # input_base = [t.clone() for t in data_iter]
input_base = {k: v.clone() for k, v in data_iter.items()} input_base = {k: v.clone() for k, v in data_iter.items()}
model_base = deepcopy(model) model_base = deepcopy(model)
layers_per_stage = stage_manager.distribute_layers(len(model.layers))
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
if rank == 0: if rank == 0:
# layer 0 & 7 to chunk 0 on rank0 # layer 0 & 7 to chunk 0 on rank0

Loading…
Cancel
Save