diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 47c17e749..7a04c5451 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -82,7 +82,7 @@ class LlamaPipelineForwards:
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape[:2]
+            batch_size, seq_length = inputs_embeds.shape[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
         if inputs_embeds is None:
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index bdc539043..ffeaf6bd8 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -924,9 +924,9 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
     "config",
     [
         (0, 4, 1, 1),
-        # (1, 2, 2, 1),
-        # (1, 2, 1, 2),
-        # (1, 1, 2, 2),
+        (1, 2, 2, 1),
+        (1, 2, 1, 2),
+        (1, 1, 2, 2),
     ],
 )
 def run_with_booster_hybridplugin(config: Tuple[int, ...]):
@@ -1010,27 +1010,22 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
 
     torch_model.train()
     parallel_model.train()
-    for i in range(2):
+    for _ in range(2):
         # gen random input
-        # input = torch.rand(
-        #     NUM_BATCH, NUM_TOK_PER_BATCH, NUM_HEADS, HIDDEN_SIZE_PER_HEAD, requires_grad=True
-        # ).cuda()
-        input_ids = torch.randint(0, torch_model.vocab_size, (NUM_BATCH, config.max_position_embeddings)).cuda()
-        attention_mask = torch.ones_like(input_ids).cuda()
-        input_ids.clone().cuda()
-        input_data = {"input_ids": input_ids, "attention_mask": attention_mask}
+        input_embeddings = torch.rand(
+            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+        ).cuda()
+        dist.all_reduce(
+            input_embeddings, group=plugin.pp_group
+        )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
 
-        # dist.all_reduce(
-        #     input, group=plugin.pp_group
-        # )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
-
-        # dist.all_reduce(input, group=plugin.tp_group)  # tp group duplicate input
-        # dist.all_reduce(input, group=plugin.sp_group)  # sp group duplicate input
+        dist.all_reduce(input_embeddings, group=plugin.tp_group)  # tp group duplicate input
+        dist.all_reduce(input_embeddings, group=plugin.sp_group)  # sp group duplicate input
 
         # run the model with hybrid parallel
         if booster.plugin.stage_manager is not None:
             # for test with pp
-            data_iter = iter([input_data])
+            data_iter = iter([{"inputs_embeds": input_embeddings}])
             sharded_output = booster.execute_pipeline(
                 data_iter,
                 parallel_model,
@@ -1053,10 +1048,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
 
         else:
             # for test without pp
-            parallel_output = parallel_model(
-                input_ids=input_data["input_ids"],
-                attention_mask=input_data["attention_mask"],
-            ).last_hidden_state.mean()
+            parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
             parallel_optimizer.backward(parallel_output)
             parallel_optimizer.step()
             parallel_optimizer.zero_grad()
@@ -1064,14 +1056,11 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
 
         # ===================================================================================
         # run normal model with all dp(different) inputs
-        all_inputs = [input_data for _ in range(dp_size)]
-        # dist.all_gather(all_inputs, input, group=plugin.dp_group)
+        all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
+        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
         torch_output_sum = 0
         for input_data_ in all_inputs:
-            torch_output = torch_model(
-                input_ids=input_data_["input_ids"],
-                attention_mask=input_data_["attention_mask"],
-            ).last_hidden_state.mean()
+            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
             torch_output.backward()
             torch_output_sum += torch_output.detach()
         # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
@@ -1082,9 +1071,9 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
         torch_optimizer.step()
         torch_optimizer.zero_grad()
 
-        print(f"loop {i} rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
-        # assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
-        # print(f"rank {dist.get_rank()} config {test_config} test passed")
+        # print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
+        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+        print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
@@ -1094,7 +1083,7 @@ def run_dist(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_with_booster_moehybridplugin()
-    # run_with_booster_hybridplugin()
+    run_with_booster_hybridplugin()
 
 
 @pytest.mark.dist
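
Note (not part of the patch): a minimal sketch of the unpacking error the llama.py hunk fixes. The tensor shape below is made up for illustration; it only assumes inputs_embeds is the usual (batch, seq_len, hidden_size) activation tensor.

    import torch

    inputs_embeds = torch.rand(2, 16, 64)  # illustrative (batch, seq_len, hidden_size) tensor
    # Old line: shape[:2] has exactly two entries, so unpacking three names raises
    # "ValueError: not enough values to unpack (expected 3, got 2)".
    # batch_size, seq_length, _ = inputs_embeds.shape[:2]
    batch_size, seq_length = inputs_embeds.shape[:2]  # patched unpacking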
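Also not part of the patch: a single-process sketch of the input-replication trick the re-enabled test relies on. all_reduce-ing each rank's random tensor over a group leaves every rank holding the identical sum, so the plain torch reference model can be fed the same data the parallel model sees. The gloo backend, address, and port below are assumptions for a standalone run; in the test the groups come from plugin.pp_group / tp_group / sp_group.

    import os

    import torch
    import torch.distributed as dist

    # standalone single-process group purely for demonstration
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29512")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    local_input = torch.rand(4, 8)  # each rank starts from different random data
    dist.all_reduce(local_input)    # sum over the group -> identical tensor on every rank
    print(local_input.sum())

    dist.destroy_process_group()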