mirror of https://github.com/hpcaitech/ColossalAI
fix gptj (#5652)
parent
1b387ca9fe
commit
8b7d535977
|
@ -54,7 +54,6 @@ class GPTJPolicy(Policy):
|
||||||
if self.shard_config.enable_sequence_parallelism:
|
if self.shard_config.enable_sequence_parallelism:
|
||||||
self.shard_config.enable_sequence_parallelism = False
|
self.shard_config.enable_sequence_parallelism = False
|
||||||
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||||
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
|
|
||||||
|
|
||||||
overlap = self.shard_config.enable_sequence_overlap
|
overlap = self.shard_config.enable_sequence_overlap
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
|
@ -78,7 +77,6 @@ class GPTJPolicy(Policy):
|
||||||
suffix="attn.k_proj",
|
suffix="attn.k_proj",
|
||||||
target_module=col_nn.Linear1D_Col,
|
target_module=col_nn.Linear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"seq_parallel": use_sequence_parallel,
|
|
||||||
"overlap": overlap,
|
"overlap": overlap,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
@ -86,7 +84,6 @@ class GPTJPolicy(Policy):
|
||||||
suffix="attn.q_proj",
|
suffix="attn.q_proj",
|
||||||
target_module=col_nn.Linear1D_Col,
|
target_module=col_nn.Linear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"seq_parallel": use_sequence_parallel,
|
|
||||||
"overlap": overlap,
|
"overlap": overlap,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
@ -94,24 +91,20 @@ class GPTJPolicy(Policy):
|
||||||
suffix="attn.v_proj",
|
suffix="attn.v_proj",
|
||||||
target_module=col_nn.Linear1D_Col,
|
target_module=col_nn.Linear1D_Col,
|
||||||
kwargs={
|
kwargs={
|
||||||
"seq_parallel": use_sequence_parallel,
|
|
||||||
"overlap": overlap,
|
"overlap": overlap,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="attn.out_proj",
|
suffix="attn.out_proj",
|
||||||
target_module=col_nn.Linear1D_Row,
|
target_module=col_nn.Linear1D_Row,
|
||||||
kwargs={"seq_parallel": use_sequence_parallel},
|
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="mlp.fc_in",
|
suffix="mlp.fc_in",
|
||||||
target_module=col_nn.Linear1D_Col,
|
target_module=col_nn.Linear1D_Col,
|
||||||
kwargs={"seq_parallel": use_sequence_parallel},
|
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="mlp.fc_out",
|
suffix="mlp.fc_out",
|
||||||
target_module=col_nn.Linear1D_Row,
|
target_module=col_nn.Linear1D_Row,
|
||||||
kwargs={"seq_parallel": use_sequence_parallel},
|
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="attn.attn_dropout",
|
suffix="attn.attn_dropout",
|
||||||
|
|
Loading…
Reference in New Issue