diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 6c5e2c2ea..429eec52f 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -391,7 +391,12 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): LlamaForSequenceClassification: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + suffix="score", + target_module=Linear1D_Col, + kwargs=dict( + gather_output=True, + fp8_communication=self.shard_config.fp8_communication, + ), ) ] )