[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pull/6016/head
pre-commit-ci[bot] 2024-08-17 09:37:37 +00:00
parent 4cf79fa275
commit 81272e9d00
17 changed files with 39 additions and 26 deletions
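Every hunk below is a mechanical cleanup applied by the hooks: over-long lines are wrapped, import groups are separated, and unused imports and bindings are dropped; no functional change is intended. As a rough sketch only (not the repository's actual configuration, and with placeholder hook revisions), a .pre-commit-config.yaml that produces this kind of auto-fix could look like:

repos:
  - repo: https://github.com/psf/black
    rev: 24.8.0  # placeholder revision
    hooks:
      - id: black  # wraps over-long lines, e.g. the __init__ signature below
  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2  # placeholder revision
    hooks:
      - id: isort  # sorts imports and separates third-party from first-party groups
  - repo: https://github.com/PyCQA/autoflake
    rev: v2.3.1  # placeholder revision
    hooks:
      - id: autoflake  # drops unused imports and unused variable bindings
        args: [--in-place, --remove-all-unused-imports, --remove-unused-variables]

Running pre-commit run --all-files locally applies the same fixes before pre-commit.ci has to push a follow-up commit like this one.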

View File

@@ -64,7 +64,12 @@ class OptimizerParamCheckState(enum.Enum):
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(
-        self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True, use_fp8: bool = False
+        self,
+        module: nn.Module,
+        precision: str,
+        overlap_allgather: bool = False,
+        cast_inputs: bool = True,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__(module)
         self.dtype = None

View File

@@ -3,6 +3,7 @@ import torch.distributed as dist
 from torch.autograd import Function
 from torch.distributed import ProcessGroup
 from torch.nn import CrossEntropyLoss
+
 from colossalai.shardformer.layer._operation import reduce_forward
 from colossalai.shardformer.shard import ShardConfig

View File

@@ -16,11 +16,6 @@ from colossalai.shardformer.layer._operation import (
     gather_forward_split_backward,
     split_forward_gather_backward,
 )
-from colossalai.shardformer.layer._operation import (
-    all_to_all_comm,
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)
 def get_flash_core_attention_forward():

View File

@@ -24,7 +24,7 @@ from colossalai.shardformer.layer._operation import (
 )
 from colossalai.shardformer.shard import ShardConfig
-from ..layer import ColoAttention, dist_cross_entropy, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]

View File

@@ -145,7 +145,9 @@ class EPDeepseekMoE(nn.Module):
         output_split_sizes = torch.zeros_like(input_split_sizes)
         # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
-        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication)
+        dist.all_to_all_single(
+            output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication
+        )
         with torch.no_grad():
             activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()

View File

@@ -26,11 +26,15 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import AttnMaskType
-from colossalai.shardformer.layer._operation import all_to_all_comm, gather_forward_split_backward, split_forward_gather_backward
+from colossalai.shardformer.layer._operation import (
+    all_to_all_comm,
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+)
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig
-from ..layer import ColoAttention, RingAttention, dist_cross_entropy, cross_entropy_1d
+from ..layer import ColoAttention, RingAttention, dist_cross_entropy
 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
@@ -162,9 +166,13 @@ class LlamaPipelineForwards:
                     hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
             elif is_share_sp_tp(sp_mode):
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+                )
             elif sp_mode == "all_to_all":
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+                )
         if self.gradient_checkpointing and self.training and use_cache:
             if use_cache:
@@ -675,7 +683,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
         past_seen_tokens = 0
         seq_len = inputs_embeds.shape[1]
-        batch_size = inputs_embeds.shape[0]
+        inputs_embeds.shape[0]
         if use_cache:  # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)

View File

@@ -691,7 +691,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication)  # (1, 4, 256)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )  # (1, 4, 256)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

View File

@@ -5,7 +5,7 @@ from typing import Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
 from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D