mirror of https://github.com/hpcaitech/ColossalAI
[fix] add & fix llama test
parent e76308c6e6
commit 705b18e1e7
@@ -82,7 +82,7 @@ class LlamaPipelineForwards:
             elif input_ids is not None:
                 batch_size, seq_length = input_ids.shape[:2]
             elif inputs_embeds is not None:
-                batch_size, seq_length, _ = inputs_embeds.shape[:2]
+                batch_size, seq_length = inputs_embeds.shape[:2]
             else:
                 raise ValueError("You have to specify either input_ids or inputs_embeds")
             if inputs_embeds is None:
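
The fix in this hunk resolves a tuple-unpacking bug: inputs_embeds.shape[:2] yields exactly two values (batch size and sequence length), so the old three-target unpack raises a ValueError as soon as the inputs_embeds branch is taken. A minimal standalone sketch of the failure and the fix (the tensor shape below is made up for illustration):

import torch

# hypothetical embeddings tensor: (batch, seq_len, hidden) = (2, 16, 64)
inputs_embeds = torch.rand(2, 16, 64)

# old code: three targets for a two-element slice -> ValueError at runtime
try:
    batch_size, seq_length, _ = inputs_embeds.shape[:2]
except ValueError as err:
    print(f"old unpack fails: {err}")

# fixed code: two targets for the two leading dimensions
batch_size, seq_length = inputs_embeds.shape[:2]
print(batch_size, seq_length)  # 2 16
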
@@ -924,9 +924,9 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
     "config",
     [
         (0, 4, 1, 1),
-        # (1, 2, 2, 1),
-        # (1, 2, 1, 2),
-        # (1, 1, 2, 2),
+        (1, 2, 2, 1),
+        (1, 2, 1, 2),
+        (1, 1, 2, 2),
     ],
 )
 def run_with_booster_hybridplugin(config: Tuple[int, ...]):
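
Uncommenting these tuples puts all four parallel layouts back into the pytest parametrization. Below is a self-contained sketch of the pattern; the field names in the unpacking are an assumption inferred from the pp_size/tp_size/sp_size values printed at the end of the test, not the test's actual variable names:

from typing import Tuple

import pytest


@pytest.mark.parametrize(
    "config",
    [
        (0, 4, 1, 1),
        (1, 2, 2, 1),
        (1, 2, 1, 2),
        (1, 1, 2, 2),
    ],
)
def test_config_layout(config: Tuple[int, ...]):
    stage, pp_size, tp_size, sp_size = config  # hypothetical field order
    # under this reading every tuple multiplies out to 4 parallel ranks
    assert pp_size * tp_size * sp_size == 4
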
@@ -1010,27 +1010,22 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):

     torch_model.train()
     parallel_model.train()
-    for i in range(2):
+    for _ in range(2):
         # gen random input
-        # input = torch.rand(
-        #     NUM_BATCH, NUM_TOK_PER_BATCH, NUM_HEADS, HIDDEN_SIZE_PER_HEAD, requires_grad=True
-        # ).cuda()
-        input_ids = torch.randint(0, torch_model.vocab_size, (NUM_BATCH, config.max_position_embeddings)).cuda()
-        attention_mask = torch.ones_like(input_ids).cuda()
-        input_ids.clone().cuda()
-        input_data = {"input_ids": input_ids, "attention_mask": attention_mask}
+        input_embeddings = torch.rand(
+            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+        ).cuda()
+        dist.all_reduce(
+            input_embeddings, group=plugin.pp_group
+        )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check

-        # dist.all_reduce(
-        #     input, group=plugin.pp_group
-        # )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
-
-        # dist.all_reduce(input, group=plugin.tp_group)  # tp group duplicate input
-        # dist.all_reduce(input, group=plugin.sp_group)  # sp group duplicate input
+        dist.all_reduce(input_embeddings, group=plugin.tp_group)  # tp group duplicate input
+        dist.all_reduce(input_embeddings, group=plugin.sp_group)  # sp group duplicate input

         # run the model with hybrid parallel
         if booster.plugin.stage_manager is not None:
             # for test with pp
-            data_iter = iter([input_data])
+            data_iter = iter([{"inputs_embeds": input_embeddings}])
             sharded_output = booster.execute_pipeline(
                 data_iter,
                 parallel_model,
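
The new input path draws a random input_embeddings tensor on every rank and then all-reduces it over the pp, tp, and sp groups, so every rank inside those groups ends up holding an identical tensor and only data-parallel ranks see different data. A single-process sketch of that trick using the gloo backend (the group here is a placeholder standing in for the plugin's pp/tp/sp groups, and the address/port are arbitrary):

import torch
import torch.distributed as dist


def make_replicated_input(batch, seq_len, hidden, groups):
    # each rank draws its own random tensor ...
    x = torch.rand(batch, seq_len, hidden)
    # ... and an in-place all-reduce (a sum) over each non-data-parallel group
    # leaves every member of those groups holding the same values
    for group in groups:
        dist.all_reduce(x, group=group)
    return x.requires_grad_(True)


if __name__ == "__main__":
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1
    )
    group = dist.new_group(ranks=[0])  # placeholder for the pp/tp/sp groups
    x = make_replicated_input(batch=1, seq_len=8, hidden=16, groups=[group, group])
    print(x.shape)  # torch.Size([1, 8, 16])
    dist.destroy_process_group()
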
@@ -1053,10 +1048,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):

         else:
             # for test without pp
-            parallel_output = parallel_model(
-                input_ids=input_data["input_ids"],
-                attention_mask=input_data["attention_mask"],
-            ).last_hidden_state.mean()
+            parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
             parallel_optimizer.backward(parallel_output)
             parallel_optimizer.step()
             parallel_optimizer.zero_grad()
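
In the non-pipeline branch the "loss" is simply the mean of the last hidden state, which is enough to drive a backward pass for the numerical comparison. A tiny standalone version of the same optimizer loop, with a plain nn.Linear standing in for the Llama model and SGD standing in for the test's optimizer:

import torch
import torch.nn as nn

model = nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

inputs_embeds = torch.rand(2, 8, 16)  # (batch, seq_len, hidden)
output = model(inputs_embeds).mean()  # scalar pseudo-loss
output.backward()                     # gradients for every parameter
optimizer.step()
optimizer.zero_grad()
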
@@ -1064,14 +1056,11 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):

         # ===================================================================================
         # run normal model with all dp(different) inputs
-        all_inputs = [input_data for _ in range(dp_size)]
-        # dist.all_gather(all_inputs, input, group=plugin.dp_group)
+        all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
+        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
         torch_output_sum = 0
         for input_data_ in all_inputs:
-            torch_output = torch_model(
-                input_ids=input_data_["input_ids"],
-                attention_mask=input_data_["attention_mask"],
-            ).last_hidden_state.mean()
+            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
             torch_output.backward()
             torch_output_sum += torch_output.detach()
         # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
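
The reference run now gathers the per-data-parallel-rank inputs with a real all_gather instead of duplicating the local input, then feeds each gathered tensor through the unsharded torch model and sums the resulting losses. A single-process gloo sketch of that gather-and-sum pattern (with world size 1 the gathered list just contains the local tensor; the address/port are arbitrary):

import torch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29513", rank=0, world_size=1
)
dp_size = dist.get_world_size()
input_embeddings = torch.rand(2, 8, 16)

# all_gather needs one pre-allocated slot per rank; afterwards all_inputs
# holds every rank's input, so a single process can replay all of them
all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings)

torch_output_sum = sum(x.mean() for x in all_inputs)
print(torch_output_sum)
dist.destroy_process_group()
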
@@ -1082,9 +1071,9 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
         torch_optimizer.step()
         torch_optimizer.zero_grad()

-        print(f"loop {i} rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
-        # assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
-    # print(f"rank {dist.get_rank()} config {test_config} test passed")
+        # print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
+        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+    print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
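
assert_loose_close comes from the test utilities and compares the parallel and reference losses under relaxed, dtype-aware tolerances. A rough stand-in for what such a check does; the tolerance values below are illustrative assumptions, not the helper's actual numbers:

import torch


def loose_close(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype) -> None:
    # illustrative per-dtype tolerances; the real helper may use different ones
    rtol, atol = {
        torch.float16: (5e-2, 5e-4),
        torch.bfloat16: (4e-3, 4e-3),
    }.get(dtype, (1e-5, 1e-8))
    torch.testing.assert_close(
        a.detach().to(dtype), b.detach().to(dtype), rtol=rtol, atol=atol
    )
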
@@ -1094,7 +1083,7 @@ def run_dist(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_with_booster_moehybridplugin()
-    # run_with_booster_hybridplugin()
+    run_with_booster_hybridplugin()


 @pytest.mark.dist
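
For context, run_dist functions like this one are usually driven by a spawn-based pytest entry point elsewhere in the file. A sketch of the typical wrapper; the test name and the world size of 4 (which matches the product of the parallel sizes in the configs above) are assumptions here:

import pytest
from colossalai.testing import rerun_if_address_is_in_use, spawn


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pp():
    # launch run_dist on 4 local processes; each process calls
    # colossalai.launch and then runs both booster test bodies
    spawn(run_dist, 4)


if __name__ == "__main__":
    test_pp()
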