|
|
|
@ -126,37 +126,65 @@ class LlamaPolicy(Policy):
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
suffix="self_attn.q_proj",
|
|
|
|
|
target_module=Linear1D_Col,
|
|
|
|
|
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
|
|
|
|
kwargs=dict(
|
|
|
|
|
seq_parallel_mode=sp_mode,
|
|
|
|
|
fp8_communication=self.shard_config.fp8_communication,
|
|
|
|
|
use_zbv=self.shard_config.use_zbv,
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
suffix="self_attn.k_proj",
|
|
|
|
|
target_module=Linear1D_Col,
|
|
|
|
|
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
|
|
|
|
kwargs=dict(
|
|
|
|
|
seq_parallel_mode=sp_mode,
|
|
|
|
|
fp8_communication=self.shard_config.fp8_communication,
|
|
|
|
|
use_zbv=self.shard_config.use_zbv,
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
suffix="self_attn.v_proj",
|
|
|
|
|
target_module=Linear1D_Col,
|
|
|
|
|
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
|
|
|
|
kwargs=dict(
|
|
|
|
|
seq_parallel_mode=sp_mode,
|
|
|
|
|
fp8_communication=self.shard_config.fp8_communication,
|
|
|
|
|
use_zbv=self.shard_config.use_zbv,
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
suffix="self_attn.o_proj",
|
|
|
|
|
target_module=Linear1D_Row,
|
|
|
|
|
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
|
|
|
|
kwargs=dict(
|
|
|
|
|
seq_parallel_mode=sp_mode,
|
|
|
|
|
fp8_communication=self.shard_config.fp8_communication,
|
|
|
|
|
use_zbv=self.shard_config.use_zbv,
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
suffix="mlp.gate_proj",
|
|
|
|
|
target_module=Linear1D_Col,
|
|
|
|
|
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
|
|
|
|
kwargs=dict(
|
|
|
|
|
seq_parallel_mode=sp_mode,
|
|
|
|
|
fp8_communication=self.shard_config.fp8_communication,
|
|
|
|
|
use_zbv=self.shard_config.use_zbv,
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
suffix="mlp.up_proj",
|
|
|
|
|
target_module=Linear1D_Col,
|
|
|
|
|
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
|
|
|
|
kwargs=dict(
|
|
|
|
|
seq_parallel_mode=sp_mode,
|
|
|
|
|
fp8_communication=self.shard_config.fp8_communication,
|
|
|
|
|
use_zbv=self.shard_config.use_zbv,
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
suffix="mlp.down_proj",
|
|
|
|
|
target_module=Linear1D_Row,
|
|
|
|
|
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
|
|
|
|
kwargs=dict(
|
|
|
|
|
seq_parallel_mode=sp_mode,
|
|
|
|
|
fp8_communication=self.shard_config.fp8_communication,
|
|
|
|
|
use_zbv=self.shard_config.use_zbv,
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
],
|
|
|
|
|
)
|
|
|
|
|