mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'
parent
b6616f544e
commit
1739df423c
|
@ -429,18 +429,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# Only attention_mask from micro_batch is used
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
# fwd calculate
|
||||
if isinstance(model_chunk, ModuleList):
|
||||
# fwd for ModuleList model
|
||||
if input_obj is None:
|
||||
output_obj = model_chunk[model_chunk_id](**micro_batch)
|
||||
else:
|
||||
output_obj = model_chunk[model_chunk_id](**input_obj)
|
||||
else:
|
||||
# fwd for shardformer
|
||||
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
|
||||
internal_inputs = {} if input_obj is None else input_obj
|
||||
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
|
||||
internal_inputs = {} if input_obj is None else input_obj
|
||||
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
|
||||
|
||||
# last layer in model
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
|
|
|
@ -48,7 +48,7 @@ def pp_linear_fwd(
|
|||
return forward(hidden_states)
|
||||
# fwd start
|
||||
elif stage_mgr.is_first_stage() and model_chunk_id == 0:
|
||||
return {"hidden_states": forward(hidden_states)}
|
||||
return {"hidden_states": forward(data)}
|
||||
# fwd middle
|
||||
else:
|
||||
return {"hidden_states": forward(hidden_states)}
|
||||
|
@ -601,7 +601,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|||
print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
|
||||
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
||||
# data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
|
||||
data_iter = {"hidden_states": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
|
||||
data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
|
||||
# input_base = [t.clone() for t in data_iter]
|
||||
input_base = {k: v.clone() for k, v in data_iter.items()}
|
||||
model_base = deepcopy(model)
|
||||
|
@ -694,7 +694,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|||
# Fwd bwd for base
|
||||
##########################
|
||||
# fwd & bwd
|
||||
output_base = model_base(input_base["hidden_states"])
|
||||
output_base = model_base(input_base["data"])
|
||||
loss_base = criterion_base(output_base)
|
||||
loss_base.backward()
|
||||
optimizer_base.step()
|
||||
|
|
Loading…
Reference in New Issue