mirror of https://github.com/hpcaitech/ColossalAI
[Shardformer] Merge flash attention branch to pipeline branch (#4362)
* [shardformer] supported flash attention test dependency (#4158)
* [shardformer] fix flash attention utils test (#4180)
* [shardformer] opt support flash attention (#4163); move the implementation to `modeling`
* [shardformer] add performance benchmark of shardformer (#4175); benchmark fixes
* [shardformer] llama support flash attention (#4185); move the xformers import statement outside the forward function
* [shardformer] gpt2 support flash attention (#4191)
* [shardformer] bloom support flash attention (#4188); add assert on the sequence length
* [shardformer] bert support flash attention (#4206)
* [shardformer] t5 support flash attention (#4216); typo fixes
* [shardformer] support 'paddedcausal' attention mask type in ColoAttention (#4215)
* [shardformer] t5 flash attention fix (#4239)
* [shardformer] update gpt2 to use ColoAttention (#4234)
* [shardformer] update opt and llama to use ColoAttention (#4226)
* [shardformer] support jit fused operators for bloom and t5 (#4236)
* [shardformer] add roadmap of flash attention; add type hint to the `self` param of `forward`
* [shardformer] merge feature/shardformer-models into feature/flash-attention-shardformer (#4290), bringing in:
  * Feature/vit support (#4182): add tests, finish vit support, fix attention dropout
  * [shardformer] support SAM (#4231): add fused qkv for nn.Linear, update utils to support setting an element in a list, override SamVisionAttention forward to use DropoutForParallelInput, remove unused code
  * [shardformer] support whisper (#4212): fix bug in vocab embedding, support downstream whisper models, update readme
  * Feature/chatglm (#4240): add ChatGLM policy and model-zoo test kit, support ChatGLM without layernorm, ChatGLM layernorm sharding, register without auto policy, pre-commit fixes
* [shardformer] whisper support flash attention and jit operator (#4301)
* [shardformer] sam support flash attention (#4316)
* [shardformer] merge blip2/chatglm (#4321): support ChatGLMForConditionalGeneration, add fused layernorm for vit, support Blip2 (#4243) with its downstream models, add forward injection, skip incompatible model tests, fix tests for the gemini and low_level_zero plugins
* [shardformer] blip2 support flash attention and jit operator (#4325)
* [shardformer] chatglm support flash attention and jit operator (#4330)
* [shardformer] vit support flash attention and jit operator (#4334)
* [pipeline] merge the flash attention branch, fix conflicts, and activate checks
* fix flash attention tests
* gemini: ignore whisper
* fix vit
* fix xformers import handling

---------

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>

pull/4445/head
parent a88e92251d
commit 906426cb44
@@ -30,7 +30,7 @@

### Quick Start

-The sample API usage is given below:
+The sample API usage is given below (if you enable flash attention, please install xformers):

```python
from colossalai.shardformer import ShardConfig, Shard
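# NOTE (editor's sketch, not part of the diff): the snippet is cut off here by
# the hunk boundary. A minimal continuation, grounded on the benchmark script
# added later in this commit (the truncated import above presumably also
# provides `ShardFormer`):
from colossalai.shardformer import ShardFormer

shard_config = ShardConfig(enable_fused_normalization=True,
                           enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
# `optimize` returns the sharded model in this branch, e.g.:
#   sharded_model = shard_former.optimize(model)
```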
@@ -106,6 +106,20 @@ We will follow this roadmap to develop Shardformer:

- [ ] Multi-modal
  - [x] SAM
  - [x] BLIP-2
+- [ ] Flash Attention Support
+  - [ ] NLP
+    - [x] BERT
+    - [x] T5
+    - [x] LlaMa
+    - [x] GPT2
+    - [x] OPT
+    - [x] BLOOM
+    - [ ] GLM
+    - [ ] RoBERTa
+    - [ ] ALBERT
+    - [ ] ERNIE
+    - [ ] GPT Neo
+    - [ ] GPT-J

## 💡 API Design
@@ -373,11 +387,49 @@ pytest tests/test_shardformer

### System Performance

-To be added.
+We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate Shardformer's performance improvement, comparing the training time of the original model with that of the sharded model.

We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.

With 2 GPUs, the training times are as follows:

| N_CTX | org_model | shard_model |
| :---: | :-------: | :---------: |
|  256  |  11.2ms   |   17.2ms    |
|  512  |   9.8ms   |   19.5ms    |
| 1024  |  19.6ms   |   18.9ms    |
| 2048  |  46.6ms   |   30.8ms    |
| 4096  |  160.5ms  |   90.4ms    |

<p align="center">
  <img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/performance_benchmark_gpus2.png" width="600" />
  <br/>
</p>

With 4 GPUs, the training times are as follows:

| N_CTX | org_model | shard_model |
| :---: | :-------: | :---------: |
|  256  |  10.0ms   |   21.1ms    |
|  512  |  11.5ms   |   20.2ms    |
| 1024  |  22.1ms   |   20.6ms    |
| 2048  |  46.9ms   |   24.8ms    |
| 4096  |  160.4ms  |   68.0ms    |

<p align="center">
  <img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/performance_benchmark_gpus4.png" width="600" />
  <br/>
</p>

As shown in the figures above, once the sequence length reaches roughly 1000 tokens, Shardformer's parallel optimization for long sequences becomes evident.
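The settings above map directly onto the constants in the benchmark script added later in this commit; they are restated here for convenience (not part of the README diff):

```python
# From examples/performance_benchmark.py (shown in full further down this diff):
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64    # batch size, attention heads, max context, head dim
x_vals = [2**i for i in range(8, 13)]             # N_CTX swept over 256, 512, 1024, 2048, 4096
```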

### Convergence

-To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results.
+To validate that training a model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) with and without Shardformer and compared the accuracy, loss, and F1 score of the training runs.

| accuracy | f1 | loss | GPU number | model shard |
| :------: | :-----: | :-----: | :--------: | :---------: |
@@ -1,4 +1,4 @@
-torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \
+torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
  --model "bert" \
  --pretrain "bert-base-uncased" \
  --max_epochs 1 \
@@ -0,0 +1,86 @@
"""
Shardformer Benchmark
"""
import torch
import torch.distributed as dist
import transformers
import triton

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer


def data_gen(batch_size, seq_length):
    input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long)
    attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_sequence_classification(batch_size, seq_length):
    # LM data gen
    # the `labels` of LM are the output tokens; since there is no padding, use `input_ids` as `labels`
    data = data_gen(batch_size, seq_length)
    data['labels'] = torch.ones((batch_size,), dtype=torch.long)
    return data


MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
                                        hidden_size=128,
                                        intermediate_size=256,
                                        num_attention_heads=4,
                                        max_position_embeddings=128,
                                        num_labels=16)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)

# vary seq length for fixed head and batch=4
configs = [
    triton.testing.Benchmark(x_names=['N_CTX'],
                             x_vals=[2**i for i in range(8, 13)],
                             line_arg='provider',
                             line_vals=['org_model', 'shard_model'],
                             line_names=['org_model', 'shard_model'],
                             styles=[('red', '-'), ('blue', '-')],
                             ylabel='ms',
                             plot_name=f'llama_for_sequence_classification-batch-{BATCH}',
                             args={
                                 'BATCH': BATCH,
                                 'dtype': torch.float16,
                                 'model_func': model_func
                             })
]


def train(model, data):
    output = model(**data)
    loss = output.logits.mean()
    loss.backward()


@triton.testing.perf_report(configs)
def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"):
    warmup = 10
    rep = 100
    # prepare data
    data = data_gen_for_sequence_classification(BATCH, N_CTX)
    data = {k: v.cuda() for k, v in data.items()}
    model = model_func().to(device)
    model.train()
    if provider == "org_model":
        fn = lambda: train(model, data)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms
    if provider == "shard_model":
        shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
        shard_former = ShardFormer(shard_config=shard_config)
        sharded_model = shard_former.optimize(model).cuda()
        fn = lambda: train(sharded_model, data)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms


# start benchmark, command:
# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
if __name__ == "__main__":
    colossalai.launch_from_torch({})
    bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0)
@@ -1,5 +1,6 @@
import math
+import warnings
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -962,3 +963,138 @@ class BertPipelineForwards:
        else:
            hidden_states = outputs.get('hidden_states')
        return {'hidden_states': hidden_states}


def get_bert_flash_attention_forward():

    try:
        from xformers.ops import memory_efficient_attention as me_attention
    except ImportError:
        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
    from transformers.models.bert.modeling_bert import BertAttention

    def forward(
        self: BertAttention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k, v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        final_attention_mask = None
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)    # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                final_attention_mask = relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                final_attention_mask = relative_position_scores_query + relative_position_scores_key

        scale = 1 / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            if final_attention_mask is not None:
                final_attention_mask = final_attention_mask * scale + attention_mask
            else:
                final_attention_mask = attention_mask
            batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
            tgt_len = key_layer.size()[2]
            final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len)

        query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
        key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
        value_layer = value_layer.permute(0, 2, 1, 3).contiguous()

        context_layer = me_attention(query_layer,
                                     key_layer,
                                     value_layer,
                                     attn_bias=final_attention_mask,
                                     p=self.dropout.p,
                                     scale=scale)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, None)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

    return forward


def get_jit_fused_bert_self_output_forward():

    from transformers.models.bert.modeling_bert import BertSelfOutput

    def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

    return forward


def get_jit_fused_bert_output_forward():

    from transformers.models.bert.modeling_bert import BertOutput

    def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

    return forward
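The `get_*_forward()` helpers above return plain functions whose first parameter is `self`. Below is a hedged, standalone sketch (editor's addition, not from the diff) of how one of them can be bound onto a Hugging Face module with `types.MethodType`; ShardFormer normally performs this substitution through its model policies, and the ColossalAI import paths are assumed from this commit's file layout.

```python
from types import MethodType

import torch
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertSelfOutput

# Assumed module paths, based on this commit's layout (colossalai/shardformer/modeling/).
from colossalai.shardformer.modeling.bert import get_jit_fused_bert_self_output_forward
from colossalai.shardformer.modeling.jit import get_dropout_add_func

config = BertConfig(hidden_size=64, num_attention_heads=4, intermediate_size=128)
module = BertSelfOutput(config)

# The replaced forward calls `self.dropout_add`, so that helper must be bound first.
# Swap in get_jit_fused_dropout_add_func() instead to use the TorchScript-fused kernels.
module.dropout_add = MethodType(get_dropout_add_func(), module)
module.forward = MethodType(get_jit_fused_bert_self_output_forward(), module)

hidden = torch.randn(2, 8, 64)
out = module(hidden, hidden)    # same shape as the input: (2, 8, 64)
```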
@@ -1,3 +1,4 @@
import math
from typing import Optional, Tuple, Union

import torch

@@ -58,3 +59,62 @@ def forward_fn():
        return outputs

    return forward


def get_blip2_flash_attention_forward():

    from transformers.models.blip_2.modeling_blip_2 import Blip2Attention

    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention

    def forward(
        self: Blip2Attention,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()
        mixed_qkv = self.qkv(hidden_states)
        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
        query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]

        attention = ColoAttention(embed_dim=self.embed_dim,
                                  num_heads=self.num_heads,
                                  dropout=self.dropout.p,
                                  scale=self.scale)
        context_layer = attention(query_states, key_states, value_states)

        output = self.projection(context_layer)
        outputs = (output, None)

        return outputs

    return forward


def get_jit_fused_blip2_QFormer_self_output_forward():

    from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput

    def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

    return forward


def get_jit_fused_blip2_QFormer_output_forward():

    from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput

    def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

    return forward
@@ -5,6 +5,7 @@ import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn import functional as F
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,

@@ -675,3 +676,223 @@ class BloomPipelineForwards:
        else:
            hidden_states = outputs.get('hidden_states')
        return {'hidden_states': hidden_states}


def get_bloom_flash_attention_forward(enable_jit_fused=False):

    try:
        from xformers.ops import memory_efficient_attention as me_attention
    except ImportError:
        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
    from transformers.models.bloom.modeling_bloom import BloomAttention

    def forward(
        self: BloomAttention,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):

        fused_qkv = self.query_key_value(hidden_states)
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
        batch_size, tgt_len, _ = hidden_states.size()
        assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

        _, kv_length, _, _ = key_layer.size()

        proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
        query_layer = query_layer.contiguous().view(*proj_shape)
        key_layer = key_layer.contiguous().view(*proj_shape)
        value_layer = value_layer.contiguous().view(*proj_shape)

        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
            # - key: [batch_size * self.num_heads, head_dim, kv_length]
            # - value: [batch_size * self.num_heads, kv_length, head_dim]
            key_layer = torch.cat((past_key, key_layer), dim=1)
            value_layer = torch.cat((past_value, value_layer), dim=1)

        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None

        tgt_len = key_layer.size()[1]

        attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length),
                                               dtype=torch.float32,
                                               device=query_layer.device,
                                               requires_grad=True)
        attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1,
                                                                         kv_length) * self.beta
        attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask,
                                                     torch.finfo(torch.float32).min)

        context_layer = me_attention(query_layer,
                                     key_layer,
                                     value_layer,
                                     attn_bias=attention_numerical_mask,
                                     scale=self.inv_norm_factor,
                                     p=self.attention_dropout.p)
        context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = self.hidden_size / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + F.linear(
                    context_layer[:, :, int(i * slices):int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
                )
        else:
            output_tensor = self.dense(context_layer)

        # TODO: replace with the jit bias_dropout_add function
        output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
        outputs = (output_tensor, present, None)

        return outputs

    return forward


def get_jit_fused_bloom_attention_forward():

    from transformers.models.bloom.modeling_bloom import BloomAttention

    def forward(
        self: BloomAttention,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        fused_qkv = self.query_key_value(hidden_states)    # [batch_size, seq_length, 3 x hidden_size]

        # 3 x [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

        batch_size, q_length, _, _ = query_layer.shape

        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
        key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
            # - key: [batch_size * self.num_heads, head_dim, kv_length]
            # - value: [batch_size * self.num_heads, kv_length, head_dim]
            key_layer = torch.cat((past_key, key_layer), dim=2)
            value_layer = torch.cat((past_value, value_layer), dim=1)

        _, _, kv_length = key_layer.shape

        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None

        # [batch_size * num_heads, q_length, kv_length]
        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
        matmul_result = alibi.baddbmm(
            batch1=query_layer,
            batch2=key_layer,
            beta=self.beta,
            alpha=self.inv_norm_factor,
        )

        # change view to [batch_size, num_heads, q_length, kv_length]
        attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
        input_dtype = attention_scores.dtype
        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
        if input_dtype == torch.float16:
            attention_scores = attention_scores.to(torch.float)
        attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

        # [batch_size, num_heads, q_length, kv_length]
        attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # change view [batch_size x num_heads, q_length, kv_length]
        attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(attention_probs_reshaped, value_layer)

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = self._merge_heads(context_layer)

        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = self.hidden_size / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + F.linear(
                    context_layer[:, :, int(i * slices):int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
                )
        else:
            output_tensor = self.dense(context_layer)

        output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

        outputs = (output_tensor, present)
        if output_attentions:
            outputs += (attention_probs,)

        return outputs

    return forward


def get_jit_fused_bloom_mlp_forward():

    from transformers.models.bloom.modeling_bloom import BloomMLP

    def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))

        if self.pretraining_tp > 1 and self.slow_but_exact:
            intermediate_output = torch.zeros_like(residual)
            slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
            for i in range(self.pretraining_tp):
                intermediate_output = intermediate_output + F.linear(
                    hidden_states[:, :, int(i * slices):int((i + 1) * slices)],
                    self.dense_4h_to_h.weight[:, int(i * slices):int((i + 1) * slices)],
                )
        else:
            intermediate_output = self.dense_4h_to_h(hidden_states)
        output = self.dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
        return output

    return forward


def get_jit_fused_bloom_gelu_forward():

    from transformers.models.bloom.modeling_bloom import BloomGelu

    from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction

    def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor:
        bias = torch.zeros_like(x)
        if self.training:
            return JitGeLUFunction.apply(x, bias)
        else:
            return self.bloom_gelu_forward(x, bias)

    return forward
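`get_bloom_flash_attention_forward` above folds BLOOM's alibi bias and the boolean padding mask into a single additive bias (`attention_numerical_mask`) that is passed to `memory_efficient_attention` as `attn_bias`. A small CPU-only sketch of that construction with toy shapes (editor's illustration, not code from the diff):

```python
import torch

bsz, heads, q_len, kv_len = 2, 4, 8, 8
alibi = torch.randn(bsz * heads, 1, kv_len)             # BLOOM-style per-head linear bias
pad = torch.zeros(bsz, heads, q_len, kv_len, dtype=torch.bool)
pad[..., -2:] = True                                     # pretend the last two key positions are padding
beta = 1.0                                               # stands in for self.beta above

bias = torch.zeros(bsz, heads, q_len, kv_len)
bias = bias + alibi.view(bsz, heads, 1, kv_len) * beta   # broadcast alibi over query positions
bias = bias.masked_fill(pad, torch.finfo(bias.dtype).min)
# `bias` now plays the role of attention_numerical_mask / attn_bias.
```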
@@ -17,6 +17,116 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
)


def get_flash_core_attention_forward():

    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention

    from .chatglm2_6b.modeling_chatglm import CoreAttention

    def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
        pytorch_major_version = int(torch.__version__.split(".")[0])
        if pytorch_major_version >= 2:
            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                                 key_layer,
                                                                                 value_layer,
                                                                                 is_causal=True)
            else:
                if attention_mask is not None:
                    attention_mask = ~attention_mask
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 attention_mask)
            context_layer = context_layer.permute(2, 0, 1, 3)
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.reshape(*new_context_layer_shape)
        else:
            # Raw attention scores
            query_layer = query_layer.permute(1, 0, 2, 3).contiguous()
            key_layer = key_layer.permute(1, 0, 2, 3).contiguous()
            value_layer = value_layer.permute(1, 0, 2, 3).contiguous()

            scale = 1.0 / self.norm_factor
            if self.coeff is not None:
                scale = scale * self.coeff

            flash_attention_mask = None
            attn_mask_type = None
            if attention_mask is None:
                attn_mask_type = AttnMaskType.causal
            else:
                flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
                attn_mask_type = AttnMaskType.paddedcausal

            attention = ColoAttention(embed_dim=self.hidden_size_per_partition,
                                      num_heads=self.num_attention_heads_per_partition,
                                      dropout=self.attention_dropout.p,
                                      scale=scale)
            context_layer = attention(query_layer,
                                      key_layer,
                                      value_layer,
                                      attn_mask=flash_attention_mask,
                                      attn_mask_type=attn_mask_type)

            context_layer = context_layer.permute(1, 0, -1).contiguous()

        return context_layer

    return forward


def get_jit_fused_glm_block_forward():

    from .chatglm2_6b.modeling_chatglm import GLMBlock

    def forward(
        self: GLMBlock,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_cache=None,
        use_cache=True,
    ):
        # hidden_states: [s, b, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = self.dropout_add(attention_output, residual, self.hidden_dropout, self.training)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.dropout_add(mlp_output, residual, self.hidden_dropout, self.training)

        return output, kv_cache

    return forward


class ChatGLMPipelineForwards:
    '''
    This class serves as a micro library for ChatGLM model forwards under pipeline parallelism.
@@ -668,3 +668,88 @@ class GPT2PipelineForwards:
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


def get_gpt2_flash_attention_forward():

    from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention

    def split_heads(tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor

    def forward(
        self: GPT2Attention,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        _, tgt_len, _ = hidden_states.size()
        assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.")

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = split_heads(query, self.num_heads, self.head_dim)
        key = split_heads(key, self.num_heads, self.head_dim)
        value = split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=1)
            value = torch.cat((past_value, value), dim=1)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if not self.is_cross_attention:
            attn_mask_type = AttnMaskType.causal
            flash_attention_mask = None
        if attention_mask is not None:
            if attn_mask_type == AttnMaskType.causal:
                attn_mask_type = AttnMaskType.paddedcausal
            else:
                attn_mask_type = AttnMaskType.padding
            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()

        scale = value.size(-1)**-0.5
        if self.scale_attn_by_inverse_layer_idx:
            scale = scale * (1 / float(self.layer_idx + 1))

        # use ColoAttention
        attention = ColoAttention(embed_dim=self.embed_dim,
                                  num_heads=self.num_heads,
                                  dropout=self.attn_dropout.p,
                                  scale=scale)

        attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)

        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        outputs = (attn_output, present, None)

        return outputs

    return forward
@@ -0,0 +1,34 @@
import torch


def get_dropout_add_func():

    from transformers.models.bloom.modeling_bloom import dropout_add

    def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
        return dropout_add(x, residual, prob, training)

    return self_dropout_add


def get_jit_fused_dropout_add_func():

    from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train

    def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
        bias = torch.zeros_like(x)
        if training:
            return bias_dropout_add_fused_train(x, bias, residual, prob)
        return bias_dropout_add_fused_inference(x, bias, residual, prob)

    return self_dropout_add


def get_jit_fused_gelu_forward_func():

    from colossalai.kernel.jit.bias_gelu import bias_gelu

    def bloom_gelu_forward(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        return bias_gelu(bias, x)

    return bloom_gelu_forward
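Both dropout helpers above expose the same `self.dropout_add(x, residual, prob, training)` signature that the replaced forwards in this commit call; the only difference is whether the bias+dropout+add runs eagerly or through the TorchScript-fused kernels. A quick sanity check of the eager variant (editor's sketch, not part of the diff):

```python
import torch
from transformers.models.bloom.modeling_bloom import dropout_add

x, residual = torch.randn(2, 8, 16), torch.randn(2, 8, 16)
# In eval mode (training=False) dropout is the identity, so the helper reduces
# to a plain residual add.
out = dropout_add(x, residual, prob=0.1, training=False)
assert torch.allclose(out, x + residual)
```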
@ -1,4 +1,4 @@
|
|||
from typing import Callable, List, Optional
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
@ -386,3 +386,67 @@ class LlamaPipelineForwards:
|
|||
else:
|
||||
hidden_states = transformer_outputs.get('hidden_states')
|
||||
return {'hidden_states': hidden_states}
|
||||
|
||||
|
||||
def get_llama_flash_attention_forward():
|
||||
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
|
||||
|
||||
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
|
||||
|
||||
def forward(
|
||||
self: LlamaAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
|
||||
query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
|
||||
key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
|
||||
value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
|
||||
|
||||
flash_attention_mask = None
|
||||
attn_mask_type = AttnMaskType.causal
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
|
||||
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
|
||||
attn_mask_type = AttnMaskType.paddedcausal
|
||||
|
||||
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
|
||||
attn_output = attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=flash_attention_mask,
|
||||
attn_mask_type=attn_mask_type)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
return forward
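# A small worked example of the mask conversion above (illustrative): HuggingFace
# passes an additive mask of shape (bsz, 1, q_len, kv_len) holding 0.0 at visible
# positions and a large negative value at padded ones; taking the last query row and
# inverting its boolean view yields a (bsz, kv_len) tensor that is True exactly at
# the non-padded key positions.
import torch

additive = torch.zeros(1, 1, 4, 4)
additive[:, :, :, -1] = torch.finfo(torch.float32).min    # last key position is padding
keep = ~(additive[:, :, -1].squeeze(1).to(torch.bool))    # tensor([[True, True, True, False]])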
@@ -0,0 +1,174 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def get_opt_flash_attention_forward():
|
||||
|
||||
from transformers.models.opt.modeling_opt import OPTAttention
|
||||
|
||||
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
|
||||
|
||||
def forward(
|
||||
self: OPTAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
|
||||
|
||||
attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*attention_input_shape)
|
||||
# get key, value proj
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k, v, cross_attentions
|
||||
key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape)
|
||||
value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape)
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*attention_input_shape)
|
||||
value_states = self.v_proj(key_value_states).view(*attention_input_shape)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*attention_input_shape)
|
||||
value_states = self.v_proj(hidden_states).view(*attention_input_shape)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=1)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=1)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*attention_input_shape)
|
||||
value_states = self.v_proj(hidden_states).view(*attention_input_shape)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
if layer_head_mask is not None:
|
||||
if layer_head_mask.size() != (self.num_heads,):
|
||||
raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
||||
f" {layer_head_mask.size()}")
|
||||
|
||||
flash_attention_mask = None
|
||||
attn_mask_type = AttnMaskType.causal
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}")
|
||||
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
|
||||
attn_mask_type = AttnMaskType.paddedcausal
|
||||
|
||||
attention = ColoAttention(embed_dim=self.embed_dim,
|
||||
num_heads=self.num_heads,
|
||||
dropout=self.dropout,
|
||||
scale=self.scaling)
|
||||
attn_output = attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=flash_attention_mask,
|
||||
attn_mask_type=attn_mask_type)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_opt_decoder_layer_forward():
|
||||
|
||||
from transformers.models.opt.modeling_opt import OPTDecoderLayer
|
||||
|
||||
def forward(
|
||||
self: OPTDecoderLayer,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
||||
if self.do_layer_norm_before:
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
|
||||
|
||||
# 350m applies layer norm AFTER attention
|
||||
if not self.do_layer_norm_before:
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
|
||||
residual = hidden_states
|
||||
|
||||
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
||||
if self.do_layer_norm_before:
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
|
||||
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training).view(hidden_states_shape)
|
||||
|
||||
# 350m applies layer norm AFTER attention
|
||||
if not self.do_layer_norm_before:
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
return forward
|
|
@@ -1,4 +1,9 @@
|
|||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def forward_fn():
|
||||
|
@@ -37,3 +42,162 @@ def forward_fn():
|
|||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_sam_flash_attention_forward():
|
||||
|
||||
from transformers.models.sam.modeling_sam import SamAttention
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention as me_attention
|
||||
except ImportError:
|
||||
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
||||
|
||||
def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor:
|
||||
batch, point_batch_size, n_tokens, channel = hidden_states.shape
|
||||
c_per_head = channel // num_attention_heads
|
||||
hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
|
||||
return hidden_states
|
||||
|
||||
def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor:
|
||||
batch, n_tokens, n_heads, c_per_head = hidden_states.shape
|
||||
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
|
||||
|
||||
def forward(self: SamAttention,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
attention_similarity: Tensor = None) -> Tensor:
|
||||
# Input projections
|
||||
query = self.q_proj(query)
|
||||
key = self.k_proj(key)
|
||||
value = self.v_proj(value)
|
||||
|
||||
point_batch_size = query.shape[1]
|
||||
# Separate into heads
|
||||
query = _separate_heads(query, self.num_attention_heads)
|
||||
key = _separate_heads(key, self.num_attention_heads)
|
||||
value = _separate_heads(value, self.num_attention_heads)
|
||||
|
||||
# SamAttention
|
||||
_, _, _, c_per_head = query.shape
|
||||
bias = None
|
||||
if attention_similarity is not None:
|
||||
bias = attention_similarity
|
||||
|
||||
scale = 1.0 / math.sqrt(c_per_head)
|
||||
out = me_attention(query, key, value, attn_bias=bias, scale=scale)
|
||||
|
||||
out = _recombine_heads(out, point_batch_size)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_sam_vision_flash_attention_forward():
|
||||
|
||||
from transformers.models.sam.modeling_sam import SamVisionAttention
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention as me_attention
|
||||
except ImportError:
|
||||
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
||||
|
||||
def add_decomposed_rel_pos(
|
||||
query: torch.Tensor,
|
||||
rel_pos_h: torch.Tensor,
|
||||
rel_pos_w: torch.Tensor,
|
||||
q_size: Tuple[int, int],
|
||||
k_size: Tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
||||
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
|
||||
|
||||
Args:
|
||||
|
||||
query (`torch.Tensor`):
|
||||
query q in the attention layer with shape (batch_size, query_height * query_width, num_heads, channels_per_head).
|
||||
rel_pos_h (`torch.Tensor`):
|
||||
relative position embeddings (Lh, channel) for height axis.
|
||||
rel_pos_w (`torch.Tensor`):
|
||||
relative position embeddings (Lw, channel) for width axis.
|
||||
q_size (tuple):
|
||||
spatial sequence size of query q with (query_height, query_width).
|
||||
k_size (tuple):
|
||||
spatial sequence size of key k with (key_height, key_width).
|
||||
|
||||
Returns:
|
||||
rel_pos (`torch.Tensor`):
decomposed relative positional embeddings with shape (batch_size, num_heads, query_height * query_width, key_height * key_width), to be added to the attention scores as a bias.
|
||||
"""
|
||||
|
||||
query_height, query_width = q_size
|
||||
key_height, key_width = k_size
|
||||
relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h)
|
||||
relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w)
|
||||
|
||||
batch_size, _, nHead, dim = query.shape
|
||||
reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim)
|
||||
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
|
||||
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
|
||||
rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
|
||||
rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width)
|
||||
return rel_pos
|
||||
|
||||
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get relative positional embeddings according to the relative positions of
|
||||
query and key sizes.
|
||||
|
||||
Args:
|
||||
q_size (int):
|
||||
size of the query.
|
||||
k_size (int):
|
||||
size of key k.
|
||||
rel_pos (`torch.Tensor`):
|
||||
relative position embeddings (L, channel).
|
||||
|
||||
Returns:
|
||||
Extracted positional embeddings according to relative positions.
|
||||
"""
|
||||
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
||||
# Interpolate rel pos.
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode="linear",
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
|
||||
# Scale the coords with short length if shapes for q and k are different.
|
||||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
||||
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
||||
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
||||
|
||||
return rel_pos_resized[relative_coords.long()]
|
||||
|
||||
def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
|
||||
batch_size, height, width, _ = hidden_states.shape
|
||||
# qkv with shape (3, batch_size, nHead, height * width, channel)
|
||||
qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads,
|
||||
-1).permute(2, 0, 1, 3, 4))
|
||||
|
||||
query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0)
|
||||
|
||||
rel_pos = None
|
||||
if self.use_rel_pos:
|
||||
rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width))
|
||||
|
||||
attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, height, width, -1)
|
||||
|
||||
attn_output = self.proj(attn_output)
|
||||
|
||||
outputs = (attn_output, None)
|
||||
|
||||
return outputs
|
||||
|
||||
return forward
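# Sanity sketch (assumes xformers is installed and a CUDA device with fp16 support):
# memory_efficient_attention consumes tensors laid out as (batch, seq_len, num_heads,
# head_dim), which is why both SAM forwards above reshape their projections into that
# layout before the call.
import math

import torch
from xformers.ops import memory_efficient_attention as me_attention

q = torch.randn(2, 16, 4, 8, device="cuda", dtype=torch.float16)
out = me_attention(q, q, q, attn_bias=None, scale=1.0 / math.sqrt(8))    # -> (2, 16, 4, 8)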
@@ -587,3 +587,209 @@ class T5PipelineForwards:
|
|||
decoder_starting_stage=decoder_starting_stage)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def get_t5_flash_attention_forward():
|
||||
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention as me_attention
|
||||
except ImportError:
|
||||
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
||||
from transformers.models.t5.modeling_t5 import T5Attention
|
||||
|
||||
def forward(
|
||||
self: T5Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
position_bias: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
query_length: Optional[int] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
|
||||
"""
|
||||
# Input is (batch_size, seq_length, dim)
|
||||
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
|
||||
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
|
||||
batch_size, seq_length = hidden_states.shape[:2]
|
||||
|
||||
real_seq_length = seq_length
|
||||
|
||||
if past_key_value is not None:
|
||||
if len(past_key_value) != 2:
|
||||
raise ValueError(
|
||||
f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
|
||||
)
|
||||
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
|
||||
|
||||
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
|
||||
|
||||
def shape(states):
|
||||
"""projection"""
|
||||
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
|
||||
|
||||
def unshape(states):
|
||||
"""reshape"""
|
||||
return states.view(batch_size, -1, self.inner_dim)
|
||||
|
||||
def project(hidden_states, proj_layer, key_value_states, past_key_value):
|
||||
"""projects hidden states correctly to key/query states"""
|
||||
if key_value_states is None:
|
||||
# self-attn
|
||||
# (batch_size, n_heads, seq_length, dim_per_head)
|
||||
hidden_states = shape(proj_layer(hidden_states))
|
||||
elif past_key_value is None:
|
||||
# cross-attn
|
||||
# (batch_size, n_heads, seq_length, dim_per_head)
|
||||
hidden_states = shape(proj_layer(key_value_states))
|
||||
|
||||
if past_key_value is not None:
|
||||
if key_value_states is None:
|
||||
# self-attn
|
||||
# (batch_size, n_heads, key_length, dim_per_head)
|
||||
hidden_states = torch.cat([past_key_value, hidden_states], dim=1)
|
||||
elif past_key_value.shape[1] != key_value_states.shape[1]:
|
||||
# checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
# cross-attn
|
||||
# (batch_size, n_heads, seq_length, dim_per_head)
|
||||
hidden_states = shape(proj_layer(key_value_states))
|
||||
else:
|
||||
# cross-attn
|
||||
hidden_states = past_key_value
|
||||
return hidden_states
|
||||
|
||||
# get query states
|
||||
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
|
||||
|
||||
# get key/value states
|
||||
key_states = project(hidden_states, self.k, key_value_states,
|
||||
past_key_value[0] if past_key_value is not None else None)
|
||||
value_states = project(hidden_states, self.v, key_value_states,
|
||||
past_key_value[1] if past_key_value is not None else None)
|
||||
|
||||
if position_bias is None:
|
||||
if not self.has_relative_attention_bias:
|
||||
position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length),
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
position_bias.requires_grad = True
|
||||
else:
|
||||
position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device)
|
||||
|
||||
# if key and values are already calculated
|
||||
# we want only the last query position bias
|
||||
if past_key_value is not None:
|
||||
position_bias = position_bias[:, :, -hidden_states.size(1):, :]
|
||||
|
||||
if mask is not None:
|
||||
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
|
||||
|
||||
if self.pruned_heads:
|
||||
mask = torch.ones(position_bias.shape[1])
|
||||
mask[list(self.pruned_heads)] = 0
|
||||
position_bias_masked = position_bias[:, mask.bool()]
|
||||
else:
|
||||
position_bias_masked = position_bias
|
||||
|
||||
position_bias_masked = position_bias_masked.contiguous()
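# T5 does not rescale the query/key dot products by 1/sqrt(d_head) (the scaling is
# absorbed into the weight initialisation), hence scale=1.0 in the call below; the
# relative position bias is handed to the kernel unchanged as attn_bias.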
|
||||
attn_output = me_attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_bias=position_bias_masked,
|
||||
p=self.dropout,
|
||||
scale=1.0)
|
||||
attn_output = unshape(attn_output)
|
||||
attn_output = self.o(attn_output)
|
||||
|
||||
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
|
||||
|
||||
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
|
||||
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_T5_layer_ff_forward():
|
||||
|
||||
from transformers.models.t5.modeling_t5 import T5LayerFF
|
||||
|
||||
def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
forwarded_states = self.layer_norm(hidden_states)
|
||||
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||
hidden_states = self.dropout_add(forwarded_states, hidden_states, self.dropout.p, self.dropout.training)
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_T5_layer_self_attention_forward():
|
||||
|
||||
from transformers.models.t5.modeling_t5 import T5LayerSelfAttention
|
||||
|
||||
def forward(
|
||||
self: T5LayerSelfAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_bias: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
attention_output = self.SelfAttention(
|
||||
normed_hidden_states,
|
||||
mask=attention_mask,
|
||||
position_bias=position_bias,
|
||||
layer_head_mask=layer_head_mask,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
|
||||
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_T5_layer_cross_attention_forward():
|
||||
|
||||
from transformers.models.t5.modeling_t5 import T5LayerCrossAttention
|
||||
|
||||
def forward(
|
||||
self: T5LayerCrossAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_bias: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
query_length: Optional[int] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
attention_output = self.EncDecAttention(
|
||||
normed_hidden_states,
|
||||
mask=attention_mask,
|
||||
key_value_states=key_value_states,
|
||||
position_bias=position_bias,
|
||||
layer_head_mask=layer_head_mask,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
query_length=query_length,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
layer_output = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
|
||||
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
|
|
@@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import math
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
@@ -335,3 +336,51 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
|
|||
)
|
||||
|
||||
return pp_forward
|
||||
|
||||
|
||||
def get_vit_flash_self_attention_forward():
|
||||
|
||||
from transformers.models.vit.modeling_vit import ViTSelfAttention
|
||||
|
||||
from colossalai.kernel.cuda_native.flash_attention import ColoAttention
|
||||
|
||||
def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x
|
||||
|
||||
def forward(self: ViTSelfAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size)
|
||||
value_layer = transpose_for_scores(self.value(hidden_states), self.num_attention_heads,
|
||||
self.attention_head_size)
|
||||
query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size)
|
||||
|
||||
scale = 1.0 / math.sqrt(self.attention_head_size)
|
||||
attention = ColoAttention(embed_dim=self.all_head_size,
|
||||
num_heads=self.num_attention_heads,
|
||||
dropout=self.dropout.p,
|
||||
scale=scale)
|
||||
context_layer = attention(query_layer, key_layer, value_layer)
|
||||
|
||||
outputs = (context_layer,)
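# Note: head_mask and output_attentions are accepted but unused on this path, since the
# fused kernel does not materialise attention probabilities; only the context tensor is returned.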
|
||||
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_vit_output_forward():
|
||||
|
||||
from transformers.models.vit.modeling_vit import ViTOutput
|
||||
|
||||
def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
|
|
@@ -0,0 +1,249 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def get_whisper_flash_attention_forward():
|
||||
|
||||
from transformers.models.whisper.modeling_whisper import WhisperAttention
|
||||
|
||||
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
|
||||
|
||||
def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
|
||||
return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
|
||||
|
||||
def forward(
|
||||
self: WhisperAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (is_cross_attention and past_key_value is not None
|
||||
and past_key_value[0].shape[1] == key_value_states.shape[1]):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
|
||||
value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
|
||||
value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=1)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=1)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
|
||||
value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
# get query proj
|
||||
query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
if layer_head_mask is not None:
|
||||
if layer_head_mask.size() != (self.num_heads,):
|
||||
raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
||||
f" {layer_head_mask.size()}")
|
||||
|
||||
attn_type = None
|
||||
flash_attention_mask = None
|
||||
|
||||
if self.is_decoder:
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
|
||||
attn_type = AttnMaskType.paddedcausal
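# In the encoder case (self.is_decoder is False) both attn_type and flash_attention_mask
# remain None, so the fused attention runs without a causal or padding mask.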
|
||||
|
||||
attention = ColoAttention(embed_dim=self.embed_dim,
|
||||
num_heads=self.num_heads,
|
||||
dropout=self.dropout,
|
||||
scale=self.scaling)
|
||||
attn_output = attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=flash_attention_mask,
|
||||
attn_mask_type=attn_type)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_whisper_encoder_layer_forward():
|
||||
|
||||
from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer
|
||||
|
||||
def forward(
|
||||
self: WhisperEncoderLayer,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_head_mask: torch.Tensor,
|
||||
output_attentions: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states, attn_weights, _ = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
|
||||
|
||||
if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any()
|
||||
or torch.isnan(hidden_states).any()):
|
||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_whisper_decoder_layer_forward():
|
||||
|
||||
from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer
|
||||
|
||||
def forward(
|
||||
self: WhisperDecoderLayer,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
encoder_hidden_states (`torch.FloatTensor`):
|
||||
cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`.
|
||||
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
||||
size `(decoder_attention_heads,)`.
|
||||
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
|
||||
|
||||
# Cross-Attention Block
|
||||
cross_attn_present_key_value = None
|
||||
cross_attn_weights = None
|
||||
if encoder_hidden_states is not None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
|
||||
|
||||
# add cross-attn to positions 3,4 of present_key_value tuple
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights, cross_attn_weights)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
return forward
|
|
@@ -7,7 +7,14 @@ from torch.nn import Module
|
|||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from ..modeling.bert import BertPipelineForwards
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.bert import (
|
||||
BertPipelineForwards,
|
||||
get_bert_flash_attention_forward,
|
||||
get_jit_fused_bert_output_forward,
|
||||
get_jit_fused_bert_self_output_forward,
|
||||
)
|
||||
from ..modeling.jit import get_jit_fused_dropout_add_func
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
|
@@ -37,7 +44,13 @@ class BertPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertEmbeddings,
|
||||
BertLayer,
|
||||
BertOutput,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput,
|
||||
)
|
||||
|
||||
policy = {}
|
||||
|
||||
|
@@ -126,6 +139,23 @@ class BertPolicy(Policy):
|
|||
policy=policy,
|
||||
target_key=BertEmbeddings)
|
||||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_bert_flash_attention_forward(),
|
||||
})
|
||||
|
||||
# use jit operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_jit_fused_bert_self_output_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
policy[BertOutput] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_jit_fused_bert_output_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
|
||||
return policy
|
||||
|
||||
def add_lm_head_policy(self, base_policy):
|
||||
|
|
|
@@ -3,7 +3,13 @@ import torch.nn as nn
|
|||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.blip2 import forward_fn
|
||||
from ..modeling.blip2 import (
|
||||
forward_fn,
|
||||
get_blip2_flash_attention_forward,
|
||||
get_jit_fused_blip2_QFormer_output_forward,
|
||||
get_jit_fused_blip2_QFormer_self_output_forward,
|
||||
)
|
||||
from ..modeling.jit import get_jit_fused_dropout_add_func
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ['BlipPolicy', 'BlipModelPolicy']
|
||||
|
@@ -33,6 +39,8 @@ class BlipPolicy(Policy):
|
|||
Blip2EncoderLayer,
|
||||
Blip2QFormerLayer,
|
||||
Blip2QFormerModel,
|
||||
Blip2QFormerOutput,
|
||||
Blip2QFormerSelfOutput,
|
||||
Blip2VisionModel,
|
||||
)
|
||||
from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM
|
||||
|
@@ -275,6 +283,24 @@ class BlipPolicy(Policy):
|
|||
policy=policy,
|
||||
target_key=OPTDecoderLayer)
|
||||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_blip2_flash_attention_forward(),
|
||||
})
|
||||
|
||||
# use jit operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[Blip2QFormerSelfOutput] = ModulePolicyDescription(
|
||||
method_replacement={
|
||||
'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_jit_fused_blip2_QFormer_output_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
|
|
@@ -7,7 +7,16 @@ from torch.nn import Module
|
|||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from ..modeling.bloom import BloomPipelineForwards, build_bloom_alibi_tensor_fn
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.bloom import (
|
||||
BloomPipelineForwards,
|
||||
build_bloom_alibi_tensor_fn,
|
||||
get_bloom_flash_attention_forward,
|
||||
get_jit_fused_bloom_attention_forward,
|
||||
get_jit_fused_bloom_gelu_forward,
|
||||
get_jit_fused_bloom_mlp_forward,
|
||||
)
|
||||
from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
|
@@ -30,7 +39,7 @@ class BloomPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel
|
||||
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomGelu, BloomMLP, BloomModel
|
||||
|
||||
policy = {}
|
||||
|
||||
|
@@ -107,6 +116,27 @@ class BloomPolicy(Policy):
|
|||
policy=policy,
|
||||
target_key=BloomBlock)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[BloomAttention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_bloom_flash_attention_forward(),
|
||||
'dropout_add': get_dropout_add_func()
|
||||
})
|
||||
|
||||
# enable jit fused operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[BloomAttention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_jit_fused_bloom_attention_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
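# Note: when enable_flash_attention and enable_jit_fused are both set, this entry
# overwrites the flash-attention replacement registered for BloomAttention above.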
|
||||
policy[BloomMLP] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_jit_fused_bloom_mlp_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
policy[BloomGelu] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_jit_fused_bloom_gelu_forward(),
|
||||
'bloom_gelu_forward': get_jit_fused_gelu_forward_func(),
|
||||
})
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
|
|
@@ -15,6 +15,8 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
|
|||
GLMBlock,
|
||||
)
|
||||
|
||||
from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward
|
||||
from ..modeling.jit import get_jit_fused_dropout_add_func
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ['ChatGLMPolicy', 'ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
|
||||
|
@@ -35,12 +37,11 @@ class ChatGLMPolicy(Policy):
|
|||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock
|
||||
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock
|
||||
|
||||
policy = {}
|
||||
|
||||
|
@@ -121,6 +122,19 @@ class ChatGLMPolicy(Policy):
|
|||
policy=policy,
|
||||
target_key=ChatGLMModel)
|
||||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[CoreAttention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_flash_core_attention_forward(),
|
||||
})
|
||||
|
||||
# use jit fused operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[GLMBlock] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_jit_fused_glm_block_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
@@ -192,7 +206,6 @@ class ChatGLMModelPolicy(ChatGLMPolicy):
|
|||
return []
|
||||
|
||||
|
||||
|
||||
class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
|
@@ -213,4 +226,3 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
|
|||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in ChatGLMForConditionalGenerationModel."""
|
||||
return []
|
||||
|
||||
|
|
|
@@ -5,7 +5,8 @@ from torch import Tensor, nn
|
|||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from ..modeling.gpt2 import GPT2PipelineForwards
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
|
@@ -33,7 +34,7 @@ class GPT2Policy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
|
||||
|
||||
policy = {}
|
||||
|
||||
|
@@ -114,6 +115,11 @@ class GPT2Policy(Policy):
|
|||
],
|
||||
policy=policy,
|
||||
target_key=GPT2Block)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_gpt2_flash_attention_forward(),
|
||||
})
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
|
|
@@ -7,7 +7,7 @@ from torch.nn import Module
|
|||
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
|
||||
from ..modeling.llama import LlamaPipelineForwards
|
||||
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
|
||||
|
@@ -31,7 +31,7 @@ class LlamaPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
|
||||
|
||||
policy = {}
|
||||
|
||||
|
@@ -104,6 +104,11 @@ class LlamaPolicy(Policy):
|
|||
policy=policy,
|
||||
target_key=LlamaModel)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_llama_flash_attention_forward(),
|
||||
})
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
|
|
@@ -25,6 +25,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.jit import get_jit_fused_dropout_add_func
|
||||
from ..modeling.opt import get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
|
@@ -114,6 +116,19 @@ class OPTPolicy(Policy):
|
|||
policy=policy,
|
||||
target_key=OPTDecoderLayer)
|
||||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[OPTAttention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_opt_flash_attention_forward(),
|
||||
})
|
||||
|
||||
# use jit fused operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_jit_fused_opt_decoder_layer_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
@@ -189,13 +204,11 @@ class OPTForCausalLMPolicy(OPTPolicy):
|
|||
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
|
||||
policy=policy,
|
||||
target_key=OPTForCausalLM)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
self.set_pipeline_forward(model_cls=OPTForCausalLM,
|
||||
new_forward=OPTPipelineForwards.opt_for_causal_lm_forward,
|
||||
|
|
|
@@ -3,7 +3,7 @@ import torch.nn as nn
|
|||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.sam import forward_fn
|
||||
from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ['SamPolicy', 'SamModelPolicy']
|
||||
|
@@ -19,6 +19,7 @@ class SamPolicy(Policy):
|
|||
|
||||
def module_policy(self):
|
||||
from transformers.models.sam.modeling_sam import (
|
||||
SamAttention,
|
||||
SamFeedForward,
|
||||
SamTwoWayAttentionBlock,
|
||||
SamTwoWayTransformer,
|
||||
|
@@ -196,6 +197,15 @@ class SamPolicy(Policy):
|
|||
policy=policy,
|
||||
target_key=SamTwoWayTransformer)
|
||||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[SamAttention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_sam_flash_attention_forward(),
|
||||
})
|
||||
policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_sam_vision_flash_attention_forward(),
|
||||
})
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
|
|
@@ -14,7 +14,14 @@ from colossalai.shardformer.layer import (
|
|||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.t5 import T5PipelineForwards
|
||||
from ..modeling.jit import get_jit_fused_dropout_add_func
|
||||
from ..modeling.t5 import (
|
||||
T5PipelineForwards,
|
||||
get_jit_fused_T5_layer_ff_forward,
|
||||
get_t5_flash_attention_forward,
|
||||
get_T5_layer_cross_attention_forward,
|
||||
get_T5_layer_self_attention_forward,
|
||||
)
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
|
||||
|
@@ -168,6 +175,27 @@ class T5BasePolicy(Policy):
|
|||
suffix="final_layer_norm", target_module=FusedRMSNorm),
|
||||
policy=policy,
|
||||
target_key=T5Stack)
|
||||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[T5Attention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_t5_flash_attention_forward(),
|
||||
})
|
||||
|
||||
# use jit operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[T5LayerFF] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_jit_fused_T5_layer_ff_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_T5_layer_self_attention_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_T5_layer_cross_attention_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
|
|
@@ -3,11 +3,15 @@ from typing import Callable, Dict, List, Union
|
|||
import torch.nn as nn
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col
|
||||
|
||||
from ..modeling.jit import get_jit_fused_dropout_add_func
|
||||
from ..modeling.vit import (
|
||||
ViTForImageClassification_pipeline_forward,
|
||||
ViTForMaskedImageModeling_pipeline_forward,
|
||||
ViTModel_pipeline_forward,
|
||||
get_jit_fused_vit_output_forward,
|
||||
get_vit_flash_self_attention_forward,
|
||||
)
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
@@ -23,7 +27,8 @@ class ViTPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel
|
||||
|
||||
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel, ViTOutput, ViTSelfAttention
|
||||
|
||||
policy = {}
|
||||
|
||||
|
@@ -33,7 +38,7 @@ class ViTPolicy(Policy):
|
|||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
target_module=DropoutForReplicatedInput,
|
||||
)
|
||||
])
|
||||
|
||||
|
@@ -83,8 +88,18 @@ class ViTPolicy(Policy):
|
|||
),
|
||||
])
|
||||
|
||||
return policy
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_vit_flash_self_attention_forward(),
|
||||
})
|
||||
|
||||
# use jit fused operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[ViTOutput] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_jit_fused_vit_output_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
return policy
|
||||
|
||||
def new_model_class(self):
|
||||
|
@@ -167,7 +182,7 @@ class ViTForImageClassificationPolicy(ViTPolicy):
|
|||
ViTForImageClassification:
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
|
||||
suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
|
|
@@ -3,6 +3,12 @@ import torch.nn as nn
|
|||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.jit import get_jit_fused_dropout_add_func
|
||||
from ..modeling.whisper import (
|
||||
get_jit_fused_whisper_decoder_layer_forward,
|
||||
get_jit_fused_whisper_encoder_layer_forward,
|
||||
get_whisper_flash_attention_forward,
|
||||
)
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
|
@ -30,6 +36,7 @@ class WhisperPolicy(Policy):
|
|||
|
||||
def module_policy(self):
|
||||
from transformers.models.whisper.modeling_whisper import (
|
||||
WhisperAttention,
|
||||
WhisperDecoder,
|
||||
WhisperDecoderLayer,
|
||||
WhisperEncoder,
|
||||
|
@ -181,6 +188,24 @@ class WhisperPolicy(Policy):
|
|||
],
|
||||
policy=policy,
|
||||
target_key=WhisperDecoder)
|
||||
|
||||
# enable flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[WhisperAttention] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_whisper_flash_attention_forward(),
|
||||
})
|
||||
|
||||
# use jit fused operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_jit_fused_whisper_encoder_layer_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
|
||||
'forward': get_jit_fused_whisper_decoder_layer_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
|
||||
return policy
|
||||
|
||||
def add_lm_head_policy(self, base_policy):
|
||||
|
|
|
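`get_whisper_flash_attention_forward` (like its ViT and T5 counterparts) returns a drop-in `forward` that feeds the query/key/value projections to a fused attention kernel instead of materializing the full softmax(QK^T)V chain. A simplified sketch of the core call, using PyTorch 2's built-in `scaled_dot_product_attention` as a stand-in for ColossalAI's ColoAttention kernels (shapes and names here are illustrative, not the shardformer code):

import torch
import torch.nn.functional as F

def fused_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    # one fused kernel replaces the explicit matmul -> mask -> softmax -> dropout -> matmul chain
    return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=causal)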
@@ -26,6 +26,8 @@ class ShardConfig:
    enable_tensor_parallelism: bool = True
    enable_fused_normalization: bool = False
    enable_all_optimization: bool = False
    enable_flash_attention: bool = False
    enable_jit_fused: bool = False

    # TODO: add support for tensor parallel
    # pipeline_parallel_size: int

@@ -44,7 +46,6 @@ class ShardConfig:
        else:
            # get the parallel size
            self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)

        # turn on all optimization if all_optimization is set to True
        if self.enable_all_optimization:
            self._turn_on_all_optimization()

@@ -55,3 +56,5 @@ class ShardConfig:
        """
        # you can add all the optimization flag here
        self.enable_fused_normalization = True
        self.enable_flash_attention = True
        self.enable_jit_fused = True
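With the two new flags, callers can opt into flash attention and the jit-fused operators per run, or flip every optimization at once through `enable_all_optimization`. A minimal usage sketch (the `model` variable is a placeholder and the import path is assumed from the package layout):

from colossalai.shardformer import ShardConfig, ShardFormer

# pick optimizations individually ...
shard_config = ShardConfig(enable_tensor_parallelism=False,
                           enable_flash_attention=True,
                           enable_jit_fused=True)
# ... or turn everything on at once:
# shard_config = ShardConfig(enable_all_optimization=True)

shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model)    # `model` is any supported transformer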
@@ -18,3 +18,5 @@ SentencePiece
ninja
flash_attn>=2.0
datasets
ninja
flash-attn
@@ -20,7 +20,7 @@ def data_gen():
    # token_type_ids = tokenized_input['token_type_ids']
    input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64)
    token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.int64)
    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

@@ -69,19 +69,21 @@ def data_gen_for_mcq():
    # data['labels'] = torch.tensor([0], dtype=torch.int64)
    input_ids = torch.tensor([[[
        101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
        4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102
        4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102
    ],
                              [
                                  101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
                                  4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
                                  2218, 1999, 1996, 2192, 1012, 102, 0
                                  2218, 1999, 1996, 2192, 1012, 102, 0, 0
                              ]]])
    token_type_ids = torch.tensor(
        [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
        [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
           0]]])
    attention_mask = torch.tensor(
        [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
        [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
           0]]])
    labels = torch.tensor([0], dtype=torch.int64)

    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
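Note that the regenerated test masks now end in a 0, so the flash-attention / padded-causal path is exercised with real padding rather than an all-ones mask. Kernels in the flash-attention family typically consume such a mask as per-sample sequence lengths (and their cumulative form) after unpadding; a small illustrative computation, not shardformer code:

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0]])      # 1 = real token, 0 = padding
seq_lens = attention_mask.sum(dim=-1, dtype=torch.int32)        # tensor([7]): unpadded length per sample
cu_seqlens = F.pad(torch.cumsum(seq_lens, dim=0), (1, 0))       # tensor([0, 7]): offsets into the packed sequence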
@@ -38,6 +38,7 @@ output_transform_fn = lambda x: x
loss_fn_blip2_model = lambda x: x.loss

config = transformers.Blip2Config()
config.vision_config.patch_size = 14
config.text_config.num_hidden_layers = 1
config.qformer_config.num_hidden_layers = 1
config.vision_config.num_hidden_layers = 1
@@ -16,8 +16,8 @@ def data_gen():
    # tokenized_input = tokenizer(input, return_tensors='pt')
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
    input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595, 632, 207595]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)

@@ -33,7 +33,7 @@ def data_gen_for_token_classification():
    # token classification data gen
    # `labels` is the type not the token id for token classification, 0 or 1
    data = data_gen()
    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
    return data

@@ -53,8 +53,8 @@ def data_gen_for_question_answering():
    # inputs = tokenizer(question, text, return_tensors="pt")

    input_ids = torch.tensor(
        [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
        [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    start_positions = torch.tensor([1], dtype=torch.int64)
    end_positions = torch.tensor([10], dtype=torch.int64)
    return dict(input_ids=input_ids,
@@ -6,7 +6,6 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM

from ..registry import ModelAttribute, model_zoo


# ================================
# Register single-sentence ChatGLM
# ================================
@@ -1,58 +0,0 @@
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    model_type = "chatglm"

    def __init__(self,
                 num_layers=28,
                 padded_vocab_size=65024,
                 hidden_size=4096,
                 ffn_hidden_size=13696,
                 kv_channels=128,
                 num_attention_heads=32,
                 seq_length=2048,
                 hidden_dropout=0.0,
                 attention_dropout=0.0,
                 layernorm_epsilon=1e-5,
                 rmsnorm=True,
                 apply_residual_connection_post_layernorm=False,
                 post_layer_norm=True,
                 add_bias_linear=False,
                 add_qkv_bias=False,
                 bias_dropout_fusion=True,
                 multi_query_attention=False,
                 multi_query_group_num=1,
                 apply_query_key_layer_scaling=True,
                 attention_softmax_in_fp32=True,
                 fp32_residual_connection=False,
                 quantization_bit=0,
                 pre_seq_len=None,
                 prefix_projection=False,
                 **kwargs):
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        super().__init__(**kwargs)

[File diff suppressed because it is too large]
@@ -18,8 +18,8 @@ def data_gen():
    # tokenized_input = tokenizer(input, return_tensors='pt')
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)

@@ -46,7 +46,7 @@ def data_gen_for_token_classification():
    # token classification data gen
    # `labels` is the type not the token id for token classification, 0 or 1
    data = data_gen()
    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 1]], dtype=torch.int64)
    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
    return data
@@ -16,8 +16,9 @@ def data_gen_for_encoder_only():
    # config = T5Config(decoder_start_token_id=0)
    # tokenizer = T5Tokenizer.from_pretrained("t5-small")
    # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long()
    return dict(input_ids=input_ids)
    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long()
    attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_conditional_generation():

@@ -25,17 +26,16 @@ def data_gen_for_conditional_generation():
    #
    # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
    data = data_gen_for_encoder_only()
    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long()
    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long()
    data['labels'] = labels
    return data


def data_gen_for_t5_model():
    # decoder_inputs_ids is obtained with the following code
    #
    # decoder_input_ids = model._shift_right(input_ids)
    data = data_gen_for_encoder_only()
    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long()
    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long()
    data['decoder_input_ids'] = decoder_input_ids
    return data
@@ -76,14 +76,14 @@ model_zoo.register(name='transformers_whisper',
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='transformers_whisperForConditionalGeneration',
model_zoo.register(name='transformers_whisper_for_conditional_generation',
                   model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
                   data_gen_fn=data_gen_for_conditional_generation,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_attr,
                   model_attribute=ModelAttribute(has_control_flow=True))

model_zoo.register(name='transformers_whisperWhisperForAudioClassification',
model_zoo.register(name='transformers_whisper_for_audio_classification',
                   model_fn=lambda: transformers.WhisperForAudioClassification(config),
                   data_gen_fn=data_gen_for_audio_classification,
                   output_transform_fn=output_transform_fn,
@@ -93,7 +93,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
                'transformers_vit_for_image_classification', 'transformers_chatglm',
                'transformers_chatglm_for_conditional_generation', 'transformers_blip2',
                'transformers_blip2_conditional_gerneration', 'transformers_sam', 'transformers_whisper',
                'transformers_whisperForConditionalGeneration', 'transformers_whisperWhisperForAudioClassification'
                'transformers_whisper_for_conditional_generation', 'transformers_whisper_for_audio_classification'
        ]:
            continue
@@ -21,7 +21,13 @@ from colossalai.shardformer._utils import getattr_
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
def build_model(model_fn,
                enable_fused_normalization=True,
                enable_tensor_parallelism=True,
                enable_flash_attention=False,
                enable_jit_fused=False,
                use_lazy_init: bool = False):
    # create new model
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        # create new model

@@ -31,7 +37,10 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
        ctx.materialize(org_model)
    # shard model
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism)
                               enable_tensor_parallelism=enable_tensor_parallelism,
                               enable_flash_attention=enable_flash_attention,
                               enable_jit_fused=enable_jit_fused)
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
    return org_model.cuda(), sharded_model.cuda()
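A call site for the widened helper then looks like the following (mirroring the updated signature above; the flag values are arbitrary). The tests below pass the same flags either positionally or as keywords; the keyword form is shown here for clarity.

org_model, sharded_model = build_model(model_fn,
                                       enable_fused_normalization=True,
                                       enable_tensor_parallelism=False,
                                       enable_flash_attention=True,
                                       enable_jit_fused=True,
                                       use_lazy_init=False)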
@@ -46,14 +46,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
    check_grad(bert, sharded_bert, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)


@parameterize('enable_fused_normalization', [False, True])
@parameterize('enable_tensor_parallelism', [False, True])
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
                  use_lazy_init):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
                                               use_lazy_init)
                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
        check_state_dict(org_model, sharded_model, name=name)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
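Stacking `@parameterize` decorators as above makes the test body run once for every combination of the listed values, so adding `enable_flash_attention` and `enable_jit_fused` multiplies the swept configurations by four. Conceptually the decorator behaves like this minimal re-implementation (illustrative sketch; ColossalAI's own `parameterize` lives in `colossalai.testing`):

import functools

def parameterize(arg_name, values):
    # call the wrapped function once per value; stacked decorators
    # therefore enumerate the full cross product of all swept arguments
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                kwargs[arg_name] = value
                fn(*args, **kwargs)
        return wrapper
    return decorator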
@@ -47,10 +47,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism):
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
                                               enable_flash_attention, enable_jit_fused)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

    torch.cuda.empty_cache()
@@ -44,13 +44,15 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
                   use_lazy_init):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
                                               use_lazy_init)
        check_state_dict(org_model, sharded_model, name=name)
                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
    torch.cuda.empty_cache()
@@ -72,7 +72,9 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        # create new model

@@ -80,7 +82,9 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):

        # shard model
        shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                   enable_tensor_parallelism=enable_tensor_parallelism)
                                   enable_tensor_parallelism=enable_tensor_parallelism,
                                   enable_flash_attention=enable_flash_attention,
                                   enable_jit_fused=enable_jit_fused)
        model_copy = copy.deepcopy(org_model)
        shard_former = ShardFormer(shard_config=shard_config)
        if name == "transformers_chatglm":
@@ -68,7 +68,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

    torch.cuda.empty_cache()


@parameterize('test_config', [{
    'tp_size': 1,
    'pp_size': 2,
@@ -49,12 +49,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, use_lazy_init):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
                                               use_lazy_init)
                                               enable_flash_attention, use_lazy_init)
        check_state_dict(org_model, sharded_model, name=name)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
    torch.cuda.empty_cache()
@@ -42,18 +42,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
    # check grad
    col_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']
    row_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
    check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False)
    check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)
    check_grad(opt_model, shard_opt_model, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
    check_grad(opt_model, shard_opt_model, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)


@parameterize('use_lazy_init', [False, True])
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_opt_test(use_lazy_init, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention,
                 enable_jit_fused):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
                                               use_lazy_init)
                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
        check_state_dict(org_model, sharded_model, name=name)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
    torch.cuda.empty_cache()

@@ -62,7 +65,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
def check_OPTModel(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_t5_test()
    run_opt_test()


@pytest.mark.dist
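The gradient checks above relax `atol` from 1e-7 to 1e-6: the fused kernels reorder floating-point operations, so bitwise-identical results are not expected and the comparison needs a little headroom. In spirit the assertion is the usual `torch.allclose` pattern (illustrative helper, not the actual `check_grad` implementation):

import torch

def assert_close_enough(a: torch.Tensor, b: torch.Tensor, atol: float = 1e-6, rtol: float = 1e-3) -> None:
    # fused attention / jit dropout+add change the summation order, so allow a small tolerance
    assert torch.allclose(a, b, atol=atol, rtol=rtol), \
        f"max abs diff {(a - b).abs().max().item():.3e} exceeds atol={atol}, rtol={rtol}"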
@@ -41,10 +41,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_sam_test(enable_fused_normalization, enable_tensor_parallelism):
@parameterize('enable_flash_attention', [True, False])
def run_sam_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_sam')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
                                               enable_flash_attention)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

    torch.cuda.empty_cache()
@@ -33,8 +33,8 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
    # check grad
    col_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.q', 'shared']
    row_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.relative_attention_bias']
    check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-7, rtol=1e-5, dim=0, verbose=False)
    check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-7, rtol=1e-5, dim=1, verbose=False)
    check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
    check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)

    # check weights are tied
    if hasattr(org_model, 'lm_head'):

@@ -45,11 +45,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init, enable_flash_attention,
                enable_jit_fused):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
                                               use_lazy_init)
                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
        check_state_dict(org_model, sharded_model, name=name)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
    torch.cuda.empty_cache()
@@ -20,7 +20,9 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
    # check forward
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                 output_transform_fn, loss_fn)

    assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)

    # do backward
    org_loss.backward()
    shard_loss.backward()

@@ -45,10 +47,13 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
                                               enable_flash_attention, enable_jit_fused)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
    torch.cuda.empty_cache()
@@ -48,12 +48,16 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism):
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn,
                                               enable_fused_normalization=enable_fused_normalization,
                                               enable_tensor_parallelism=enable_tensor_parallelism)
                                               enable_tensor_parallelism=enable_tensor_parallelism,
                                               enable_flash_attention=enable_flash_attention,
                                               enable_jit_fused=enable_jit_fused)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

    torch.cuda.empty_cache()