From 00525f77724428e7d883893d07bbfbf4dd1ad35e Mon Sep 17 00:00:00 2001 From: Insu Jang Date: Wed, 27 Mar 2024 01:57:00 -0400 Subject: [PATCH] [shardformer] fix pipeline forward error if custom layer distribution is used (#5189) * Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution * Change static methods for t5 layer distribution to member functions * Change static methods for whisper layer distribution to member functions * Replace whisper policy usage with self one * Fix test case to use non-static layer distribution methods * fix: fix typo --------- Co-authored-by: Wenhao Chen --- .../colossal_moe/models/mixtral_policy.py | 2 +- .../shardformer/policies/base_policy.py | 7 +- colossalai/shardformer/policies/bert.py | 67 ++++++++++++++----- colossalai/shardformer/policies/bloom.py | 4 +- colossalai/shardformer/policies/chatglm2.py | 4 +- colossalai/shardformer/policies/falcon.py | 4 +- colossalai/shardformer/policies/gpt2.py | 8 +-- colossalai/shardformer/policies/gptj.py | 4 +- colossalai/shardformer/policies/llama.py | 8 +-- colossalai/shardformer/policies/opt.py | 4 +- colossalai/shardformer/policies/t5.py | 32 +++++---- colossalai/shardformer/policies/vit.py | 4 +- colossalai/shardformer/policies/whisper.py | 30 ++++----- .../language/openmoe/model/openmoe_policy.py | 17 ++--- .../test_plugin/test_3d_plugin.py | 8 +-- .../test_t5_pipeline_utils.py | 18 +++-- .../test_whisper_pipeline_utils.py | 16 +++-- .../test_layer/test_dist_crossentropy.py | 5 +- 18 files changed, 136 insertions(+), 106 deletions(-) diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py index 218b05b27..23ffbf5d3 100644 --- a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py @@ -110,7 +110,7 @@ class MixtralPolicy(Policy): module = self.model.model layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 9a49b1ba6..762e75481 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -197,8 +197,7 @@ class Policy(ABC): """ return [] - @staticmethod - def distribute_layers(num_layers: int, num_stages: int) -> List[int]: + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: """Divide layers into stages""" quotient = num_layers // num_stages remainder = num_layers % num_stages @@ -213,8 +212,8 @@ class Policy(ABC): layers_per_stage[i] += 1 return layers_per_stage - @staticmethod def get_stage_index( + self, layers_per_stage: List[int], stage: int, num_model_chunks: int = 1, @@ -242,4 +241,4 @@ class Policy(ABC): end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] stage_indices.append([start_idx, end_idx]) - return stage_indices[0] if num_model_chunks == 1 else stage_indices \ No newline at end of file + return stage_indices[0] if num_model_chunks == 1 else stage_indices diff --git 
a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 0ab63b765..4d50a3c99 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -84,17 +84,26 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="attention.self.query", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attention.self.key", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attention.self.value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="attention.self.dropout", @@ -112,7 +121,10 @@ class BertPolicy(Policy): SubModuleReplacementDescription( suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + kwargs={ + "seq_parallel": use_sequence_parallel, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="output.dense", @@ -214,7 +226,9 @@ class BertPolicy(Policy): if self.shard_config.enable_tensor_parallelism: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( - suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + suffix="decoder", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}, ), policy=base_policy, target_key=BertLMPredictionHead, @@ -241,7 +255,9 @@ class BertPolicy(Policy): "_load_from_state_dict": col_nn.ParallelModule._load_from_state_dict, } self.append_or_create_method_replacement( - description=method_replacement, policy=base_policy, target_key=BertLMPredictionHead + description=method_replacement, + policy=base_policy, + target_key=BertLMPredictionHead, ) return base_policy @@ -264,24 +280,32 @@ class BertPolicy(Policy): if stage_manager.is_interleave: layers_per_stage = self.distribute_layers( - len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks + len(module.encoder.layer), + stage_manager.num_stages * stage_manager.num_model_chunks, ) - stage_manager.stage_indices = Policy.get_stage_index( + stage_manager.stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, num_stages=stage_manager.num_stages, ) method_replacement = { - "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + "forward": partial( + new_forward, + stage_manager=stage_manager, + shard_config=self.shard_config, + ) } else: - layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( - new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + 
shard_config=self.shard_config, ) } @@ -301,9 +325,10 @@ class BertPolicy(Policy): if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = self.distribute_layers( - len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks + len(module.encoder.layer), + stage_manager.num_stages * stage_manager.num_model_chunks, ) - stage_indices = Policy.get_stage_index( + stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, @@ -320,7 +345,7 @@ class BertPolicy(Policy): layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) if stage_manager.is_first_stage(): held_layers.append(module.embeddings) - start_idx, end_idx = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) held_layers.extend(module.encoder.layer[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.pooler) @@ -336,7 +361,9 @@ class BertModelPolicy(BertPolicy): if self.pipeline_stage_manager: self.set_pipeline_forward( - model_cls=BertModel, new_forward=BertPipelineForwards.bert_model_forward, policy=policy + model_cls=BertModel, + new_forward=BertPipelineForwards.bert_model_forward, + policy=policy, ) return policy @@ -399,7 +426,9 @@ class BertLMHeadModelPolicy(BertPolicy): if self.pipeline_stage_manager: self.set_pipeline_forward( - model_cls=BertLMHeadModel, new_forward=BertPipelineForwards.bert_lm_head_model_forward, policy=policy + model_cls=BertLMHeadModel, + new_forward=BertPipelineForwards.bert_lm_head_model_forward, + policy=policy, ) return policy @@ -437,7 +466,9 @@ class BertForMaskedLMPolicy(BertPolicy): if self.pipeline_stage_manager: self.set_pipeline_forward( - model_cls=BertForMaskedLM, new_forward=BertPipelineForwards.bert_for_masked_lm_forward, policy=policy + model_cls=BertForMaskedLM, + new_forward=BertPipelineForwards.bert_for_masked_lm_forward, + policy=policy, ) return policy diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index eddfafdcb..e4714c8c1 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -203,8 +203,8 @@ class BloomPolicy(Policy): else: module = self.model.transformer - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index d1ad9f914..cbe6254d1 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -204,8 +204,8 @@ class ChatGLMPolicy(Policy): else: module = self.model.transformer - layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": 
partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 5c148880f..16bbc3f23 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -161,8 +161,8 @@ class FalconPolicy(Policy): else: module = self.model.transformer - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 5b43ecaed..d1a8c9dce 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -188,7 +188,7 @@ class GPT2Policy(Policy): layers_per_stage = self.distribute_layers( len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks ) - stage_indices = Policy.get_stage_index( + stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, @@ -229,7 +229,7 @@ class GPT2Policy(Policy): layers_per_stage = self.distribute_layers( len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks ) - stage_manager.stage_indices = Policy.get_stage_index( + stage_manager.stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, @@ -243,8 +243,8 @@ class GPT2Policy(Policy): ) } else: - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index b001a2009..b24443298 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -200,8 +200,8 @@ class GPTJPolicy(Policy): else: module = self.model.transformer - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index db8468713..daa7708c8 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -167,7 +167,7 @@ class LlamaPolicy(Policy): layers_per_stage = self.distribute_layers( len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks ) - stage_manager.stage_indices = Policy.get_stage_index( + stage_manager.stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, @@ -178,8 +178,8 @@ class 
LlamaPolicy(Policy): } else: - layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config @@ -207,7 +207,7 @@ class LlamaPolicy(Policy): layers_per_stage = self.distribute_layers( len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks ) - stage_indices = Policy.get_stage_index( + stage_indices = self.get_stage_index( layers_per_stage, stage_manager.stage, num_model_chunks=stage_manager.num_model_chunks, diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 9a74da0b8..683f3a9d5 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -208,8 +208,8 @@ class OPTPolicy(Policy): else: module = self.model.model.decoder - layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index e183b0632..f5f701dc0 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings from functools import partial from typing import Callable, Dict, List, Tuple @@ -241,9 +243,8 @@ class T5BasePolicy(Policy): def postprocess(self): return self.model - @staticmethod def distribute_t5_layers( - num_encoder_layers: int, num_decoder_layers: int, num_stages: int + self, num_encoder_layers: int, num_decoder_layers: int, num_stages: int ) -> Tuple[List[int], int]: """ Distribute t5 layers into stages when pipeline parallel is used. 
@@ -261,7 +262,7 @@ class T5BasePolicy(Policy): # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist if num_decoder_layers == 0: - return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages + return self.distribute_layers(num_encoder_layers, num_stages), num_stages # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) @@ -272,22 +273,21 @@ class T5BasePolicy(Policy): num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages - encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) - decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages) + encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages) return encoder_distribution + decoder_distribution, num_encoder_stages - @staticmethod def get_t5_stage_index( - layers_per_stage: List[int], stage: int, decoder_starting_stage: int + self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int ) -> Tuple[bool, int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. Return the starting/ending idx of layers in encoder/decoder """ if stage < decoder_starting_stage: - return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) else: - return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" @@ -302,12 +302,10 @@ class T5BasePolicy(Policy): num_decoder_layers = len(decoder.block) if decoder else 0 held_layers = [] - layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + layers_per_stage, decoder_starting_stage = self.distribute_t5_layers( num_encoder_layers, num_decoder_layers, stage_manager.num_stages ) - start_idx, end_idx = T5BasePolicy.get_t5_stage_index( - layers_per_stage, stage_manager.stage, decoder_starting_stage - ) + start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) if stage_manager.stage < decoder_starting_stage: # current stage is in t5's encoder @@ -343,10 +341,10 @@ class T5BasePolicy(Policy): num_encoder_layers = len(encoder.block) num_decoder_layers = len(decoder.block) if decoder else 0 - layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + layers_per_stage, decoder_starting_stage = self.distribute_t5_layers( num_encoder_layers, num_decoder_layers, stage_manager.num_stages ) - stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) + stage_index = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) method_replacement = { "forward": partial( @@ -386,7 +384,7 @@ class T5ModelPolicy(T5BasePolicy): module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None and stage_manager.num_stages > 1: - _, decoder_starting_stage = 
T5BasePolicy.distribute_t5_layers( + _, decoder_starting_stage = self.distribute_t5_layers( len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages ) @@ -434,7 +432,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy): module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None and stage_manager.num_stages > 1: - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + _, decoder_starting_stage = self.distribute_t5_layers( len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages ) diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 584d4e265..b0f224e22 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -149,8 +149,8 @@ class ViTPolicy(Policy): else: module = self.model.vit - layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 14e1e3e0f..480a4beea 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -292,9 +292,8 @@ class WhisperPolicy(Policy): def postprocess(self): return self.model - @staticmethod def distribute_whisper_layers( - num_encoder_layers: int, num_decoder_layers: int, num_stages: int + self, num_encoder_layers: int, num_decoder_layers: int, num_stages: int ) -> Tuple[List[int], int]: """ Distribute whisper layers into stages when pipeline parallel is used. @@ -312,7 +311,7 @@ class WhisperPolicy(Policy): # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist if num_decoder_layers == 0: - return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages + return self.distribute_layers(num_encoder_layers, num_stages), num_stages # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) @@ -323,22 +322,21 @@ class WhisperPolicy(Policy): num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages - encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) - decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages) + encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages) return encoder_distribution + decoder_distribution, num_encoder_stages - @staticmethod def get_whisper_stage_index( - layers_per_stage: List[int], stage: int, decoder_starting_stage: int + self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int ) -> Tuple[bool, int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. 
Return the starting/ending idx of layers in encoder/decoder """ if stage < decoder_starting_stage: - return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) else: - return Policy.get_stage_index( + return self.get_stage_index( layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage, ) @@ -369,12 +367,10 @@ class WhisperPolicy(Policy): num_decoder_layers = 0 held_layers = [] - layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers( num_encoder_layers, num_decoder_layers, stage_manager.num_stages ) - start_idx, end_idx = WhisperPolicy.get_whisper_stage_index( - layers_per_stage, stage_manager.stage, decoder_starting_stage - ) + start_idx, end_idx = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) if stage_manager.stage < decoder_starting_stage: # current stage is in whisper's encoder @@ -424,12 +420,10 @@ class WhisperPolicy(Policy): else: num_decoder_layers = 0 - layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers( num_encoder_layers, num_decoder_layers, stage_manager.num_stages ) - stage_index = WhisperPolicy.get_whisper_stage_index( - layers_per_stage, stage_manager.stage, decoder_starting_stage - ) + stage_index = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) method_replacement = { "forward": partial( @@ -511,7 +505,7 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy): stage_manager = self.pipeline_stage_manager if stage_manager is not None and stage_manager.num_stages > 1: - _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + _, decoder_starting_stage = self.distribute_whisper_layers( num_encoder_layers, num_decoder_layers, stage_manager.num_stages ) shared_params = [] diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index 17e7aa46c..66a42e017 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -98,11 +98,11 @@ class OpenMoePolicy(Policy): module = self.model.model layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=model_cls) + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) return @@ -126,12 +126,9 @@ class OpenMoePolicy(Policy): held_layers.append(module.norm) return held_layers - - @staticmethod - def distribute_layers(num_layers: int, num_stages: int) -> List[int]: - """Divide layers into stages - """ + def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: + """Divide layers into stages""" if num_layers == 24 and num_stages == 4: return [7, 7, 7, 3] elif num_layers == 24 and num_stages == 2: @@ -142,7 +139,7 @@ class OpenMoePolicy(Policy): return [8, 4] else: print(f"num_layers: {num_layers}, num_stages: {num_stages} 
not optimized, use origin pp policy") - return Policy.distribute_layers(num_layers, num_stages) + return super().distribute_layers(num_layers, num_stages) class OpenMoeModelPolicy(OpenMoePolicy): diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index 61558c003..52cb8c46e 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -83,7 +83,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ @parameterize("init_method", ["none", "lazy"]) def check_3d_plugin(init_method: str = "none", early_stop: bool = True): - """check gemini plugin over model zoo + """check hybrid plugin over model zoo Args: early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. @@ -260,7 +260,7 @@ def run_grad_acc_test(test_args): origin_model, origin_optimizer, dataloader=dataloader ) for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()): - assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) + assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2) def run_dist(rank, world_size, port, early_stop: bool = True): @@ -271,9 +271,9 @@ def run_dist(rank, world_size, port, early_stop: bool = True): @rerun_if_address_is_in_use() -def test_gemini_plugin(early_stop: bool = True): +def test_3d_plugin(early_stop: bool = True): spawn(run_dist, 4, early_stop=early_stop) if __name__ == "__main__": - test_gemini_plugin(early_stop=False) \ No newline at end of file + test_3d_plugin(early_stop=False) diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py index 3723c9c10..4ba67225f 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -10,9 +10,12 @@ def test_t5_pipeline_distribution(): "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } + policy = T5BasePolicy() for i in range(num_test_cases): - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( - test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + _, decoder_starting_stage = policy.distribute_t5_layers( + test_dict["num_encoder_layers"][i], + test_dict["num_decoder_layers"][i], + test_dict["num_stages"][i], ) assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage @@ -32,14 +35,15 @@ def test_t5_pipeline_layers(): } for i in range(num_test_cases): - layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( - test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + policy = T5BasePolicy() + layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers( + test_dict["num_encoder_layers"][i], + test_dict["num_decoder_layers"][i], + test_dict["num_stages"][i], ) for stage in range(test_dict["num_stages"][i]): start_idx, end_idx = test_dict["layers_per_stage"][i][stage] - predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index( - layers_per_stage, stage, decoder_starting_stage - ) + predicted_start, predicted_end = policy.get_t5_stage_index(layers_per_stage, stage, decoder_starting_stage) assert start_idx == predicted_start assert end_idx == predicted_end diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py index 
f6be8f6fe..0500e46e8 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -10,9 +10,12 @@ def test_whisper_pipeline_distribution(): "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } + policy = WhisperPolicy() for i in range(num_test_cases): - _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( - test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + _, decoder_starting_stage = policy.distribute_whisper_layers( + test_dict["num_encoder_layers"][i], + test_dict["num_decoder_layers"][i], + test_dict["num_stages"][i], ) assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage @@ -31,14 +34,17 @@ def test_whisper_pipeline_layers(): ], } + policy = WhisperPolicy() for i in range(num_test_cases): - layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( - test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers( + test_dict["num_encoder_layers"][i], + test_dict["num_decoder_layers"][i], + test_dict["num_stages"][i], ) for stage in range(test_dict["num_stages"][i]): start_idx, end_idx = test_dict["layers_per_stage"][i][stage] - predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index( + predicted_start, predicted_end = policy.get_whisper_stage_index( layers_per_stage, stage, decoder_starting_stage ) assert start_idx == predicted_start diff --git a/tests/test_shardformer/test_layer/test_dist_crossentropy.py b/tests/test_shardformer/test_layer/test_dist_crossentropy.py index f594a80a4..414157c22 100644 --- a/tests/test_shardformer/test_layer/test_dist_crossentropy.py +++ b/tests/test_shardformer/test_layer/test_dist_crossentropy.py @@ -38,9 +38,10 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index): org_loss, dist_loss, atol=1e-5 ), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" - target_grad = torch.chunk(pred.grad, world_size, dim=-1)[rank] - assert torch.allclose(target_grad, dist_pred.grad), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}" + assert torch.allclose( + target_grad, dist_pred.grad + ), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}" @pytest.mark.dist
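
Why the change from Policy.<helper> to self.<helper> matters: the old call sites resolved distribute_layers and get_stage_index statically on the base Policy class, so a subclass that customized layer distribution (as OpenMoePolicy does in this patch) still had its held layers and stage indices computed from the default even split. Below is a minimal, self-contained sketch of that behaviour under stated assumptions: BasePolicy, CustomPolicy, and the held_layer_range_* helpers are illustrative names, not the real ColossalAI API, and the real get_stage_index also handles interleaved model chunks, which is omitted here. The splitting rule mirrors the one in the patched base_policy.py.

from typing import List


class BasePolicy:
    """Simplified stand-in for the shardformer base Policy (illustrative only)."""

    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
        # Same rule as the patched base_policy.py: even split, remainder
        # assigned to the middle stages.
        quotient = num_layers // num_stages
        remainder = num_layers % num_stages
        layers_per_stage = [quotient] * num_stages
        start_position = num_stages // 2 - remainder // 2
        for i in range(start_position, start_position + remainder):
            layers_per_stage[i] += 1
        return layers_per_stage

    def get_stage_index(self, layers_per_stage: List[int], stage: int) -> List[int]:
        # Single-chunk version: turn per-stage counts into a [start, end) range.
        start_idx = sum(layers_per_stage[:stage])
        return [start_idx, start_idx + layers_per_stage[stage]]

    def held_layer_range_before_fix(self, num_layers: int, num_stages: int, stage: int) -> List[int]:
        # Old call sites resolved the helpers on the base class (they were
        # @staticmethod there), so a subclass override was silently ignored.
        layers_per_stage = BasePolicy.distribute_layers(self, num_layers, num_stages)
        return BasePolicy.get_stage_index(self, layers_per_stage, stage)

    def held_layer_range_after_fix(self, num_layers: int, num_stages: int, stage: int) -> List[int]:
        # Patched call sites dispatch through self, so custom distributions apply.
        layers_per_stage = self.distribute_layers(num_layers, num_stages)
        return self.get_stage_index(layers_per_stage, stage)


class CustomPolicy(BasePolicy):
    """Mimics an OpenMoE-style override with a hand-tuned split."""

    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
        if num_layers == 24 and num_stages == 4:
            return [7, 7, 7, 3]
        return super().distribute_layers(num_layers, num_stages)


if __name__ == "__main__":
    policy = CustomPolicy()
    print(policy.held_layer_range_before_fix(24, 4, 3))  # [18, 24]: custom split ignored
    print(policy.held_layer_range_after_fix(24, 4, 3))   # [21, 24]: custom split respected

Making the helpers instance methods means a policy subclass defines its distribution once and every consumer (get_held_layers, set_pipeline_forward, the T5/Whisper encoder-decoder splitting) picks it up consistently, which is the pipeline forward error this patch fixes.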