From 351a1c269b4c6d45195c605c66fdebe71384c44a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 20 Jun 2024 06:50:39 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/shardformer/layer/loss.py     | 2 +-
 colossalai/shardformer/modeling/llama.py | 9 +++------
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index 2af18d677..a6d19edf5 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -131,4 +131,4 @@ def cross_entropy_1d(
     vocab_size: int = None,
     dtype: torch.dtype = None,
 ) -> torch.Tensor:
-    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
\ No newline at end of file
+    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 6502457f2..a57d29815 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -24,10 +24,7 @@ from transformers.models.llama.modeling_llama import (
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer._operation import (
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import ColoAttention, cross_entropy_1d
@@ -566,7 +563,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
-            #attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+            # attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
@@ -826,4 +823,4 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             attentions=outputs.attentions,
         )
 
-    return forward
\ No newline at end of file
+    return forward
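
Reviewer note (not part of the patch): for context on the loss.py hunk, below is a minimal single-process sketch of the quantity the patched cross_entropy_1d wrapper computes. The real DistCrossEntropy.apply operates on vocabulary-sharded logits and exchanges partial statistics across process_group; the helper name, shapes, and the plain F.cross_entropy stand-in here are illustrative assumptions, not the library's API.

    # Sketch only: approximates cross_entropy_1d without sharding or a process group.
    import torch
    import torch.nn.functional as F

    def cross_entropy_1d_reference(
        vocab_logits: torch.Tensor,   # (batch * seq_len, vocab_size), full unsharded vocab
        labels: torch.Tensor,         # (batch * seq_len,)
        ignore_index: int = -100,
        dtype: torch.dtype = None,
    ) -> torch.Tensor:
        # Mirror the wrapper's optional dtype cast.
        if dtype is not None:
            vocab_logits = vocab_logits.to(dtype)
        # Plain cross entropy; the distributed version computes the same loss over
        # vocab shards, roughly by exchanging partial max/sum-exp terms across ranks.
        return F.cross_entropy(vocab_logits.float(), labels, ignore_index=ignore_index)

    logits = torch.randn(8, 32000)
    labels = torch.randint(0, 32000, (8,))
    print(cross_entropy_1d_reference(logits, labels))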