|
|
|
@ -119,6 +119,7 @@ class Qwen2Policy(Policy):
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
suffix="self_attn.q_proj",
|
|
|
|
|
target_module=Linear1D_Col,
|
|
|
|
|
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
|
|
|
|
),
|
|
|
|
|
SubModuleReplacementDescription(
|
|
|
|
|
suffix="self_attn.k_proj",
|
|
|
|
@ -319,7 +320,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
|
|
|
|
setattr(self.shard_config, "causal_lm", True)
|
|
|
|
|
|
|
|
|
|
if self.shard_config.enable_tensor_parallelism:
|
|
|
|
|
# add a new item for causal lm
|
|
|
|
|
# add a new item for casual lm
|
|
|
|
|
new_item = {
|
|
|
|
|
Qwen2ForCausalLM: ModulePolicyDescription(
|
|
|
|
|
sub_module_replacement=[
|
|
|
|
|