mirror of https://github.com/hpcaitech/ColossalAI

commit 53823118f2
parent 1f703e0ef4

    fix
@@ -1278,31 +1278,31 @@ class HybridParallelPlugin(PipelinePluginBase):
                 overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
                 use_fp8=self.use_fp8,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if zero_stage == 0:
                 is_zero = False
                 if self.precision in ["fp16", "bf16"]:
                     optimizer = HybridParallelAMPOptimizer(
                         optimizer,
                         model,
                         use_pipeline=self.enable_pipeline_parallelism,
                         param_info=param_info,
                         precision=self.precision,
                         max_norm=self.max_norm,
                         pp_process_group=self.pp_group,
                         tp_process_group=self.tp_group,
                         **self.amp_config,
                     )
                 else:
                     optimizer = HybridParallelNaiveOptimizer(
                         optimizer,
                         model,
                         use_pipeline=self.enable_pipeline_parallelism,
                         param_info=param_info,
                         max_norm=self.max_norm,
                         pp_process_group=self.pp_group,
                         tp_process_group=self.tp_group,
                     )
             else:
                 is_zero = self.dp_size > 1
                 if self.dp_size == 1:
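Note: in this hunk the plugin decides how to wrap the user optimizer. Without ZeRO (zero_stage == 0) it picks an AMP wrapper for fp16/bf16 and a naive wrapper otherwise; with ZeRO, only dp_size > 1 counts as the real sharded case. A minimal stand-in sketch of that branching follows; the wrapper classes are placeholders, not ColossalAI's HybridParallelAMPOptimizer / HybridParallelNaiveOptimizer.

    class AMPWrapperStub:
        # placeholder for HybridParallelAMPOptimizer
        def __init__(self, optim, precision, max_norm):
            self.optim, self.precision, self.max_norm = optim, precision, max_norm

    class NaiveWrapperStub:
        # placeholder for HybridParallelNaiveOptimizer
        def __init__(self, optim, max_norm):
            self.optim, self.max_norm = optim, max_norm

    def wrap_optimizer(optim, precision, zero_stage, dp_size, max_norm=1.0):
        # Mirrors the branch above: no ZeRO -> AMP or naive wrapper by precision;
        # with ZeRO, only dp_size > 1 actually shards optimizer states.
        if zero_stage == 0:
            is_zero = False
            wrapped = (
                AMPWrapperStub(optim, precision, max_norm)
                if precision in ["fp16", "bf16"]
                else NaiveWrapperStub(optim, max_norm)
            )
        else:
            is_zero = dp_size > 1
            wrapped = optim  # ZeRO wrapping happens further down, outside this hunk
        return wrapped, is_zero

    print(wrap_optimizer("sgd", "bf16", zero_stage=0, dp_size=4))  # AMP path
    print(wrap_optimizer("sgd", "fp32", zero_stage=0, dp_size=4))  # naive path
    print(wrap_optimizer("sgd", "fp16", zero_stage=1, dp_size=1))  # ZeRO stage set but dp_size == 1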
@@ -90,13 +90,14 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         if self.dtype is not None and cast_inputs:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         if overlap_allgather:
-            self.op_hook = ZeroOpHook()
+            self.op_hooks.append(ZeroOpHook())
+        if use_fp8:
+            self.op_hooks.append(FP8Hook())
+        if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
                     p.__class__ = ColoParameter
                     p.__init__(p, requires_grad=True)
-        if use_fp8:
-            self.op_hooks.append(FP8Hook())
 
     def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
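Note: this change replaces the single self.op_hook with an op_hooks list so the ZeRO all-gather-overlap hook and the FP8 hook can coexist, and the ColoParameter promotion now runs when either feature is enabled. A self-contained stand-in of that collection logic, using stub classes rather than ColossalAI's ZeroOpHook / FP8Hook:

    class ZeroOpHookStub:  # placeholder for ColossalAI's ZeroOpHook
        pass

    class FP8HookStub:  # placeholder for ColossalAI's FP8Hook
        pass

    def collect_op_hooks(overlap_allgather: bool, use_fp8: bool):
        # Both features append to the same list, as in the new code path.
        op_hooks = []
        if overlap_allgather:
            op_hooks.append(ZeroOpHookStub())
        if use_fp8:
            op_hooks.append(FP8HookStub())
        # Either flag requires parameters to be promoted to ColoParameter
        # so the hooks can intercept their ops.
        needs_colo_params = overlap_allgather or use_fp8
        return op_hooks, needs_colo_params

    hooks, promote = collect_op_hooks(overlap_allgather=True, use_fp8=True)
    print([type(h).__name__ for h in hooks], promote)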
@@ -348,9 +349,9 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
+        cast_inputs: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
-        cast_inputs: bool = True,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
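Note: this hunk only moves cast_inputs ahead of fp8_communication / use_fp8 in the constructor's parameter list; call sites that pass these options by keyword are unaffected by the reordering. A small sketch with a stand-in signature mirroring the parameters shown in the diff (not the plugin's full signature):

    def low_level_zero_options(
        stage: int = 1,
        cpu_offload: bool = False,
        master_weights: bool = True,
        verbose: bool = False,
        cast_inputs: bool = True,
        fp8_communication: bool = False,
        use_fp8: bool = False,
    ):
        # Same guard as in the plugin constructor.
        assert stage in (1, 2), "LowLevelZeroPlugin only supports stage 1/2 training"
        return locals()

    # Keyword arguments are order-independent, so the parameter shuffle above
    # does not change the meaning of calls like this one:
    print(low_level_zero_options(stage=2, use_fp8=True, cast_inputs=False))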
@@ -497,8 +498,8 @@ class LowLevelZeroPlugin(DPPluginBase):
                 model,
                 self.precision,
                 overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
-                use_fp8=self.use_fp8,
                 cast_inputs=self.cast_inputs,
+                use_fp8=self.use_fp8,
             )
 
         # TODO: Support Galore + ZeRO
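Note: this is the call where the plugin wraps the user model; since cast_inputs and use_fp8 are keyword arguments, swapping their order is behavior-neutral. For context, a hedged end-to-end usage sketch of the public API these options flow through. It assumes ColossalAI is installed and the script is started under torchrun / colossalai launch, and exact constructor defaults may differ between versions.

    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    colossalai.launch_from_torch()  # requires torchrun-style env vars to be set

    # Options shown in the diff are passed by keyword, so their position
    # inside the constructor signature does not matter to callers.
    plugin = LowLevelZeroPlugin(stage=1, cast_inputs=True, use_fp8=False)
    booster = Booster(plugin=plugin)

    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)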
@@ -1082,8 +1082,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, f
     return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
 
 
-def reduce_forward(input_, process_group, fp8_communication=False):
-    return _ReduceForward.apply(input_, process_group, fp8_communication)
+def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False):
+    return _ReduceForward.apply(input_, process_group, grad_scale, fp8_communication)
 
 
 def reduce_backward(input_, process_group, fp8_communication=False):
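Note: reduce_forward now accepts a grad_scale, mirroring split_forward_gather_backward above it: the activation is all-reduced in the forward pass and the incoming gradient can be rescaled in the backward pass. A stand-in autograd sketch of that signature (illustrative only; this is not the internal _ReduceForward implementation, and the fp8_communication path is omitted):

    import torch
    import torch.distributed as dist

    class _ReduceForwardSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input_, process_group, grad_scale=None):
            ctx.grad_scale = grad_scale
            output = input_.clone()
            dist.all_reduce(output, group=process_group)  # sum across the group
            return output

        @staticmethod
        def backward(ctx, grad_output):
            if ctx.grad_scale is not None:
                grad_output = grad_output * ctx.grad_scale
            # one gradient slot per forward argument; non-tensor args get None
            return grad_output, None, None

    def reduce_forward_sketch(input_, process_group, grad_scale=None):
        # assumes torch.distributed is already initialized when called
        return _ReduceForwardSketch.apply(input_, process_group, grad_scale)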
@@ -698,7 +698,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
 
         if shard_config.enable_flash_attention:
             mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
-            attention_mask = ColoAttention.prepare_attn_kwargs(
+            attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
                 mask_shape,
                 inputs_embeds.dtype,
                 inputs_embeds.device,
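Note: the rename from attention_mask to attn_kwargs: dict reflects that ColoAttention.prepare_attn_kwargs returns a dictionary of attention arguments rather than a bare mask tensor. A stand-in sketch of that shape (the keys and mask layout here are illustrative, not the exact ColoAttention contract):

    import torch

    def prepare_attn_kwargs_sketch(mask_shape, dtype, device, is_causal=True) -> dict:
        # Build a (batch, 1, q_len, kv_len) additive mask and return it in a dict,
        # so callers treat the result as a bundle of kwargs, not a single tensor.
        _, _, q_len, kv_len = mask_shape
        mask = torch.zeros(mask_shape, dtype=dtype, device=device)
        if is_causal:
            causal = torch.ones(q_len, kv_len, dtype=torch.bool, device=device).triu(1)
            mask = mask.masked_fill(causal, torch.finfo(dtype).min)
        return {"attention_mask": mask, "is_causal": is_causal}

    seq_len, past_seen_tokens = 8, 0
    mask_shape = (2, 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
    attn_kwargs: dict = prepare_attn_kwargs_sketch(mask_shape, torch.float16, "cpu")
    print({k: getattr(v, "shape", v) for k, v in attn_kwargs.items()})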
@@ -144,10 +144,14 @@ class MixtralPolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={
-                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
-                        "fp8_communication": self.shard_config.fp8_communication,
-                    },
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=MixtralModel,
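Note: the embedding replacement now builds its kwargs conditionally: fp8_communication is only forwarded when tensor parallelism is enabled, otherwise only the vocab-size padding option is passed. A self-contained stand-in of that pattern (the dataclass below is a stub for the ShardConfig fields named in the diff):

    from dataclasses import dataclass

    @dataclass
    class ShardConfigStub:  # stub for the ShardConfig fields used in this hunk
        enable_tensor_parallelism: bool
        make_vocab_size_divisible_by: int = 64
        fp8_communication: bool = False

    def embed_tokens_kwargs(shard_config: ShardConfigStub) -> dict:
        return (
            {
                "make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by,
                "fp8_communication": shard_config.fp8_communication,
            }
            if shard_config.enable_tensor_parallelism
            else {"make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by}
        )

    print(embed_tokens_kwargs(ShardConfigStub(enable_tensor_parallelism=True, fp8_communication=True)))
    print(embed_tokens_kwargs(ShardConfigStub(enable_tensor_parallelism=False)))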
@@ -164,7 +168,6 @@ class MixtralPolicy(Policy):
                         "ep_group": self.shard_config.ep_group,
                         "tp_group": self.shard_config.tensor_parallel_process_group,
                         "moe_dp_group": self.shard_config.moe_dp_group,
-                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 )
             ],
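Note: here the fp8_communication entry is dropped from the expert-parallel block replacement, leaving only the process-group kwargs. A minimal sketch of the resulting kwargs construction; the config object is a stand-in, and the consuming module is ColossalAI's expert-parallel Mixtral block, not shown here.

    from types import SimpleNamespace

    def moe_block_kwargs(shard_config) -> dict:
        # Only the parallel process groups are forwarded after this change.
        return {
            "ep_group": shard_config.ep_group,
            "tp_group": shard_config.tensor_parallel_process_group,
            "moe_dp_group": shard_config.moe_dp_group,
        }

    # Stand-in config object; in ColossalAI these fields hold process groups.
    cfg = SimpleNamespace(ep_group="ep", tensor_parallel_process_group="tp", moe_dp_group="moe_dp")
    print(moe_block_kwargs(cfg))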