[fix] fix fwd branch, fwd pass both micro_batch & internal_inputs

pull/6065/head
duanjunwen 2024-09-20 07:34:43 +00:00
parent b6616f544e
commit 1739df423c
2 changed files with 6 additions and 15 deletions

View File

@@ -429,18 +429,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
# fwd calculate
if isinstance(model_chunk, ModuleList):
# fwd for ModuleList model
if input_obj is None:
output_obj = model_chunk[model_chunk_id](**micro_batch)
else:
output_obj = model_chunk[model_chunk_id](**input_obj)
else:
# fwd for shardformer
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
internal_inputs = {} if input_obj is None else input_obj
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
internal_inputs = {} if input_obj is None else input_obj
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
# last layer in model
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):

View File

@@ -48,7 +48,7 @@ def pp_linear_fwd(
return forward(hidden_states)
# fwd start
elif stage_mgr.is_first_stage() and model_chunk_id == 0:
return {"hidden_states": forward(hidden_states)}
return {"hidden_states": forward(data)}
# fwd middle
else:
return {"hidden_states": forward(hidden_states)}
@@ -601,7 +601,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
# data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
data_iter = {"hidden_states": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
# input_base = [t.clone() for t in data_iter]
input_base = {k: v.clone() for k, v in data_iter.items()}
model_base = deepcopy(model)
@@ -694,7 +694,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
# Fwd bwd for base
##########################
# fwd & bwd
output_base = model_base(input_base["hidden_states"])
output_base = model_base(input_base["data"])
loss_base = criterion_base(output_base)
loss_base.backward()
optimizer_base.step()