[fix] fix fp8 args in HybridParallel

pull/6107/head
duanjunwen 2024-11-01 03:54:08 +00:00
parent c82c75a9b4
commit 3b5c314bea
1 changed files with 0 additions and 5 deletions

View File

@ -722,8 +722,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
overlap_allgather=overlap_allgather,
fp8_communication=fp8_communication,
backward_context=model._hook_context,
fp8_communication=fp8_communication,
backward_context=model._hook_context,
)
def sync_dp_grads(self):
@ -1162,7 +1160,6 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_metadata_cache=enable_metadata_cache,
overlap_p2p=overlap_p2p,
fp8_communication=fp8_communication,
fp8_communication=fp8_communication,
)
elif pp_style == "1f1b":
self.scheduler = OneForwardOneBackwardSchedule(
@ -1213,7 +1210,6 @@ class HybridParallelPlugin(PipelinePluginBase):
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
fp8_communication=fp8_communication,
inner_ring_size=inner_ring_size,
pg_mesh=self.pg_mesh,
sp_axis=self.sp_axis,
@ -1247,7 +1243,6 @@ class HybridParallelPlugin(PipelinePluginBase):
forced_dtype=PRECISION_TORCH_TYPE[precision],
overlap_allgather=overlap_allgather,
fp8_communication=fp8_communication,
fp8_communication=fp8_communication,
)
self.max_norm = max_norm