[shardformer] fix pipeline forward error if custom layer distribution is used (#5189)

* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution

* Change static methods for t5 layer distribution to member functions

* Change static methods for whisper layer distribution to member functions

* Replace whisper policy usage with self one

* Fix test case to use non-static layer distribution methods

* fix: fix typo

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
pull/4309/merge
Insu Jang 2024-03-27 01:57:00 -04:00 committed by GitHub
parent e6707a6e8d
commit 00525f7772
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 136 additions and 106 deletions

View File

@ -110,7 +110,7 @@ class MixtralPolicy(Policy):
module = self.model.model module = self.model.model
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls description=method_replacement, policy=policy, target_key=model_cls

View File

@ -197,8 +197,7 @@ class Policy(ABC):
""" """
return [] return []
@staticmethod def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
"""Divide layers into stages""" """Divide layers into stages"""
quotient = num_layers // num_stages quotient = num_layers // num_stages
remainder = num_layers % num_stages remainder = num_layers % num_stages
@ -213,8 +212,8 @@ class Policy(ABC):
layers_per_stage[i] += 1 layers_per_stage[i] += 1
return layers_per_stage return layers_per_stage
@staticmethod
def get_stage_index( def get_stage_index(
self,
layers_per_stage: List[int], layers_per_stage: List[int],
stage: int, stage: int,
num_model_chunks: int = 1, num_model_chunks: int = 1,

View File

@ -84,17 +84,26 @@ class BertPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.self.query", suffix="attention.self.query",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.self.key", suffix="attention.self.key",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.self.value", suffix="attention.self.value",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.self.dropout", suffix="attention.self.dropout",
@ -112,7 +121,10 @@ class BertPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="intermediate.dense", suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="output.dense", suffix="output.dense",
@ -214,7 +226,9 @@ class BertPolicy(Policy):
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription( description=SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True},
), ),
policy=base_policy, policy=base_policy,
target_key=BertLMPredictionHead, target_key=BertLMPredictionHead,
@ -241,7 +255,9 @@ class BertPolicy(Policy):
"_load_from_state_dict": col_nn.ParallelModule._load_from_state_dict, "_load_from_state_dict": col_nn.ParallelModule._load_from_state_dict,
} }
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description=method_replacement, policy=base_policy, target_key=BertLMPredictionHead description=method_replacement,
policy=base_policy,
target_key=BertLMPredictionHead,
) )
return base_policy return base_policy
@ -264,24 +280,32 @@ class BertPolicy(Policy):
if stage_manager.is_interleave: if stage_manager.is_interleave:
layers_per_stage = self.distribute_layers( layers_per_stage = self.distribute_layers(
len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks len(module.encoder.layer),
stage_manager.num_stages * stage_manager.num_model_chunks,
) )
stage_manager.stage_indices = Policy.get_stage_index( stage_manager.stage_indices = self.get_stage_index(
layers_per_stage, layers_per_stage,
stage_manager.stage, stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks, num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages, num_stages=stage_manager.num_stages,
) )
method_replacement = { method_replacement = {
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) "forward": partial(
new_forward,
stage_manager=stage_manager,
shard_config=self.shard_config,
)
} }
else: else:
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = { method_replacement = {
"forward": partial( "forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config,
) )
} }
@ -301,9 +325,10 @@ class BertPolicy(Policy):
if stage_manager.is_interleave: if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None assert stage_manager.num_model_chunks is not None
layers_per_stage = self.distribute_layers( layers_per_stage = self.distribute_layers(
len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks len(module.encoder.layer),
stage_manager.num_stages * stage_manager.num_model_chunks,
) )
stage_indices = Policy.get_stage_index( stage_indices = self.get_stage_index(
layers_per_stage, layers_per_stage,
stage_manager.stage, stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks, num_model_chunks=stage_manager.num_model_chunks,
@ -320,7 +345,7 @@ class BertPolicy(Policy):
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
if stage_manager.is_first_stage(): if stage_manager.is_first_stage():
held_layers.append(module.embeddings) held_layers.append(module.embeddings)
start_idx, end_idx = Policy.get_stage_index(layers_per_stage, stage_manager.stage) start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.encoder.layer[start_idx:end_idx]) held_layers.extend(module.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
held_layers.append(module.pooler) held_layers.append(module.pooler)
@ -336,7 +361,9 @@ class BertModelPolicy(BertPolicy):
if self.pipeline_stage_manager: if self.pipeline_stage_manager:
self.set_pipeline_forward( self.set_pipeline_forward(
model_cls=BertModel, new_forward=BertPipelineForwards.bert_model_forward, policy=policy model_cls=BertModel,
new_forward=BertPipelineForwards.bert_model_forward,
policy=policy,
) )
return policy return policy
@ -399,7 +426,9 @@ class BertLMHeadModelPolicy(BertPolicy):
if self.pipeline_stage_manager: if self.pipeline_stage_manager:
self.set_pipeline_forward( self.set_pipeline_forward(
model_cls=BertLMHeadModel, new_forward=BertPipelineForwards.bert_lm_head_model_forward, policy=policy model_cls=BertLMHeadModel,
new_forward=BertPipelineForwards.bert_lm_head_model_forward,
policy=policy,
) )
return policy return policy
@ -437,7 +466,9 @@ class BertForMaskedLMPolicy(BertPolicy):
if self.pipeline_stage_manager: if self.pipeline_stage_manager:
self.set_pipeline_forward( self.set_pipeline_forward(
model_cls=BertForMaskedLM, new_forward=BertPipelineForwards.bert_for_masked_lm_forward, policy=policy model_cls=BertForMaskedLM,
new_forward=BertPipelineForwards.bert_for_masked_lm_forward,
policy=policy,
) )
return policy return policy

View File

@ -203,8 +203,8 @@ class BloomPolicy(Policy):
else: else:
module = self.model.transformer module = self.model.transformer
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = { method_replacement = {
"forward": partial( "forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config

View File

@ -204,8 +204,8 @@ class ChatGLMPolicy(Policy):
else: else:
module = self.model.transformer module = self.model.transformer
layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages) layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = { method_replacement = {
"forward": partial( "forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config

View File

@ -161,8 +161,8 @@ class FalconPolicy(Policy):
else: else:
module = self.model.transformer module = self.model.transformer
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = { method_replacement = {
"forward": partial( "forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config

View File

@ -188,7 +188,7 @@ class GPT2Policy(Policy):
layers_per_stage = self.distribute_layers( layers_per_stage = self.distribute_layers(
len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
) )
stage_indices = Policy.get_stage_index( stage_indices = self.get_stage_index(
layers_per_stage, layers_per_stage,
stage_manager.stage, stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks, num_model_chunks=stage_manager.num_model_chunks,
@ -229,7 +229,7 @@ class GPT2Policy(Policy):
layers_per_stage = self.distribute_layers( layers_per_stage = self.distribute_layers(
len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
) )
stage_manager.stage_indices = Policy.get_stage_index( stage_manager.stage_indices = self.get_stage_index(
layers_per_stage, layers_per_stage,
stage_manager.stage, stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks, num_model_chunks=stage_manager.num_model_chunks,
@ -243,8 +243,8 @@ class GPT2Policy(Policy):
) )
} }
else: else:
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = { method_replacement = {
"forward": partial( "forward": partial(
new_forward, new_forward,

View File

@ -200,8 +200,8 @@ class GPTJPolicy(Policy):
else: else:
module = self.model.transformer module = self.model.transformer
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = { method_replacement = {
"forward": partial( "forward": partial(
new_forward, new_forward,

View File

@ -167,7 +167,7 @@ class LlamaPolicy(Policy):
layers_per_stage = self.distribute_layers( layers_per_stage = self.distribute_layers(
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
) )
stage_manager.stage_indices = Policy.get_stage_index( stage_manager.stage_indices = self.get_stage_index(
layers_per_stage, layers_per_stage,
stage_manager.stage, stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks, num_model_chunks=stage_manager.num_model_chunks,
@ -178,8 +178,8 @@ class LlamaPolicy(Policy):
} }
else: else:
layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = { method_replacement = {
"forward": partial( "forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@ -207,7 +207,7 @@ class LlamaPolicy(Policy):
layers_per_stage = self.distribute_layers( layers_per_stage = self.distribute_layers(
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
) )
stage_indices = Policy.get_stage_index( stage_indices = self.get_stage_index(
layers_per_stage, layers_per_stage,
stage_manager.stage, stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks, num_model_chunks=stage_manager.num_model_chunks,

View File

@ -208,8 +208,8 @@ class OPTPolicy(Policy):
else: else:
module = self.model.model.decoder module = self.model.model.decoder
layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = { method_replacement = {
"forward": partial( "forward": partial(
new_forward, new_forward,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import warnings import warnings
from functools import partial from functools import partial
from typing import Callable, Dict, List, Tuple from typing import Callable, Dict, List, Tuple
@ -241,9 +243,8 @@ class T5BasePolicy(Policy):
def postprocess(self): def postprocess(self):
return self.model return self.model
@staticmethod
def distribute_t5_layers( def distribute_t5_layers(
num_encoder_layers: int, num_decoder_layers: int, num_stages: int self, num_encoder_layers: int, num_decoder_layers: int, num_stages: int
) -> Tuple[List[int], int]: ) -> Tuple[List[int], int]:
""" """
Distribute t5 layers into stages when pipeline parallel is used. Distribute t5 layers into stages when pipeline parallel is used.
@ -261,7 +262,7 @@ class T5BasePolicy(Policy):
# in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
if num_decoder_layers == 0: if num_decoder_layers == 0:
return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages return self.distribute_layers(num_encoder_layers, num_stages), num_stages
# the number of stages distributed between encoder and decoder is optimized in this way: # the number of stages distributed between encoder and decoder is optimized in this way:
# num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@ -272,22 +273,21 @@ class T5BasePolicy(Policy):
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
num_decoder_stages = num_stages - num_encoder_stages num_decoder_stages = num_stages - num_encoder_stages
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages) decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
return encoder_distribution + decoder_distribution, num_encoder_stages return encoder_distribution + decoder_distribution, num_encoder_stages
@staticmethod
def get_t5_stage_index( def get_t5_stage_index(
layers_per_stage: List[int], stage: int, decoder_starting_stage: int self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
) -> Tuple[bool, int, int]: ) -> Tuple[bool, int, int]:
""" """
Input the distribution of layers among stages, the current stage and the first stage of decoder. Input the distribution of layers among stages, the current stage and the first stage of decoder.
Return the starting/ending idx of layers in encoder/decoder Return the starting/ending idx of layers in encoder/decoder
""" """
if stage < decoder_starting_stage: if stage < decoder_starting_stage:
return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
else: else:
return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
def get_held_layers(self) -> List[nn.Module]: def get_held_layers(self) -> List[nn.Module]:
"""Get pipeline layers for current stage.""" """Get pipeline layers for current stage."""
@ -302,12 +302,10 @@ class T5BasePolicy(Policy):
num_decoder_layers = len(decoder.block) if decoder else 0 num_decoder_layers = len(decoder.block) if decoder else 0
held_layers = [] held_layers = []
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages num_encoder_layers, num_decoder_layers, stage_manager.num_stages
) )
start_idx, end_idx = T5BasePolicy.get_t5_stage_index( start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
layers_per_stage, stage_manager.stage, decoder_starting_stage
)
if stage_manager.stage < decoder_starting_stage: if stage_manager.stage < decoder_starting_stage:
# current stage is in t5's encoder # current stage is in t5's encoder
@ -343,10 +341,10 @@ class T5BasePolicy(Policy):
num_encoder_layers = len(encoder.block) num_encoder_layers = len(encoder.block)
num_decoder_layers = len(decoder.block) if decoder else 0 num_decoder_layers = len(decoder.block) if decoder else 0
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages num_encoder_layers, num_decoder_layers, stage_manager.num_stages
) )
stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) stage_index = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
method_replacement = { method_replacement = {
"forward": partial( "forward": partial(
@ -386,7 +384,7 @@ class T5ModelPolicy(T5BasePolicy):
module = self.model module = self.model
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1: if stage_manager is not None and stage_manager.num_stages > 1:
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( _, decoder_starting_stage = self.distribute_t5_layers(
len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
) )
@ -434,7 +432,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
module = self.model module = self.model
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1: if stage_manager is not None and stage_manager.num_stages > 1:
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( _, decoder_starting_stage = self.distribute_t5_layers(
len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
) )

View File

@ -149,8 +149,8 @@ class ViTPolicy(Policy):
else: else:
module = self.model.vit module = self.model.vit
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls description=method_replacement, policy=policy, target_key=model_cls

View File

@ -292,9 +292,8 @@ class WhisperPolicy(Policy):
def postprocess(self): def postprocess(self):
return self.model return self.model
@staticmethod
def distribute_whisper_layers( def distribute_whisper_layers(
num_encoder_layers: int, num_decoder_layers: int, num_stages: int self, num_encoder_layers: int, num_decoder_layers: int, num_stages: int
) -> Tuple[List[int], int]: ) -> Tuple[List[int], int]:
""" """
Distribute whisper layers into stages when pipeline parallel is used. Distribute whisper layers into stages when pipeline parallel is used.
@ -312,7 +311,7 @@ class WhisperPolicy(Policy):
# in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist
if num_decoder_layers == 0: if num_decoder_layers == 0:
return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages return self.distribute_layers(num_encoder_layers, num_stages), num_stages
# the number of stages distributed between encoder and decoder is optimized in this way: # the number of stages distributed between encoder and decoder is optimized in this way:
# num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@ -323,22 +322,21 @@ class WhisperPolicy(Policy):
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
num_decoder_stages = num_stages - num_encoder_stages num_decoder_stages = num_stages - num_encoder_stages
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages) decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
return encoder_distribution + decoder_distribution, num_encoder_stages return encoder_distribution + decoder_distribution, num_encoder_stages
@staticmethod
def get_whisper_stage_index( def get_whisper_stage_index(
layers_per_stage: List[int], stage: int, decoder_starting_stage: int self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
) -> Tuple[bool, int, int]: ) -> Tuple[bool, int, int]:
""" """
Input the distribution of layers among stages, the current stage and the first stage of decoder. Input the distribution of layers among stages, the current stage and the first stage of decoder.
Return the starting/ending idx of layers in encoder/decoder Return the starting/ending idx of layers in encoder/decoder
""" """
if stage < decoder_starting_stage: if stage < decoder_starting_stage:
return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
else: else:
return Policy.get_stage_index( return self.get_stage_index(
layers_per_stage[decoder_starting_stage:], layers_per_stage[decoder_starting_stage:],
stage - decoder_starting_stage, stage - decoder_starting_stage,
) )
@ -369,12 +367,10 @@ class WhisperPolicy(Policy):
num_decoder_layers = 0 num_decoder_layers = 0
held_layers = [] held_layers = []
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages num_encoder_layers, num_decoder_layers, stage_manager.num_stages
) )
start_idx, end_idx = WhisperPolicy.get_whisper_stage_index( start_idx, end_idx = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
layers_per_stage, stage_manager.stage, decoder_starting_stage
)
if stage_manager.stage < decoder_starting_stage: if stage_manager.stage < decoder_starting_stage:
# current stage is in whisper's encoder # current stage is in whisper's encoder
@ -424,12 +420,10 @@ class WhisperPolicy(Policy):
else: else:
num_decoder_layers = 0 num_decoder_layers = 0
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages num_encoder_layers, num_decoder_layers, stage_manager.num_stages
) )
stage_index = WhisperPolicy.get_whisper_stage_index( stage_index = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
layers_per_stage, stage_manager.stage, decoder_starting_stage
)
method_replacement = { method_replacement = {
"forward": partial( "forward": partial(
@ -511,7 +505,7 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1: if stage_manager is not None and stage_manager.num_stages > 1:
_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( _, decoder_starting_stage = self.distribute_whisper_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages num_encoder_layers, num_decoder_layers, stage_manager.num_stages
) )
shared_params = [] shared_params = []

View File

@ -98,11 +98,11 @@ class OpenMoePolicy(Policy):
module = self.model.model module = self.model.model
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(description=method_replacement, self.append_or_create_method_replacement(
policy=policy, description=method_replacement, policy=policy, target_key=model_cls
target_key=model_cls) )
return return
@ -127,11 +127,8 @@ class OpenMoePolicy(Policy):
return held_layers return held_layers
@staticmethod def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
def distribute_layers(num_layers: int, num_stages: int) -> List[int]: """Divide layers into stages"""
"""Divide layers into stages
"""
if num_layers == 24 and num_stages == 4: if num_layers == 24 and num_stages == 4:
return [7, 7, 7, 3] return [7, 7, 7, 3]
elif num_layers == 24 and num_stages == 2: elif num_layers == 24 and num_stages == 2:
@ -142,7 +139,7 @@ class OpenMoePolicy(Policy):
return [8, 4] return [8, 4]
else: else:
print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, use origin pp policy") print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, use origin pp policy")
return Policy.distribute_layers(num_layers, num_stages) return super().distribute_layers(num_layers, num_stages)
class OpenMoeModelPolicy(OpenMoePolicy): class OpenMoeModelPolicy(OpenMoePolicy):

View File

@ -83,7 +83,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
@parameterize("init_method", ["none", "lazy"]) @parameterize("init_method", ["none", "lazy"])
def check_3d_plugin(init_method: str = "none", early_stop: bool = True): def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
"""check gemini plugin over model zoo """check hybrid plugin over model zoo
Args: Args:
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
@ -271,9 +271,9 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True): def test_3d_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop) spawn(run_dist, 4, early_stop=early_stop)
if __name__ == "__main__": if __name__ == "__main__":
test_gemini_plugin(early_stop=False) test_3d_plugin(early_stop=False)

View File

@ -10,9 +10,12 @@ def test_t5_pipeline_distribution():
"decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
} }
policy = T5BasePolicy()
for i in range(num_test_cases): for i in range(num_test_cases):
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( _, decoder_starting_stage = policy.distribute_t5_layers(
test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] test_dict["num_encoder_layers"][i],
test_dict["num_decoder_layers"][i],
test_dict["num_stages"][i],
) )
assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage
@ -32,14 +35,15 @@ def test_t5_pipeline_layers():
} }
for i in range(num_test_cases): for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( policy = T5BasePolicy()
test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers(
test_dict["num_encoder_layers"][i],
test_dict["num_decoder_layers"][i],
test_dict["num_stages"][i],
) )
for stage in range(test_dict["num_stages"][i]): for stage in range(test_dict["num_stages"][i]):
start_idx, end_idx = test_dict["layers_per_stage"][i][stage] start_idx, end_idx = test_dict["layers_per_stage"][i][stage]
predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index( predicted_start, predicted_end = policy.get_t5_stage_index(layers_per_stage, stage, decoder_starting_stage)
layers_per_stage, stage, decoder_starting_stage
)
assert start_idx == predicted_start assert start_idx == predicted_start
assert end_idx == predicted_end assert end_idx == predicted_end

View File

@ -10,9 +10,12 @@ def test_whisper_pipeline_distribution():
"decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
} }
policy = WhisperPolicy()
for i in range(num_test_cases): for i in range(num_test_cases):
_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( _, decoder_starting_stage = policy.distribute_whisper_layers(
test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] test_dict["num_encoder_layers"][i],
test_dict["num_decoder_layers"][i],
test_dict["num_stages"][i],
) )
assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage
@ -31,14 +34,17 @@ def test_whisper_pipeline_layers():
], ],
} }
policy = WhisperPolicy()
for i in range(num_test_cases): for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers(
test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] test_dict["num_encoder_layers"][i],
test_dict["num_decoder_layers"][i],
test_dict["num_stages"][i],
) )
for stage in range(test_dict["num_stages"][i]): for stage in range(test_dict["num_stages"][i]):
start_idx, end_idx = test_dict["layers_per_stage"][i][stage] start_idx, end_idx = test_dict["layers_per_stage"][i][stage]
predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index( predicted_start, predicted_end = policy.get_whisper_stage_index(
layers_per_stage, stage, decoder_starting_stage layers_per_stage, stage, decoder_starting_stage
) )
assert start_idx == predicted_start assert start_idx == predicted_start

View File

@ -38,9 +38,10 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index):
org_loss, dist_loss, atol=1e-5 org_loss, dist_loss, atol=1e-5
), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" ), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}"
target_grad = torch.chunk(pred.grad, world_size, dim=-1)[rank] target_grad = torch.chunk(pred.grad, world_size, dim=-1)[rank]
assert torch.allclose(target_grad, dist_pred.grad), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}" assert torch.allclose(
target_grad, dist_pred.grad
), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}"
@pytest.mark.dist @pytest.mark.dist