pull/5661/head
flybird11111 2024-04-26 11:52:27 +08:00 committed by GitHub
parent 1b387ca9fe
commit 8b7d535977
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 0 additions and 7 deletions

View File

@ -54,7 +54,6 @@ class GPTJPolicy(Policy):
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
@ -78,7 +77,6 @@ class GPTJPolicy(Policy):
suffix="attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
@ -86,7 +84,6 @@ class GPTJPolicy(Policy):
suffix="attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
@ -94,24 +91,20 @@ class GPTJPolicy(Policy):
suffix="attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
SubModuleReplacementDescription(
suffix="attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="mlp.fc_in",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="mlp.fc_out",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",