pull/6023/head
wangbluo 2024-08-19 09:02:16 +00:00
parent 0d8e82a024
commit 12b44012d9
1 changed file with 25 additions and 24 deletions


@@ -1278,30 +1278,31 @@ class HybridParallelPlugin(PipelinePluginBase):
                 overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
                 use_fp8=self.use_fp8,
             )
-        if zero_stage == 0:
-            is_zero = False
-            if self.precision in ["fp16", "bf16"]:
-                optimizer = HybridParallelAMPOptimizer(
-                    optimizer,
-                    model,
-                    use_pipeline=self.enable_pipeline_parallelism,
-                    param_info=param_info,
-                    precision=self.precision,
-                    max_norm=self.max_norm,
-                    pp_process_group=self.pp_group,
-                    tp_process_group=self.tp_group,
-                    **self.amp_config,
-                )
-            else:
-                optimizer = HybridParallelNaiveOptimizer(
-                    optimizer,
-                    model,
-                    use_pipeline=self.enable_pipeline_parallelism,
-                    param_info=param_info,
-                    max_norm=self.max_norm,
-                    pp_process_group=self.pp_group,
-                    tp_process_group=self.tp_group,
-                )
+        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+            if zero_stage == 0:
+                is_zero = False
+                if self.precision in ["fp16", "bf16"]:
+                    optimizer = HybridParallelAMPOptimizer(
+                        optimizer,
+                        model,
+                        use_pipeline=self.enable_pipeline_parallelism,
+                        param_info=param_info,
+                        precision=self.precision,
+                        max_norm=self.max_norm,
+                        pp_process_group=self.pp_group,
+                        tp_process_group=self.tp_group,
+                        **self.amp_config,
+                    )
+                else:
+                    optimizer = HybridParallelNaiveOptimizer(
+                        optimizer,
+                        model,
+                        use_pipeline=self.enable_pipeline_parallelism,
+                        param_info=param_info,
+                        max_norm=self.max_norm,
+                        pp_process_group=self.pp_group,
+                        tp_process_group=self.tp_group,
+                    )
             else:
                 is_zero = self.dp_size > 1
                 if self.dp_size == 1:
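
For context: the added guard makes the optimizer-wrapping step conditional, so the plugin only builds a HybridParallelAMPOptimizer or HybridParallelNaiveOptimizer when it is handed a raw, non-None optimizer; an optimizer that is already an OptimizerWrapper passes through untouched. A minimal sketch of this guard pattern, using hypothetical stand-in classes rather than the real ColossalAI types:

class OptimizerWrapper:
    """Stand-in for an already-configured optimizer wrapper (hypothetical)."""
    def __init__(self, optim):
        self.optim = optim

def configure(optimizer):
    # Wrap only a raw, non-None optimizer; anything already wrapped
    # (or absent) passes through unchanged, so calling configure()
    # more than once never double-wraps.
    if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
        optimizer = OptimizerWrapper(optimizer)
    return optimizer

class RawSGD:
    """Hypothetical raw optimizer."""

opt = configure(RawSGD())
assert isinstance(opt, OptimizerWrapper)   # first call wraps
assert configure(opt) is opt               # second call is a no-op
assert configure(None) is None             # None passes through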