mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] update shardformer readme (#4689)
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme

pull/4697/head
parent 1d454733c4
commit 8844691f4b

@ -30,27 +30,48 @@

### Quick Start

The sample API usage is given below (if you enable flash attention, please install `flash_attn`; in addition, xformers' `cutlass_op` provides a supplementary optimization):

```python
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertConfig, BertForMaskedLM
import colossalai

# launch colossalai
colossalai.launch_from_torch(config={})

# create model
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

# create the shard config; tp_group (a torch.distributed process group) and
# stage_manager (a PipelineStageManager) are assumed to be created beforehand
shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                           pipeline_stage_manager=stage_manager,
                           enable_tensor_parallelism=True,
                           enable_fused_normalization=True,
                           enable_flash_attention=True,
                           enable_jit_fused=True,
                           enable_sequence_parallelism=True,
                           enable_sequence_overlap=True)

# shard the model
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model)
sharded_model = sharded_model.to('cuda')

# do everything like normal
...
```

Shardformer configuration:

- `tensor_parallel_process_group`: the process group used for tensor parallelism; required when tensor parallelism is enabled.
- `pipeline_stage_manager`: if pipeline parallelism is used, a pipeline stage manager must be specified to handle inter-stage communication.

  {{ autodoc:colossalai.pipeline.stage_manager.PipelineStageManager }}

- `enable_tensor_parallelism`: whether to use tensor parallelism, which partitions the model weights along columns or rows.
- `enable_fused_normalization`: whether to use the Apex fused LayerNorm kernel.
- `enable_flash_attention`: whether to use flash attention.
- `enable_jit_fused`: whether to use JIT-fused operators.
- `enable_sequence_parallelism`: whether to use sequence parallelism, which partitions the non-tensor-parallel regions along the sequence dimension.
- `enable_sequence_overlap`: whether to overlap computation and communication in sequence parallelism; only effective together with `enable_sequence_parallelism`.

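If you do not need pipeline parallelism, a much smaller configuration is enough. The following is a minimal, tensor-parallel-only sketch that mirrors the `convergence_benchmark.py` change further down in this diff; the BERT model and the single NCCL process group spanning all ranks are illustrative assumptions:

```python
import torch.distributed as dist
from transformers import BertForMaskedLM
from colossalai.shardformer import ShardConfig, ShardFormer
import colossalai

colossalai.launch_from_torch(config={})

# use one process group spanning all ranks for tensor parallelism
tp_group = dist.new_group(backend='nccl')

shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                           enable_tensor_parallelism=True,
                           enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)

model = BertForMaskedLM.from_pretrained('bert-base-uncased')
# optimize() returns the sharded model together with any shared parameters
sharded_model, shared_params = shard_former.optimize(model)
sharded_model = sharded_model.to('cuda')
```
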
### Write your own policy
@ -82,44 +103,30 @@ We will follow this roadmap to develop Shardformer:

- [x] API Implementation
- [x] Unit Testing
- [ ] Policy Implementation
  - [ ] Hugging Face
    - [ ] NLP
      - [x] BERT
      - [x] T5
      - [x] LlaMa
      - [x] GPT2
      - [x] OPT
      - [x] BLOOM
      - [ ] GLM
      - [ ] RoBERTa
      - [ ] ALBERT
      - [ ] ERNIE
      - [ ] GPT Neo
      - [ ] GPT-J
    - [ ] CV
      - [x] ViT
      - [ ] BEiT
      - [ ] SwinTransformer
      - [ ] SwinTransformer V2
    - [ ] Audio
      - [x] Whisper
    - [ ] Multi-modal
      - [x] SAM
      - [x] BLIP-2
- [ ] Flash Attention Support
  - [ ] NLP
    - [x] BERT
    - [x] T5
    - [x] LlaMa
    - [x] GPT2
    - [x] OPT
    - [x] BLOOM
    - [ ] GLM
    - [ ] RoBERTa
    - [ ] ALBERT
    - [ ] ERNIE
    - [ ] GPT Neo
    - [ ] GPT-J

| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: |
| bert | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| t5 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| llama V1/V2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| gpt2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| opt | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
| vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| whisper | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |

## 💡 API Design
@ -286,41 +293,36 @@ class ShardFormer:

    Example:

        org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
        shard_config = ShardConfig()
        shard_former = ShardFormer(shard_config=shard_config)
        model, shared_params = shard_former.optimize(org_model)

    """

    def __init__(self, shard_config: ShardConfig):
        """
        Do two things:
        1. Create a distributed coordinator
        2. Serve as a store for the shard config
        """
        self.shard_config = shard_config
        self.coordinator = DistCoordinator()

    def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
        r"""
        This method will optimize the model based on the given policy.

        Args:
            model (`torch.nn.Module`): the original huggingface model
            shard_config (`ShardConfig`): the config describing how to distribute the model
            policy (`Policy`): the custom policy for sharding

        Returns: the sharded model and the shared parameters
        """
        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
        shared_params = sharder.shard()
        return model, shared_params
```

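Since `optimize` accepts an optional `policy`, a custom policy can be plugged in without any other change. The brief sketch below reuses the `shard_config` and `model` from the Quick Start snippet; `MyBertPolicy` is an assumed user-defined policy class, not part of the library:

```python
# MyBertPolicy is assumed to be a user-defined Policy subclass, e.g. one
# written by following the "Write your own policy" section above
custom_policy = MyBertPolicy()

shard_former = ShardFormer(shard_config=shard_config)
# policy=None (the default) lets Shardformer pick the built-in policy for the
# model class; passing a policy object overrides that choice
sharded_model, shared_params = shard_former.optimize(model, policy=custom_policy)
```
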
## ⌨️ Development Notes
@ -429,13 +431,24 @@ As shown in the figures above, when the sequence length is around 1000 or greater
### Convergence

To validate that training the model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both the Shardformer and non-Shardformer approaches. The example runs Shardformer together with Pipeline Parallelism and Data Parallelism (ZeRO-1). We then compared the accuracy, loss, and F1 score of the training results.

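For context, a Pipeline Parallelism + ZeRO-1 setup such as the one described above is usually assembled through ColossalAI's Booster API rather than by calling ShardFormer directly. The sketch below only illustrates that kind of configuration; the plugin arguments are assumptions, not the exact values used in the example:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# illustrative values: no tensor parallelism, 2 pipeline stages,
# microbatching for the pipeline schedule, ZeRO stage-1 data parallelism
plugin = HybridParallelPlugin(tp_size=1,
                              pp_size=2,
                              num_microbatches=4,
                              zero_stage=1)
booster = Booster(plugin=plugin)
# the model, optimizer, dataloader, etc. are then wrapped via booster.boost(...)
```
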
The configurations are as follows:

```python
batch_size = 2
epoch = 3
lr = 2.4e-5
accumulation_steps = 8
warmup_fraction = 0.03
```
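
With gradient accumulation, the effective batch size per device is `batch_size * accumulation_steps = 2 * 8 = 16`.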
| accuracy | f1 | loss | GPU number | model sharded |
| :------: | :-----: | :-----: | :--------: | :---------: |
| 0.84589 | 0.88613 | 0.43414 | 4 | True |
| 0.83594 | 0.88064 | 0.43298 | 1 | False |
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
| 0.84521 | 0.88700 | 0.21822 | 1 | False |

Overall, the results demonstrate that using Shardformer during model training does not affect convergence.
@ -49,9 +49,12 @@ def train(args):

    # if multiple GPUs, shard the model
    if dist.get_world_size() > 1:
        tp_group = dist.new_group(backend='nccl')
        shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                                   enable_tensor_parallelism=True,
                                   enable_all_optimization=True)
        shard_former = ShardFormer(shard_config=shard_config)
        model, _ = shard_former.optimize(model)

    optim = Adam(model.parameters(), lr=args.lr)
    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps

@ -1,7 +1,7 @@

torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
    --model "bert" \
    --pretrain "bert-base-uncased" \
    --max_epochs 3 \
    --batch_size 2 \
    --lr 2.4e-5 \
    --fused_layernorm False \

@ -29,7 +29,8 @@ MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,

                                        intermediate_size=256,
                                        num_attention_heads=4,
                                        max_position_embeddings=128,
                                        num_labels=16,
                                        pad_token_id=2)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)

@ -73,7 +74,8 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d

    if provider == "shard_model":
        shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
        shard_former = ShardFormer(shard_config=shard_config)
        sharded_model, _ = shard_former.optimize(model)
        sharded_model = sharded_model.cuda()
        fn = lambda: train(sharded_model, data)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms