[fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009)

Wang Binluo 2024-08-16 10:12:50 +08:00 committed by GitHub
parent 20722a8c93
commit 3f09a6145f
1 changed file with 3 additions and 1 deletion

@@ -215,6 +215,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_p2p: bool = True,
overlap_allgather: bool = False,
fp8_communication: bool = False,
use_fp8: bool = False,
) -> None:
if overlap_communication or zero_stage == 2:
overlap_communication = False
@@ -324,7 +325,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
self.use_fp8 = use_fp8
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
sequence_parallel_process_group=self.sp_group,
@@ -428,6 +429,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
use_fp8=self.use_fp8,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.ep_size > 1:
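
For reference, a minimal usage sketch of the new option (not part of this commit). The import paths and the other constructor arguments (tp_size, pp_size, ep_size, precision) are assumptions based on ColossalAI's usual plugin interface; only fp8_communication and use_fp8 appear in this diff.

    from colossalai.booster import Booster
    from colossalai.booster.plugin import MoeHybridParallelPlugin

    # Illustrative parallelism settings; adjust to your cluster.
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        ep_size=2,
        precision="bf16",
        fp8_communication=True,  # pre-existing flag: fp8 for communication
        use_fp8=True,            # new flag added by this commit, forwarded to the shard call
    )
    booster = Booster(plugin=plugin)
    # model, optimizer, ... = booster.boost(model, optimizer, ...)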