mirror of https://github.com/hpcaitech/ColossalAI
[ShardFormer] fix qwen2 sp (#5903)
parent
45c49dde96
commit
1c961b20f3
|
@ -1,3 +1,4 @@
|
||||||
|
import math
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -513,7 +514,6 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
||||||
query_states = self.q_proj(hidden_states)
|
query_states = self.q_proj(hidden_states)
|
||||||
key_states = self.k_proj(hidden_states)
|
key_states = self.k_proj(hidden_states)
|
||||||
value_states = self.v_proj(hidden_states)
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
# sp: all-to-all comminucation when introducing sequence parallel
|
# sp: all-to-all comminucation when introducing sequence parallel
|
||||||
if sp_mode == "all_to_all":
|
if sp_mode == "all_to_all":
|
||||||
query_states = all_to_all_comm(query_states, sp_group)
|
query_states = all_to_all_comm(query_states, sp_group)
|
||||||
|
@ -698,9 +698,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
||||||
next_decoder_cache = None
|
next_decoder_cache = None
|
||||||
|
|
||||||
if sp_mode in ["ring", "split_gather"]:
|
if sp_mode in ["ring", "split_gather"]:
|
||||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
|
||||||
elif sp_mode == "all_to_all":
|
elif sp_mode == "all_to_all":
|
||||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
|
||||||
|
|
||||||
for decoder_layer in self.layers:
|
for decoder_layer in self.layers:
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
|
|
|
@ -135,51 +135,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
"precision": "fp16",
|
"precision": "fp16",
|
||||||
"initial_scale": 1,
|
"initial_scale": 1,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"tp_size": 1,
|
|
||||||
"pp_size": 2,
|
|
||||||
"num_microbatches": 2,
|
|
||||||
"enable_all_optimization": True,
|
|
||||||
"use_lazy_init": True,
|
|
||||||
"zero_stage": 1,
|
|
||||||
"precision": "fp16",
|
|
||||||
"initial_scale": 1,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def run_qwen2_test(test_config):
|
|
||||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
|
|
||||||
|
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
|
||||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
|
||||||
|
|
||||||
clear_layout_converter()
|
|
||||||
Randomizer.reset_index()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
@parameterize(
|
|
||||||
"test_config",
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"tp_size": 2,
|
|
||||||
"pp_size": 2,
|
|
||||||
"num_microbatches": 4,
|
|
||||||
"enable_all_optimization": False,
|
|
||||||
"use_lazy_init": False,
|
|
||||||
"precision": "fp32",
|
|
||||||
"initial_scale": 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"tp_size": 2,
|
|
||||||
"pp_size": 2,
|
|
||||||
"num_microbatches": 4,
|
|
||||||
"enable_all_optimization": False,
|
|
||||||
"use_lazy_init": False,
|
|
||||||
"precision": "fp16",
|
|
||||||
"zero_stage": 1,
|
|
||||||
"initial_scale": 1,
|
|
||||||
},
|
|
||||||
{ # Ulysess + Flash attention
|
{ # Ulysess + Flash attention
|
||||||
"tp_size": 1,
|
"tp_size": 1,
|
||||||
"pp_size": 2,
|
"pp_size": 2,
|
||||||
|
@ -242,6 +197,54 @@ def run_qwen2_test(test_config):
|
||||||
"precision": "fp16",
|
"precision": "fp16",
|
||||||
"initial_scale": 1,
|
"initial_scale": 1,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"tp_size": 1,
|
||||||
|
"pp_size": 2,
|
||||||
|
"num_microbatches": 2,
|
||||||
|
"enable_all_optimization": True,
|
||||||
|
"use_lazy_init": True,
|
||||||
|
"zero_stage": 1,
|
||||||
|
"precision": "fp16",
|
||||||
|
"initial_scale": 1,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def run_qwen2_test(test_config):
|
||||||
|
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
|
||||||
|
|
||||||
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||||
|
try:
|
||||||
|
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed config: {test_config}")
|
||||||
|
raise e
|
||||||
|
clear_layout_converter()
|
||||||
|
Randomizer.reset_index()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
"test_config",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"tp_size": 2,
|
||||||
|
"pp_size": 2,
|
||||||
|
"num_microbatches": 4,
|
||||||
|
"enable_all_optimization": False,
|
||||||
|
"use_lazy_init": False,
|
||||||
|
"precision": "fp32",
|
||||||
|
"initial_scale": 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tp_size": 2,
|
||||||
|
"pp_size": 2,
|
||||||
|
"num_microbatches": 4,
|
||||||
|
"enable_all_optimization": False,
|
||||||
|
"use_lazy_init": False,
|
||||||
|
"precision": "fp16",
|
||||||
|
"zero_stage": 1,
|
||||||
|
"initial_scale": 1,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"tp_size": 2,
|
"tp_size": 2,
|
||||||
"pp_size": 2,
|
"pp_size": 2,
|
||||||
|
@ -259,7 +262,11 @@ def run_qwen2_3d_test(test_config):
|
||||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
|
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
|
||||||
|
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
try:
|
||||||
|
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed config: {test_config}")
|
||||||
|
raise e
|
||||||
|
|
||||||
clear_layout_converter()
|
clear_layout_converter()
|
||||||
Randomizer.reset_index()
|
Randomizer.reset_index()
|
||||||
|
|
Loading…
Reference in New Issue