diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 0a97c466a..92d214bad 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -691,7 +691,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss=accum_loss,
             outputs=outputs,
         )
-        # print(f"stage {self.stage_manager.stage}; chunk {model_chunk_id}; output_obj {output_obj}")
 
         # Step3:
         # 3-1:detach output; detach output for send fwd;
@@ -896,7 +895,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
-            # print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
@@ -925,6 +923,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
+        # wait here to ensure all communication is done
+        for h in self.wait_handles:
+            for hh in h:
+                hh.wait()
         # return loss & output
         if outputs is not None:
             outputs = merge_batch(outputs)
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index b18aa933c..d962057b1 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -506,25 +506,24 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
                 )
             }
             policy.update(new_item)
-        # TODO: test lora bug here
-        # # enable tp, replace layer to LinearWithGradAccum
-        # else:
-        #     # add a new item for sequence classification
-        #     new_item = {
-        #         LlamaForSequenceClassification: ModulePolicyDescription(
-        #             sub_module_replacement=[
-        #                 SubModuleReplacementDescription(
-        #                     suffix="score",
-        #                     target_module=LinearWithGradAccum,
-        #                     kwargs=dict(
-        #                         fp8_communication=self.shard_config.fp8_communication,
-        #                         use_zbv=use_zbv,
-        #                     ),
-        #                 )
-        #             ]
-        #         )
-        #     }
-        #     policy.update(new_item)
+        # enable tp, replace layer to LinearWithGradAccum
+        else:
+            # add a new item for sequence classification
+            new_item = {
+                LlamaForSequenceClassification: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="score",
+                            target_module=LinearWithGradAccum,
+                            kwargs=dict(
+                                fp8_communication=self.shard_config.fp8_communication,
+                                use_zbv=use_zbv,
+                            ),
+                        )
+                    ]
+                )
+            }
+            policy.update(new_item)
 
         # to be confirmed
         if self.pipeline_stage_manager:
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index b630d30b1..ba6e82e88 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -36,24 +36,6 @@
 NUM_HEADS = 4
 TOP_K = 1
 
-def register_hooks(module: torch.nn.Module):
-
-    def fwd_hook(module, input, output):
-        torch.cuda.synchronize()
-        name = module._name if hasattr(module, "_name") else module
-        print(f"Fwd hook {name} \n output {output}")
-
-    def bwd_hook(module, grad_input, grad_output):
-        torch.cuda.synchronize()
-
-    def bwd_pre_hook(module, grad_output):
-        torch.cuda.synchronize()
-
-    module.register_forward_hook(fwd_hook)
-    # module.register_backward_hook(bwd_hook)
-    # module.register_full_backward_pre_hook(bwd_pre_hook)
-
-
 class MlpModel(nn.Module):
     def __init__(
         self,
@@ -1068,7 +1050,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
     all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
     dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
     torch_output_sum = 0
-    # torch_model.apply(register_hooks) # register hook for base model
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
         torch_output.backward()