From 4c69e2dc9161a3aece27b806d8bfeded2a89a285 Mon Sep 17 00:00:00 2001
From: Wang Binluo <2538539015@qq.com>
Date: Tue, 9 Apr 2024 11:50:35 +0800
Subject: [PATCH] support qwen model

---
 colossalai/shardformer/modeling/qwen2.py    | 39 ++++++++-------------
 colossalai/shardformer/policies/qwen2.py    | 16 ++++-----
 tests/test_shardformer/test_model/_utils.py |  6 ----
 3 files changed, 23 insertions(+), 38 deletions(-)

diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 0abe4fe03..3641eedfc 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -44,7 +44,7 @@ class Qwen2PipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
-    ):
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
         logger = logging.get_logger(__name__)
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -82,14 +82,18 @@
         if output_hidden_states:
             logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
             output_hidden_states = False
-        if use_cache:
-            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
-            use_cache = False
 
         assert past_key_values is None, "past_key_values is not supported for Qwen2 models at the moment."
         past_key_values_length = 0
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
         if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
@@ -123,18 +127,11 @@
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
                 (batch_size, seq_length),
-                inputs_embeds,
+                hidden_states,
                 past_key_values_length,
                 sliding_window=self.config.sliding_window,
             )
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
@@ -148,20 +145,14 @@
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, None)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     position_ids,
-                    None,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -315,7 +306,7 @@
         else:
             hidden_states = outputs.get("hidden_states")
             return {"hidden_states": hidden_states}
-    
+
     @staticmethod
     def qwen2_for_sequence_classification_forward(
         self: Qwen2ForSequenceClassification,
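The modeling hunks above swap the removed `create_custom_forward` closure around `torch.utils.checkpoint.checkpoint` for the `_gradient_checkpointing_func` hook that recent transformers releases (4.35+) attach to the model, and they move the `use_cache` warning so it fires only when gradient checkpointing is actually enabled. A minimal sketch of the new call pattern, in plain PyTorch; TinyBlock is a stand-in for Qwen2DecoderLayer, not the real layer:

    # `_gradient_checkpointing_func` in recent transformers defaults to a
    # functools.partial around torch.utils.checkpoint.checkpoint, so the
    # patched call site reduces to roughly this.
    from functools import partial

    import torch
    import torch.nn as nn

    class TinyBlock(nn.Module):
        def forward(self, hidden_states, attention_mask=None, position_ids=None):
            # returns a tuple, like a HF decoder layer
            return (hidden_states * 2.0,)

    block = TinyBlock()
    ckpt = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

    hidden = torch.randn(2, 4, 8, requires_grad=True)
    # positional arguments only, mirroring the patched call site
    layer_outputs = ckpt(block.__call__, hidden, None, None)
    layer_outputs[0].sum().backward()  # activations are recomputed here

Using `decoder_layer.__call__` (rather than `decoder_layer.forward`) keeps forward hooks running inside the checkpointed segment.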
diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py
index f01518ef7..933223ba7 100644
--- a/colossalai/shardformer/policies/qwen2.py
+++ b/colossalai/shardformer/policies/qwen2.py
@@ -151,10 +151,10 @@ class Qwen2Policy(Policy):
             module = self.model.model
 
         if stage_manager.is_interleave:
-            layers_per_stage = self.distribute_layers(
+            layers_per_stage = stage_manager.distribute_layers(
                 len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
             )
-            stage_manager.stage_indices = Policy.get_stage_index(
+            stage_manager.stage_indices = stage_manager.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
@@ -165,8 +165,8 @@
             }
 
         else:
-            layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers), stage_manager.num_stages)
+            stage_index = stage_manager.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
                     new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -192,10 +192,10 @@
         held_layers = []
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
-            layers_per_stage = self.distribute_layers(
+            layers_per_stage = stage_manager.distribute_layers(
                 len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
             )
-            stage_indices = Policy.get_stage_index(
+            stage_indices = stage_manager.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
@@ -209,10 +209,10 @@
                 held_layers.append(module.norm)
 
         else:
-            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers), stage_manager.num_stages)
             if stage_manager.is_first_stage():
                 held_layers.append(module.embed_tokens)
-            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage, stage_manager.stage)
             held_layers.extend(module.layers[start_idx:end_idx])
             if stage_manager.is_last_stage():
                 held_layers.append(module.norm)
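Both policy methods now ask the PipelineStageManager for the layer partition instead of the static Policy helpers. As simplified stand-ins (illustrative only; ColossalAI's real `distribute_layers` may place remainder layers differently), the two calls amount to:

    from typing import List, Tuple

    def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
        # even split, with any remainder handed to the earliest stages
        base, rem = divmod(num_layers, num_stages)
        return [base + (1 if stage < rem else 0) for stage in range(num_stages)]

    def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
        # half-open [start, end) range of decoder layers owned by `stage`
        start = sum(layers_per_stage[:stage])
        return start, start + layers_per_stage[stage]

    per_stage = distribute_layers(32, 4)  # 32 Qwen2 layers on 4 stages -> [8, 8, 8, 8]
    print(get_stage_index(per_stage, 2))  # (16, 24): stage 2 holds layers 16..23

`get_held_layers` then appends `module.embed_tokens` on the first stage, the `[start_idx:end_idx]` slice of `module.layers` on every stage, and `module.norm` on the last stage, exactly as the hunks above show.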
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index f5fc21b4c..e3e5045de 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -182,12 +182,6 @@ def run_forward_backward_with_hybrid_plugin(
         data_iter = iter([data])
         sharded_output = booster.execute_pipeline(
             data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True
-            data_iter,
-            sharded_model,
-            _criterion,
-            sharded_optimizer,
-            return_loss=True,
-            return_outputs=True,
         )
         sharded_loss = sharded_output["loss"]
     else:
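The test change drops an argument list that had been duplicated inside the `booster.execute_pipeline(...)` call, apparently a merge leftover: with positional arguments following keyword arguments, the old call could not even be parsed. A self-contained illustration (the call appears only inside a string, so nothing here depends on ColossalAI):

    # A call shaped like the pre-patch one fails at compile time, before
    # any test logic runs.
    import textwrap

    bad_call = textwrap.dedent("""
        execute_pipeline(data_iter, model, criterion, optimizer,
                         return_loss=True, return_outputs=True,
                         data_iter, model, criterion, optimizer)
    """)

    try:
        compile(bad_call, "<snippet>", "exec")
    except SyntaxError as err:
        print(err.msg)  # positional argument follows keyword argument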