mirror of https://github.com/hpcaitech/ColossalAI
fix
parent
1f703e0ef4
commit
53823118f2
|
@ -1278,31 +1278,31 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
|
||||
use_fp8=self.use_fp8,
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if zero_stage == 0:
|
||||
is_zero = False
|
||||
if self.precision in ["fp16", "bf16"]:
|
||||
optimizer = HybridParallelAMPOptimizer(
|
||||
optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info,
|
||||
precision=self.precision,
|
||||
max_norm=self.max_norm,
|
||||
pp_process_group=self.pp_group,
|
||||
tp_process_group=self.tp_group,
|
||||
**self.amp_config,
|
||||
)
|
||||
else:
|
||||
optimizer = HybridParallelNaiveOptimizer(
|
||||
optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info,
|
||||
max_norm=self.max_norm,
|
||||
pp_process_group=self.pp_group,
|
||||
tp_process_group=self.tp_group,
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if zero_stage == 0:
|
||||
is_zero = False
|
||||
if self.precision in ["fp16", "bf16"]:
|
||||
optimizer = HybridParallelAMPOptimizer(
|
||||
optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info,
|
||||
precision=self.precision,
|
||||
max_norm=self.max_norm,
|
||||
pp_process_group=self.pp_group,
|
||||
tp_process_group=self.tp_group,
|
||||
**self.amp_config,
|
||||
)
|
||||
else:
|
||||
optimizer = HybridParallelNaiveOptimizer(
|
||||
optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info,
|
||||
max_norm=self.max_norm,
|
||||
pp_process_group=self.pp_group,
|
||||
tp_process_group=self.tp_group,
|
||||
)
|
||||
else:
|
||||
is_zero = self.dp_size > 1
|
||||
if self.dp_size == 1:
|
||||
|
|
|
@ -90,13 +90,14 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
|||
if self.dtype is not None and cast_inputs:
|
||||
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
|
||||
if overlap_allgather:
|
||||
self.op_hook = ZeroOpHook()
|
||||
self.op_hooks.append(ZeroOpHook())
|
||||
if use_fp8:
|
||||
self.op_hooks.append(FP8Hook())
|
||||
if overlap_allgather or use_fp8:
|
||||
for p in module.parameters():
|
||||
if p.requires_grad and type(p) is not ColoParameter:
|
||||
p.__class__ = ColoParameter
|
||||
p.__init__(p, requires_grad=True)
|
||||
if use_fp8:
|
||||
self.op_hooks.append(FP8Hook())
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.convert_fn is not None:
|
||||
|
@ -348,9 +349,9 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
cpu_offload: bool = False,
|
||||
master_weights: bool = True,
|
||||
verbose: bool = False,
|
||||
cast_inputs: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
use_fp8: bool = False,
|
||||
cast_inputs: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
|
||||
|
@ -497,8 +498,8 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
model,
|
||||
self.precision,
|
||||
overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
|
||||
use_fp8=self.use_fp8,
|
||||
cast_inputs=self.cast_inputs,
|
||||
use_fp8=self.use_fp8,
|
||||
)
|
||||
|
||||
# TODO: Support Galore + ZeRO
|
||||
|
|
|
@ -1082,8 +1082,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, f
|
|||
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
|
||||
|
||||
|
||||
def reduce_forward(input_, process_group, fp8_communication=False):
|
||||
return _ReduceForward.apply(input_, process_group, fp8_communication)
|
||||
def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False):
|
||||
return _ReduceForward.apply(input_, process_group, grad_scale, fp8_communication)
|
||||
|
||||
|
||||
def reduce_backward(input_, process_group, fp8_communication=False):
|
||||
|
|
|
@ -698,7 +698,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
|
|||
|
||||
if shard_config.enable_flash_attention:
|
||||
mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
inputs_embeds.dtype,
|
||||
inputs_embeds.device,
|
||||
|
|
|
@ -144,10 +144,14 @@ class MixtralPolicy(Policy):
|
|||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
kwargs=(
|
||||
{
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
}
|
||||
if self.shard_config.enable_tensor_parallelism
|
||||
else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=MixtralModel,
|
||||
|
@ -164,7 +168,6 @@ class MixtralPolicy(Policy):
|
|||
"ep_group": self.shard_config.ep_group,
|
||||
"tp_group": self.shard_config.tensor_parallel_process_group,
|
||||
"moe_dp_group": self.shard_config.moe_dp_group,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
)
|
||||
],
|
||||
|
|
Loading…
Reference in New Issue