From 351351a36eb9d11e5bdb3610b0d3705055d90e7d Mon Sep 17 00:00:00 2001
From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Date: Tue, 22 Aug 2023 17:35:35 +0800
Subject: [PATCH] [shardformer/sequence parallel] not support opt of
 seq-parallel, add warning and fix a bug in gpt2 pp (#4488)

---
 colossalai/shardformer/modeling/gpt2.py | 2 +-
 colossalai/shardformer/policies/opt.py  | 4 ++++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 722f0f523..8ed367b25 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -148,7 +148,7 @@ class GPT2PipelineForwards:
             if token_type_ids is not None:
                 token_type_embeds = self.wte(token_type_ids)
                 hidden_states = hidden_states + token_type_embeds
-                hidden_states = self.drop(hidden_states)
+            hidden_states = self.drop(hidden_states)
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index ba6036bd0..58663553b 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List
 
@@ -39,6 +40,9 @@ class OPTPolicy(Policy):
         from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
 
         policy = {}
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
         if self.shard_config.enable_tensor_parallelism:
             policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[
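
Note: the snippet below is a minimal, self-contained sketch of the guard pattern this patch adds to OPTPolicy.module_policy. It does not use ColossalAI's real classes; the ShardConfig stand-in and the preprocess_opt_policy helper are hypothetical names introduced only for illustration. The point it demonstrates is the design choice in the patch: when an unsupported feature flag (sequence parallelism for OPT) is set, the policy resets the flag and emits a warning instead of failing later during sharding.

import warnings
from dataclasses import dataclass


@dataclass
class ShardConfig:
    # Simplified stand-in for a shard configuration object (assumption,
    # not ColossalAI's actual ShardConfig).
    enable_sequence_parallelism: bool = False
    enable_tensor_parallelism: bool = True


def preprocess_opt_policy(shard_config: ShardConfig) -> ShardConfig:
    # Mirrors the guard added in this patch: OPT does not support
    # sequence parallelism, so the flag is switched off with a warning.
    if shard_config.enable_sequence_parallelism:
        shard_config.enable_sequence_parallelism = False
        warnings.warn("OPT doesn't support sequence parallelism now, "
                      "will ignore the sequence parallelism flag.")
    return shard_config


if __name__ == "__main__":
    # Usage example: the flag is silently downgraded and a warning is emitted.
    cfg = preprocess_opt_policy(ShardConfig(enable_sequence_parallelism=True))
    assert cfg.enable_sequence_parallelism is False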