pull/6023/head
wangbluo 2024-08-20 03:20:13 +00:00
parent 1f703e0ef4
commit 53823118f2
5 changed files with 42 additions and 38 deletions


@@ -1278,31 +1278,31 @@ class HybridParallelPlugin(PipelinePluginBase):
                 overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
                 use_fp8=self.use_fp8,
             )
-        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
-            if zero_stage == 0:
-                is_zero = False
-                if self.precision in ["fp16", "bf16"]:
-                    optimizer = HybridParallelAMPOptimizer(
-                        optimizer,
-                        model,
-                        use_pipeline=self.enable_pipeline_parallelism,
-                        param_info=param_info,
-                        precision=self.precision,
-                        max_norm=self.max_norm,
-                        pp_process_group=self.pp_group,
-                        tp_process_group=self.tp_group,
-                        **self.amp_config,
-                    )
-                else:
-                    optimizer = HybridParallelNaiveOptimizer(
-                        optimizer,
-                        model,
-                        use_pipeline=self.enable_pipeline_parallelism,
-                        param_info=param_info,
-                        max_norm=self.max_norm,
-                        pp_process_group=self.pp_group,
-                        tp_process_group=self.tp_group,
-                    )
+        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+            if zero_stage == 0:
+                is_zero = False
+                if self.precision in ["fp16", "bf16"]:
+                    optimizer = HybridParallelAMPOptimizer(
+                        optimizer,
+                        model,
+                        use_pipeline=self.enable_pipeline_parallelism,
+                        param_info=param_info,
+                        precision=self.precision,
+                        max_norm=self.max_norm,
+                        pp_process_group=self.pp_group,
+                        tp_process_group=self.tp_group,
+                        **self.amp_config,
+                    )
+                else:
+                    optimizer = HybridParallelNaiveOptimizer(
+                        optimizer,
+                        model,
+                        use_pipeline=self.enable_pipeline_parallelism,
+                        param_info=param_info,
+                        max_norm=self.max_norm,
+                        pp_process_group=self.pp_group,
+                        tp_process_group=self.tp_group,
+                    )
             else:
                 is_zero = self.dp_size > 1
                 if self.dp_size == 1:
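The removed and re-added block above appears to differ only in whitespace (a re-indentation after the merge); its logic is unchanged. For reference, a standalone sketch of the dispatch it performs, using stand-in classes rather than the real HybridParallelAMPOptimizer / HybridParallelNaiveOptimizer:

# Standalone sketch with stand-in wrappers; it illustrates only the
# precision-based dispatch, not the real ColossalAI constructors.
class _AMPOptimizerStandIn:
    def __init__(self, optimizer, model, precision):
        self.optimizer, self.model, self.precision = optimizer, model, precision


class _NaiveOptimizerStandIn:
    def __init__(self, optimizer, model):
        self.optimizer, self.model = optimizer, model


def wrap_optimizer_sketch(optimizer, model, precision: str):
    # Mirrors the zero_stage == 0 branch above: mixed precision gets the AMP
    # wrapper, everything else gets the plain wrapper.
    if precision in ("fp16", "bf16"):
        return _AMPOptimizerStandIn(optimizer, model, precision)
    return _NaiveOptimizerStandIn(optimizer, model)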


@@ -90,13 +90,14 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         if self.dtype is not None and cast_inputs:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         if overlap_allgather:
-            self.op_hook = ZeroOpHook()
+            self.op_hooks.append(ZeroOpHook())
+        if use_fp8:
+            self.op_hooks.append(FP8Hook())
+        if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
                     p.__class__ = ColoParameter
                     p.__init__(p, requires_grad=True)
-        if use_fp8:
-            self.op_hooks.append(FP8Hook())

     def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
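The hunk above replaces the single self.op_hook with an op_hooks list so the ZeRO all-gather hook and the FP8 hook can coexist. A minimal sketch of that list pattern, with stand-in classes rather than colossalai's ZeroOpHook / FP8Hook:

# Minimal sketch of the hook-list pattern; the hook classes are stand-ins.
class _AllGatherHookStandIn:
    name = "overlap_allgather"


class _FP8HookStandIn:
    name = "fp8_cast"


def collect_op_hooks(overlap_allgather: bool, use_fp8: bool) -> list:
    op_hooks = []
    if overlap_allgather:
        op_hooks.append(_AllGatherHookStandIn())
    if use_fp8:
        op_hooks.append(_FP8HookStandIn())
    return op_hooks


# Both features enabled -> both hooks registered, in a fixed order.
assert [h.name for h in collect_op_hooks(True, True)] == ["overlap_allgather", "fp8_cast"]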
@@ -348,9 +349,9 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
-        cast_inputs: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
+        cast_inputs: bool = True,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -497,8 +498,8 @@ class LowLevelZeroPlugin(DPPluginBase):
                 model,
                 self.precision,
                 overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
-                use_fp8=self.use_fp8,
                 cast_inputs=self.cast_inputs,
+                use_fp8=self.use_fp8,
             )

         # TODO: Support Galore + ZeRO
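Because the two hunks above only reorder keyword parameters, callers that pass keywords are unaffected. A hedged usage sketch (assumes colossalai is installed and the distributed environment has already been launched; model and optimizer setup are omitted):

# Hedged usage sketch: keyword arguments only, so the parameter reorder above
# does not affect callers. Assumes torch.distributed / colossalai launch has
# already been initialized elsewhere.
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin(
    stage=2,                 # ZeRO stage 1 or 2, per the assertion in __init__
    precision="bf16",
    cast_inputs=True,
    use_fp8=False,
    fp8_communication=False,
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, criterion, dataloader, lr_scheduler
# )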


@@ -1082,8 +1082,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
     return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)


-def reduce_forward(input_, process_group, fp8_communication=False):
-    return _ReduceForward.apply(input_, process_group, fp8_communication)
+def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False):
+    return _ReduceForward.apply(input_, process_group, grad_scale, fp8_communication)


 def reduce_backward(input_, process_group, fp8_communication=False):
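The new grad_scale keyword defaults to None, so existing call sites keep working. A rough re-implementation sketch of what an all-reduce-in-forward op with an optional backward gradient scale can look like (this is not the library's _ReduceForward, and the fp8_communication path is ignored here):

# Illustrative sketch only, not colossalai's _ReduceForward: all-reduce the
# activation in the forward pass, optionally rescale the incoming gradient
# in the backward pass.
import torch
import torch.distributed as dist


class _ReduceForwardSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, process_group, grad_scale=None):
        ctx.grad_scale = grad_scale
        output = input_.clone()
        dist.all_reduce(output, group=process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale is not None:
            grad_output = grad_output * ctx.grad_scale
        # gradients flow only to the tensor argument
        return grad_output, None, None


def reduce_forward_sketch(input_, process_group, grad_scale=None, fp8_communication=False):
    # fp8_communication is accepted for signature parity but ignored in this sketch
    return _ReduceForwardSketch.apply(input_, process_group, grad_scale)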


@@ -698,7 +698,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None
         if shard_config.enable_flash_attention:
             mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
-            attention_mask = ColoAttention.prepare_attn_kwargs(
+            attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
                 mask_shape,
                 inputs_embeds.dtype,
                 inputs_embeds.device,
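A worked example of the mask shape computed above, with illustrative numbers only:

# Illustrative numbers: batch size 2, no KV cache, sequence length 16.
batch_size, past_seen_tokens, seq_len = 2, 0, 16
mask_shape = (batch_size, 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
assert mask_shape == (2, 1, 16, 16)  # one (kv_len, kv_len) mask broadcast over heads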


@@ -144,10 +144,14 @@ class MixtralPolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={
-                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
-                        "fp8_communication": self.shard_config.fp8_communication,
-                    },
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=MixtralModel,
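A standalone sketch of the conditional-kwargs pattern introduced above (shard_config here is any object exposing the same attributes, not the real ShardConfig):

# Standalone sketch: only forward fp8_communication when tensor parallelism
# is enabled, mirroring the conditional kwargs above.
def build_embedding_kwargs(shard_config) -> dict:
    kwargs = {"make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by}
    if shard_config.enable_tensor_parallelism:
        kwargs["fp8_communication"] = shard_config.fp8_communication
    return kwargs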
@@ -164,7 +168,6 @@ class MixtralPolicy(Policy):
                         "ep_group": self.shard_config.ep_group,
                         "tp_group": self.shard_config.tensor_parallel_process_group,
                         "moe_dp_group": self.shard_config.moe_dp_group,
-                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 )
             ],