Browse Source

[fix] rm debug info; update llama policy; update wait handle

pull/6114/head
duanjunwen 7 days ago
parent
commit
0fb500c7d4
  1. 6
      colossalai/pipeline/schedule/zero_bubble_pp.py
  2. 37
      colossalai/shardformer/policies/llama.py
  3. 19
      tests/test_pipeline/test_schedule/test_zerobubble_pp.py

6
colossalai/pipeline/schedule/zero_bubble_pp.py

@ -691,7 +691,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
accum_loss=accum_loss,
outputs=outputs,
)
# print(f"stage {self.stage_manager.stage}; chunk {model_chunk_id}; output_obj {output_obj}")
# Step3:
# 3-1:detach output; detach output for send fwd;
@ -896,7 +895,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
for it in range(len(schedule)):
scheduled_node = schedule[it]
# print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication
communication_func = self.communication_map[scheduled_node.type]
@ -925,6 +923,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id=scheduled_node.chunk,
optimizer=optimizer,
)
# wait here to ensure all communication is done
for h in self.wait_handles:
for hh in h:
hh.wait()
# return loss & output
if outputs is not None:
outputs = merge_batch(outputs)

37
colossalai/shardformer/policies/llama.py

@ -506,25 +506,24 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
)
}
policy.update(new_item)
# TODO: test lora bug here
# # enable tp, replace layer to LinearWithGradAccum
# else:
# # add a new item for sequence classification
# new_item = {
# LlamaForSequenceClassification: ModulePolicyDescription(
# sub_module_replacement=[
# SubModuleReplacementDescription(
# suffix="score",
# target_module=LinearWithGradAccum,
# kwargs=dict(
# fp8_communication=self.shard_config.fp8_communication,
# use_zbv=use_zbv,
# ),
# )
# ]
# )
# }
# policy.update(new_item)
# enable tp, replace layer to LinearWithGradAccum
else:
# add a new item for sequence classification
new_item = {
LlamaForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score",
target_module=LinearWithGradAccum,
kwargs=dict(
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
]
)
}
policy.update(new_item)
# to be confirmed
if self.pipeline_stage_manager:

19
tests/test_pipeline/test_schedule/test_zerobubble_pp.py

@ -36,24 +36,6 @@ NUM_HEADS = 4
TOP_K = 1
def register_hooks(module: torch.nn.Module):
def fwd_hook(module, input, output):
torch.cuda.synchronize()
name = module._name if hasattr(module, "_name") else module
print(f"Fwd hook {name} \n output {output}")
def bwd_hook(module, grad_input, grad_output):
torch.cuda.synchronize()
def bwd_pre_hook(module, grad_output):
torch.cuda.synchronize()
module.register_forward_hook(fwd_hook)
# module.register_backward_hook(bwd_hook)
# module.register_full_backward_pre_hook(bwd_pre_hook)
class MlpModel(nn.Module):
def __init__(
self,
@ -1068,7 +1050,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
torch_output_sum = 0
# torch_model.apply(register_hooks) # register hook for base model
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
torch_output.backward()

Loading…
Cancel
Save