[fix] remove debug print statements; re-enable LlamaForSequenceClassification tensor-parallel policy; wait on all pending communication handles before returning

pull/6114/head
duanjunwen 2024-11-15 09:47:05 +00:00
parent cf86c1b1c5
commit 0fb500c7d4
3 changed files with 22 additions and 40 deletions

View File

@ -691,7 +691,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
accum_loss=accum_loss, accum_loss=accum_loss,
outputs=outputs, outputs=outputs,
) )
# print(f"stage {self.stage_manager.stage}; chunk {model_chunk_id}; output_obj {output_obj}")
# Step3: # Step3:
# 3-1:detach output; detach output for send fwd; # 3-1:detach output; detach output for send fwd;
@ -896,7 +895,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
for it in range(len(schedule)): for it in range(len(schedule)):
scheduled_node = schedule[it] scheduled_node = schedule[it]
# print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication # communication
communication_func = self.communication_map[scheduled_node.type] communication_func = self.communication_map[scheduled_node.type]
@ -925,6 +923,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id=scheduled_node.chunk, model_chunk_id=scheduled_node.chunk,
optimizer=optimizer, optimizer=optimizer,
) )
# wait here to ensure all communication is done
for h in self.wait_handles:
for hh in h:
hh.wait()
# return loss & output # return loss & output
if outputs is not None: if outputs is not None:
outputs = merge_batch(outputs) outputs = merge_batch(outputs)

View File

@ -506,25 +506,24 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
) )
} }
policy.update(new_item) policy.update(new_item)
# TODO: test lora bug here # enable tp, replace layer to LinearWithGradAccum
# # enable tp, replace layer to LinearWithGradAccum else:
# else: # add a new item for sequence classification
# # add a new item for sequence classification new_item = {
# new_item = { LlamaForSequenceClassification: ModulePolicyDescription(
# LlamaForSequenceClassification: ModulePolicyDescription( sub_module_replacement=[
# sub_module_replacement=[ SubModuleReplacementDescription(
# SubModuleReplacementDescription( suffix="score",
# suffix="score", target_module=LinearWithGradAccum,
# target_module=LinearWithGradAccum, kwargs=dict(
# kwargs=dict( fp8_communication=self.shard_config.fp8_communication,
# fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv,
# use_zbv=use_zbv, ),
# ), )
# ) ]
# ] )
# ) }
# } policy.update(new_item)
# policy.update(new_item)
# to be confirmed # to be confirmed
if self.pipeline_stage_manager: if self.pipeline_stage_manager:

View File

@ -36,24 +36,6 @@ NUM_HEADS = 4
TOP_K = 1 TOP_K = 1
def register_hooks(module: torch.nn.Module):
def fwd_hook(module, input, output):
torch.cuda.synchronize()
name = module._name if hasattr(module, "_name") else module
print(f"Fwd hook {name} \n output {output}")
def bwd_hook(module, grad_input, grad_output):
torch.cuda.synchronize()
def bwd_pre_hook(module, grad_output):
torch.cuda.synchronize()
module.register_forward_hook(fwd_hook)
# module.register_backward_hook(bwd_hook)
# module.register_full_backward_pre_hook(bwd_pre_hook)
class MlpModel(nn.Module): class MlpModel(nn.Module):
def __init__( def __init__(
self, self,
@ -1068,7 +1050,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
all_inputs = [input_embeddings.clone() for _ in range(dp_size)] all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
torch_output_sum = 0 torch_output_sum = 0
# torch_model.apply(register_hooks) # register hook for base model
for input_data_ in all_inputs: for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
torch_output.backward() torch_output.backward()