[fix] fix hybrid parallel use_fp8 config

pull/6107/head
duanjunwen 2024-11-01 05:27:11 +00:00
parent 3b5c314bea
commit 5b5fbcff09
1 changed file with 0 additions and 3 deletions

View File

@@ -78,7 +78,6 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
self.require_grad_sync = True self.require_grad_sync = True
self.overlap_allgather = overlap_allgather self.overlap_allgather = overlap_allgather
self.use_fp8 = use_fp8 self.use_fp8 = use_fp8
self.use_fp8 = use_fp8
shardformer = ShardFormer(shard_config) shardformer = ShardFormer(shard_config)
if custom_policy is not None: if custom_policy is not None:
@@ -1099,7 +1098,6 @@ class HybridParallelPlugin(PipelinePluginBase):
self.enable_jit_fused = enable_jit_fused self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism self.enable_sequence_parallelism = enable_sequence_parallelism
self.use_fp8 = use_fp8 self.use_fp8 = use_fp8
self.use_fp8 = use_fp8
if dp_outside: if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
@@ -1325,7 +1323,6 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy=self.custom_policy, custom_policy=self.custom_policy,
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]), overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
use_fp8=self.use_fp8, use_fp8=self.use_fp8,
use_fp8=self.use_fp8,
) )
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if zero_stage == 0: if zero_stage == 0: