
[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
Branch: pull/5842/head
Author: pre-commit-ci[bot] (5 months ago)
Commit: 351a1c269b
Changed files:
  1. colossalai/shardformer/layer/loss.py (2 lines changed)
  2. colossalai/shardformer/modeling/llama.py (9 lines changed)

colossalai/shardformer/layer/loss.py (2 lines changed)

@@ -131,4 +131,4 @@ def cross_entropy_1d(
     vocab_size: int = None,
     dtype: torch.dtype = None,
 ) -> torch.Tensor:
-    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
+    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
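
For reference, cross_entropy_1d is a thin wrapper around the DistCrossEntropy autograd function shown in this hunk. Below is a minimal sketch of how it might be called on vocabulary-sharded logits; the function name shard_loss, the tensor shapes, the vocabulary size, and the tp_group handle are illustrative assumptions, not values taken from this commit.

# Sketch only. Assumes a tensor-parallel process group is already initialized
# and that vocab_logits holds this rank's shard of the vocabulary dimension.
import torch
import torch.distributed as dist
from colossalai.shardformer.layer import cross_entropy_1d

def shard_loss(vocab_logits: torch.Tensor, labels: torch.Tensor, tp_group: dist.ProcessGroup) -> torch.Tensor:
    # vocab_logits: [batch * seq_len, vocab_size // tp_world_size] (local shard, assumed layout)
    # labels:       [batch * seq_len], with ignore_index marking padded positions
    return cross_entropy_1d(
        vocab_logits,
        labels,
        ignore_index=-100,
        process_group=tp_group,
        vocab_size=32000,      # hypothetical full (unsharded) vocabulary size
        dtype=torch.float32,   # accumulate the reduced loss in fp32
    )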

colossalai/shardformer/modeling/llama.py (9 lines changed)

@@ -24,10 +24,7 @@ from transformers.models.llama.modeling_llama import (
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer._operation import (
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig

 from ..layer import ColoAttention, cross_entropy_1d
@@ -566,7 +563,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
-            #attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+            # attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
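
The commented-out all_to_all_comm call above refers to the exchange that, under "all_to_all" sequence parallelism, scatters the sequence dimension and gathers the head dimension across the sequence-parallel group. The snippet below is a rough sketch of that scatter_dim=1 / gather_dim=2 pattern built on plain torch.distributed primitives; it illustrates the communication, it is not ColossalAI's implementation, and it assumes seq_len divides evenly by the sequence-parallel world size with identically shaped inputs on every rank.

# Illustrative only: rebuilds the full hidden size on a sequence shard via all-to-all.
import torch
import torch.distributed as dist

def seq_scatter_head_gather(x: torch.Tensor, sp_group: dist.ProcessGroup) -> torch.Tensor:
    # x: [bsz, seq_len, local_hidden]  ->  [bsz, seq_len // sp, local_hidden * sp]
    sp = dist.get_world_size(sp_group)
    bsz, seq_len, local_hidden = x.shape
    # Chunk the sequence dimension; chunk j is destined for rank j.
    send = x.reshape(bsz, sp, seq_len // sp, local_hidden).permute(1, 0, 2, 3).contiguous()
    recv = torch.empty_like(send)
    dist.all_to_all_single(recv, send, group=sp_group)
    # recv[r] now holds rank r's slice of the heads for this rank's sequence chunk;
    # stitch the head slices back together along the last dimension.
    return recv.permute(1, 2, 0, 3).reshape(bsz, seq_len // sp, local_hidden * sp)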
@@ -826,4 +823,4 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             attentions=outputs.attentions,
         )

-    return forward
+    return forward
