diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index da78dfc0b..55822b150 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -1,3 +1,4 @@
+import math
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -513,7 +514,6 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
-
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             query_states = all_to_all_comm(query_states, sp_group)
@@ -698,9 +698,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         next_decoder_cache = None
 
         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
 
         for decoder_layer in self.layers:
             if output_hidden_states:
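Why the `inputs_embeds` to `hidden_states` rename above matters: by this point in the forward, `hidden_states` has already been assigned from `inputs_embeds` (as in the upstream HF `Qwen2Model.forward`), so rebinding the name `inputs_embeds` to the split tensor left the `hidden_states` actually fed to the decoder layers unsplit along the sequence dimension. Below is a minimal single-process sketch of that aliasing bug; `split_seq` is a hypothetical stand-in for `split_forward_gather_backward`, which in the real implementation also gathers gradients in the backward pass.

```python
import torch

# Hypothetical stand-in for split_forward_gather_backward: keep only this
# rank's chunk of the sequence dimension (rank/world fixed for illustration).
def split_seq(t: torch.Tensor, dim: int, rank: int = 0, world: int = 2) -> torch.Tensor:
    return t.chunk(world, dim=dim)[rank]

inputs_embeds = torch.randn(2, 8, 16)  # (batch, seq, hidden)
hidden_states = inputs_embeds          # alias taken before the split, as in the forward

# Buggy version: rebinding the name inputs_embeds does not touch hidden_states.
inputs_embeds = split_seq(inputs_embeds, dim=1)
print(hidden_states.shape)  # torch.Size([2, 8, 16]) -- sequence dim still unsplit

# Fixed version: split the name the decoder layers actually read.
hidden_states = split_seq(hidden_states, dim=1)
print(hidden_states.shape)  # torch.Size([2, 4, 16]) -- sequence dim sharded
```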
"pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2, @@ -259,7 +262,11 @@ def run_qwen2_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + try: + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + except Exception as e: + print(f"Failed config: {test_config}") + raise e clear_layout_converter() Randomizer.reset_index()