2023-07-21 02:46:39 +00:00
|
|
|
import warnings
|
|
|
|
from typing import Callable, List, Optional, Tuple, Union
|
|
|
|
|
2023-07-04 02:28:31 +00:00
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
from torch.distributed import ProcessGroup
|
2023-07-21 02:46:39 +00:00
|
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
2023-08-07 08:41:07 +00:00
|
|
|
from torch.nn import functional as F
|
2024-04-24 14:51:50 +00:00
|
|
|
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
2023-07-21 02:46:39 +00:00
|
|
|
from transformers.modeling_outputs import (
|
|
|
|
BaseModelOutputWithPastAndCrossAttentions,
|
|
|
|
CausalLMOutputWithCrossAttentions,
|
2024-05-21 03:07:13 +00:00
|
|
|
CausalLMOutputWithPast,
|
2023-07-21 02:46:39 +00:00
|
|
|
QuestionAnsweringModelOutput,
|
|
|
|
SequenceClassifierOutputWithPast,
|
|
|
|
TokenClassifierOutput,
|
|
|
|
)
|
|
|
|
from transformers.models.bloom.modeling_bloom import (
|
|
|
|
BloomForCausalLM,
|
|
|
|
BloomForQuestionAnswering,
|
|
|
|
BloomForSequenceClassification,
|
|
|
|
BloomForTokenClassification,
|
|
|
|
BloomModel,
|
|
|
|
)
|
|
|
|
from transformers.utils import logging
|
|
|
|
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
2023-08-18 07:34:18 +00:00
|
|
|
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
|
|
|
from colossalai.shardformer.shard import ShardConfig
|
|
|
|
|
[FP8] rebase main (#5963)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up the construction of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtral pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redundant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686d1415a83d5726dc5ff02222c742582.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
2024-08-06 08:29:37 +00:00
|
|
|
from ..layer import dist_cross_entropy
|
2024-05-21 03:07:13 +00:00
|
|
|
|
2023-08-18 07:34:18 +00:00
|
|
|
logger = logging.get_logger(__name__)
|
2023-07-04 02:28:31 +00:00
|
|
|
|
|
|
|
|
|
|
|
def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> Callable[..., torch.Tensor]:
    """Create a tensor-parallel aware replacement for BLOOM's ``build_alibi_tensor``.

    The returned function takes ``(self, attention_mask, num_heads, dtype)`` so it can be bound as a method.
    When ``torch.distributed`` is initialized, ``num_heads`` is multiplied by the world size of
    ``process_group``; ALiBi slopes are computed for that global head count; and only the contiguous slice of
    heads belonging to the current rank of ``process_group`` is returned. Without distributed initialization,
    all ``num_heads`` heads are returned.

    Args:
        process_group (`ProcessGroup`): group whose world size and local rank select the head slice.

    Returns:
        A callable that builds the ALiBi bias tensor.
    """

    def build_bloom_alibi_tensor(
        self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
        relies on a translation invariance of softmax for quick implementation: with l being a tensor and a a fixed
        value, `softmax(l+a) = softmax(l)`. Based on
        https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
        TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

        Args:
            self: unused; present only so the function can be bound as a method.
            attention_mask (`torch.Tensor`):
                Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
            num_heads (`int`):
                Number of heads. When torch.distributed is initialized this is treated as the per-rank count and
                is multiplied by the process-group world size before the slopes are computed.
            dtype (`torch.dtype`):
                dtype of the output tensor.

        Returns:
            `torch.Tensor` of shape (batch_size * heads_returned, 1, max_seq_len), where heads_returned is
            num_heads (non-distributed) or the per-rank head count (distributed).
        """
        import math

        distributed = dist.is_initialized()
        if distributed:
            world_size = dist.get_world_size(process_group)
            # Slopes depend on the global head count, so rebuild it from the per-rank count.
            num_heads = num_heads * world_size

        batch_size, seq_length = attention_mask.shape
        closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
        base = torch.tensor(
            2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
        )
        powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.pow(base, powers)

        if closest_power_of_2 != num_heads:
            # Non power-of-2 head counts: interleave slopes from the next power of 2 (odd powers only).
            extra_base = torch.tensor(
                2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
                device=attention_mask.device,
                dtype=torch.float32,
            )
            num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
            extra_powers = torch.arange(
                1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32
            )
            slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

        # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
        # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
        # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
        # => the query_length dimension will then be broadcasted correctly
        # This is more or less identical to T5's relative position bias:
        # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
        # Positions count only attended tokens; masked (padding) positions get 0.
        arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
        alibi = slopes[..., None] * arange_tensor

        if distributed:
            # Keep only the contiguous block of heads owned by this rank.
            num_heads_per_rank = num_heads // world_size
            offset = dist.get_rank(process_group) * num_heads_per_rank
            alibi = alibi.view(batch_size, num_heads, 1, seq_length)
            alibi = alibi[:, offset : offset + num_heads_per_rank, :, :]
            return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
        return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)

    return build_bloom_alibi_tensor
|
2023-07-21 02:46:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
class BloomPipelineForwards:
|
2023-09-19 06:20:26 +00:00
|
|
|
"""
|
2023-07-21 02:46:39 +00:00
|
|
|
This class serves as a micro library for bloom pipeline forwards.
|
2023-09-19 06:20:26 +00:00
|
|
|
"""
|
2023-07-21 02:46:39 +00:00
|
|
|
|
|
|
|
@staticmethod
def bloom_model_forward(
    self: BloomModel,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    stage_manager: Optional[PipelineStageManager] = None,
    hidden_states: Optional[torch.FloatTensor] = None,
    stage_index: Optional[List[int]] = None,
    shard_config: ShardConfig = None,
    **deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], "BaseModelOutputWithPastAndCrossAttentions"]:
    """Pipeline-stage forward for ``BloomModel``.

    Runs only the decoder blocks ``self.h[stage_index[0]:stage_index[1]]`` that belong to this pipeline stage.
    The first stage embeds ``input_ids`` or ``inputs_embeds``. Later stages consume the ``hidden_states``
    passed in from the previous stage. ``output_attentions``, ``output_hidden_states`` and ``use_cache`` are not
    supported here and are forced to ``False`` with a warning.

    Args:
        stage_manager: tells whether this is the first and/or last pipeline stage.
            NOTE(review): it defaults to None but is dereferenced unconditionally, so callers must pass it.
        hidden_states: activations received from the previous stage. Overwritten on the first stage.
        stage_index: ``[start, end)`` indices of the decoder blocks run by this stage. Must be provided.
        shard_config: when it enables sequence parallelism in ``"split_gather"`` mode, hidden states are split
            along dim 1 before the blocks and gathered back after them over the tensor-parallel group.
        **deprecated_arguments: only ``position_ids`` is tolerated (emits a FutureWarning). Anything else raises.

    Returns:
        On the last stage, a ``BaseModelOutputWithPastAndCrossAttentions``, or a tuple of its non-``None``
        fields when ``return_dict`` is false. On any other stage, ``{"hidden_states": hidden_states}`` for the
        next stage.

    Raises:
        ValueError: on unexpected keyword arguments, or on the first stage when both or neither of
            ``input_ids`` and ``inputs_embeds`` are given.
    """
    logger = logging.get_logger(__name__)

    if deprecated_arguments.pop("position_ids", False) is not False:
        # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
        warnings.warn(
            "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
            " passing `position_ids`.",
            FutureWarning,
        )
    if len(deprecated_arguments) > 0:
        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # add warnings here
    # These features are not supported across pipeline stages, so they are switched off after resolving defaults.
    if output_attentions:
        logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
        output_attentions = False
    if output_hidden_states:
        logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
        output_hidden_states = False
    if use_cache:
        logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
        use_cache = False
    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape batch_size x num_heads x N x N

    # head_mask has shape n_layer x batch x num_heads x N x N
    head_mask = self.get_head_mask(head_mask, self.config.n_layer)

    # case: First stage of training
    if stage_manager.is_first_stage():
        # check input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        # initialize in the first stage and then pass to the next stage
    else:
        # Later stages: recover batch/sequence sizes from the received activations (all dims except the last).
        input_shape = hidden_states.shape[:-1]
        batch_size, seq_length = input_shape

    # extra recording tensor should be generated in the first stage

    presents = () if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    if past_key_values is None:
        past_key_values = tuple([None] * len(self.h))
    # Compute alibi tensor: check build_alibi_tensor documentation, build for every stage
    seq_length_with_past = seq_length
    past_key_values_length = 0
    if past_key_values[0] is not None:
        past_key_values_length = past_key_values[0][0].shape[2]  # source_len

        seq_length_with_past = seq_length_with_past + past_key_values_length
    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
    else:
        attention_mask = attention_mask.to(hidden_states.device)

    alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

    # causal_mask is constructed every stage and its input is passed through different stages
    causal_mask = _prepare_4d_causal_attention_mask(
        attention_mask,
        input_shape=(batch_size, seq_length),
        inputs_embeds=hidden_states,
        past_key_values_length=past_key_values_length,
    )
    # Convert the additive mask to a boolean mask before handing it to the blocks.
    causal_mask = causal_mask.bool()
    # split the input tensor along sequence dimension
    # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
    # This happens after alibi/causal_mask are built from the full sequence length.
    if shard_config and shard_config.enable_sequence_parallelism:
        if shard_config.sequence_parallelism_mode == "split_gather":
            hidden_states = split_forward_gather_backward(
                hidden_states,
                dim=1,
                process_group=shard_config.tensor_parallel_process_group,
                fp8_communication=shard_config.fp8_communication,
            )

    start_idx, end_idx = stage_index[0], stage_index[1]
    # Enumerate with start=start_idx so `i` is the global layer index used to pick this layer's head_mask.
    for i, (block, layer_past) in enumerate(
        zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx
    ):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if self.gradient_checkpointing and self.training:
            # Checkpointed call: arguments are positional, in the order block.__call__ expects.
            outputs = self._gradient_checkpointing_func(
                block.__call__,
                hidden_states,
                alibi,
                causal_mask,
                layer_past,
                head_mask[i],
                use_cache,
                output_attentions,
            )
        else:
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=causal_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
                alibi=alibi,
            )

        hidden_states = outputs[0]

        if use_cache is True:
            presents = presents + (outputs[1],)
        if output_attentions:
            # Attention weights follow the cached key/value entry when caching is on.
            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

    # When sequence parallelism done, gather the output tensor in forward and split it in backward
    if shard_config and shard_config.enable_sequence_parallelism:
        if shard_config.sequence_parallelism_mode == "split_gather":
            hidden_states = gather_forward_split_backward(
                hidden_states,
                dim=1,
                process_group=shard_config.tensor_parallel_process_group,
                fp8_communication=shard_config.fp8_communication,
            )

    if stage_manager.is_last_stage():
        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)

    # TODO(jianghai): deal with all_hidden_states, all_self_attentions, presents
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if stage_manager.is_last_stage():
        if not return_dict:
            return tuple(
                v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
            )

        # attention_mask is not returned ; presents = past_key_values
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
    else:
        # always return dict for intermediate stage
        return {"hidden_states": hidden_states}
|
2023-07-21 02:46:39 +00:00
|
|
|
|
|
|
|
@staticmethod
def bloom_for_causal_lm_forward(
    self: BloomForCausalLM,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    stage_manager: Optional[PipelineStageManager] = None,
    hidden_states: Optional[torch.FloatTensor] = None,
    stage_index: Optional[List[int]] = None,
    shard_config: ShardConfig = None,
    **deprecated_arguments,
):
    r"""
    Pipeline-stage forward for ``BloomForCausalLM``.

    Delegates the decoder to ``BloomPipelineForwards.bloom_model_forward``. Only the last pipeline stage applies
    ``lm_head`` and computes the loss. Every other stage just forwards its hidden states.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
        `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
        are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`

    Returns:
        On the last stage, a ``CausalLMOutputWithCrossAttentions``. When ``return_dict`` is false, the result is
        instead a tuple ``(loss?, lm_logits, *rest)``, where ``loss`` is present only when ``labels`` was given.
        On intermediate stages, ``{"hidden_states": hidden_states}`` for the next stage.

    Raises:
        ValueError: on unexpected keyword arguments other than the deprecated ``position_ids``.
    """
    logger = logging.get_logger(__name__)

    if deprecated_arguments.pop("position_ids", False) is not False:
        # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
        warnings.warn(
            "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
            " passing `position_ids`.",
            FutureWarning,
        )
    if len(deprecated_arguments) > 0:
        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
    if output_attentions:
        logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
        output_attentions = False
    if output_hidden_states:
        logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
        output_hidden_states = False

    transformer_outputs = BloomPipelineForwards.bloom_model_forward(
        self.transformer,
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        stage_manager=stage_manager,
        hidden_states=hidden_states,
        stage_index=stage_index,
        shard_config=shard_config,
    )

    if not stage_manager.is_last_stage():
        # Intermediate stage: the model forward returned a dict; pass its activations on to the next stage.
        hidden_states = transformer_outputs.get("hidden_states")
        return {"hidden_states": hidden_states}

    hidden_states = transformer_outputs[0]
    lm_logits = self.lm_head(hidden_states).contiguous()

    loss = None
    if labels is not None:
        # Cross entropy over possibly vocab-sharded logits; label shifting is handled inside dist_cross_entropy.
        loss = dist_cross_entropy(
            labels,
            lm_logits,
            shard_config,
            self.lm_head.out_features,
            self.transformer.dtype,
        )

    if not return_dict:
        output = (lm_logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return CausalLMOutputWithCrossAttentions(
        loss=loss,
        logits=lm_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )
|
2023-07-21 02:46:39 +00:00
|
|
|
|
|
|
|
@staticmethod
def bloom_for_sequence_classification_forward(
    self: BloomForSequenceClassification,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    stage_manager: Optional[PipelineStageManager] = None,
    hidden_states: Optional[torch.FloatTensor] = None,
    stage_index: Optional[List[int]] = None,
    shard_config: ShardConfig = None,
    **deprecated_arguments,
):
    r"""
    Pipeline-parallel forward pass of `BloomForSequenceClassification`.

    Each stage runs its share of the transformer through `BloomPipelineForwards.bloom_model_forward`.
    Non-last stages return `{"hidden_states": ...}` for the next stage. The last stage applies
    `self.score`, pools one logit row per sequence (the position just before the first pad token in
    `input_ids`, or the last position when there is no pad token / only `inputs_embeds` is given)
    and optionally computes a loss.

    labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
        config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
        `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

    Raises:
        ValueError: if unexpected keyword arguments are passed, or if the batch size is larger than 1
            while `config.pad_token_id` is None.
    """
    logger = logging.get_logger(__name__)

    if deprecated_arguments.pop("position_ids", False) is not False:
        # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
        warnings.warn(
            "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
            " passing `position_ids`.",
            FutureWarning,
        )
    if len(deprecated_arguments) > 0:
        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
    if output_attentions:
        logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
        output_attentions = False
    if output_hidden_states:
        logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
        output_hidden_states = False

    transformer_outputs = BloomPipelineForwards.bloom_model_forward(
        self.transformer,
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        stage_manager=stage_manager,
        hidden_states=hidden_states,
        stage_index=stage_index,
        shard_config=shard_config,
    )
    past_key_values = None

    if stage_manager.is_last_stage():
        hidden_states = transformer_outputs[0]
        # Read the batch size from the final hidden states, not from the incoming `hidden_states`
        # argument: that argument is None when the last stage is also the first stage.
        batch_size = hidden_states.shape[0]
        logits = self.score(hidden_states)

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        # One logit row per sequence, taken at the position computed above.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels and the label dtype, as HF does.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
    else:
        # Intermediate stage: hand the activations to the next pipeline stage.
        hidden_states = transformer_outputs.get("hidden_states")
        return {"hidden_states": hidden_states}
|
2023-07-21 02:46:39 +00:00
|
|
|
|
|
|
|
@staticmethod
def bloom_for_token_classification_forward(
    self: BloomForTokenClassification,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    stage_manager: Optional[PipelineStageManager] = None,
    hidden_states: Optional[torch.FloatTensor] = None,
    stage_index: Optional[List[int]] = None,
    shard_config: ShardConfig = None,
    **deprecated_arguments,
):
    r"""
    Pipeline-parallel forward pass of `BloomForTokenClassification`.

    Intermediate stages return `{"hidden_states": ...}` for the next stage; the last stage applies
    dropout and `self.classifier` to every token and optionally computes a token-level loss.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Per-token class indices in `[0, ..., config.num_labels - 1]`. When given, a cross-entropy
        loss over all tokens is computed.

    Raises:
        ValueError: if unexpected keyword arguments are passed.
    """
    logger = logging.get_logger(__name__)

    # Defaulting the pop to `False` distinguishes "not passed" from an explicit `None`.
    legacy_position_ids = deprecated_arguments.pop("position_ids", False)
    if legacy_position_ids is not False:
        warnings.warn(
            "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
            " passing `position_ids`.",
            FutureWarning,
        )
    if deprecated_arguments:
        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

    if return_dict is None:
        return_dict = self.config.use_return_dict

    # Attentions / intermediate hidden states cannot be collected across pipeline stages yet.
    if output_attentions:
        logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
        output_attentions = False
    if output_hidden_states:
        logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
        output_hidden_states = False

    transformer_outputs = BloomPipelineForwards.bloom_model_forward(
        self.transformer,
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        stage_manager=stage_manager,
        hidden_states=hidden_states,
        stage_index=stage_index,
        shard_config=shard_config,
    )

    if not stage_manager.is_last_stage():
        # Intermediate stage: forward the activations to the next stage.
        return {"hidden_states": transformer_outputs.get("hidden_states")}

    token_states = self.dropout(transformer_outputs[0])
    logits = self.classifier(token_states)

    loss = None
    if labels is not None:
        # move labels to correct device to enable model parallelism
        labels = labels.to(logits.device)
        batch_size, seq_length = labels.shape
        num_tokens = batch_size * seq_length
        loss = CrossEntropyLoss()(logits.view(num_tokens, self.num_labels), labels.view(num_tokens))

    if return_dict:
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    tuple_output = (logits,) + transformer_outputs[2:]
    return tuple_output if loss is None else (loss,) + tuple_output
|
2023-07-21 02:46:39 +00:00
|
|
|
|
|
|
|
@staticmethod
def bloom_for_question_answering_forward(
    self: BloomForQuestionAnswering,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    start_positions: Optional[torch.LongTensor] = None,
    end_positions: Optional[torch.LongTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    stage_manager: Optional[PipelineStageManager] = None,
    hidden_states: Optional[torch.FloatTensor] = None,
    stage_index: Optional[List[int]] = None,
    shard_config: ShardConfig = None,
):
    r"""
    Pipeline-parallel forward pass of `BloomForQuestionAnswering`.

    Intermediate stages return `{"hidden_states": ...}` for the next stage. The last stage projects
    the sequence output with `self.qa_outputs` into start/end span logits and, when both position
    tensors are supplied, averages the two cross-entropy losses.

    start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for position (index) of the start of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
        are not taken into account for computing the loss.
    end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for position (index) of the end of the labelled span for computing the token classification loss.
        Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
        are not taken into account for computing the loss.
    """
    logger = logging.get_logger(__name__)

    if return_dict is None:
        return_dict = self.config.use_return_dict

    # Attentions / intermediate hidden states cannot be collected across pipeline stages yet.
    if output_attentions:
        logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
        output_attentions = False
    if output_hidden_states:
        logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
        output_hidden_states = False

    outputs = BloomPipelineForwards.bloom_model_forward(
        self.transformer,
        input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        stage_manager=stage_manager,
        hidden_states=hidden_states,
        stage_index=stage_index,
        shard_config=shard_config,
    )

    if not stage_manager.is_last_stage():
        # Intermediate stage: forward the activations to the next stage.
        return {"hidden_states": outputs.get("hidden_states")}

    span_logits = self.qa_outputs(outputs[0])
    start_logits, end_logits = (part.squeeze(-1).contiguous() for part in span_logits.split(1, dim=-1))

    total_loss = None
    if start_positions is not None and end_positions is not None:
        # On multi-GPU runs the positions may carry an extra trailing dimension.
        if start_positions.dim() > 1:
            start_positions = start_positions.squeeze(-1)
        if end_positions.dim() > 1:
            end_positions = end_positions.squeeze(-1)
        # Out-of-range positions are clamped onto `ignored_index`, which the loss then skips.
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        criterion = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = criterion(start_logits, start_positions)
        end_loss = criterion(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2

    if return_dict:
        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    tuple_output = (start_logits, end_logits) + outputs[2:]
    return tuple_output if total_loss is None else (total_loss,) + tuple_output
|
2023-08-07 08:41:07 +00:00
|
|
|
|
|
|
|
|
|
|
|
def get_jit_fused_bloom_attention_forward():
    """Return a replacement ``forward`` for ``BloomAttention``.

    The returned function computes ALiBi-biased self-attention and finishes with
    ``self.dropout_add(output, residual, self.hidden_dropout, self.training)``. The residual is
    therefore an explicit argument.

    Returns:
        ``forward(self, hidden_states, residual, alibi, attention_mask, layer_past=None,
        head_mask=None, use_cache=False, output_attentions=False)``. It returns
        ``(output, present)``, plus the attention probabilities when ``output_attentions`` is set.
    """
    from transformers.models.bloom.modeling_bloom import BloomAttention

    def forward(
        self: BloomAttention,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        # Fused projection, then split into 3 x [batch_size, seq_length, num_heads, head_dim].
        query, key, value = self._split_heads(self.query_key_value(hidden_states))
        bsz, q_len, _, _ = query.shape
        flat_heads = bsz * self.num_heads

        # Fold heads into the batch dimension; keys are laid out transposed for baddbmm.
        query = query.transpose(1, 2).reshape(flat_heads, q_len, self.head_dim)
        key = key.permute(0, 2, 3, 1).reshape(flat_heads, self.head_dim, q_len)
        value = value.transpose(1, 2).reshape(flat_heads, q_len, self.head_dim)

        if layer_past is not None:
            cached_key, cached_value = layer_past
            # key grows along its last dim, value along its middle dim.
            key = torch.cat((cached_key, key), dim=2)
            value = torch.cat((cached_value, value), dim=1)

        _, _, kv_len = key.shape
        present = (key, value) if use_cache is True else None

        # alibi * beta + (query @ key) * inv_norm_factor -> [batch_size * num_heads, q_length, kv_length].
        # The Tensor method is used because `torch.baddbmm` isn't supported by TorchScript v1.11.
        raw_scores = alibi.baddbmm(
            batch1=query,
            batch2=key,
            beta=self.beta,
            alpha=self.inv_norm_factor,
        )
        scores = raw_scores.view(bsz, self.num_heads, q_len, kv_len)

        # Upcast fp16 before masking: fp16's minimum (-65504) is far above fp32/bf16's (-3.4e38).
        score_dtype = scores.dtype
        if score_dtype == torch.float16:
            scores = scores.to(torch.float)
        masked_scores = torch.masked_fill(scores, attention_mask, torch.finfo(scores.dtype).min)
        probs = F.softmax(masked_scores, dim=-1, dtype=torch.float32).to(score_dtype)

        probs = self.attention_dropout(probs)
        if head_mask is not None:
            probs = probs * head_mask

        # [batch_size * num_heads, q_length, head_dim], then merge heads back.
        context = torch.bmm(probs.view(flat_heads, q_len, kv_len), value)
        context = self._merge_heads(context)

        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slice_width = self.hidden_size / self.pretraining_tp
            projected = torch.zeros_like(context)
            for rank in range(self.pretraining_tp):
                lo, hi = int(rank * slice_width), int((rank + 1) * slice_width)
                projected = projected + F.linear(context[:, :, lo:hi], self.dense.weight[:, lo:hi])
        else:
            projected = self.dense(context)

        projected = self.dropout_add(projected, residual, self.hidden_dropout, self.training)

        result = (projected, present)
        if output_attentions:
            result += (probs,)
        return result

    return forward
|
|
|
|
|
|
|
|
|
|
|
|
def get_jit_fused_bloom_mlp_forward():
    """Return a replacement ``forward`` for ``BloomMLP``.

    The returned ``forward(self, hidden_states, residual)`` runs the following steps:
    ``dense_h_to_4h`` -> ``self.gelu_impl`` -> ``dense_4h_to_h`` -> ``self.dropout_add`` with the
    residual. When ``pretraining_tp > 1`` and ``slow_but_exact`` are both set, the down-projection
    is summed slice by slice over the input features.
    """
    from transformers.models.bloom.modeling_bloom import BloomMLP

    def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        activated = self.gelu_impl(self.dense_h_to_4h(hidden_states))

        if not (self.pretraining_tp > 1 and self.slow_but_exact):
            projected = self.dense_4h_to_h(activated)
        else:
            down_weight = self.dense_4h_to_h.weight
            slice_width = down_weight.shape[-1] / self.pretraining_tp
            projected = torch.zeros_like(residual)
            for rank in range(self.pretraining_tp):
                lo, hi = int(rank * slice_width), int((rank + 1) * slice_width)
                projected = projected + F.linear(activated[:, :, lo:hi], down_weight[:, lo:hi])

        return self.dropout_add(projected, residual, self.hidden_dropout, self.training)

    return forward
|
|
|
|
|
|
|
|
|
|
|
|
def get_jit_fused_bloom_gelu_forward():
    """Return a replacement ``forward`` for ``BloomGelu``.

    In training mode the JIT ``GeLUFunction`` is applied with a zero bias. In eval mode the call is
    delegated to ``self.bloom_gelu_forward`` with the same zero bias.
    """
    from transformers.models.bloom.modeling_bloom import BloomGelu

    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction

    def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor:
        # Both kernels take a bias term; Bloom's GELU has none, so feed zeros.
        zero_bias = torch.zeros_like(x)
        if not self.training:
            return self.bloom_gelu_forward(x, zero_bias)
        return JitGeLUFunction.apply(x, zero_bias)

    return forward
|
2023-08-18 07:34:18 +00:00
|
|
|
|
|
|
|
|
|
|
|
def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
    """Build a sequence-parallel replacement for ``BloomModel.forward``.

    The returned ``forward`` embeds the inputs and builds the ALiBi tensor and causal mask for
    the full sequence. It then splits the hidden states along dim 1 across
    ``shard_config.tensor_parallel_process_group`` before the decoder blocks, and gathers them
    back along dim 1 before ``ln_f``.

    Args:
        shard_config: Provides the tensor-parallel process group and the ``fp8_communication``
            flag used by the split/gather collectives.

    Returns:
        A function with the same signature as ``BloomModel.forward``.
    """
    from transformers import BloomModel

    def forward(
        self: BloomModel,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        """Sequence-parallel ``BloomModel.forward``.

        Raises:
            ValueError: on unexpected keyword arguments, or when both or neither of
                ``input_ids`` / ``inputs_embeds`` are given.
        """
        # Defined locally so the gradient-checkpointing warning below does not depend on a
        # module-level `logger` existing.
        logger = logging.get_logger(__name__)

        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape batch_size x num_heads x N x N
        # head_mask has shape n_layer x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Compute alibi tensor: check build_alibi_tensor documentation
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

        # The mask is built for the full (unsplit) sequence, before the split below.
        causal_mask = _prepare_4d_causal_attention_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            inputs_embeds=hidden_states,
            past_key_values_length=past_key_values_length,
        )
        causal_mask = causal_mask.bool()

        # split the input tensor along sequence dimension
        # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
        hidden_states = split_forward_gather_backward(
            hidden_states,
            dim=1,
            process_group=shard_config.tensor_parallel_process_group,
            fp8_communication=shard_config.fp8_communication,
        )

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    alibi,
                    causal_mask,
                    layer_past,
                    head_mask[i],
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # When sequence parallelism done, gather the output tensor in forward and split it in backward
        hidden_states = gather_forward_split_backward(
            hidden_states,
            dim=1,
            process_group=shard_config.tensor_parallel_process_group,
            fp8_communication=shard_config.fp8_communication,
        )
        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    return forward
|
2024-05-21 03:07:13 +00:00
|
|
|
|
|
|
|
|
|
|
|
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
    """Build a replacement ``forward`` for ``BloomForCausalLM`` whose loss uses ``dist_cross_entropy``.

    The returned function follows the regular causal-LM forward pass
    (``self.transformer`` -> ``self.lm_head``). When ``labels`` are given, the loss
    is computed by ``dist_cross_entropy``, which receives:

    - the labels and logits,
    - ``shard_config``,
    - ``self.lm_head.out_features``,
    - ``self.transformer.dtype``.

    NOTE(review): presumably this is so the loss can be computed on vocab-sharded
    logits; confirm against the definition of ``dist_cross_entropy``.

    Args:
        shard_config (ShardConfig): sharding configuration, captured by the closure
            and forwarded to ``dist_cross_entropy``.

    Returns:
        Callable: a ``forward(self, ...)`` function intended to be bound to a
        ``BloomForCausalLM`` instance.
    """
    from transformers import BloomForCausalLM

    def forward(
        self: BloomForCausalLM,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
            `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        # Fall back to the model config for any output flag the caller left unset.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 is the last hidden state for both tuple and ModelOutput returns.
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Loss goes through dist_cross_entropy instead of a plain CrossEntropyLoss;
            # it needs the lm_head output width and the transformer dtype.
            loss = dist_cross_entropy(
                labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
            )

        if not return_dict:
            # Tuple form: (loss?, logits, *remaining transformer outputs).
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    return forward
|