mirror of https://github.com/hpcaitech/ColossalAI
[fix] rm debug info; update llama policy; update wait handle
parent cf86c1b1c5
commit 0fb500c7d4
@@ -691,7 +691,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss=accum_loss,
             outputs=outputs,
         )
-        # print(f"stage {self.stage_manager.stage}; chunk {model_chunk_id}; output_obj {output_obj}")
 
         # Step3:
         # 3-1:detach output; detach output for send fwd;
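The "3-1:detach output; detach output for send fwd" step referenced in the context above is the usual pipeline-parallel trick of separating a tensor's autograd history from the buffer that gets communicated. A minimal sketch of that idea in plain Python (illustrative only, not the scheduler's actual helper; the function name is made up):

import torch

def detach_for_send(output_obj: torch.Tensor) -> torch.Tensor:
    # The stage keeps `output_obj` (still attached to its local autograd graph)
    # for its own backward pass; the detached tensor is what gets sent forward.
    send_obj = output_obj.detach()
    # Re-enable requires_grad so the next stage can build a fresh graph on it.
    send_obj.requires_grad_(output_obj.requires_grad)
    return send_obj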
@@ -896,7 +895,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
-            # print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
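For context, this loop walks the per-stage schedule and routes communication nodes through a type-to-function map. A self-contained toy sketch of that dispatch pattern (the node types and handlers here are placeholders, not ColossalAI's real ones):

from collections import namedtuple

Node = namedtuple("Node", ["type", "chunk"])

# Placeholder communication node types; the real scheduler defines its own set.
AUTO_SCHEDULE_COMMUNICATION_TYPES = {"SEND_FORWARD", "RECV_FORWARD", "SEND_BACKWARD", "RECV_BACKWARD"}

# Map each communication node type to the function that performs it.
communication_map = {t: (lambda t=t: print(f"run {t}")) for t in AUTO_SCHEDULE_COMMUNICATION_TYPES}

schedule = [Node("RECV_FORWARD", 0), Node("F", 0), Node("SEND_FORWARD", 0)]
for scheduled_node in schedule:
    if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
        communication_map[scheduled_node.type]()  # communication node
    # compute nodes ("F", "B", "W") are handled by the scheduler's other branches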
@@ -925,6 +923,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
+        # wait here to ensure all communication is done
+        for h in self.wait_handles:
+            for hh in h:
+                hh.wait()
         # return loss & output
         if outputs is not None:
             outputs = merge_batch(outputs)
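The nested drain loop added here matches `self.wait_handles` being a list of lists: each communication call can issue several non-blocking p2p ops and append their handles as one group. A hedged sketch of that pattern with torch.distributed (assumes a process group is already initialized; everything except `isend` and `wait` is illustrative naming):

import torch
import torch.distributed as dist

wait_handles = []  # list of lists: one inner list per communication call

def send_forward_async(tensor: torch.Tensor, peer: int, group=None):
    # dist.isend returns a Work handle; a single step may queue several such ops.
    handles = [dist.isend(tensor, dst=peer, group=group)]
    wait_handles.append(handles)

def drain_wait_handles():
    # Mirror of the nested loop in the diff: block until every queued p2p op has finished.
    for handles in wait_handles:
        for h in handles:
            h.wait()
    wait_handles.clear()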
@@ -506,25 +506,24 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
                 )
             }
             policy.update(new_item)
-        # TODO: test lora bug here
-        # # enable tp, replace layer to LinearWithGradAccum
-        # else:
-        #     # add a new item for sequence classification
-        #     new_item = {
-        #         LlamaForSequenceClassification: ModulePolicyDescription(
-        #             sub_module_replacement=[
-        #                 SubModuleReplacementDescription(
-        #                     suffix="score",
-        #                     target_module=LinearWithGradAccum,
-        #                     kwargs=dict(
-        #                         fp8_communication=self.shard_config.fp8_communication,
-        #                         use_zbv=use_zbv,
-        #                     ),
-        #                 )
-        #             ]
-        #         )
-        #     }
-        #     policy.update(new_item)
+        # enable tp, replace layer to LinearWithGradAccum
+        else:
+            # add a new item for sequence classification
+            new_item = {
+                LlamaForSequenceClassification: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="score",
+                            target_module=LinearWithGradAccum,
+                            kwargs=dict(
+                                fp8_communication=self.shard_config.fp8_communication,
+                                use_zbv=use_zbv,
+                            ),
+                        )
+                    ]
+                )
+            }
+            policy.update(new_item)
 
         # to be confirmed
         if self.pipeline_stage_manager:
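Each SubModuleReplacementDescription above pairs a `suffix` (the attribute name of the sub-module to swap, here the `score` head) with a `target_module` class and constructor `kwargs`. A toy, self-contained illustration of that idea (not ColossalAI's shardformer internals; a plain nn.Linear stands in for LinearWithGradAccum):

import torch.nn as nn

def replace_by_suffix(model: nn.Module, suffix: str, target_module=nn.Linear, **kwargs):
    # Locate the sub-module named by `suffix` and replace it with `target_module`,
    # forwarding any extra constructor kwargs.
    old = getattr(model, suffix)
    new = target_module(old.in_features, old.out_features, bias=old.bias is not None, **kwargs)
    setattr(model, suffix, new)
    return model

class TinyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.score = nn.Linear(8, 2, bias=False)  # plays the role of the `score` head

model = replace_by_suffix(TinyClassifier(), "score")

The real policy only declares the replacement; shardformer applies it when the model is sharded.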
@@ -36,24 +36,6 @@ NUM_HEADS = 4
 TOP_K = 1
 
 
-def register_hooks(module: torch.nn.Module):
-
-    def fwd_hook(module, input, output):
-        torch.cuda.synchronize()
-        name = module._name if hasattr(module, "_name") else module
-        print(f"Fwd hook {name} \n output {output}")
-
-    def bwd_hook(module, grad_input, grad_output):
-        torch.cuda.synchronize()
-
-    def bwd_pre_hook(module, grad_output):
-        torch.cuda.synchronize()
-
-    module.register_forward_hook(fwd_hook)
-    # module.register_backward_hook(bwd_hook)
-    # module.register_full_backward_pre_hook(bwd_pre_hook)
-
-
 class MlpModel(nn.Module):
     def __init__(
         self,
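The deleted helper attached debugging hooks that printed whole activation tensors and synchronized CUDA on every call. Should something like it be needed again, here is a lighter hedged sketch that uses the non-deprecated hook API (register_full_backward_hook instead of register_backward_hook) and logs shapes rather than full tensors:

import torch
import torch.nn as nn

def attach_debug_hooks(module: nn.Module, name: str = ""):
    label = name or module.__class__.__name__

    def fwd_hook(mod, inputs, output):
        print(f"[fwd] {label}: output shape {tuple(output.shape)}")

    def bwd_hook(mod, grad_input, grad_output):
        print(f"[bwd] {label}: grad_output shape {tuple(grad_output[0].shape)}")

    module.register_forward_hook(fwd_hook)
    module.register_full_backward_hook(bwd_hook)

layer = nn.Linear(4, 4)
attach_debug_hooks(layer, "proj")
layer(torch.randn(2, 4)).sum().backward()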
@@ -1068,7 +1050,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
     all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
     dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
     torch_output_sum = 0
-    # torch_model.apply(register_hooks) # register hook for base model
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
         torch_output.backward()
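For reference, the surrounding test builds its baseline by gathering every data-parallel rank's micro-batch and running the unsharded torch model over all of them, so the summed loss is comparable to the parallel run. A hedged sketch of that pattern (assumes an initialized process group; `torch_model`, `input_embeddings`, `dtype`, and `dp_group` stand for the objects used in the test):

import torch
import torch.distributed as dist

def reference_loss(torch_model, input_embeddings: torch.Tensor, dp_group, dtype) -> float:
    dp_size = dist.get_world_size(group=dp_group)
    # Gather every rank's micro-batch so each rank sees the full data-parallel batch.
    all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
    dist.all_gather(all_inputs, input_embeddings, group=dp_group)

    total = 0.0
    for input_data_ in all_inputs:
        out = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
        out.backward()  # accumulate baseline grads, as the test does
        total += out.item()
    return total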