[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pull/6016/head
pre-commit-ci[bot] 2024-08-17 09:37:37 +00:00
parent 4cf79fa275
commit 81272e9d00
17 changed files with 39 additions and 26 deletions
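Most of the hunks below are mechanical fixes of the kind these hooks apply: long lines rewrapped to the configured line length, and duplicate or unused imports removed. As a minimal, hypothetical sketch of the rewrapping rule (the function name and parameters are illustrative, not taken from the changed files), a signature that exceeds the limit becomes one parameter per line with a trailing comma:

# Illustrative only: before the fix, the signature would sit on one long line, e.g.
# def configure_runner(name: str, precision: str = "bf16", overlap_allgather: bool = False, use_fp8: bool = False) -> None: ...
# After the line-length fix it is wrapped to one parameter per line with a trailing comma:
def configure_runner(
    name: str,
    precision: str = "bf16",
    overlap_allgather: bool = False,
    use_fp8: bool = False,
) -> None:
    """Stub body; the actual changes are shown in the hunks below."""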

View File

@@ -392,4 +392,4 @@ def tokenize_kto(
        "label": data_point["label"],
        "input_id_decode": decoded_full_prompt,
        "completion_decode": decoded_completion,
    }

View File

@@ -356,4 +356,4 @@ class DPOTrainer(SLTrainer):
            os.makedirs(self.save_dir, exist_ok=True)
            with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
                f.write(msg)
        step_bar.close()

View File

@@ -346,4 +346,4 @@ class KTOTrainer(SLTrainer):
            os.makedirs(self.save_dir, exist_ok=True)
            with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
                f.write(msg)
        step_bar.close()

View File

@@ -323,4 +323,4 @@ class ORPOTrainer(SLTrainer):
            os.makedirs(self.save_dir, exist_ok=True)
            with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
                f.write(msg)
        step_bar.close()

View File

@@ -903,4 +903,4 @@ For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/mai
## Attention
The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.

View File

@@ -375,4 +375,4 @@ if __name__ == "__main__":
    os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
    with open(args.config_file, "w") as f:
        json.dump(args.__dict__, f, indent=4)
    train(args)

View File

@@ -340,4 +340,4 @@ if __name__ == "__main__":
    os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
    with open(args.config_file, "w") as f:
        json.dump(args.__dict__, f, indent=4)
    train(args)

View File

@@ -20,4 +20,4 @@ datasets
ninja==1.11.1
sentencepiece==0.1.99
flash-attn
tiktoken

View File

@@ -640,4 +640,4 @@ for lora_rank in ${LORA_RANK[@]}; do
            fi
        done
    done
done

View File

@@ -64,7 +64,12 @@ class OptimizerParamCheckState(enum.Enum):
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
    def __init__(
-        self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True, use_fp8: bool = False
+        self,
+        module: nn.Module,
+        precision: str,
+        overlap_allgather: bool = False,
+        cast_inputs: bool = True,
+        use_fp8: bool = False,
    ) -> None:
        super().__init__(module)
        self.dtype = None

View File

@@ -3,6 +3,7 @@ import torch.distributed as dist
from torch.autograd import Function
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from colossalai.shardformer.layer._operation import reduce_forward
from colossalai.shardformer.shard import ShardConfig

View File

@@ -16,11 +16,6 @@ from colossalai.shardformer.layer._operation import (
    gather_forward_split_backward,
    split_forward_gather_backward,
)
-from colossalai.shardformer.layer._operation import (
-    all_to_all_comm,
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)
def get_flash_core_attention_forward():

View File

@@ -24,7 +24,7 @@ from colossalai.shardformer.layer._operation import (
)
from colossalai.shardformer.shard import ShardConfig
-from ..layer import ColoAttention, dist_cross_entropy, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]

View File

@@ -145,7 +145,9 @@ class EPDeepseekMoE(nn.Module):
        output_split_sizes = torch.zeros_like(input_split_sizes)
        # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
-        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication)
+        dist.all_to_all_single(
+            output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication
+        )
        with torch.no_grad():
            activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
@@ -694,7 +696,7 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # TODO: upgrade transformers to 4.44.0 to fix the bug, remove the hard code.
        self._use_flash_attention_2 = shard_config.enable_flash_attention
        self._use_sdpa = False if shard_config.enable_flash_attention else self._use_sdpa

View File

@@ -26,11 +26,15 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import AttnMaskType
-from colossalai.shardformer.layer._operation import all_to_all_comm, gather_forward_split_backward, split_forward_gather_backward
+from colossalai.shardformer.layer._operation import (
+    all_to_all_comm,
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+)
from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
from colossalai.shardformer.shard import ShardConfig
-from ..layer import ColoAttention, RingAttention, dist_cross_entropy, cross_entropy_1d
+from ..layer import ColoAttention, RingAttention, dist_cross_entropy
_SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
@@ -162,9 +166,13 @@ class LlamaPipelineForwards:
                hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
            elif is_share_sp_tp(sp_mode):
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+                )
            elif sp_mode == "all_to_all":
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+                )
        if self.gradient_checkpointing and self.training and use_cache:
            if use_cache:
@@ -355,7 +363,7 @@ class LlamaPipelineForwards:
            loss = dist_cross_entropy(
                labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
            )
            if not return_dict:
                output = (logits,) + outputs[1:]
                return (loss,) + output if loss is not None else output
@@ -675,7 +683,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
        past_seen_tokens = 0
        seq_len = inputs_embeds.shape[1]
-        batch_size = inputs_embeds.shape[0]
+        inputs_embeds.shape[0]
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)

View File

@@ -691,7 +691,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
        # sp: all-to-all comminucation when introducing sequence parallel
        if sp_mode == "all_to_all":
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication)  # (1, 4, 256)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )  # (1, 4, 256)
        else:
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

View File

@@ -5,7 +5,7 @@ from typing import Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D