From 6aface9316e77349eeb23362eb562f07357d7252 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Wed, 21 Aug 2024 03:51:25 +0000 Subject: [PATCH] fix --- colossalai/shardformer/policies/llama.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 6c5e2c2ea..429eec52f 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -391,7 +391,12 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): LlamaForSequenceClassification: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="score", + target_module=Linear1D_Col, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + ), ) ] )