mirror of https://github.com/hpcaitech/ColossalAI
[fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009)
parent 20722a8c93
commit 3f09a6145f
@@ -215,6 +215,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         if overlap_communication or zero_stage == 2:
             overlap_communication = False
@@ -324,7 +325,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
-
+        self.use_fp8 = use_fp8
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,
@@ -428,6 +429,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 use_ddp=use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
+                use_fp8=self.use_fp8,
             )
             if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
                 if self.ep_size > 1:
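For context, a minimal usage sketch (not part of the commit) of how the new flag would be passed to the plugin. Only use_fp8 and fp8_communication come from the diff above; the parallel sizes, the colossalai.launch_from_torch() call, and the Booster usage are illustrative assumptions based on ColossalAI's booster API.

# Hypothetical usage sketch: enabling the new use_fp8 option on MoeHybridParallelPlugin.
# Parallel sizes and the surrounding launch/boost calls are illustrative assumptions,
# not part of this commit.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()  # assumes the script is started with torchrun

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,
    zero_stage=1,
    use_fp8=True,              # new option added by this commit
    fp8_communication=False,   # pre-existing option, kept independent of use_fp8
)
booster = Booster(plugin=plugin)
# model, optimizer, *_ = booster.boost(model, optimizer)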