pull/6023/head
wangbluo 2024-08-20 03:20:13 +00:00
parent 1f703e0ef4
commit 53823118f2
5 changed files with 42 additions and 38 deletions

View File

@@ -1278,31 +1278,31 @@ class HybridParallelPlugin(PipelinePluginBase):
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]), overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
use_fp8=self.use_fp8, use_fp8=self.use_fp8,
) )
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if zero_stage == 0: if zero_stage == 0:
is_zero = False is_zero = False
if self.precision in ["fp16", "bf16"]: if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer( optimizer = HybridParallelAMPOptimizer(
optimizer, optimizer,
model, model,
use_pipeline=self.enable_pipeline_parallelism, use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info, param_info=param_info,
precision=self.precision, precision=self.precision,
max_norm=self.max_norm, max_norm=self.max_norm,
pp_process_group=self.pp_group, pp_process_group=self.pp_group,
tp_process_group=self.tp_group, tp_process_group=self.tp_group,
**self.amp_config, **self.amp_config,
) )
else: else:
optimizer = HybridParallelNaiveOptimizer( optimizer = HybridParallelNaiveOptimizer(
optimizer, optimizer,
model, model,
use_pipeline=self.enable_pipeline_parallelism, use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info, param_info=param_info,
max_norm=self.max_norm, max_norm=self.max_norm,
pp_process_group=self.pp_group, pp_process_group=self.pp_group,
tp_process_group=self.tp_group, tp_process_group=self.tp_group,
) )
else: else:
is_zero = self.dp_size > 1 is_zero = self.dp_size > 1
if self.dp_size == 1: if self.dp_size == 1:

View File

@@ -90,13 +90,14 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
if self.dtype is not None and cast_inputs: if self.dtype is not None and cast_inputs:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
if overlap_allgather: if overlap_allgather:
self.op_hook = ZeroOpHook() self.op_hooks.append(ZeroOpHook())
if use_fp8:
self.op_hooks.append(FP8Hook())
if overlap_allgather or use_fp8:
for p in module.parameters(): for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter: if p.requires_grad and type(p) is not ColoParameter:
p.__class__ = ColoParameter p.__class__ = ColoParameter
p.__init__(p, requires_grad=True) p.__init__(p, requires_grad=True)
if use_fp8:
self.op_hooks.append(FP8Hook())
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
if self.convert_fn is not None: if self.convert_fn is not None:
@@ -348,9 +349,9 @@ class LowLevelZeroPlugin(DPPluginBase):
cpu_offload: bool = False, cpu_offload: bool = False,
master_weights: bool = True, master_weights: bool = True,
verbose: bool = False, verbose: bool = False,
cast_inputs: bool = True,
fp8_communication: bool = False, fp8_communication: bool = False,
use_fp8: bool = False, use_fp8: bool = False,
cast_inputs: bool = True,
) -> None: ) -> None:
super().__init__() super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -497,8 +498,8 @@ class LowLevelZeroPlugin(DPPluginBase):
model, model,
self.precision, self.precision,
overlap_allgather=self.zero_optim_kwargs["overlap_allgather"], overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
use_fp8=self.use_fp8,
cast_inputs=self.cast_inputs, cast_inputs=self.cast_inputs,
use_fp8=self.use_fp8,
) )
# TODO: Support Galore + ZeRO # TODO: Support Galore + ZeRO

View File

@@ -1082,8 +1082,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication) return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
def reduce_forward(input_, process_group, fp8_communication=False): def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False):
return _ReduceForward.apply(input_, process_group, fp8_communication) return _ReduceForward.apply(input_, process_group, grad_scale, fp8_communication)
def reduce_backward(input_, process_group, fp8_communication=False): def reduce_backward(input_, process_group, fp8_communication=False):

View File

@@ -698,7 +698,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None
if shard_config.enable_flash_attention: if shard_config.enable_flash_attention:
mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len) mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
attention_mask = ColoAttention.prepare_attn_kwargs( attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
mask_shape, mask_shape,
inputs_embeds.dtype, inputs_embeds.dtype,
inputs_embeds.device, inputs_embeds.device,

View File

@@ -144,10 +144,14 @@ class MixtralPolicy(Policy):
description=SubModuleReplacementDescription( description=SubModuleReplacementDescription(
suffix="embed_tokens", suffix="embed_tokens",
target_module=embedding_cls, target_module=embedding_cls,
kwargs={ kwargs=(
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, {
"fp8_communication": self.shard_config.fp8_communication, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
}, "fp8_communication": self.shard_config.fp8_communication,
}
if self.shard_config.enable_tensor_parallelism
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
),
), ),
policy=policy, policy=policy,
target_key=MixtralModel, target_key=MixtralModel,
@@ -164,7 +168,6 @@ class MixtralPolicy(Policy):
"ep_group": self.shard_config.ep_group, "ep_group": self.shard_config.ep_group,
"tp_group": self.shard_config.tensor_parallel_process_group, "tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group, "moe_dp_group": self.shard_config.moe_dp_group,
"fp8_communication": self.shard_config.fp8_communication,
}, },
) )
], ],