[Shardformer] Merge flash attention branch to pipeline branch (#4362)

* [shardformer] supported flash attention test dependency (#4158)

* [shardformer] fix flash attention utils test (#4180)

* [shardformer] opt support flash attention (#4163)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] add performance benchmark of shardformer (#4175)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] benchmark fix

* [shardformer] benchmark fix

* [shardformer] llama support flash attention (#4185)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] llama support flash attention

* [shardformer] llama support flash attention

* [shardformer] Move the import statement for xformer outside the forward function.

* [shardformer] gpt2 support flash attention. (#4191)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] gpt2 support flash attention

* [shardformer] gpt2 support flash attention

* [shardformer] gpt2 support flash attention

* [shardformer] bloom support flash attention (#4188)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bloom support flash attention

* [shardformer] add assert to sequence length

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

* [shardformer] bert support flash attention. (#4206)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bert support flash attention

* [shardformer] t5 support flash attention. (#4216)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] t5 support flash attention

* [shardformer] t5 support flash attention

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* [shardformer] support 'paddedcausal' type of attention mask in ColoAttention. (#4215)

* added padded causal attn mask type for ColoAttention

* [shardformer]t5 flash attention fix (#4239)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] t5 flash attention fix

* [shardformer] update gpt2 to use coloattention. (#4234)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2

* [shardformer] update opt and llama to use coloattention. (#4226)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt

* [shardformer] shardformer support jit fused operator. (#4236)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bloom support jit fused operator

* [shardformer] bloom support jit fused operator

* [shardformer] bloom support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] add roadmap of flash attention

* [shardformer] add roadmap of flash attention

* [shardformer] add roadmap of flash attention

* [shardformer] add type hint to 'self' param of forward

* [shardformer] merge feature/shardformer-models branch to feature/flash-attention-shardformer branch. (#4290)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

---------

Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>

* [shardformer] whisper support flash attention (#4301)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] whisper support flash attention

* [shardformer] whisper support flash attention

* [shardformer]whisper support jit operator

---------

Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>

* [shardformer] sam support flash attention (#4316)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] sam support flash attention

---------

Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>

* [shardformer] merge blip2/chatglm  (#4321)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

---------

Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>

* [shardformer] blip2 support flash attention and jit operator (#4325)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

* [shardformer] blip2 support flash attention and jit operator

* [shardformer] blip2 support flash attention and jit operator

* [shardformer] blip2 support flash attention and jit operator

---------

Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>

* [shardformer] chatglm support flash attention and jit operator (#4330)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

---------

Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>

* [shardformer] vit support flash attention and jit operator (#4334)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

* [shardformer] vit support flash attention and jit operator

* [shardformer] vit support flash attention and jit operator

---------

Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>

* [pipeline] merge flash attention branch

* [pipeline] merge flash attention branch

* [pipeline] merge flash attention branch

* [pipeline] fix conflict

* [pipeline] fix conflict

* Merge branch 'feature/pipeline' into feature/pipeline

* Merge branch 'feature/pipeline' into feature/pipeline

* Merge branch 'feature/pipeline' into feature/pipeline

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* fix flash attention tests

* gemini ignore whisper

* fix vit

* fix xformers import handling

---------

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>
pull/4445/head
flybird1111 2023-08-07 16:41:07 +08:00 committed by Hongxin Liu
parent a88e92251d
commit 906426cb44
52 changed files with 2060 additions and 1555 deletions


@ -30,7 +30,7 @@
### Quick Start
The sample API usage is given below:
The sample API usage is given below (if you enable the use of flash attention, please install xformers first):
```python
from colossalai.shardformer import ShardConfig, Shard
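# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original diff): the README example above is cut
# off at this hunk boundary. The lines below assemble a minimal usage example
# from the benchmark script added later in this commit (ShardConfig flags,
# ShardFormer, optimize); the model choice is illustrative only.
import transformers
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch({})    # set up the distributed environment
model = transformers.LlamaForSequenceClassification(
    transformers.LlamaConfig(num_hidden_layers=4, num_labels=16))
shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
sharded_model = ShardFormer(shard_config=shard_config).optimize(model).cuda()
# ---------------------------------------------------------------------------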
@ -106,6 +106,20 @@ We will follow this roadmap to develop Shardformer:
- [ ] Multi-modal
- [x] SAM
- [x] BLIP-2
- [ ] Flash Attention Support
- [ ] NLP
- [x] BERT
- [x] T5
- [x] LLaMA
- [x] GPT2
- [x] OPT
- [x] BLOOM
- [ ] GLM
- [ ] RoBERTa
- [ ] ALBERT
- [ ] ERNIE
- [ ] GPT Neo
- [ ] GPT-J
## 💡 API Design
@ -373,11 +387,49 @@ pytest tests/test_shardformer
### System Performance
To be added.
We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate the performance improvement of Shardformer, comparing the training time of the original model with that of the sharded model.
We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.
With 2 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
| :------: | :-----: | :-----: |
| 256 | 11.2ms | 17.2ms |
| 512 | 9.8ms | 19.5ms |
| 1024 | 19.6ms | 18.9ms |
| 2048 | 46.6ms | 30.8ms |
| 4096 | 160.5ms | 90.4ms |
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/performance_benchmark_gpus2.png" width="600" />
<br/>
</p>
With 4 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
| :------: | :-----: | :-----: |
| 256 | 10.0ms | 21.1ms |
| 512 | 11.5ms | 20.2ms |
| 1024 | 22.1ms | 20.6ms |
| 2048 | 46.9ms | 24.8ms |
| 4096 | 160.4ms | 68.0ms |
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/performance_benchmark_gpus4.png" width="600" />
<br/>
</p>
As shown in the figures above, once the sequence length reaches roughly 1000 or more, the benefit of Shardformer's parallel optimization for long sequences becomes evident.
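For reference, the numbers quoted above correspond to the constants used by the benchmark script added in this commit; a minimal sketch of the sweep (not part of the diff, names taken from `examples/performance_benchmark.py`):

```python
# Benchmark geometry described above: batch size 4, 8 attention heads, head dim 64.
BATCH, N_HEADS, D_HEAD = 4, 8, 64
# 'N_CTX' (sequence length) is swept in powers of two from 256 to 4096.
N_CTX_VALUES = [2 ** i for i in range(8, 13)]   # -> [256, 512, 1024, 2048, 4096]
```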
### Convergence
To validate that training the model using Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both shardformer and non-shardformer approaches and compared the accuracy, loss, and F1 score of the training results.
To validate that training the model using Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches and compared the accuracy, loss, and F1 score of the training results.
| accuracy | f1 | loss | GPU number | model shard |
| :------: | :-----: | :-----: | :--------: | :---------: |


@ -1,4 +1,4 @@
torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \
torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
--model "bert" \
--pretrain "bert-base-uncased" \
--max_epochs 1 \


@ -0,0 +1,86 @@
"""
Shardformer Benchmark
"""
import torch
import torch.distributed as dist
import transformers
import triton
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
def data_gen(batch_size, seq_length):
input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long)
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_sequence_classification(batch_size, seq_length):
# sequence classification data gen: reuse the LM-style inputs and add dummy integer class labels
data = data_gen(batch_size, seq_length)
data['labels'] = torch.ones((batch_size), dtype=torch.long)
return data
MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_attention_heads=4,
max_position_embeddings=128,
num_labels=16)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)
# vary seq length for fixed head and batch=4
configs = [
triton.testing.Benchmark(x_names=['N_CTX'],
x_vals=[2**i for i in range(8, 13)],
line_arg='provider',
line_vals=['org_model', 'shard_model'],
line_names=['org_model', 'shard_model'],
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'llama_for_sequence_classification-batch-{BATCH}',
args={
'BATCH': BATCH,
'dtype': torch.float16,
'model_func': model_func
})
]
def train(model, data):
output = model(**data)
loss = output.logits.mean()
loss.backward()
@triton.testing.perf_report(configs)
def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"):
warmup = 10
rep = 100
# prepare data
data = data_gen_for_sequence_classification(BATCH, N_CTX)
data = {k: v.cuda() for k, v in data.items()}
model = model_func().to(device)
model.train()
if provider == "org_model":
fn = lambda: train(model, data)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == "shard_model":
shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model).cuda()
fn = lambda: train(sharded_model, data)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
# start benchmark, command:
# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
if __name__ == "__main__":
colossalai.launch_from_torch({})
bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0)


@ -1,5 +1,6 @@
import math
import warnings
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@ -962,3 +963,138 @@ class BertPipelineForwards:
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
def get_bert_flash_attention_forward():
try:
from xformers.ops import memory_efficient_attention as me_attention
except ImportError:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.bert.modeling_bert import BertAttention
def forward(
self: BertAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
final_attention_mask = None
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
final_attention_mask = relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
final_attention_mask = relative_position_scores_query + relative_position_scores_key
scale = 1 / math.sqrt(self.attention_head_size)
if attention_mask is not None:
if final_attention_mask is not None:
final_attention_mask = final_attention_mask * scale + attention_mask
else:
final_attention_mask = attention_mask
batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
tgt_len = key_layer.size()[2]
final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len)
query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
value_layer = value_layer.permute(0, 2, 1, 3).contiguous()
context_layer = me_attention(query_layer,
key_layer,
value_layer,
attn_bias=final_attention_mask,
p=self.dropout.p,
scale=scale)
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, None)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
return forward
def get_jit_fused_bert_self_output_forward():
from transformers.models.bert.modeling_bert import BertSelfOutput
def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
return forward
def get_jit_fused_bert_output_forward():
from transformers.models.bert.modeling_bert import BertOutput
def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
return forward


@ -1,3 +1,4 @@
import math
from typing import Optional, Tuple, Union
import torch
@ -58,3 +59,62 @@ def forward_fn():
return outputs
return forward
def get_blip2_flash_attention_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
def forward(
self: Blip2Attention,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
mixed_qkv = self.qkv(hidden_states)
mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
attention = ColoAttention(embed_dim=self.embed_dim,
num_heads=self.num_heads,
dropout=self.dropout.p,
scale=self.scale)
context_layer = attention(query_states, key_states, value_states)
output = self.projection(context_layer)
outputs = (output, None)
return outputs
return forward
def get_jit_fused_blip2_QFormer_self_output_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput
def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
return forward
def get_jit_fused_blip2_QFormer_output_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput
def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
return forward


@ -5,6 +5,7 @@ import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@ -675,3 +676,223 @@ class BloomPipelineForwards:
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
def get_bloom_flash_attention_forward(enable_jit_fused=False):
try:
from xformers.ops import memory_efficient_attention as me_attention
except ImportError:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.bloom.modeling_bloom import BloomAttention
def forward(
self: BloomAttention,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states)
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, tgt_len, _ = hidden_states.size()
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
_, kv_length, _, _ = key_layer.size()
proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
query_layer = query_layer.contiguous().view(*proj_shape)
key_layer = key_layer.contiguous().view(*proj_shape)
value_layer = value_layer.contiguous().view(*proj_shape)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, head_dim, kv_length]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
if use_cache is True:
present = (key_layer, value_layer)
else:
present = None
tgt_len = key_layer.size()[1]
attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length),
dtype=torch.float32,
device=query_layer.device,
requires_grad=True)
attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1,
kv_length) * self.beta
attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask,
torch.finfo(torch.float32).min)
context_layer = me_attention(query_layer,
key_layer,
value_layer,
attn_bias=attention_numerical_mask,
scale=self.inv_norm_factor,
p=self.attention_dropout.p)
context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices):int((i + 1) * slices)],
self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
# TODO to replace with the bias_dropout_add function in jit
output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
outputs = (output_tensor, present, None)
return outputs
return forward
def get_jit_fused_bloom_attention_forward():
from transformers.models.bloom.modeling_bloom import BloomAttention
def forward(
self: BloomAttention,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, q_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, head_dim, kv_length]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=2)
value_layer = torch.cat((past_value, value_layer), dim=1)
_, _, kv_length = key_layer.shape
if use_cache is True:
present = (key_layer, value_layer)
else:
present = None
# [batch_size * num_heads, q_length, kv_length]
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
matmul_result = alibi.baddbmm(
batch1=query_layer,
batch2=key_layer,
beta=self.beta,
alpha=self.inv_norm_factor,
)
# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16:
attention_scores = attention_scores.to(torch.float)
attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# change view [batch_size x num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs_reshaped, value_layer)
# change view [batch_size, num_heads, q_length, head_dim]
context_layer = self._merge_heads(context_layer)
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices):int((i + 1) * slices)],
self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
outputs = (output_tensor, present)
if output_attentions:
outputs += (attention_probs,)
return outputs
return forward
def get_jit_fused_bloom_mlp_forward():
from transformers.models.bloom.modeling_bloom import BloomMLP
def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
if self.pretraining_tp > 1 and self.slow_but_exact:
intermediate_output = torch.zeros_like(residual)
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
for i in range(self.pretraining_tp):
intermediate_output = intermediate_output + F.linear(
hidden_states[:, :, int(i * slices):int((i + 1) * slices)],
self.dense_4h_to_h.weight[:, int(i * slices):int((i + 1) * slices)],
)
else:
intermediate_output = self.dense_4h_to_h(hidden_states)
output = self.dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
return output
return forward
def get_jit_fused_bloom_gelu_forward():
from transformers.models.bloom.modeling_bloom import BloomGelu
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor:
bias = torch.zeros_like(x)
if self.training:
return JitGeLUFunction.apply(x, bias)
else:
return self.bloom_gelu_forward(x, bias)
return forward


@ -17,6 +17,116 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
)
def get_flash_core_attention_forward():
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
from .chatglm2_6b.modeling_chatglm import CoreAttention
def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
pytorch_major_version = int(torch.__version__.split(".")[0])
if pytorch_major_version >= 2:
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
key_layer,
value_layer,
is_causal=True)
else:
if attention_mask is not None:
attention_mask = ~attention_mask
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
attention_mask)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
else:
# Raw attention scores
query_layer = query_layer.permute(1, 0, 2, 3).contiguous()
key_layer = key_layer.permute(1, 0, 2, 3).contiguous()
value_layer = value_layer.permute(1, 0, 2, 3).contiguous()
scale = 1.0 / self.norm_factor
if self.coeff is not None:
scale = scale * self.coeff
flash_attention_mask = None
attn_mask_type = None
if attention_mask is None:
attn_mask_type = AttnMaskType.causal
else:
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(embed_dim=self.hidden_size_per_partition,
num_heads=self.num_attention_heads_per_partition,
dropout=self.attention_dropout.p,
scale=scale)
context_layer = attention(query_layer,
key_layer,
value_layer,
attn_mask=flash_attention_mask,
attn_mask_type=attn_mask_type)
context_layer = context_layer.permute(1, 0, -1).contiguous()
return context_layer
return forward
def get_jit_fused_glm_block_forward():
from .chatglm2_6b.modeling_chatglm import GLMBlock
def forward(
self: GLMBlock,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_cache=None,
use_cache=True,
):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, kv_cache = self.self_attention(
layernorm_output,
attention_mask,
rotary_pos_emb,
kv_cache=kv_cache,
use_cache=use_cache,
)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
layernorm_input = self.dropout_add(attention_output, residual, self.hidden_dropout, self.training)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
output = self.dropout_add(mlp_output, residual, self.hidden_dropout, self.training)
return output, kv_cache
return forward
class ChatGLMPipelineForwards:
'''
This class serves as a micro library for ChatGLM model forwards under pipeline parallelism.


@ -668,3 +668,88 @@ class GPT2PipelineForwards:
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def get_gpt2_flash_attention_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
def split_heads(tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor
def forward(
self: GPT2Attention,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
_, tgt_len, _ = hidden_states.size()
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.")
query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query = split_heads(query, self.num_heads, self.head_dim)
key = split_heads(key, self.num_heads, self.head_dim)
value = split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)
if use_cache is True:
present = (key, value)
else:
present = None
if not self.is_cross_attention:
attn_mask_type = AttnMaskType.causal
flash_attention_mask = None
if attention_mask is not None:
if attn_mask_type == AttnMaskType.causal:
attn_mask_type = AttnMaskType.paddedcausal
else:
attn_mask_type = AttnMaskType.padding
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
scale = value.size(-1)**-0.5
if self.scale_attn_by_inverse_layer_idx:
scale = scale * (1 / float(self.layer_idx + 1))
# use coloattention
attention = ColoAttention(embed_dim=self.embed_dim,
num_heads=self.num_heads,
dropout=self.attn_dropout.p,
scale=scale)
attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present, None)
return outputs
return forward


@ -0,0 +1,34 @@
import torch
def get_dropout_add_func():
from transformers.models.bloom.modeling_bloom import dropout_add
def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
return dropout_add(x, residual, prob, training)
return self_dropout_add
def get_jit_fused_dropout_add_func():
from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train
def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
bias = torch.zeros_like(x)
if training:
return bias_dropout_add_fused_train(x, bias, residual, prob)
return bias_dropout_add_fused_inference(x, bias, residual, prob)
return self_dropout_add
def get_jit_fused_gelu_forward_func():
from colossalai.kernel.jit.bias_gelu import bias_gelu
def bloom_gelu_forward(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
return bias_gelu(bias, x)
return bloom_gelu_forward
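These helpers return plain functions whose first parameter is `self`, so they are meant to be attached to the corresponding HuggingFace modules as bound methods and then used by the jit-fused forwards defined earlier in this diff (for example `get_jit_fused_bert_self_output_forward`, which calls `self.dropout_add`). A minimal, hedged sketch of that wiring is shown below; in Shardformer the replacement is actually done through its policy mechanism, and `layer` here is a hypothetical `BertSelfOutput` instance with the getter functions assumed to be in scope.

```python
from types import MethodType

from transformers.models.bert.modeling_bert import BertSelfOutput

def patch_bert_self_output(layer: BertSelfOutput) -> None:
    # Bind the fused dropout+add helper first; the replaced forward relies on
    # `self.dropout_add` being available on the module.
    # `get_jit_fused_dropout_add_func` and `get_jit_fused_bert_self_output_forward`
    # are the functions defined in this commit and are assumed to be importable here.
    layer.dropout_add = MethodType(get_jit_fused_dropout_add_func(), layer)
    layer.forward = MethodType(get_jit_fused_bert_self_output_forward(), layer)
```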


@ -1,4 +1,4 @@
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Tuple
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@ -386,3 +386,67 @@ class LlamaPipelineForwards:
else:
hidden_states = transformer_outputs.get('hidden_states')
return {'hidden_states': hidden_states}
def get_llama_flash_attention_forward():
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
def forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
flash_attention_mask = None
attn_mask_type = AttnMaskType.causal
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
attn_output = attention(query_states,
key_states,
value_states,
attn_mask=flash_attention_mask,
attn_mask_type=attn_mask_type)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
return forward


@ -0,0 +1,174 @@
from typing import Optional, Tuple
import torch
from torch import nn
def get_opt_flash_attention_forward():
from transformers.models.opt.modeling_opt import OPTAttention
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
def forward(
self: OPTAttention,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
# get query proj
query_states = self.q_proj(hidden_states).view(*attention_input_shape)
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k, v, cross_attentions
key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape)
value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape)
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*attention_input_shape)
value_states = self.v_proj(key_value_states).view(*attention_input_shape)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*attention_input_shape)
value_states = self.v_proj(hidden_states).view(*attention_input_shape)
key_states = torch.cat([past_key_value[0], key_states], dim=1)
value_states = torch.cat([past_key_value[1], value_states], dim=1)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*attention_input_shape)
value_states = self.v_proj(hidden_states).view(*attention_input_shape)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
src_len = key_states.size(1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}")
flash_attention_mask = None
attn_mask_type = AttnMaskType.causal
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}")
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(embed_dim=self.embed_dim,
num_heads=self.num_heads,
dropout=self.dropout,
scale=self.scaling)
attn_output = attention(query_states,
key_states,
value_states,
attn_mask=flash_attention_mask,
attn_mask_type=attn_mask_type)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
return forward
def get_jit_fused_opt_decoder_layer_forward():
from transformers.models.opt.modeling_opt import OPTDecoderLayer
def forward(
self: OPTDecoderLayer,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE the feed-forward block
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training).view(hidden_states_shape)
# 350m applies layer norm AFTER the feed-forward block
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
return forward
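`self.dropout_add` above is not defined in this file; the policies further down bind it to `get_jit_fused_dropout_add_func()` from `..modeling.jit`, whose body is not part of this diff. A functionally equivalent sketch, under the assumption that it simply fuses dropout with the residual addition:

import torch
from torch.nn import functional as F

@torch.jit.script
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # dropout on the branch output, then the residual add, compiled into one TorchScript graph
    return residual + F.dropout(x, p=prob, training=training)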

View File

@ -1,4 +1,9 @@
import math
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
def forward_fn():
@ -37,3 +42,162 @@ def forward_fn():
return outputs
return forward
def get_sam_flash_attention_forward():
from transformers.models.sam.modeling_sam import SamAttention
try:
from xformers.ops import memory_efficient_attention as me_attention
except ImportError:
raise ImportError("xformers is not installed. Please install it to use flash attention.")
def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads
hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
return hidden_states
def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor:
batch, n_tokens, n_heads, c_per_head = hidden_states.shape
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
def forward(self: SamAttention,
query: Tensor,
key: Tensor,
value: Tensor,
attention_similarity: Tensor = None) -> Tensor:
# Input projections
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)
point_batch_size = query.shape[1]
# Separate into heads
query = _separate_heads(query, self.num_attention_heads)
key = _separate_heads(key, self.num_attention_heads)
value = _separate_heads(value, self.num_attention_heads)
# SamAttention
_, _, _, c_per_head = query.shape
bias = None
if attention_similarity is not None:
bias = attention_similarity
scale = 1.0 / math.sqrt(c_per_head)
out = me_attention(query, key, value, attn_bias=bias, scale=scale)
out = _recombine_heads(out, point_batch_size)
out = self.out_proj(out)
return out
return forward
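For reference, xformers' memory_efficient_attention expects q/k/v laid out as (batch, seq_len, num_heads, head_dim), which is why _separate_heads above only reshapes and never moves the head dimension next to batch. A standalone sketch of the call convention (all sizes illustrative; requires a CUDA device):

import torch
from xformers.ops import memory_efficient_attention as me_attention

b, s, h, d = 2, 16, 4, 32
q = torch.randn(b, s, h, d, device="cuda", dtype=torch.float16)
k = torch.randn(b, s, h, d, device="cuda", dtype=torch.float16)
v = torch.randn(b, s, h, d, device="cuda", dtype=torch.float16)
bias = torch.zeros(b, h, s, s, device="cuda", dtype=torch.float16)    # optional additive bias of shape (batch, heads, q_len, k_len)

out = me_attention(q, k, v, attn_bias=bias, scale=1.0 / d**0.5)       # explicit softmax scale, as the SAM forward above does -> (b, s, h, d)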
def get_sam_vision_flash_attention_forward():
from transformers.models.sam.modeling_sam import SamVisionAttention
try:
from xformers.ops import memory_efficient_attention as me_attention
except ImportError:
raise ImportError("xformers is not installed. Please install it to use flash attention.")
def add_decomposed_rel_pos(
query: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Args:
query (`torch.Tensor`):
query q in the attention layer with shape (batch_size, query_height * query_width, num_heads, head_dim).
rel_pos_h (`torch.Tensor`):
relative position embeddings (Lh, channel) for height axis.
rel_pos_w (`torch.Tensor`):
relative position embeddings (Lw, channel) for width axis.
q_size (tuple):
spatial sequence size of query q with (query_height, query_width).
k_size (tuple):
spatial sequence size of key k with (key_height, key_width).
Returns:
rel_pos (`torch.Tensor`):
decomposed relative position bias of shape (batch_size, num_heads, query_height * query_width, key_height * key_width),
ready to be added to the attention scores (here: passed to the kernel as `attn_bias`).
"""
query_height, query_width = q_size
key_height, key_width = k_size
relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h)
relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w)
batch_size, _, nHead, dim = query.shape
reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width)
return rel_pos
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int):
size of the query.
k_size (int):
size of key k.
rel_pos (`torch.Tensor`):
relative position embeddings (L, channel).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
# qkv with shape (3, batch_size, height * width, nHead, head_dim)
qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads,
-1).permute(2, 0, 1, 3, 4))
query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0)
rel_pos = None
if self.use_rel_pos:
rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width))
attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale)
attn_output = attn_output.reshape(batch_size, height, width, -1)
attn_output = self.proj(attn_output)
outputs = (attn_output, None)
return outputs
return forward
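Unlike the stock SamVisionAttention, which adds rel_h/rel_w into an already-computed attention map, the helper above materializes the full decomposed bias so it can be handed to memory_efficient_attention as attn_bias. A toy check of what that broadcast-and-reshape builds (sizes are illustrative):

import torch

bh, qh, qw, kh, kw = 1, 2, 3, 2, 3                          # bh stands for batch * num_heads
rel_h = torch.randn(bh, qh, qw, kh)                         # bias contribution indexed by key row
rel_w = torch.randn(bh, qh, qw, kw)                         # bias contribution indexed by key column

bias = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]    # (bh, qh, qw, kh, kw)
bias = bias.reshape(bh, qh * qw, kh * kw)

i_q, j_q, i_k, j_k = 1, 2, 0, 1
assert torch.isclose(bias[0, i_q * qw + j_q, i_k * kw + j_k],
                     rel_h[0, i_q, j_q, i_k] + rel_w[0, i_q, j_q, j_k])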

View File

@ -587,3 +587,209 @@ class T5PipelineForwards:
decoder_starting_stage=decoder_starting_stage)
return outputs
def get_t5_flash_attention_forward():
try:
from xformers.ops import memory_efficient_attention as me_attention
except ImportError:
raise ImportError("xformers is not installed. Please install it to use flash attention.")
from transformers.models.t5.modeling_t5 import T5Attention
def forward(
self: T5Attention,
hidden_states: torch.Tensor,
mask: Optional[torch.Tensor] = None,
key_value_states: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
layer_head_mask: Optional[torch.Tensor] = None,
query_length: Optional[int] = None,
use_cache: bool = False,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
"""
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
# past_key_value[0] is (batch_size, q_len - 1, n_heads, dim_per_head) in this flash layout
batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length
if past_key_value is not None:
if len(past_key_value) != 2:
raise ValueError(
f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
)
real_seq_length += past_key_value[0].shape[1] if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
def unshape(states):
"""reshape"""
return states.view(batch_size, -1, self.inner_dim)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, seq_length, n_heads, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, seq_length, n_heads, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, key_length, n_heads, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=1)
elif past_key_value.shape[1] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, seq_length, n_heads, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
else:
# cross-attn
hidden_states = past_key_value
return hidden_states
# get query states
query_states = shape(self.q(hidden_states))    # (batch_size, seq_length, n_heads, dim_per_head)
# get key/value states
key_states = project(hidden_states, self.k, key_value_states,
past_key_value[0] if past_key_value is not None else None)
value_states = project(hidden_states, self.v, key_value_states,
past_key_value[1] if past_key_value is not None else None)
if position_bias is None:
if not self.has_relative_attention_bias:
position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length),
device=query_states.device,
dtype=query_states.dtype)
if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True
else:
position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device)
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1):, :]
if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
if self.pruned_heads:
mask = torch.ones(position_bias.shape[1])
mask[list(self.pruned_heads)] = 0
position_bias_masked = position_bias[:, mask.bool()]
else:
position_bias_masked = position_bias
position_bias_masked = position_bias_masked.contiguous()
attn_output = me_attention(query_states,
key_states,
value_states,
attn_bias=position_bias_masked,
p=self.dropout,
scale=1.0)
attn_output = unshape(attn_output)
attn_output = self.o(attn_output)
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
return outputs
return forward
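Two details of the call above are worth spelling out: T5 never scales QK^T by 1/sqrt(d) (its learned relative-position bias is trained against unscaled logits), hence scale=1.0, and the bias plus the additive mask travel into the kernel as attn_bias. An unfused reference of what the kernel computes (sketch only, dropout omitted; shapes follow shape() above):

import torch

def t5_attention_reference(q, k, v, position_bias_masked):
    # q, k, v: (batch, seq, n_heads, d); position_bias_masked: (batch or 1, n_heads, q_len, k_len)
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k)               # note: no 1/sqrt(d) scaling
    weights = torch.softmax(scores + position_bias_masked, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", weights, v)           # back to (batch, seq, n_heads, d)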
def get_jit_fused_T5_layer_ff_forward():
from transformers.models.t5.modeling_t5 import T5LayerFF
def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor:
forwarded_states = self.layer_norm(hidden_states)
forwarded_states = self.DenseReluDense(forwarded_states)
hidden_states = self.dropout_add(forwarded_states, hidden_states, self.dropout.p, self.dropout.training)
return hidden_states
return forward
def get_T5_layer_self_attention_forward():
from transformers.models.t5.modeling_t5 import T5LayerSelfAttention
def forward(
self: T5LayerSelfAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
return outputs
return forward
def get_T5_layer_cross_attention_forward():
from transformers.models.t5.modeling_t5 import T5LayerCrossAttention
def forward(
self: T5LayerCrossAttention,
hidden_states: torch.Tensor,
key_value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
query_length: Optional[int] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
normed_hidden_states,
mask=attention_mask,
key_value_states=key_value_states,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
query_length=query_length,
output_attentions=output_attentions,
)
layer_output = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
return outputs
return forward

View File

@ -1,4 +1,5 @@
import logging
import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
@ -335,3 +336,51 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
)
return pp_forward
def get_vit_flash_self_attention_forward():
from transformers.models.vit.modeling_vit import ViTSelfAttention
from colossalai.kernel.cuda_native.flash_attention import ColoAttention
def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
x = x.view(new_x_shape)
return x
def forward(self: ViTSelfAttention,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size)
value_layer = transpose_for_scores(self.value(hidden_states), self.num_attention_heads,
self.attention_head_size)
query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size)
scale = 1.0 / math.sqrt(self.attention_head_size)
attention = ColoAttention(embed_dim=self.all_head_size,
num_heads=self.num_attention_heads,
dropout=self.dropout.p,
scale=scale)
context_layer = attention(query_layer, key_layer, value_layer)
outputs = (context_layer,)
return outputs
return forward
def get_jit_fused_vit_output_forward():
from transformers.models.vit.modeling_vit import ViTOutput
def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
return hidden_states
return forward

View File

@ -0,0 +1,249 @@
from typing import Optional, Tuple
import torch
from torch import nn
def get_whisper_flash_attention_forward():
from transformers.models.whisper.modeling_whisper import WhisperAttention
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
def forward(
self: WhisperAttention,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get key, value proj
# `past_key_value[0].shape[1] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (is_cross_attention and past_key_value is not None
and past_key_value[0].shape[1] == key_value_states.shape[1]):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
key_states = torch.cat([past_key_value[0], key_states], dim=1)
value_states = torch.cat([past_key_value[1], value_states], dim=1)
else:
# self_attention
key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
# get query proj
query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim)
src_len = key_states.size(1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}")
attn_type = None
flash_attention_mask = None
if self.is_decoder:
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
attn_type = AttnMaskType.paddedcausal
attention = ColoAttention(embed_dim=self.embed_dim,
num_heads=self.num_heads,
dropout=self.dropout,
scale=self.scaling)
attn_output = attention(query_states,
key_states,
value_states,
attn_mask=flash_attention_mask,
attn_mask_type=attn_type)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
return forward
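The encoder and decoder paths above differ only in the mask arguments handed to ColoAttention: encoder self-attention and cross-attention pass attn_mask_type=None (dense, bidirectional), while decoder self-attention passes a boolean padding mask with AttnMaskType.paddedcausal. A condensed sketch of the two call patterns (sizes illustrative; requires CUDA and fp16):

import torch
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention

attn = ColoAttention(embed_dim=512, num_heads=8, dropout=0.0, scale=1.0 / 8)
q = k = v = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)

# encoder self-attention / cross-attention: no mask type -> full bidirectional attention
out_enc = attn(q, k, v, attn_mask=None, attn_mask_type=None)

# decoder self-attention: boolean padding mask (True = real token) plus causal masking
pad_mask = torch.ones(2, 16, dtype=torch.bool, device="cuda")
out_dec = attn(q, k, v, attn_mask=pad_mask, attn_mask_type=AttnMaskType.paddedcausal)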
def get_jit_fused_whisper_encoder_layer_forward():
from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer
def forward(
self: WhisperEncoderLayer,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
layer_head_mask: torch.Tensor,
output_attentions: bool = False,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any()
or torch.isnan(hidden_states).any()):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
return forward
def get_jit_fused_whisper_decoder_layer_forward():
from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer
def forward(
self: WhisperDecoderLayer,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(decoder_attention_heads,)`.
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# Fully Connected
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
if use_cache:
outputs += (present_key_value,)
return outputs
return forward

View File

@ -7,7 +7,14 @@ from torch.nn import Module
import colossalai.shardformer.layer as col_nn
from ..modeling.bert import BertPipelineForwards
from .._utils import getattr_, setattr_
from ..modeling.bert import (
BertPipelineForwards,
get_bert_flash_attention_forward,
get_jit_fused_bert_output_forward,
get_jit_fused_bert_self_output_forward,
)
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@ -37,7 +44,13 @@ class BertPolicy(Policy):
return self.model
def module_policy(self):
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertLayer,
BertOutput,
BertSelfAttention,
BertSelfOutput,
)
policy = {}
@ -126,6 +139,23 @@ class BertPolicy(Policy):
policy=policy,
target_key=BertEmbeddings)
# use flash attention
if self.shard_config.enable_flash_attention:
policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={
'forward': get_bert_flash_attention_forward(),
})
# use jit operator
if self.shard_config.enable_jit_fused:
policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_bert_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[BertOutput] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_bert_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy
def add_lm_head_policy(self, base_policy):

View File

@ -3,7 +3,13 @@ import torch.nn as nn
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.blip2 import forward_fn
from ..modeling.blip2 import (
forward_fn,
get_blip2_flash_attention_forward,
get_jit_fused_blip2_QFormer_output_forward,
get_jit_fused_blip2_QFormer_self_output_forward,
)
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['BlipPolicy', 'BlipModelPolicy']
@ -33,6 +39,8 @@ class BlipPolicy(Policy):
Blip2EncoderLayer,
Blip2QFormerLayer,
Blip2QFormerModel,
Blip2QFormerOutput,
Blip2QFormerSelfOutput,
Blip2VisionModel,
)
from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM
@ -275,6 +283,24 @@ class BlipPolicy(Policy):
policy=policy,
target_key=OPTDecoderLayer)
# use flash attention
if self.shard_config.enable_flash_attention:
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={
'forward': get_blip2_flash_attention_forward(),
})
# use jit operator
if self.shard_config.enable_jit_fused:
policy[Blip2QFormerSelfOutput] = ModulePolicyDescription(
method_replacement={
'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_blip2_QFormer_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy
def postprocess(self):

View File

@ -7,7 +7,16 @@ from torch.nn import Module
import colossalai.shardformer.layer as col_nn
from ..modeling.bloom import BloomPipelineForwards, build_bloom_alibi_tensor_fn
from .._utils import getattr_, setattr_
from ..modeling.bloom import (
BloomPipelineForwards,
build_bloom_alibi_tensor_fn,
get_bloom_flash_attention_forward,
get_jit_fused_bloom_attention_forward,
get_jit_fused_bloom_gelu_forward,
get_jit_fused_bloom_mlp_forward,
)
from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -30,7 +39,7 @@ class BloomPolicy(Policy):
return self.model
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomGelu, BloomMLP, BloomModel
policy = {}
@ -107,6 +116,27 @@ class BloomPolicy(Policy):
policy=policy,
target_key=BloomBlock)
if self.shard_config.enable_flash_attention:
policy[BloomAttention] = ModulePolicyDescription(method_replacement={
'forward': get_bloom_flash_attention_forward(),
'dropout_add': get_dropout_add_func()
})
# enable jit fused operator
if self.shard_config.enable_jit_fused:
policy[BloomAttention] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_bloom_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[BloomMLP] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_bloom_mlp_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[BloomGelu] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_bloom_gelu_forward(),
'bloom_gelu_forward': get_jit_fused_gelu_forward_func(),
})
return policy
def postprocess(self):

View File

@ -15,6 +15,8 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
GLMBlock,
)
from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['ChatGLMPolicy', 'ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
@ -35,12 +37,11 @@ class ChatGLMPolicy(Policy):
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock
policy = {}
@ -121,6 +122,19 @@ class ChatGLMPolicy(Policy):
policy=policy,
target_key=ChatGLMModel)
# use flash attention
if self.shard_config.enable_flash_attention:
policy[CoreAttention] = ModulePolicyDescription(method_replacement={
'forward': get_flash_core_attention_forward(),
})
# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[GLMBlock] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_glm_block_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy
def postprocess(self):
@ -192,7 +206,6 @@ class ChatGLMModelPolicy(ChatGLMPolicy):
return []
class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
def module_policy(self):
@ -213,4 +226,3 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in ChatGLMForConditionalGenerationModel."""
return []

View File

@ -5,7 +5,8 @@ from torch import Tensor, nn
import colossalai.shardformer.layer as col_nn
from ..modeling.gpt2 import GPT2PipelineForwards
from .._utils import getattr_, setattr_
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@ -33,7 +34,7 @@ class GPT2Policy(Policy):
return self.model
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
policy = {}
@ -114,6 +115,11 @@ class GPT2Policy(Policy):
],
policy=policy,
target_key=GPT2Block)
if self.shard_config.enable_flash_attention:
policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
'forward': get_gpt2_flash_attention_forward(),
})
return policy
def postprocess(self):

View File

@ -7,7 +7,7 @@ from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from ..modeling.llama import LlamaPipelineForwards
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
@ -31,7 +31,7 @@ class LlamaPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
policy = {}
@ -104,6 +104,11 @@ class LlamaPolicy(Policy):
policy=policy,
target_key=LlamaModel)
if self.shard_config.enable_flash_attention:
policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
'forward': get_llama_flash_attention_forward(),
})
return policy
def postprocess(self):

View File

@ -25,6 +25,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .._utils import getattr_, setattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.opt import get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@ -114,6 +116,19 @@ class OPTPolicy(Policy):
policy=policy,
target_key=OPTDecoderLayer)
# use flash attention
if self.shard_config.enable_flash_attention:
policy[OPTAttention] = ModulePolicyDescription(method_replacement={
'forward': get_opt_flash_attention_forward(),
})
# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_opt_decoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy
def postprocess(self):
@ -189,13 +204,11 @@ class OPTForCausalLMPolicy(OPTPolicy):
from transformers.models.opt.modeling_opt import OPTForCausalLM
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy,
target_key=OPTForCausalLM)
if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=OPTForCausalLM,
new_forward=OPTPipelineForwards.opt_for_causal_lm_forward,

View File

@ -3,7 +3,7 @@ import torch.nn as nn
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.sam import forward_fn
from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['SamPolicy', 'SamModelPolicy']
@ -19,6 +19,7 @@ class SamPolicy(Policy):
def module_policy(self):
from transformers.models.sam.modeling_sam import (
SamAttention,
SamFeedForward,
SamTwoWayAttentionBlock,
SamTwoWayTransformer,
@ -196,6 +197,15 @@ class SamPolicy(Policy):
policy=policy,
target_key=SamTwoWayTransformer)
# use flash attention
if self.shard_config.enable_flash_attention:
policy[SamAttention] = ModulePolicyDescription(method_replacement={
'forward': get_sam_flash_attention_forward(),
})
policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={
'forward': get_sam_vision_flash_attention_forward(),
})
return policy
def postprocess(self):

View File

@ -14,7 +14,14 @@ from colossalai.shardformer.layer import (
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
from .._utils import getattr_, setattr_
from ..modeling.t5 import T5PipelineForwards
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.t5 import (
T5PipelineForwards,
get_jit_fused_T5_layer_ff_forward,
get_t5_flash_attention_forward,
get_T5_layer_cross_attention_forward,
get_T5_layer_self_attention_forward,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
@ -168,6 +175,27 @@ class T5BasePolicy(Policy):
suffix="final_layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5Stack)
# use flash attention
if self.shard_config.enable_flash_attention:
policy[T5Attention] = ModulePolicyDescription(method_replacement={
'forward': get_t5_flash_attention_forward(),
})
# use jit operator
if self.shard_config.enable_jit_fused:
policy[T5LayerFF] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_T5_layer_ff_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={
'forward': get_T5_layer_self_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={
'forward': get_T5_layer_cross_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy
def postprocess(self):

View File

@ -3,11 +3,15 @@ from typing import Callable, Dict, List, Union
import torch.nn as nn
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.vit import (
ViTForImageClassification_pipeline_forward,
ViTForMaskedImageModeling_pipeline_forward,
ViTModel_pipeline_forward,
get_jit_fused_vit_output_forward,
get_vit_flash_self_attention_forward,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -23,7 +27,8 @@ class ViTPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel, ViTOutput, ViTSelfAttention
policy = {}
@ -33,7 +38,7 @@ class ViTPolicy(Policy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
target_module=DropoutForReplicatedInput,
)
])
@ -83,8 +88,18 @@ class ViTPolicy(Policy):
),
])
return policy
# use flash attention
if self.shard_config.enable_flash_attention:
policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={
'forward': get_vit_flash_self_attention_forward(),
})
# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[ViTOutput] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_vit_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy
def new_model_class(self):
@ -167,7 +182,7 @@ class ViTForImageClassificationPolicy(ViTPolicy):
ViTForImageClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)

View File

@ -3,6 +3,12 @@ import torch.nn as nn
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.whisper import (
get_jit_fused_whisper_decoder_layer_forward,
get_jit_fused_whisper_encoder_layer_forward,
get_whisper_flash_attention_forward,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@ -30,6 +36,7 @@ class WhisperPolicy(Policy):
def module_policy(self):
from transformers.models.whisper.modeling_whisper import (
WhisperAttention,
WhisperDecoder,
WhisperDecoderLayer,
WhisperEncoder,
@ -181,6 +188,24 @@ class WhisperPolicy(Policy):
],
policy=policy,
target_key=WhisperDecoder)
# enable flash attention
if self.shard_config.enable_flash_attention:
policy[WhisperAttention] = ModulePolicyDescription(method_replacement={
'forward': get_whisper_flash_attention_forward(),
})
# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_whisper_encoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_whisper_decoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy
def add_lm_head_policy(self, base_policy):

View File

@ -26,6 +26,8 @@ class ShardConfig:
enable_tensor_parallelism: bool = True
enable_fused_normalization: bool = False
enable_all_optimization: bool = False
enable_flash_attention: bool = False
enable_jit_fused: bool = False
# TODO: add support for tensor parallel
# pipeline_parallel_size: int
@ -44,7 +46,6 @@ class ShardConfig:
else:
# get the parallel size
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
# turn on all optimization if all_optimization is set to True
if self.enable_all_optimization:
self._turn_on_all_optimization()
@ -55,3 +56,5 @@ class ShardConfig:
"""
# you can add all the optimization flag here
self.enable_fused_normalization = True
self.enable_flash_attention = True
self.enable_jit_fused = True
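Both new flags default to False; each policy's module_policy() checks them before swapping in the flash-attention or jit-fused forwards, and enable_all_optimization turns them on together with fused normalization. A usage sketch mirroring the build_model() test helper further down (model_fn stands in for any supported model constructor and is hypothetical here; the import path is assumed):

import copy
from colossalai.shardformer import ShardConfig, ShardFormer    # import path assumed

org_model = model_fn()                                         # e.g. a transformers BERT/OPT/T5 model
shard_config = ShardConfig(enable_fused_normalization=True,
                           enable_tensor_parallelism=False,
                           enable_flash_attention=True,        # policies replace attention forwards
                           enable_jit_fused=True)              # policies bind jit-fused dropout_add
sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(copy.deepcopy(org_model))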

View File

@ -18,3 +18,5 @@ SentencePiece
ninja
flash_attn>=2.0
datasets
ninja
flash-attn

View File

@ -20,7 +20,7 @@ def data_gen():
# token_type_ids = tokenized_input['token_type_ids']
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64)
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
@ -69,19 +69,21 @@ def data_gen_for_mcq():
# data['labels'] = torch.tensor([0], dtype=torch.int64)
input_ids = torch.tensor([[[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102
4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102
],
[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
2218, 1999, 1996, 2192, 1012, 102, 0
2218, 1999, 1996, 2192, 1012, 102, 0, 0
]]])
token_type_ids = torch.tensor(
[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0]]])
attention_mask = torch.tensor(
[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0]]])
labels = torch.tensor([0], dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)

View File

@ -38,6 +38,7 @@ output_transform_fn = lambda x: x
loss_fn_blip2_model = lambda x: x.loss
config = transformers.Blip2Config()
config.vision_config.patch_size = 14
config.text_config.num_hidden_layers = 1
config.qformer_config.num_hidden_layers = 1
config.vision_config.num_hidden_layers = 1

View File

@ -16,8 +16,8 @@ def data_gen():
# tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595, 632, 207595]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
@ -33,7 +33,7 @@ def data_gen_for_token_classification():
# token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
return data
@ -53,8 +53,8 @@ def data_gen_for_question_answering():
# inputs = tokenizer(question, text, return_tensors="pt")
input_ids = torch.tensor(
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
start_positions = torch.tensor([1], dtype=torch.int64)
end_positions = torch.tensor([10], dtype=torch.int64)
return dict(input_ids=input_ids,

View File

@ -6,7 +6,6 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
from ..registry import ModelAttribute, model_zoo
# ================================
# Register single-sentence ChatGLM
# ================================

View File

@ -1,58 +0,0 @@
from transformers import PretrainedConfig
class ChatGLMConfig(PretrainedConfig):
model_type = "chatglm"
def __init__(self,
num_layers=28,
padded_vocab_size=65024,
hidden_size=4096,
ffn_hidden_size=13696,
kv_channels=128,
num_attention_heads=32,
seq_length=2048,
hidden_dropout=0.0,
attention_dropout=0.0,
layernorm_epsilon=1e-5,
rmsnorm=True,
apply_residual_connection_post_layernorm=False,
post_layer_norm=True,
add_bias_linear=False,
add_qkv_bias=False,
bias_dropout_fusion=True,
multi_query_attention=False,
multi_query_group_num=1,
apply_query_key_layer_scaling=True,
attention_softmax_in_fp32=True,
fp32_residual_connection=False,
quantization_bit=0,
pre_seq_len=None,
prefix_projection=False,
**kwargs):
self.num_layers = num_layers
self.vocab_size = padded_vocab_size
self.padded_vocab_size = padded_vocab_size
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.kv_channels = kv_channels
self.num_attention_heads = num_attention_heads
self.seq_length = seq_length
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.layernorm_epsilon = layernorm_epsilon
self.rmsnorm = rmsnorm
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.post_layer_norm = post_layer_norm
self.add_bias_linear = add_bias_linear
self.add_qkv_bias = add_qkv_bias
self.bias_dropout_fusion = bias_dropout_fusion
self.multi_query_attention = multi_query_attention
self.multi_query_group_num = multi_query_group_num
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.fp32_residual_connection = fp32_residual_connection
self.quantization_bit = quantization_bit
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
super().__init__(**kwargs)

View File

@ -18,8 +18,8 @@ def data_gen():
# tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
@ -46,7 +46,7 @@ def data_gen_for_token_classification():
# token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 1]], dtype=torch.int64)
data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
return data

View File

@ -16,8 +16,9 @@ def data_gen_for_encoder_only():
# config = T5Config(decoder_start_token_id=0)
# tokenizer = T5Tokenizer.from_pretrained("t5-small")
# input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long()
return dict(input_ids=input_ids)
input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long()
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_conditional_generation():
@ -25,17 +26,16 @@ def data_gen_for_conditional_generation():
#
# labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
data = data_gen_for_encoder_only()
labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long()
labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long()
data['labels'] = labels
return data
def data_gen_for_t5_model():
# decoder_inputs_ids is obtained with the following code
#
# decoder_input_ids = model._shift_right(input_ids)
data = data_gen_for_encoder_only()
decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long()
decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long()
data['decoder_input_ids'] = decoder_input_ids
return data

View File

@ -76,14 +76,14 @@ model_zoo.register(name='transformers_whisper',
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_whisperForConditionalGeneration',
model_zoo.register(name='transformers_whisper_for_conditional_generation',
model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_attr,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_whisperWhisperForAudioClassification',
model_zoo.register(name='transformers_whisper_for_audio_classification',
model_fn=lambda: transformers.WhisperForAudioClassification(config),
data_gen_fn=data_gen_for_audio_classification,
output_transform_fn=output_transform_fn,

View File

@ -93,7 +93,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
'transformers_vit_for_image_classification', 'transformers_chatglm',
'transformers_chatglm_for_conditional_generation', 'transformers_blip2',
'transformers_blip2_conditional_gerneration', 'transformers_sam', 'transformers_whisper',
'transformers_whisperForConditionalGeneration', 'transformers_whisperWhisperForAudioClassification'
'transformers_whisper_for_conditional_generation', 'transformers_whisper_for_audio_classification'
]:
continue

View File

@ -21,7 +21,13 @@ from colossalai.shardformer._utils import getattr_
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
def build_model(model_fn,
enable_fused_normalization=True,
enable_tensor_parallelism=True,
enable_flash_attention=False,
enable_jit_fused=False,
use_lazy_init: bool = False):
# create new model
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
# create new model
@ -31,7 +37,10 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
ctx.materialize(org_model)
# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism)
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model.cuda(), sharded_model.cuda()

View File

@ -46,14 +46,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
check_grad(bert, sharded_bert, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)
@parameterize('enable_fused_normalization', [False, True])
@parameterize('enable_tensor_parallelism', [False, True])
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
use_lazy_init):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
use_lazy_init)
enable_flash_attention, enable_jit_fused, use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

View File

@@ -47,10 +47,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
 sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2')
 for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                       enable_flash_attention, enable_jit_fused)
 check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 torch.cuda.empty_cache()

View File

@@ -44,13 +44,15 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
 @parameterize('use_lazy_init', [False, True])
-def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
+                   use_lazy_init):
 sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
 for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
 org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                       use_lazy_init)
-check_state_dict(org_model, sharded_model, name=name)
+                                       enable_flash_attention, enable_jit_fused, use_lazy_init)
 check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 torch.cuda.empty_cache()

View File

@@ -72,7 +72,9 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
 sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
 for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
 # create new model
@@ -80,7 +82,9 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
 # shard model
 shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
-                           enable_tensor_parallelism=enable_tensor_parallelism)
+                           enable_tensor_parallelism=enable_tensor_parallelism,
+                           enable_flash_attention=enable_flash_attention,
+                           enable_jit_fused=enable_jit_fused)
 model_copy = copy.deepcopy(org_model)
 shard_former = ShardFormer(shard_config=shard_config)
 if name == "transformers_chatglm":

View File

@@ -68,7 +68,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
torch.cuda.empty_cache()
@parameterize('test_config', [{
'tp_size': 1,
'pp_size': 2,

View File

@@ -49,12 +49,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
+@parameterize('enable_flash_attention', [True, False])
 @parameterize('use_lazy_init', [False, True])
-def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, use_lazy_init):
 sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
 for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
 org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                       use_lazy_init)
+                                       enable_flash_attention, use_lazy_init)
 check_state_dict(org_model, sharded_model, name=name)
 check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 torch.cuda.empty_cache()

View File

@@ -42,18 +42,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 # check grad
 col_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
 row_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
-check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False)
-check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)
+check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
+check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)
+@parameterize('use_lazy_init', [False, True])
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('use_lazy_init', [False, True])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_opt_test(use_lazy_init, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention,
+                 enable_jit_fused):
 sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
 for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
 org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                       use_lazy_init)
+                                       enable_flash_attention, enable_jit_fused, use_lazy_init)
 check_state_dict(org_model, sharded_model, name=name)
 check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 torch.cuda.empty_cache()
@@ -62,7 +65,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
 def check_OPTModel(rank, world_size, port):
 disable_existing_loggers()
 colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-run_t5_test()
+run_opt_test()
 @pytest.mark.dist
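For context, `check_OPTModel` above is the per-rank entry point; the pytest wrapper spawns it across processes, and each rank then calls `run_opt_test()`. A rough sketch of that wrapper, assuming the `spawn` and `rerun_if_address_is_in_use` helpers from `colossalai.testing`; the decorators actually used in this file may differ:

```python
import pytest

from colossalai.testing import rerun_if_address_is_in_use, spawn  # helpers assumed


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_OPTModel():
    # Launch check_OPTModel on two ranks; each rank runs colossalai.launch(...)
    # followed by run_opt_test(), as shown in the hunk above.
    spawn(check_OPTModel, 2)
```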

View File

@@ -41,10 +41,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_sam_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+def run_sam_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
 sub_model_zoo = model_zoo.get_sub_registry('transformers_sam')
 for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                       enable_flash_attention)
 check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 torch.cuda.empty_cache()

View File

@@ -33,8 +33,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 # check grad
 col_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.q', 'shared']
 row_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.relative_attention_bias']
-check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-7, rtol=1e-5, dim=0, verbose=False)
-check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-7, rtol=1e-5, dim=1, verbose=False)
+check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
+check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
 # check weights are tied
 if hasattr(org_model, 'lm_head'):
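The change above (and the matching one in the OPT hunk) relaxes the absolute gradient tolerance from 1e-7 to 1e-6: flash-attention and fused kernels reorder floating-point operations, so sharded gradients match the unsharded reference only up to slightly larger rounding error. A toy illustration, using `torch.allclose` as a stand-in for the comparison done inside `check_grad`:

```python
import torch

# A reference gradient and a copy shifted by 5e-7, standing in for the small
# numerical drift introduced by flash-attention / fused kernels.
ref_grad = torch.randn(4, 8)
fused_grad = ref_grad + 5e-7

# The relaxed bound used after this change tolerates the drift...
assert torch.allclose(ref_grad, fused_grad, atol=1e-6, rtol=1e-5)
# ...while the old atol=1e-7 with no relative slack rejects it.
assert not torch.allclose(ref_grad, fused_grad, atol=1e-7, rtol=0.0)
```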
@@ -45,11 +45,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
 @parameterize('use_lazy_init', [False, True])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init, enable_flash_attention,
+                enable_jit_fused):
 sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
 for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
 org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                       use_lazy_init)
+                                       enable_flash_attention, enable_jit_fused, use_lazy_init)
 check_state_dict(org_model, sharded_model, name=name)
 check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 torch.cuda.empty_cache()

View File

@@ -20,7 +20,9 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)
# do backward
org_loss.backward()
shard_loss.backward()
@@ -45,10 +47,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
 sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
 for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                       enable_flash_attention, enable_jit_fused)
 check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 torch.cuda.empty_cache()

View File

@@ -48,12 +48,16 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_flash_attention', [True, False])
+@parameterize('enable_jit_fused', [True, False])
+def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
 sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
 for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
 org_model, sharded_model = build_model(model_fn,
                                        enable_fused_normalization=enable_fused_normalization,
-                                       enable_tensor_parallelism=enable_tensor_parallelism)
+                                       enable_tensor_parallelism=enable_tensor_parallelism,
+                                       enable_flash_attention=enable_flash_attention,
+                                       enable_jit_fused=enable_jit_fused)
 check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 torch.cuda.empty_cache()