From 53823118f2fc91539b46442f216cd203fe7b5f60 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Tue, 20 Aug 2024 03:20:13 +0000
Subject: [PATCH] fix

---
 .../booster/plugin/hybrid_parallel_plugin.py | 50 +++++++++----------
 .../booster/plugin/low_level_zero_plugin.py  | 11 ++--
 colossalai/shardformer/layer/_operation.py   |  4 +-
 colossalai/shardformer/modeling/llama.py     |  2 +-
 colossalai/shardformer/policies/mixtral.py   | 13 +++--
 5 files changed, 42 insertions(+), 38 deletions(-)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index bd970878f..a92371485 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1278,31 +1278,31 @@ class HybridParallelPlugin(PipelinePluginBase):
                 overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
                 use_fp8=self.use_fp8,
             )
-            if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
-                if zero_stage == 0:
-                    is_zero = False
-                    if self.precision in ["fp16", "bf16"]:
-                        optimizer = HybridParallelAMPOptimizer(
-                            optimizer,
-                            model,
-                            use_pipeline=self.enable_pipeline_parallelism,
-                            param_info=param_info,
-                            precision=self.precision,
-                            max_norm=self.max_norm,
-                            pp_process_group=self.pp_group,
-                            tp_process_group=self.tp_group,
-                            **self.amp_config,
-                        )
-                    else:
-                        optimizer = HybridParallelNaiveOptimizer(
-                            optimizer,
-                            model,
-                            use_pipeline=self.enable_pipeline_parallelism,
-                            param_info=param_info,
-                            max_norm=self.max_norm,
-                            pp_process_group=self.pp_group,
-                            tp_process_group=self.tp_group,
-                        )
+        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+            if zero_stage == 0:
+                is_zero = False
+                if self.precision in ["fp16", "bf16"]:
+                    optimizer = HybridParallelAMPOptimizer(
+                        optimizer,
+                        model,
+                        use_pipeline=self.enable_pipeline_parallelism,
+                        param_info=param_info,
+                        precision=self.precision,
+                        max_norm=self.max_norm,
+                        pp_process_group=self.pp_group,
+                        tp_process_group=self.tp_group,
+                        **self.amp_config,
+                    )
+                else:
+                    optimizer = HybridParallelNaiveOptimizer(
+                        optimizer,
+                        model,
+                        use_pipeline=self.enable_pipeline_parallelism,
+                        param_info=param_info,
+                        max_norm=self.max_norm,
+                        pp_process_group=self.pp_group,
+                        tp_process_group=self.tp_group,
+                    )
             else:
                 is_zero = self.dp_size > 1
                 if self.dp_size == 1:
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 448fb9e21..088fa1daa 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -90,13 +90,14 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         if self.dtype is not None and cast_inputs:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         if overlap_allgather:
-            self.op_hook = ZeroOpHook()
+            self.op_hooks.append(ZeroOpHook())
+        if use_fp8:
+            self.op_hooks.append(FP8Hook())
+        if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
                     p.__class__ = ColoParameter
                     p.__init__(p, requires_grad=True)
-        if use_fp8:
-            self.op_hooks.append(FP8Hook())
 
     def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
@@ -348,9 +349,9 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
+        cast_inputs: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
-        cast_inputs: bool = True,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -497,8 +498,8 @@ class LowLevelZeroPlugin(DPPluginBase):
                 model,
                 self.precision,
                 overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
-                use_fp8=self.use_fp8,
                 cast_inputs=self.cast_inputs,
+                use_fp8=self.use_fp8,
             )
 
         # TODO: Support Galore + ZeRO
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index aed9d8351..bfe408065 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -1082,8 +1082,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, f
     return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
 
 
-def reduce_forward(input_, process_group, fp8_communication=False):
-    return _ReduceForward.apply(input_, process_group, fp8_communication)
+def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False):
+    return _ReduceForward.apply(input_, process_group, grad_scale, fp8_communication)
 
 
 def reduce_backward(input_, process_group, fp8_communication=False):
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 71d8daa35..3f05d0428 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -698,7 +698,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
 
         if shard_config.enable_flash_attention:
             mask_shape = (inputs_embeds.shape[0], 1, past_seen_tokens + seq_len, past_seen_tokens + seq_len)
-            attention_mask = ColoAttention.prepare_attn_kwargs(
+            attn_kwargs: dict = ColoAttention.prepare_attn_kwargs(
                 mask_shape,
                 inputs_embeds.dtype,
                 inputs_embeds.device,
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 4bdca78cb..3a373889c 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -144,10 +144,14 @@ class MixtralPolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={
-                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
-                        "fp8_communication": self.shard_config.fp8_communication,
-                    },
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=MixtralModel,
@@ -164,7 +168,6 @@ class MixtralPolicy(Policy):
                             "ep_group": self.shard_config.ep_group,
                             "tp_group": self.shard_config.tensor_parallel_process_group,
                             "moe_dp_group": self.shard_config.moe_dp_group,
-                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     )
                 ],
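
Reviewer note: the `reduce_forward` hunk above only threads a new `grad_scale` argument through to `_ReduceForward`. Purely as a rough, standalone illustration of the pattern that signature implies (all-reduce the activation in the forward pass, optionally scale the incoming gradient in the backward pass), a minimal sketch in plain PyTorch could look like the following. The names `_ReduceForwardSketch` and `reduce_forward_sketch` are hypothetical, the `fp8_communication` path is omitted, and this is not ColossalAI's actual `_ReduceForward` implementation.

# Illustrative sketch only; not part of the patch.
import torch
import torch.distributed as dist


class _ReduceForwardSketch(torch.autograd.Function):
    """All-reduce the activation in forward; optionally scale the gradient in backward."""

    @staticmethod
    def forward(ctx, input_, process_group, grad_scale):
        ctx.grad_scale = grad_scale
        if dist.is_initialized() and dist.get_world_size(group=process_group) > 1:
            # Sum the partial activations across the ranks of the process group.
            dist.all_reduce(input_, group=process_group)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # grad_scale=None (the default in the patched signature) leaves the gradient untouched.
        if ctx.grad_scale is not None:
            grad_output = grad_output * ctx.grad_scale
        # One gradient slot per forward argument; non-tensor arguments get None.
        return grad_output, None, None


def reduce_forward_sketch(input_, process_group=None, grad_scale=None):
    return _ReduceForwardSketch.apply(input_, process_group, grad_scale)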