mirror of https://github.com/hpcaitech/ColossalAI
[fix] rm debug info; update llama policy; update wait handle
parent cf86c1b1c5
commit 0fb500c7d4
@@ -691,7 +691,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 accum_loss=accum_loss,
                 outputs=outputs,
             )
-            # print(f"stage {self.stage_manager.stage}; chunk {model_chunk_id}; output_obj {output_obj}")

         # Step3:
         # 3-1:detach output; detach output for send fwd;
@@ -896,7 +895,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
-            # print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
@@ -925,6 +923,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
+        # wait here to ensure all communication is done
+        for h in self.wait_handles:
+            for hh in h:
+                hh.wait()
         # return loss & output
         if outputs is not None:
             outputs = merge_batch(outputs)
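Note on the wait-handle change above: each entry of self.wait_handles holds a list of asynchronous communication handles, and every handle is drained with .wait() before the loss and outputs are returned. A minimal sketch of that nested-drain pattern, assuming only that each handle exposes a blocking wait() method; FakeWork below is a hypothetical stand-in for a torch.distributed work handle, not ColossalAI code:

# Minimal sketch: drain a nested list of async communication handles before
# consuming results, mirroring the wait loop added in this commit.
class FakeWork:
    """Hypothetical stand-in for an async work handle exposing .wait()."""

    def __init__(self, tag: str):
        self.tag = tag

    def wait(self):
        # A real handle would block until its send/recv completes.
        print(f"handle {self.tag} finished")


# One inner list per scheduled communication step (assumed shape).
wait_handles = [
    [FakeWork("send_forward"), FakeWork("recv_backward")],
    [FakeWork("send_backward")],
]

# wait here to ensure all communication is done
for h in wait_handles:
    for hh in h:
        hh.wait()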
@@ -506,25 +506,24 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
                 )
             }
             policy.update(new_item)
-            # TODO: test lora bug here
-            # # enable tp, replace layer to LinearWithGradAccum
-            # else:
-            # # add a new item for sequence classification
-            # new_item = {
-            # LlamaForSequenceClassification: ModulePolicyDescription(
-            # sub_module_replacement=[
-            # SubModuleReplacementDescription(
-            # suffix="score",
-            # target_module=LinearWithGradAccum,
-            # kwargs=dict(
-            # fp8_communication=self.shard_config.fp8_communication,
-            # use_zbv=use_zbv,
-            # ),
-            # )
-            # ]
-            # )
-            # }
-            # policy.update(new_item)
+        # enable tp, replace layer to LinearWithGradAccum
+        else:
+            # add a new item for sequence classification
+            new_item = {
+                LlamaForSequenceClassification: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="score",
+                            target_module=LinearWithGradAccum,
+                            kwargs=dict(
+                                fp8_communication=self.shard_config.fp8_communication,
+                                use_zbv=use_zbv,
+                            ),
+                        )
+                    ]
+                )
+            }
+            policy.update(new_item)

         # to be confirmed
         if self.pipeline_stage_manager:
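On the policy branch enabled above: when tensor parallelism is not used, the sequence-classification head named "score" is swapped for LinearWithGradAccum through a suffix-based sub-module replacement. As a rough, self-contained illustration of what such a replacement does at model-build time; replace_by_suffix and GradAccumLinearStub are hypothetical stand-ins, not the shardformer API:

import torch.nn as nn


class GradAccumLinearStub(nn.Linear):
    """Hypothetical stand-in for a replacement linear layer (e.g. one with grad accumulation)."""


def replace_by_suffix(model: nn.Module, suffix: str, target_cls) -> None:
    # Look up the attribute named `suffix` and rebuild it with the target class,
    # reusing the original layer's shape so the checkpoint layout is preserved.
    old = getattr(model, suffix)
    new = target_cls(old.in_features, old.out_features, bias=old.bias is not None)
    setattr(model, suffix, new)


class TinyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 16)
        self.score = nn.Linear(16, 2)  # plays the role of the "score" head


model = TinyClassifier()
replace_by_suffix(model, "score", GradAccumLinearStub)
print(type(model.score).__name__)  # GradAccumLinearStub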
@@ -36,24 +36,6 @@ NUM_HEADS = 4
 TOP_K = 1


-def register_hooks(module: torch.nn.Module):
-
-    def fwd_hook(module, input, output):
-        torch.cuda.synchronize()
-        name = module._name if hasattr(module, "_name") else module
-        print(f"Fwd hook {name} \n output {output}")
-
-    def bwd_hook(module, grad_input, grad_output):
-        torch.cuda.synchronize()
-
-    def bwd_pre_hook(module, grad_output):
-        torch.cuda.synchronize()
-
-    module.register_forward_hook(fwd_hook)
-    # module.register_backward_hook(bwd_hook)
-    # module.register_full_backward_pre_hook(bwd_pre_hook)
-
-
 class MlpModel(nn.Module):
     def __init__(
         self,
@@ -1068,7 +1050,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
     all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
     dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
     torch_output_sum = 0
-    # torch_model.apply(register_hooks) # register hook for base model
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
         torch_output.backward()
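For the test hunk above: the inputs of every data-parallel rank are gathered and replayed through the unsharded reference model so its accumulated loss and gradients can be compared against the parallel run. A minimal single-process sketch of that reference-accumulation pattern on a toy module, with the dist.all_gather step replaced by a plain list of fabricated batches (names and the averaging over dp_size are assumptions for illustration):

import torch
import torch.nn as nn

# Toy stand-in for the unsharded reference ("torch") model used in the test.
torch_model = nn.Linear(8, 8)

# In the real test, all_inputs comes from dist.all_gather over the DP group;
# here we fabricate one batch per hypothetical rank.
dp_size = 2
all_inputs = [torch.randn(4, 8) for _ in range(dp_size)]

torch_output_sum = 0.0
for input_data_ in all_inputs:
    torch_output = torch_model(input_data_).mean()
    torch_output.backward()
    torch_output_sum += torch_output.detach()

# Assumed normalization: average so the reference loss/grads line up with a
# data-parallel run that averages over ranks.
torch_output_sum /= dp_size
for p in torch_model.parameters():
    if p.grad is not None:
        p.grad /= dp_size
print(torch_output_sum)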