[shardformer/sequence parallel] OPT does not support sequence parallelism: add a warning, and fix a bug in GPT2 pipeline parallelism (#4488)

pull/4493/head
Bin Jia 1 year ago committed by GitHub
parent 5545114fd8
commit 351351a36e

@@ -148,7 +148,7 @@ class GPT2PipelineForwards:
         if token_type_ids is not None:
             token_type_embeds = self.wte(token_type_ids)
             hidden_states = hidden_states + token_type_embeds
-            hidden_states = self.drop(hidden_states)
+        hidden_states = self.drop(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
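
The GPT2 pipeline bug fixed here is an indentation error: self.drop was applied inside the "if token_type_ids is not None:" branch, so dropout on the embedding output was skipped whenever token type ids were absent. Below is a minimal, runnable sketch of the corrected embedding path under that assumption; ToyGPT2Embedding is a hypothetical stand-in rather than the actual GPT2PipelineForwards code, while the wte and drop attribute names mirror those visible in the hunk above.

import torch
import torch.nn as nn

# Minimal sketch of the corrected embedding path. ToyGPT2Embedding is a
# hypothetical stand-in for the real GPT-2 module; wte, wpe and drop mirror
# the attribute names that appear in the hunk above.
class ToyGPT2Embedding(nn.Module):
    def __init__(self, vocab_size=50257, max_positions=1024, hidden_size=768, p_drop=0.1):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, hidden_size)     # token embeddings
        self.wpe = nn.Embedding(max_positions, hidden_size)  # position embeddings
        self.drop = nn.Dropout(p_drop)

    def forward(self, input_ids, position_ids, token_type_ids=None):
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds
        # The fix: dropout is applied unconditionally, not only when
        # token_type_ids is provided.
        hidden_states = self.drop(hidden_states)
        return hidden_states

# Dropout is now applied even when token_type_ids is None.
ids = torch.arange(8).unsqueeze(0)
out = ToyGPT2Embedding()(ids, ids)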

@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List
@@ -39,6 +40,9 @@ class OPTPolicy(Policy):
         from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer

         policy = {}
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
         if self.shard_config.enable_tensor_parallelism:
             policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[
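
The OPT change turns an unsupported configuration into a soft failure: when sequence parallelism is requested, the policy flips the flag off and emits a warning instead of producing incorrect shards. A minimal sketch of that guard pattern follows, using hypothetical ToyShardConfig and ToyOPTPolicy stand-ins for Colossal-AI's ShardConfig and OPTPolicy; the tensor-parallel branch is only stubbed out.

import warnings
from dataclasses import dataclass

# Hypothetical stand-in for the shard configuration; the real ShardConfig in
# Colossal-AI carries many more fields.
@dataclass
class ToyShardConfig:
    enable_sequence_parallelism: bool = False
    enable_tensor_parallelism: bool = True

class ToyOPTPolicy:
    def __init__(self, shard_config: ToyShardConfig):
        self.shard_config = shard_config

    def module_policy(self):
        policy = {}
        # Guard pattern from the hunk above: silently disable the unsupported
        # feature and warn, rather than fail later with an obscure error.
        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("OPT doesn't support sequence parallelism now, "
                          "will ignore the sequence parallelism flag.")
        if self.shard_config.enable_tensor_parallelism:
            policy["OPTDecoder"] = "tensor-parallel sub-module replacements go here"
        return policy

# Usage: requesting sequence parallelism for OPT degrades gracefully.
cfg = ToyShardConfig(enable_sequence_parallelism=True)
ToyOPTPolicy(cfg).module_policy()
assert cfg.enable_sequence_parallelism is False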
