[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pull/6023/head
pre-commit-ci[bot] 2024-08-17 09:37:37 +00:00
parent 4cf79fa275
commit 81272e9d00
17 changed files with 39 additions and 26 deletions

@@ -64,7 +64,12 @@ class OptimizerParamCheckState(enum.Enum):

 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(
-        self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True, use_fp8: bool = False
+        self,
+        module: nn.Module,
+        precision: str,
+        overlap_allgather: bool = False,
+        cast_inputs: bool = True,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__(module)
         self.dtype = None

@@ -3,6 +3,7 @@ import torch.distributed as dist
 from torch.autograd import Function
 from torch.distributed import ProcessGroup
 from torch.nn import CrossEntropyLoss
+
 from colossalai.shardformer.layer._operation import reduce_forward
 from colossalai.shardformer.shard import ShardConfig

@@ -16,11 +16,6 @@ from colossalai.shardformer.layer._operation import (
     gather_forward_split_backward,
     split_forward_gather_backward,
 )
-from colossalai.shardformer.layer._operation import (
-    all_to_all_comm,
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)


 def get_flash_core_attention_forward():

@@ -24,7 +24,7 @@ from colossalai.shardformer.layer._operation import (
 )
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention, dist_cross_entropy, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy

 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]

@@ -145,7 +145,9 @@ class EPDeepseekMoE(nn.Module):
         output_split_sizes = torch.zeros_like(input_split_sizes)

         # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
-        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication)
+        dist.all_to_all_single(
+            output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication
+        )

         with torch.no_grad():
             activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
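
The call re-wrapped above is the split-size handshake that precedes the expert-parallel token dispatch in EPDeepseekMoE: each rank tells every other rank how many tokens it will send to each expert. Below is a minimal single-process sketch of the semantics of that dist.all_to_all_single exchange, simulating two ranks with plain tensor ops instead of a real process group (all names and sizes are illustrative, not ColossalAI's):

import torch

# Toy setup: 2 expert-parallel "ranks", 4 experts total, 2 experts hosted per rank.
WORLD_SIZE, NUM_EXPERTS = 2, 4
EXPERTS_PER_RANK = NUM_EXPERTS // WORLD_SIZE

# input_split_sizes[r][e] = number of tokens rank r wants to send to expert e.
input_split_sizes = [
    torch.tensor([3, 1, 4, 2]),  # rank 0: [n0, n1, n2, n3]
    torch.tensor([5, 0, 2, 7]),  # rank 1: [m0, m1, m2, m3]
]

# With no explicit split sizes, all_to_all_single chunks each input into WORLD_SIZE
# equal pieces and delivers piece k to rank k; simulated here with chunk + cat.
chunks = [t.chunk(WORLD_SIZE) for t in input_split_sizes]
output_split_sizes = [
    torch.cat([chunks[src][dst] for src in range(WORLD_SIZE)]) for dst in range(WORLD_SIZE)
]
print(output_split_sizes[0])  # tensor([3, 1, 5, 0]) -> [n0, n1, m0, m1]
print(output_split_sizes[1])  # tensor([4, 2, 2, 7]) -> [n2, n3, m2, m3]

# Each rank can now size its receive buffers: incoming tokens per local expert.
tokens_per_local_expert = output_split_sizes[0].view(WORLD_SIZE, EXPERTS_PER_RANK).sum(dim=0)
print(tokens_per_local_expert)  # tensor([8, 1]) for the two experts hosted on rank 0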

@@ -26,11 +26,15 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import AttnMaskType
-from colossalai.shardformer.layer._operation import all_to_all_comm, gather_forward_split_backward, split_forward_gather_backward
+from colossalai.shardformer.layer._operation import (
+    all_to_all_comm,
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+)
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention, RingAttention, dist_cross_entropy, cross_entropy_1d
+from ..layer import ColoAttention, RingAttention, dist_cross_entropy

 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
@@ -162,9 +166,13 @@ class LlamaPipelineForwards:
                 hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
             elif is_share_sp_tp(sp_mode):
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+                )
             elif sp_mode == "all_to_all":
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+                )

         if self.gradient_checkpointing and self.training and use_cache:
             if use_cache:
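
For context on the calls that were only re-wrapped here: split_forward_gather_backward keeps a 1/sp_size slice of the sequence dimension on each sequence-parallel rank in the forward pass and gathers gradients in backward; the extra 1 / sp_size positional argument in the all_to_all branch is presumably a gradient scale. A rough, single-process sketch of the forward-side shape effect, with made-up sizes and no real process group:

import torch

batch, seq_len, hidden = 2, 8, 16
sp_size, sp_rank = 4, 1  # 4 sequence-parallel ranks; pretend we are rank 1

hidden_states = torch.randn(batch, seq_len, hidden)

# Forward of a split-forward / gather-backward op along dim=1: each rank keeps
# only its contiguous slice of the sequence.
local_states = hidden_states.chunk(sp_size, dim=1)[sp_rank]
print(local_states.shape)  # torch.Size([2, 2, 16]) -- seq_len shrinks from 8 to 2
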
@@ -675,7 +683,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
         past_seen_tokens = 0
         seq_len = inputs_embeds.shape[1]
-        batch_size = inputs_embeds.shape[0]
+        inputs_embeds.shape[0]

         if use_cache:  # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)

@@ -691,7 +691,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication)  # (1, 4, 256)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )  # (1, 4, 256)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
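
The re-wrapped all_to_all_comm call is the sequence-parallel re-layout the inline comments describe: scatter along the sequence dimension, gather along the flattened head dimension, so a per-rank (1, 8, 128) tensor becomes (1, 4, 256) when sp_size is 2. A toy single-process simulation of that scatter/gather follows (illustrative only; the real op runs an all-to-all collective over sp_group):

import torch

sp_size = 2
rank_outputs = [torch.randn(1, 8, 128) for _ in range(sp_size)]  # attn_output held by each rank

# Each rank splits its tensor into sp_size chunks along the scatter dim (sequence)...
scattered = [t.chunk(sp_size, dim=1) for t in rank_outputs]
# ...and concatenates the chunk it receives from every rank along the gather dim (heads * head_dim).
gathered = [
    torch.cat([scattered[src][dst] for src in range(sp_size)], dim=2) for dst in range(sp_size)
]
print(gathered[0].shape)  # torch.Size([1, 4, 256])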

@@ -5,7 +5,7 @@ from typing import Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel

 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
 from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D