@@ -69,13 +69,13 @@ class WhisperPolicy(Policy):
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn(
-                "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
+                "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
             )

         # TODO using the jit fused add_and_dropout affect the accuracy
         if self.shard_config.enable_jit_fused:
             self.shard_config.enable_jit_fused = False
-            warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.")
+            warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused operator flag.")

         if self.shard_config.enable_tensor_parallelism:
             policy[WhisperEncoderLayer] = ModulePolicyDescription(
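For context, the hunk above uses a guard pattern: an unsupported `shard_config` feature is switched off and a warning is emitted instead of raising. A minimal self-contained sketch of that pattern, using a stand-in `DummyShardConfig` rather than the real ColossalAI `ShardConfig`:

```python
import warnings
from dataclasses import dataclass


@dataclass
class DummyShardConfig:
    # Stand-in for the real shard config; only the flags used below.
    enable_sequence_parallelism: bool = True
    enable_jit_fused: bool = True


def disable_unsupported_features(shard_config: DummyShardConfig) -> None:
    # Mirrors the guard pattern in the hunk above: downgrade the flag, warn, continue.
    if shard_config.enable_sequence_parallelism:
        shard_config.enable_sequence_parallelism = False
        warnings.warn("Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
    if shard_config.enable_jit_fused:
        shard_config.enable_jit_fused = False
        warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused operator flag.")


config = DummyShardConfig()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    disable_unsupported_features(config)
assert not config.enable_sequence_parallelism and not config.enable_jit_fused
assert len(caught) == 2
```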
@@ -302,7 +302,7 @@ class WhisperPolicy(Policy):
         if num_decoder_layers == 0:
             return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages

-        # the number of stages distributed between encoder and decoder is optmized in this way:
+        # the number of stages distributed between encoder and decoder is optimized in this way:
         # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
         # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1
         def objective(num_encoder_stages):
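The comment in this hunk describes a small brute-force search: choose the split of pipeline stages between encoder and decoder that best balances layers per stage. A minimal standalone sketch of that search (the function name `pick_encoder_stages` is illustrative, not the repository's API):

```python
def pick_encoder_stages(num_encoder_layers: int, num_decoder_layers: int, num_stages: int) -> int:
    """Sketch: choose how many pipeline stages the encoder should get.

    Minimizes |encoder_layers / encoder_stages - decoder_layers / decoder_stages|
    subject to encoder_stages + decoder_stages == num_stages and both >= 1.
    """

    def objective(num_encoder_stages: int) -> float:
        num_decoder_stages = num_stages - num_encoder_stages
        return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / num_decoder_stages)

    # Brute-force over all feasible splits; num_stages is small in practice.
    return min(range(1, num_stages), key=objective)


# Example: 12 encoder layers, 12 decoder layers, 4 pipeline stages -> encoder gets 2 stages.
assert pick_encoder_stages(12, 12, 4) == 2
```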