mirror of https://github.com/hpcaitech/ColossalAI
fix gptj (#5652)
parent
1b387ca9fe
commit
8b7d535977
|
@ -54,7 +54,6 @@ class GPTJPolicy(Policy):
|
|||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
|
||||
|
||||
overlap = self.shard_config.enable_sequence_overlap
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
|
@ -78,7 +77,6 @@ class GPTJPolicy(Policy):
|
|||
suffix="attn.k_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"seq_parallel": use_sequence_parallel,
|
||||
"overlap": overlap,
|
||||
},
|
||||
),
|
||||
|
@ -86,7 +84,6 @@ class GPTJPolicy(Policy):
|
|||
suffix="attn.q_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"seq_parallel": use_sequence_parallel,
|
||||
"overlap": overlap,
|
||||
},
|
||||
),
|
||||
|
@ -94,24 +91,20 @@ class GPTJPolicy(Policy):
|
|||
suffix="attn.v_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"seq_parallel": use_sequence_parallel,
|
||||
"overlap": overlap,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc_in",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc_out",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={"seq_parallel": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.attn_dropout",
|
||||
|
|
Loading…
Reference in New Issue