mirror of https://github.com/hpcaitech/ColossalAI
[fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009)
parent 20722a8c93
commit 3f09a6145f
@@ -215,6 +215,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         if overlap_communication or zero_stage == 2:
             overlap_communication = False
@@ -324,7 +325,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
-
+        self.use_fp8 = use_fp8
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,
@@ -428,6 +429,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 use_ddp=use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
+                use_fp8=self.use_fp8,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.ep_size > 1:
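
In short, the commit threads a new use_fp8 flag through MoeHybridParallelPlugin: it is accepted by the constructor, stored as self.use_fp8, and forwarded to the model wrapper in configure(). Below is a minimal usage sketch, not taken from the commit: it assumes the standard Booster workflow, and the parallel sizes, zero stage, and precision are illustrative only; the options relevant to this change are use_fp8 and the pre-existing fp8_communication.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

# Distributed init is assumed (e.g. the script is launched with torchrun).
colossalai.launch_from_torch()

plugin = MoeHybridParallelPlugin(
    tp_size=1,               # illustrative parallel sizes
    pp_size=1,
    ep_size=2,
    zero_stage=1,
    precision="bf16",
    fp8_communication=True,  # existing option shown in the diff context
    use_fp8=True,            # new option added by this commit
)
booster = Booster(plugin=plugin)
# model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)

Note that within this diff the flag is only stored and passed through to the wrapped model; whatever FP8 behavior it enables is implemented in the model wrapper, outside these hunks.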