[fix] add & fix llama test

pull/6083/head
duanjunwen 2024-10-16 03:58:50 +00:00
parent e76308c6e6
commit 705b18e1e7
2 changed files with 22 additions and 33 deletions

@@ -82,7 +82,7 @@ class LlamaPipelineForwards:
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape[:2]
+            batch_size, seq_length = inputs_embeds.shape[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
         if inputs_embeds is None:
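
Note: the change above removes an unpacking bug. `inputs_embeds` is 3-D (batch, sequence, hidden), so `shape[:2]` yields exactly two values and the old three-target unpack would raise a ValueError. A minimal, standalone sketch of the difference (not taken from the repo):

import torch

inputs_embeds = torch.rand(2, 8, 64)  # (batch_size, seq_length, hidden_size)

# Old pattern: shape[:2] holds only two values, so a three-target unpack fails.
# batch_size, seq_length, _ = inputs_embeds.shape[:2]  # ValueError: not enough values to unpack

# Fixed pattern: unpack exactly the two leading dimensions.
batch_size, seq_length = inputs_embeds.shape[:2]
print(batch_size, seq_length)  # 2 8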

@@ -924,9 +924,9 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
     "config",
     [
         (0, 4, 1, 1),
-        # (1, 2, 2, 1),
-        # (1, 2, 1, 2),
-        # (1, 1, 2, 2),
+        (1, 2, 2, 1),
+        (1, 2, 1, 2),
+        (1, 1, 2, 2),
     ],
 )
 def run_with_booster_hybridplugin(config: Tuple[int, ...]):
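
Note: re-enabling the three commented-out tuples means pytest now invokes run_with_booster_hybridplugin once per configuration rather than only for (0, 4, 1, 1). A minimal, self-contained illustration of how parametrize drives that (the test name below is hypothetical):

import pytest

@pytest.mark.parametrize(
    "config",
    [
        (0, 4, 1, 1),
        (1, 2, 2, 1),
    ],
)
def test_example(config):
    # pytest calls this once per tuple; the real test unpacks its parallel sizes from config.
    assert len(config) == 4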
@@ -1010,27 +1010,22 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
     torch_model.train()
     parallel_model.train()
-    for i in range(2):
+    for _ in range(2):
         # gen random input
-        # input = torch.rand(
-        #     NUM_BATCH, NUM_TOK_PER_BATCH, NUM_HEADS, HIDDEN_SIZE_PER_HEAD, requires_grad=True
-        # ).cuda()
-        input_ids = torch.randint(0, torch_model.vocab_size, (NUM_BATCH, config.max_position_embeddings)).cuda()
-        attention_mask = torch.ones_like(input_ids).cuda()
-        input_ids.clone().cuda()
-        input_data = {"input_ids": input_ids, "attention_mask": attention_mask}
-        # dist.all_reduce(
-        #     input, group=plugin.pp_group
-        # )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
-        # dist.all_reduce(input, group=plugin.tp_group)  # tp group duplicate input
-        # dist.all_reduce(input, group=plugin.sp_group)  # sp group duplicate input
+        input_embeddings = torch.rand(
+            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+        ).cuda()
+        dist.all_reduce(
+            input_embeddings, group=plugin.pp_group
+        )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
+        dist.all_reduce(input_embeddings, group=plugin.tp_group)  # tp group duplicate input
+        dist.all_reduce(input_embeddings, group=plugin.sp_group)  # sp group duplicate input

         # run the model with hybrid parallel
         if booster.plugin.stage_manager is not None:
             # for test with pp
-            data_iter = iter([input_data])
+            data_iter = iter([{"inputs_embeds": input_embeddings}])
             sharded_output = booster.execute_pipeline(
                 data_iter,
                 parallel_model,
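
Note: the new input path feeds random embeddings instead of token ids and then all-reduces them over the pipeline, tensor, and sequence parallel groups, so every rank holds identical values before the run is compared against the single-process reference model. A minimal sketch of that replication trick, assuming an initialized torch.distributed environment (the helper name is illustrative, not from the repo):

import torch
import torch.distributed as dist

def replicate_input(x: torch.Tensor, groups) -> torch.Tensor:
    """Give every rank an identical copy of x by summing it across each group.

    Each rank starts with its own random tensor; after the all-reduce, the summed
    values are the same on every rank, which is all the reference-model check needs.
    """
    for group in groups:
        if group is not None:
            dist.all_reduce(x, op=dist.ReduceOp.SUM, group=group)
    return x

# usage sketch: replicate_input(input_embeddings, [plugin.pp_group, plugin.tp_group, plugin.sp_group])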
@@ -1053,10 +1048,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
         else:
             # for test without pp
-            parallel_output = parallel_model(
-                input_ids=input_data["input_ids"],
-                attention_mask=input_data["attention_mask"],
-            ).last_hidden_state.mean()
+            parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
             parallel_optimizer.backward(parallel_output)
             parallel_optimizer.step()
             parallel_optimizer.zero_grad()
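
Note: without pipeline parallelism the sharded model is called directly with inputs_embeds, and the mean of the last hidden state serves as a scalar surrogate loss. The same construction on a plain Hugging Face LlamaModel, as a single-process sketch (the tiny config sizes are assumptions so the snippet runs quickly; the test's real model is built elsewhere in the file):

import torch
from transformers import LlamaConfig, LlamaModel

# Tiny, illustrative config, not the test's actual one.
config = LlamaConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=4, vocab_size=128,
)
model = LlamaModel(config)

inputs_embeds = torch.rand(2, 8, config.hidden_size)  # (batch, seq_len, hidden)
loss = model(inputs_embeds=inputs_embeds).last_hidden_state.mean()  # scalar surrogate loss
loss.backward()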
@@ -1064,14 +1056,11 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
         # ===================================================================================
         # run normal model with all dp(different) inputs
-        all_inputs = [input_data for _ in range(dp_size)]
-        # dist.all_gather(all_inputs, input, group=plugin.dp_group)
+        all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
+        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
         torch_output_sum = 0
         for input_data_ in all_inputs:
-            torch_output = torch_model(
-                input_ids=input_data_["input_ids"],
-                attention_mask=input_data_["attention_mask"],
-            ).last_hidden_state.mean()
+            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
             torch_output.backward()
             torch_output_sum += torch_output.detach()
         # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
@@ -1082,9 +1071,9 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
         torch_optimizer.step()
         torch_optimizer.zero_grad()
-        print(f"loop {i} rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
-        # assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
-    # print(f"rank {dist.get_rank()} config {test_config} test passed")
+        # print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
+        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+    print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
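
Note: assert_loose_close comes from the ColossalAI test utilities and performs a tolerance-based comparison, since the sharded and single-process runs accumulate floating-point rounding differently. An illustrative stand-in for the kind of check it makes (the tolerances below are assumptions, not the library's actual values):

import torch

def loose_close(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype) -> None:
    """Dtype-aware closeness check; tolerances here are illustrative, not ColossalAI's."""
    rtol, atol = {
        torch.float16: (5e-3, 5e-3),
        torch.bfloat16: (4e-3, 4e-3),
    }.get(dtype, (1e-5, 1e-8))
    assert torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol), f"{a} vs {b}"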
@@ -1094,7 +1083,7 @@ def run_dist(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_with_booster_moehybridplugin()
-    # run_with_booster_hybridplugin()
+    run_with_booster_hybridplugin()


 @pytest.mark.dist
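
Note: in ColossalAI's test suite, a run_dist entry point like the one above is typically launched through the colossalai.testing spawn helper. A sketch of such a launcher, as an assumption about the pattern rather than this file's actual continuation (the test name and world size are hypothetical):

import pytest
from colossalai.testing import rerun_if_address_is_in_use, spawn

@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_llama_hybrid():
    # spawn starts world_size processes, each calling run_dist(rank, world_size, port).
    spawn(run_dist, 4)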