[fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009)

Wang Binluo 2024-08-16 10:12:50 +08:00 committed by GitHub
parent 20722a8c93
commit 3f09a6145f
1 changed file with 3 additions and 1 deletion

@@ -215,6 +215,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         if overlap_communication or zero_stage == 2:
             overlap_communication = False
@@ -324,7 +325,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
+        self.use_fp8 = use_fp8
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,
@@ -428,6 +429,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                use_ddp=use_ddp,
                ddp_config=self.ddp_config,
                custom_policy=self.custom_policy,
+               use_fp8=self.use_fp8,
            )
        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            if self.ep_size > 1:
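
Usage note: a minimal sketch (not part of this commit) of how the new flag surfaces through the booster API. Only fp8_communication and use_fp8 come from this change; the model, optimizer, parallel sizes, and launch call below are placeholder assumptions and may differ by ColossalAI version.

# Minimal sketch: pass use_fp8=True when constructing the plugin; the plugin
# stores it as self.use_fp8 and forwards it to the sharded model wrapper when
# booster.boost() runs configure(). MyMoeModel is a hypothetical placeholder.
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()  # distributed init; exact launch call may vary by version

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,
    zero_stage=1,
    fp8_communication=True,  # existing option: FP8 for collective communication
    use_fp8=True,            # new option added here: forwarded via self.use_fp8
)
booster = Booster(plugin=plugin)

model = MyMoeModel()                              # placeholder MoE model
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)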