@@ -24,10 +24,7 @@ from transformers.models.llama.modeling_llama import (
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer._operation import (
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import ColoAttention, cross_entropy_1d
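
For context on the two operators this consolidated import pulls in, below is a minimal, self-contained sketch of the gather-forward / split-backward pattern as a custom `torch.autograd.Function`. This is an illustration of the idea only, not the ColossalAI implementation, and it ignores details such as gradient scaling; `split_forward_gather_backward` is the mirror image (keep the local chunk in forward, all-gather gradients in backward).

```python
# Illustrative sketch only -- not the ColossalAI implementation.
import torch
import torch.distributed as dist


class _GatherForwardSplitBackward(torch.autograd.Function):
    """All-gather along `dim` in the forward pass; in the backward pass,
    return only the gradient slice that belongs to the local rank."""

    @staticmethod
    def forward(ctx, x, dim, group):
        ctx.dim, ctx.group = dim, group
        world_size = dist.get_world_size(group)
        chunks = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(chunks, x.contiguous(), group=group)
        return torch.cat(chunks, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        rank = dist.get_rank(ctx.group)
        local_grad = torch.chunk(grad_output, world_size, dim=ctx.dim)[rank]
        return local_grad.contiguous(), None, None


def gather_forward_split_backward_sketch(x, dim, group):
    # Hypothetical helper name, used only for this illustration.
    return _GatherForwardSplitBackward.apply(x, dim, group)
```
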
@@ -566,7 +563,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
         # sp: all-to-all communication when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
-            #attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+            # attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
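
The commented-out `all_to_all_comm` call above is the resharding step that "all_to_all" sequence parallelism normally performs after attention: scatter the sequence dimension (dim 1) across the sequence-parallel group and gather the head/hidden dimension (dim 2). A rough, forward-only sketch of that collective, assuming the scattered dimension divides evenly by the group size and ignoring autograd (which the real operator wraps in a custom Function):

```python
# Forward-only sketch of an all-to-all reshard; not the ColossalAI operator.
import torch
import torch.distributed as dist


def all_to_all_comm_sketch(x: torch.Tensor, group, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    # Split the tensor into one chunk per rank along the dimension being scattered.
    inputs = [t.contiguous() for t in torch.chunk(x, world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(inputs[0]) for _ in inputs]
    # Each rank sends inputs[i] to rank i and receives that rank's chunk into outputs[i].
    dist.all_to_all(outputs, inputs, group=group)
    # Reassemble the received chunks along the dimension being gathered.
    return torch.cat(outputs, dim=gather_dim)


# With scatter_dim=1 and gather_dim=2, a (bsz, q_len, local_heads * head_dim)
# attn_output becomes (bsz, q_len // sp_size, num_heads * head_dim) on each rank.
```
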
@@ -826,4 +823,4 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             attentions=outputs.attentions,
         )
 
-    return forward
+    return forward
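
`get_lm_forward_with_dist_cross_entropy` relies on `cross_entropy_1d`, which computes the loss over vocabulary-sharded logits without materializing the full vocabulary on any rank. Below is a conceptual, Megatron-style sketch of that kind of vocab-parallel cross entropy; it is not the actual ColossalAI kernel and omits details such as `ignore_index` handling and dtype management.

```python
# Conceptual sketch of vocab-parallel cross entropy over logits of shape
# (tokens, vocab_size // tp_size) per rank. Not ColossalAI's cross_entropy_1d.
import torch
import torch.distributed as dist


def vocab_parallel_cross_entropy_sketch(local_logits: torch.Tensor, labels: torch.Tensor, group) -> torch.Tensor:
    rank = dist.get_rank(group)
    partition = local_logits.size(-1)
    vocab_start = rank * partition
    vocab_end = vocab_start + partition

    # 1) Numerically stable softmax denominator across all vocab shards.
    local_max = local_logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=group)
    shifted = local_logits - local_max
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)

    # 2) Pick out the logit of the target class; only the rank owning the
    #    target's vocab slice contributes, the others contribute zero.
    in_range = (labels >= vocab_start) & (labels < vocab_end)
    local_labels = (labels - vocab_start).clamp(min=0, max=partition - 1)
    target_logit = shifted[torch.arange(labels.numel(), device=labels.device), local_labels]
    target_logit = torch.where(in_range, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=group)

    # 3) loss = log(sum_exp) - target_logit, averaged over tokens.
    return (sum_exp.log() - target_logit).mean()
```
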