mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] update shardformer readme (#4689)
* [shardformer] update shardformer readme
parent 1d454733c4
commit 8844691f4b

@@ -30,27 +30,48 @@
### Quick Start

The sample API usage is given below (if you enable flash attention, please install `flash_attn`; in addition, xformers' `cutlass_op` provides a supplementary optimization):

```python
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertConfig, BertForMaskedLM

# launch colossalai
colossalai.launch_from_torch(config={})

# create the huggingface model as normal
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

# create a ShardConfig and shard the model
# (tp_group and stage_manager must be created beforehand; see the configuration notes and the sketch below)
shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                           pipeline_stage_manager=stage_manager,
                           enable_tensor_parallelism=True,
                           enable_fused_normalization=True,
                           enable_flash_attention=True,
                           enable_jit_fused=True,
                           enable_sequence_parallelism=True,
                           enable_sequence_overlap=True)

shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model)
sharded_model = sharded_model.to('cuda')

# do everything like normal
...
```

Shardformer configuration:

`tensor_parallel_process_group`: the process group used for tensor parallelism; it is required when tensor parallelism is enabled.

`pipeline_stage_manager`: required when pipeline parallelism is used; the stage manager handles inter-process communication between pipeline stages.

{{ autodoc:colossalai.pipeline.stage_manager.PipelineStageManager }}

`enable_tensor_parallelism`: whether to use tensor parallelism, which partitions the model weights along rows or columns.

`enable_fused_normalization`: whether to use the Apex fused LayerNorm kernel.

`enable_flash_attention`: whether to use flash attention.

`enable_jit_fused`: whether to use JIT-fused operators.

`enable_sequence_parallelism`: whether to use sequence parallelism, which partitions the non-tensor-parallel regions along the sequence dimension.

`enable_sequence_overlap`: whether to overlap computation and communication in sequence parallelism; it must be used together with `enable_sequence_parallelism`.
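The snippet below is a minimal sketch of how the process group referenced above might be prepared before building a `ShardConfig`. It assumes a tensor-parallel-only setup launched with `torchrun` (so no `pipeline_stage_manager` is needed); the group creation mirrors the fine-tuning example later in this document, and the variable names are only illustrative.

```python
import torch.distributed as dist

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer

# initialize the distributed environment (rank/world size are read from torchrun)
colossalai.launch_from_torch(config={})

# a single tensor-parallel group spanning all ranks, as in the fine-tuning example below
tp_group = dist.new_group(backend='nccl')

# tensor parallelism only: pipeline_stage_manager can be omitted in this case
shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                           enable_tensor_parallelism=True,
                           enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)
```

When pipeline parallelism is enabled as well, a `PipelineStageManager` (see the autodoc reference above) has to be constructed and passed as `pipeline_stage_manager`.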
### Write your own policy

@@ -82,44 +103,30 @@ We will follow this roadmap to develop Shardformer:
- [x] API Implementation
- [x] Unit Testing
- [ ] Policy Implementation

| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: |
| bert | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| t5 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| llama V1/V2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| gpt2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| opt | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| whisper | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |

## 💡 API Design

@@ -286,41 +293,36 @@ class ShardFormer:
    Example:

        org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
        shard_config = ShardConfig()
        shard_former = ShardFormer(shard_config=shard_config)
        model, shared_params = shard_former.optimize(org_model)

    """

    def __init__(self, shard_config: ShardConfig):
        """
        Do two things:
        1. Create a distributed coordinator
        2. Serve as a store for the shard config
        """
        self.shard_config = shard_config
        self.coordinator = DistCoordinator()

    def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
        r"""
        This method will optimize the model based on the given policy.

        Args:
            model (`torch.nn.Module`): the original huggingface model
            policy (`Policy`): the custom policy for sharding

        Returns: the sharded model and the shared parameters
        """
        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
        shared_params = sharder.shard()
        return model, shared_params
```
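For orientation, here is a short usage sketch of the API above. It assumes the distributed environment has already been launched as in the Quick Start, and it only rearranges calls that already appear in this document; `MyPolicy` is a hypothetical user-defined policy class and is shown commented out.

```python
from transformers import BertForMaskedLM

from colossalai.shardformer import ShardConfig, ShardFormer

model = BertForMaskedLM.from_pretrained('bert-base-uncased')

shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)

# default behaviour: Shardformer selects the built-in policy for the model architecture
sharded_model, shared_params = shard_former.optimize(model)

# with a custom policy (MyPolicy is hypothetical and must implement the Policy interface):
# sharded_model, shared_params = shard_former.optimize(model, policy=MyPolicy())

# shared_params describes parameters shared across pipeline stages (e.g. tied embeddings)
print(len(shared_params))
```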
## ⌨️ Development Notes

@@ -429,13 +431,24 @@ As shown in the figures above, when the sequence length is around 1000 or greate
### Convergence

To validate that training the model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both the Shardformer and non-Shardformer approaches. The Shardformer run uses Pipeline Parallelism together with Data Parallelism (ZeRO-1). We then compared the accuracy, loss, and F1 score of the training results.

The configurations are as follows:

```python
batch_size = 2
epoch = 3
lr = 2.4e-5
accumulation_steps = 8
warmup_fraction = 0.03
```
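As a quick sanity check on these hyperparameters, the sketch below works out the effective batch size and the warmup steps. The dataset size and data-parallel size are assumptions (not taken from the benchmark), and it assumes `warmup_fraction` is applied to the total number of optimizer updates, so the numbers are only indicative.

```python
# assumed values, not taken from the benchmark script
dataset_size = 10_000   # number of training samples (assumption)
dp_size = 1             # data-parallel world size (assumption)

batch_size = 2
epoch = 3
accumulation_steps = 8
warmup_fraction = 0.03

# one optimizer update consumes batch_size * accumulation_steps samples per data-parallel rank
effective_batch_size = batch_size * accumulation_steps * dp_size   # 16
updates_per_epoch = dataset_size // (batch_size * accumulation_steps * dp_size)
total_updates = updates_per_epoch * epoch
warmup_steps = int(total_updates * warmup_fraction)

print(effective_batch_size, total_updates, warmup_steps)   # 16 1875 56
```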

| accuracy | f1 | loss | GPU number | model sharded |
| :------: | :-----: | :-----: | :--------: | :---------: |
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
| 0.84521 | 0.88700 | 0.21822 | 1 | False |

Overall, the results demonstrate that using Shardformer during model training does not affect convergence.

@@ -49,9 +49,12 @@ def train(args):
    # if multiple GPUs, shard the model
    if dist.get_world_size() > 1:
        # create a tensor-parallel process group covering all ranks
        tp_group = dist.new_group(backend='nccl')
        shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                                   enable_tensor_parallelism=True,
                                   enable_all_optimization=True)
        shard_former = ShardFormer(shard_config=shard_config)
        model, _ = shard_former.optimize(model)

    optim = Adam(model.parameters(), lr=args.lr)
    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
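For context, `accumulation_steps` implies a gradient-accumulation loop roughly along the lines below. This is only an illustrative sketch under assumed names; the script's actual training loop is not shown in this diff.

```python
# illustrative gradient-accumulation loop (not the script's actual code)
for epoch in range(args.max_epochs):
    for step, batch in enumerate(train_dataloader):
        outputs = model(**batch)
        loss = outputs.loss / args.accumulation_steps   # scale so the accumulated gradient matches a full batch
        loss.backward()
        if (step + 1) % args.accumulation_steps == 0:
            optim.step()
            optim.zero_grad()
```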

@@ -1,7 +1,7 @@
torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
    --model "bert" \
    --pretrain "bert-base-uncased" \
    --max_epochs 3 \
    --batch_size 2 \
    --lr 2.4e-5 \
    --fused_layernorm False \

@@ -29,7 +29,8 @@ MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
                                        intermediate_size=256,
                                        num_attention_heads=4,
                                        max_position_embeddings=128,
                                        num_labels=16,
                                        pad_token_id=2)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)

@@ -73,7 +74,8 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d
if provider == "shard_model":
|
if provider == "shard_model":
|
||||||
shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
|
shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
|
||||||
shard_former = ShardFormer(shard_config=shard_config)
|
shard_former = ShardFormer(shard_config=shard_config)
|
||||||
sharded_model = shard_former.optimize(model).cuda()
|
sharded_model, _ = shard_former.optimize(model)
|
||||||
|
sharded_model = sharded_model.cuda()
|
||||||
fn = lambda: train(sharded_model, data)
|
fn = lambda: train(sharded_model, data)
|
||||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||||
return ms
|
return ms
|
||||||
|
|