mirror of https://github.com/hpcaitech/ColossalAI

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

branch: pull/6023/head
parent: 4cf79fa275
commit: 81272e9d00

@@ -392,4 +392,4 @@ def tokenize_kto(
         "label": data_point["label"],
         "input_id_decode": decoded_full_prompt,
         "completion_decode": decoded_completion,
-    }
+    }

@@ -356,4 +356,4 @@ class DPOTrainer(SLTrainer):
             os.makedirs(self.save_dir, exist_ok=True)
             with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
                 f.write(msg)
-        step_bar.close()
+        step_bar.close()

@@ -346,4 +346,4 @@ class KTOTrainer(SLTrainer):
             os.makedirs(self.save_dir, exist_ok=True)
             with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
                 f.write(msg)
-        step_bar.close()
+        step_bar.close()

@@ -323,4 +323,4 @@ class ORPOTrainer(SLTrainer):
             os.makedirs(self.save_dir, exist_ok=True)
             with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
                 f.write(msg)
-        step_bar.close()
+        step_bar.close()

@@ -903,4 +903,4 @@ For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/mai
 ## Attention


-The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.
+The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.

@@ -375,4 +375,4 @@ if __name__ == "__main__":
     os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
     with open(args.config_file, "w") as f:
         json.dump(args.__dict__, f, indent=4)
-    train(args)
+    train(args)

@@ -340,4 +340,4 @@ if __name__ == "__main__":
     os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
     with open(args.config_file, "w") as f:
         json.dump(args.__dict__, f, indent=4)
-    train(args)
+    train(args)

@@ -20,4 +20,4 @@ datasets
 ninja==1.11.1
 sentencepiece==0.1.99
 flash-attn
-tiktoken
+tiktoken

@@ -640,4 +640,4 @@ for lora_rank in ${LORA_RANK[@]}; do
             fi
         done
     done
-done
+done

@@ -64,7 +64,12 @@ class OptimizerParamCheckState(enum.Enum):

 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(
-        self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True, use_fp8: bool = False
+        self,
+        module: nn.Module,
+        precision: str,
+        overlap_allgather: bool = False,
+        cast_inputs: bool = True,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__(module)
         self.dtype = None

@@ -3,6 +3,7 @@ import torch.distributed as dist
 from torch.autograd import Function
 from torch.distributed import ProcessGroup
 from torch.nn import CrossEntropyLoss

 from colossalai.shardformer.layer._operation import reduce_forward
 from colossalai.shardformer.shard import ShardConfig

@@ -16,11 +16,6 @@ from colossalai.shardformer.layer._operation import (
     gather_forward_split_backward,
     split_forward_gather_backward,
 )
-from colossalai.shardformer.layer._operation import (
-    all_to_all_comm,
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)


 def get_flash_core_attention_forward():

@@ -24,7 +24,7 @@ from colossalai.shardformer.layer._operation import (
 )
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention, dist_cross_entropy, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy

 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]

@@ -145,7 +145,9 @@ class EPDeepseekMoE(nn.Module):
         output_split_sizes = torch.zeros_like(input_split_sizes)

         # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
-        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication)
+        dist.all_to_all_single(
+            output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication
+        )

         with torch.no_grad():
             activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()

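An aside on the hunk above: before tokens are dispatched to experts on other ranks, each expert-parallel rank has to learn how many tokens it will receive, which is what the all_to_all_single over the split-size tensor does. Below is a hypothetical minimal sketch of that exchange using plain torch.distributed; the exchange_split_sizes name is invented here, and the fp8_communication keyword seen in the diff (presumably from a ColossalAI-patched variant of the collective) is omitted.

# Hypothetical sketch of the split-size exchange used in expert parallelism.
# Assumes torch.distributed is initialized with a backend that supports
# all-to-all (e.g. NCCL under torchrun); not ColossalAI code.
import torch
import torch.distributed as dist

def exchange_split_sizes(input_split_sizes: torch.Tensor, ep_group=None) -> torch.Tensor:
    # input_split_sizes[j]: how many tokens this rank will send to rank j.
    # After the collective, output_split_sizes[j]: how many tokens rank j
    # will send to this rank, so the receive side can allocate exact buffers.
    output_split_sizes = torch.zeros_like(input_split_sizes)
    dist.all_to_all_single(output_split_sizes, input_split_sizes, group=ep_group)
    return output_split_sizes

With the counts known on both sides, the subsequent all-to-all of the actual hidden states can be sized without padding.
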
@@ -694,7 +696,7 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)


         # TODO: upgrade transformers to 4.44.0 to fix the bug, remove the hard code.
         self._use_flash_attention_2 = shard_config.enable_flash_attention
         self._use_sdpa = False if shard_config.enable_flash_attention else self._use_sdpa

@@ -26,11 +26,15 @@ from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import AttnMaskType
-from colossalai.shardformer.layer._operation import all_to_all_comm, gather_forward_split_backward, split_forward_gather_backward
+from colossalai.shardformer.layer._operation import (
+    all_to_all_comm,
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+)
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention, RingAttention, dist_cross_entropy, cross_entropy_1d
+from ..layer import ColoAttention, RingAttention, dist_cross_entropy

 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]

@@ -162,9 +166,13 @@ class LlamaPipelineForwards:
                 hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)

             elif is_share_sp_tp(sp_mode):
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+                )
             elif sp_mode == "all_to_all":
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+                )

         if self.gradient_checkpointing and self.training and use_cache:
             if use_cache:

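The split_forward_gather_backward calls reformatted above are the core sequence-parallel primitive in this hunk: in the forward pass each rank keeps only its slice of the sequence dimension, and in the backward pass the gradient is all-gathered back to the full sequence (the 1 / sp_size argument in the all_to_all branch is presumably a gradient-scaling factor, omitted below). A minimal sketch of that pattern follows; it assumes an initialized process group and a sequence length divisible by the group size, and is not ColossalAI's actual implementation.

# Hypothetical sketch of the split-forward / gather-backward pattern.
import torch
import torch.distributed as dist
from torch.autograd import Function

class SplitForwardGatherBackward(Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, dim: int, group) -> torch.Tensor:
        ctx.dim, ctx.group = dim, group
        world = dist.get_world_size(group)
        rank = dist.get_rank(group)
        # Keep only this rank's slice of the given dimension (e.g. the sequence dim).
        return x.chunk(world, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        world = dist.get_world_size(ctx.group)
        chunks = [torch.empty_like(grad_output) for _ in range(world)]
        # Reassemble the gradient of the full (unsplit) tensor.
        dist.all_gather(chunks, grad_output.contiguous(), group=ctx.group)
        return torch.cat(chunks, dim=ctx.dim), None, None
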
@@ -355,7 +363,7 @@ class LlamaPipelineForwards:
             loss = dist_cross_entropy(
                 labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
             )


             if not return_dict:
                 output = (logits,) + outputs[1:]
                 return (loss,) + output if loss is not None else output

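For readers unfamiliar with the dist_cross_entropy call above: when tensor parallelism shards the vocabulary dimension of the logits, the loss has to be computed across those shards. A naive reference that pins down the intended semantics (a hypothetical helper, not ColossalAI's implementation, which presumably avoids materializing the gathered logits) is to all-gather the vocab shards and apply ordinary causal-LM cross entropy:

# Hypothetical reference for a vocab-sharded cross entropy; assumes the usual
# causal-LM shift, an initialized tensor-parallel group, and vocab shards
# ordered by rank as with a column-parallel lm_head.
import torch
import torch.distributed as dist
import torch.nn.functional as F

def naive_dist_cross_entropy(labels, local_logits, tp_group, ignore_index=-100):
    world = dist.get_world_size(tp_group)
    shards = [torch.empty_like(local_logits) for _ in range(world)]
    dist.all_gather(shards, local_logits.contiguous(), group=tp_group)
    full_logits = torch.cat(shards, dim=-1)                # (batch, seq, full_vocab)
    shift_logits = full_logits[..., :-1, :].contiguous()   # token t predicts token t+1
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )
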
@@ -675,7 +683,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N

         past_seen_tokens = 0
         seq_len = inputs_embeds.shape[1]
-        batch_size = inputs_embeds.shape[0]
+        inputs_embeds.shape[0]
         if use_cache:  # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)

@@ -691,7 +691,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication)  # (1, 4, 256)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )  # (1, 4, 256)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

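The all_to_all_comm call in the Mixtral hunk above converts sequence-parallel attention output from (batch, full sequence, local heads * head_dim) to (batch, sequence / sp_size, full hidden), hence the (1, 8, 128) -> (1, 4, 256) shape comments: it scatters along the sequence dimension and gathers along the hidden dimension. A hypothetical forward-only sketch of that idea, ignoring the autograd backward and the fp8_communication option, and with an invented function name:

# Hypothetical forward-only sketch of an all-to-all that scatters one dim and
# gathers another; assumes both dims are divisible by the group size and an
# initialized backend that supports all-to-all (e.g. NCCL).
import torch
import torch.distributed as dist

def all_to_all_scatter_gather(x: torch.Tensor, group, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    world = dist.get_world_size(group)
    # One chunk per rank along the scatter dimension.
    send = [chunk.contiguous() for chunk in x.chunk(world, dim=scatter_dim)]
    recv = [torch.empty_like(send[0]) for _ in range(world)]
    dist.all_to_all(recv, send, group=group)
    # Reassemble the received chunks along the gather dimension.
    return torch.cat(recv, dim=gather_dim)
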
@@ -5,7 +5,7 @@ from typing import Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel

 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
 from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D