mirror of https://github.com/hpcaitech/ColossalAI
[fix] add & fix llama test
parent e76308c6e6
commit 705b18e1e7
@@ -82,7 +82,7 @@ class LlamaPipelineForwards:
             elif input_ids is not None:
                 batch_size, seq_length = input_ids.shape[:2]
             elif inputs_embeds is not None:
-                batch_size, seq_length, _ = inputs_embeds.shape[:2]
+                batch_size, seq_length = inputs_embeds.shape[:2]
             else:
                 raise ValueError("You have to specify either input_ids or inputs_embeds")
             if inputs_embeds is None:
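Note on this hunk: `inputs_embeds` is a 3-D tensor of shape (batch, seq_len, hidden), so `shape[:2]` yields exactly two values; the old three-target unpacking raised "ValueError: not enough values to unpack" at runtime. A minimal sketch of the fix (tensor sizes are illustrative assumptions, not the test's real dimensions):

import torch

inputs_embeds = torch.rand(2, 16, 64)  # (batch, seq_len, hidden)

# Old, broken: shape[:2] has two elements, but three names are unpacked.
# batch_size, seq_length, _ = inputs_embeds.shape[:2]  # ValueError

# Fixed: unpack only the two leading dimensions.
batch_size, seq_length = inputs_embeds.shape[:2]
assert (batch_size, seq_length) == (2, 16)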
@@ -924,9 +924,9 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
     "config",
     [
         (0, 4, 1, 1),
-        # (1, 2, 2, 1),
-        # (1, 2, 1, 2),
-        # (1, 1, 2, 2),
+        (1, 2, 2, 1),
+        (1, 2, 1, 2),
+        (1, 1, 2, 2),
     ],
 )
 def run_with_booster_hybridplugin(config: Tuple[int, ...]):
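This hunk re-enables the previously commented-out parallelism configurations via the standard `@pytest.mark.parametrize` pattern, where each tuple becomes one independent test invocation. A self-contained sketch of the mechanics; the field layout (stage, pp_size, tp_size, sp_size) is an assumption inferred from the pp/tp/sp sizes printed at the end of the test, not a confirmed meaning:

import pytest

@pytest.mark.parametrize(
    "config",
    [
        (0, 4, 1, 1),
        (1, 2, 2, 1),
        (1, 2, 1, 2),
        (1, 1, 2, 2),
    ],
)
def test_config_grid(config):
    stage, pp_size, tp_size, sp_size = config
    # every enabled combination multiplies out to the same 4-GPU world
    assert pp_size * tp_size * sp_size == 4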
@@ -1010,27 +1010,22 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
 
     torch_model.train()
     parallel_model.train()
-    for i in range(2):
+    for _ in range(2):
         # gen random input
-        # input = torch.rand(
-        #     NUM_BATCH, NUM_TOK_PER_BATCH, NUM_HEADS, HIDDEN_SIZE_PER_HEAD, requires_grad=True
-        # ).cuda()
-        input_ids = torch.randint(0, torch_model.vocab_size, (NUM_BATCH, config.max_position_embeddings)).cuda()
-        attention_mask = torch.ones_like(input_ids).cuda()
-        input_ids.clone().cuda()
-        input_data = {"input_ids": input_ids, "attention_mask": attention_mask}
-
-        # dist.all_reduce(
-        #     input, group=plugin.pp_group
-        # ) # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
-
-        # dist.all_reduce(input, group=plugin.tp_group) # tp group duplicate input
-        # dist.all_reduce(input, group=plugin.sp_group) # sp group duplicate input
+        input_embeddings = torch.rand(
+            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+        ).cuda()
+        dist.all_reduce(
+            input_embeddings, group=plugin.pp_group
+        )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
+
+        dist.all_reduce(input_embeddings, group=plugin.tp_group)  # tp group duplicate input
+        dist.all_reduce(input_embeddings, group=plugin.sp_group)  # sp group duplicate input
 
         # run the model with hybrid parallel
         if booster.plugin.stage_manager is not None:
             # for test with pp
-            data_iter = iter([input_data])
+            data_iter = iter([{"inputs_embeds": input_embeddings}])
             sharded_output = booster.execute_pipeline(
                 data_iter,
                 parallel_model,
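The switch from `input_ids` to `inputs_embeds` lets the test drive the model with a continuous random tensor, and the `all_reduce` calls replicate that tensor across the pp/tp/sp groups: each rank draws different random values, and an in-place sum over a group leaves every member holding the identical tensor. A minimal sketch of why the sum works, simulated without a process group (the two "rank" tensors are illustrative assumptions):

import torch

torch.manual_seed(0)
x_rank_a = torch.rand(2, 8)    # rank A's private random draw
x_rank_b = torch.rand(2, 8)    # rank B's private random draw
reduced = x_rank_a + x_rank_b  # what both ranks hold after all_reduce(SUM)
# both "ranks" now feed the same tensor to their model replica,
# which is what the torch-model parity check requires
assert torch.equal(reduced, x_rank_b + x_rank_a)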
@@ -1053,10 +1048,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
 
         else:
             # for test without pp
-            parallel_output = parallel_model(
-                input_ids=input_data["input_ids"],
-                attention_mask=input_data["attention_mask"],
-            ).last_hidden_state.mean()
+            parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
             parallel_optimizer.backward(parallel_output)
             parallel_optimizer.step()
             parallel_optimizer.zero_grad()
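The `.last_hidden_state.mean()` reduction is just a cheap way to obtain a scalar pseudo-loss from a base model so `backward()` can run; any differentiable scalar serves for a numerical parity check. A minimal sketch with a stand-in HF-style module (TinyModel and its sizes are assumptions, not the test's real model):

import torch
import torch.nn as nn
from types import SimpleNamespace

class TinyModel(nn.Module):
    # stand-in for an HF-style base model returning an output object
    # that carries last_hidden_state
    def __init__(self, hidden: int = 32):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, inputs_embeds: torch.Tensor):
        return SimpleNamespace(last_hidden_state=self.proj(inputs_embeds))

model = TinyModel()
x = torch.rand(2, 8, 32)
loss = model(inputs_embeds=x).last_hidden_state.mean()  # scalar pseudo-loss
loss.backward()  # enough to exercise the full backward pass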
@@ -1064,14 +1056,11 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
 
         # ===================================================================================
         # run normal model with all dp(different) inputs
-        all_inputs = [input_data for _ in range(dp_size)]
-        # dist.all_gather(all_inputs, input, group=plugin.dp_group)
+        all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
+        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
         torch_output_sum = 0
         for input_data_ in all_inputs:
-            torch_output = torch_model(
-                input_ids=input_data_["input_ids"],
-                attention_mask=input_data_["attention_mask"],
-            ).last_hidden_state.mean()
+            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
             torch_output.backward()
             torch_output_sum += torch_output.detach()
         # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
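Here `all_gather` hands every rank the full list of per-dp-rank inputs, so the single unsharded `torch_model` can replay all of them and accumulate a reference loss. A runnable sketch of the collective using a 1-process gloo group purely so it executes locally (addresses and sizes are assumptions; the test reduces over `plugin.dp_group` on NCCL):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.rand(2, 8)  # this rank's (normally distinct) input
gathered = [torch.empty_like(x) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, x)
# gathered[r] holds rank r's tensor on every rank after the call
assert torch.equal(gathered[0], x)
dist.destroy_process_group()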
@@ -1082,9 +1071,9 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
         torch_optimizer.step()
         torch_optimizer.zero_grad()
 
-        print(f"loop {i} rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
-        # assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
-        # print(f"rank {dist.get_rank()} config {test_config} test passed")
+        # print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
+        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+        print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
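This hunk turns the debug print into a real assertion: `assert_loose_close` compares the parallel loss against the reference sum with a tolerance keyed to the training dtype. A sketch in that spirit; the tolerance table below is an assumption for illustration, not ColossalAI's actual values:

import torch

_TOLS = {
    torch.float16: dict(rtol=5e-3, atol=5e-3),
    torch.bfloat16: dict(rtol=4e-2, atol=4e-2),
    torch.float32: dict(rtol=1e-5, atol=1e-5),
}

def loose_close(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype) -> None:
    # compare in fp32 with a dtype-dependent tolerance
    tol = _TOLS[dtype]
    a = a.detach().to(torch.float32)
    b = b.detach().to(torch.float32)
    assert torch.allclose(a, b, **tol), f"{a} vs {b} (dtype={dtype})"

loose_close(torch.tensor(1.0), torch.tensor(1.0 + 1e-6), torch.float32)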
@@ -1094,7 +1083,7 @@ def run_dist(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_with_booster_moehybridplugin()
-    # run_with_booster_hybridplugin()
+    run_with_booster_hybridplugin()
 
 
 @pytest.mark.dist
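For context on the harness: `run_dist` is the per-process entry point that a spawner calls once per rank, and `colossalai.launch` wires up the NCCL process group from the explicit rank/world_size/host/port. A generic sketch of the surrounding spawn pattern with `torch.multiprocessing` (ColossalAI's own tests typically use the `spawn` helper from `colossalai.testing` instead; `spawn_run` is a hypothetical name):

import torch.multiprocessing as mp

def run_dist(rank: int, world_size: int, port: int):
    # per-rank entry point, as in the diff: colossalai.launch(...) then
    # run_with_booster_moehybridplugin() / run_with_booster_hybridplugin()
    ...

def spawn_run(world_size: int = 4, port: int = 29500):
    # mp.spawn passes the rank as the first argument automatically
    mp.spawn(run_dist, args=(world_size, port), nprocs=world_size)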