Stable Diffusion with Colossal-AI
Colossal-AI provides a faster, lower-cost solution for pretraining and fine-tuning models for AIGC (AI-Generated Content) applications, such as Stable Diffusion from Stability AI.
We use Colossal-AI to combine multiple optimization strategies, e.g. data parallelism, tensor parallelism, mixed precision and ZeRO, to scale training to multiple GPUs.
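As a rough illustration of why ZeRO saves memory, the sketch below applies the per-parameter accounting from the ZeRO paper for mixed-precision Adam: 2 bytes each for fp16 parameters and gradients, plus 12 bytes of fp32 optimizer state (master weights, momentum, variance). It estimates per-GPU model-state memory for ZeRO stages 1 to 3 on N GPUs. This is a back-of-the-envelope sketch that ignores activations, buffers and fragmentation; it is not how Colossal-AI measures memory, and the 860M parameter count in the example is an assumed illustrative size.

```python
"""Back-of-the-envelope per-GPU model-state memory under ZeRO (mixed-precision Adam)."""

# Bytes per parameter in mixed-precision Adam training (accounting from the ZeRO paper).
FP16_PARAM_BYTES = 2   # fp16 copy of the weights
FP16_GRAD_BYTES = 2    # fp16 gradients
FP32_OPTIM_BYTES = 12  # fp32 master weights + Adam momentum + Adam variance (4 bytes each)


def zero_bytes_per_gpu(num_params: int, num_gpus: int, stage: int) -> float:
    """Return model-state bytes per GPU for ZeRO stage 0 (plain data parallel) to 3.

    Activations, temporary buffers and memory fragmentation are ignored.
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2 or 3")
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    p, g, o = FP16_PARAM_BYTES, FP16_GRAD_BYTES, FP32_OPTIM_BYTES
    if stage == 0:    # every GPU keeps a full replica of all model states
        per_param = p + g + o
    elif stage == 1:  # optimizer states are partitioned across GPUs
        per_param = p + g + o / num_gpus
    elif stage == 2:  # optimizer states and gradients are partitioned
        per_param = p + (g + o) / num_gpus
    else:             # parameters, gradients and optimizer states are all partitioned
        per_param = (p + g + o) / num_gpus
    return num_params * per_param


if __name__ == "__main__":
    n_params = 860_000_000  # assumed example size, for illustration only
    for stage in range(4):
        gib = zero_bytes_per_gpu(n_params, num_gpus=4, stage=stage) / 2**30
        print(f"ZeRO stage {stage}: {gib:.2f} GiB of model states per GPU on 4 GPUs")
```

Activation memory, which this estimate leaves out, is usually reduced separately (for example with activation checkpointing), so real per-GPU usage will be higher than these numbers.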
🚀Quick Start
- Create a new environment for diffusion
conda env create -f environment.yaml
conda activate ldm
- Install Colossal-AI from our official page
pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org
- Install a compatible PyTorch Lightning commit
git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa
pip install -r requirements.txt && pip install .
cd ..
- Comment out the from_pretrained field in train_colossalai_cifar10.yaml.
- Run training with CIFAR10.
python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai_cifar10.yaml
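If you prefer to script the step of commenting out from_pretrained, a minimal sketch is shown below. It prefixes any line whose key is from_pretrained with "#" and works on the raw text, so the rest of the YAML (indentation, other comments) is preserved exactly. The helper name is hypothetical, and the sample YAML, including its path and lr values, is made up for illustration and not copied from train_colossalai_cifar10.yaml.

```python
import re

# Matches a YAML line whose key is `from_pretrained`, capturing its indentation.
_FROM_PRETRAINED = re.compile(r"^(\s*)(from_pretrained\s*:.*)$")


def comment_out_from_pretrained(yaml_text: str) -> str:
    """Return yaml_text with every `from_pretrained:` line commented out (hypothetical helper)."""
    out = []
    for line in yaml_text.splitlines(keepends=True):
        body = line.rstrip("\r\n")
        ending = line[len(body):]  # keep the original line ending
        match = _FROM_PRETRAINED.match(body)
        if match:
            indent, rest = match.groups()
            # Keep the indentation so the surrounding YAML structure is untouched.
            body = f"{indent}# {rest}"
        out.append(body + ending)
    return "".join(out)


if __name__ == "__main__":
    sample = "model:\n  params:\n    from_pretrained: /path/to/weights\n    lr: 1.0e-4\n"
    print(comment_out_from_pretrained(sample))
```

Lines that are already commented out no longer match the pattern, so running the helper twice leaves the file unchanged. Write the result to the config file (or a copy) before launching main.py.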
Stable Diffusion
Stable Diffusion is a latent text-to-image diffusion model. Thanks to a generous compute donation from Stability AI and support from LAION, we were able to train a Latent Diffusion Model on 512x512 images from a subset of the LAION-5B database. Similar to Google's Imagen, this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts.
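To make "latent" concrete: in the public Stable Diffusion v1 setup, the autoencoder downsamples images by a factor of 8 into 4 latent channels, so a 512x512 RGB image becomes a 4x64x64 latent. These factors are stated here as assumptions about that setup; check the model config you actually train. The sketch computes the latent shape and how many fewer values the diffusion model operates on compared with pixel space.

```python
def latent_shape(height: int, width: int, downsample: int = 8, latent_channels: int = 4):
    """Shape (C, H, W) of the autoencoder latent for an image of the given size."""
    if height % downsample or width % downsample:
        raise ValueError("image size must be divisible by the downsampling factor")
    return (latent_channels, height // downsample, width // downsample)


def compression_ratio(height: int, width: int, image_channels: int = 3, **kw) -> float:
    """How many times fewer values the latent has than the pixel-space image."""
    c, h, w = latent_shape(height, width, **kw)
    return (image_channels * height * width) / (c * h * w)


print(latent_shape(512, 512))       # (4, 64, 64)
print(compression_ratio(512, 512))  # 48.0
```

Running diffusion on 48x fewer values per image is a large part of why latent diffusion is cheaper to train than pixel-space diffusion at the same resolution.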
With Colossal-AI, Stable Diffusion training is 6.5x faster with lower pretraining cost, and the hardware cost of fine-tuning can be almost 7x lower (from an RTX 3090/4090 with 24 GB to an RTX 3050/2070 with 8 GB).
Requirements
A suitable conda environment named ldm can be created and activated with:
conda env create -f environment.yaml
conda activate ldm
You can also update an existing latent diffusion environment by running:
conda install pytorch torchvision -c pytorch
pip install transformers==4.19.2 diffusers invisible-watermark
pip install -e .
Install Colossal-AI v0.1.10 From Our Official Website
pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org
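Because these instructions pin specific versions (colossalai 0.1.10 built for torch 1.11 / CUDA 11.3, and transformers 4.19.2), a quick sanity check of the active environment can save debugging time. The sketch below uses only the standard library's importlib.metadata; the helper name, and the choice to compare only the public version (the part before a local "+" tag such as +torch1.11cu11.3), are assumptions for illustration.

```python
from importlib import metadata

# Pins taken from the install commands above.
EXPECTED = {
    "colossalai": "0.1.10",
    "transformers": "4.19.2",
}


def check_pins(expected=EXPECTED, lookup=metadata.version):
    """Map each package to (installed_version_or_None, matches_pin).

    Only the public version is compared, so a local tag such as
    '+torch1.11cu11.3' does not cause a mismatch.
    """
    report = {}
    for name, want in expected.items():
        try:
            have = lookup(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)  # package is not installed
            continue
        report[name] = (have, have.split("+", 1)[0] == want)
    return report


if __name__ == "__main__":
    for pkg, (have, ok) in check_pins().items():
        print(f"{pkg}: installed={have} {'OK' if ok else 'MISMATCH'}")
```

The lookup parameter exists only so the check can be exercised without the real packages installed; in normal use you call check_pins() with no arguments.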
Install Lightning
We use the September 2022 version, commit b04a7aa.
git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa
pip install -r requirements.txt && pip install .
This version is pinned because the latest Lightning update introduced an interface incompatibility, which will be fixed in the near future.
Dataset
The dataset is a subset of LAION-5B. You should change data.file_path in configs/train_colossalai.yaml to point to your copy of the data.
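If you generate configs programmatically, a dotted key such as data.file_path can be set on the parsed config with a small helper like the one below. This is a sketch on plain nested dicts: the helper name, the sample config layout and the dataset path are hypothetical, and the real train_colossalai.yaml may nest file_path under additional keys, so adjust the dotted path to match your file. Loading and saving the YAML is left out to keep the sketch dependency-free.

```python
from typing import Any, Dict


def set_dotted(cfg: Dict[str, Any], dotted_key: str, value: Any) -> Dict[str, Any]:
    """Set cfg[a][b]...[z] = value for dotted_key 'a.b....z' (hypothetical helper).

    Raises KeyError if an intermediate key is missing, so a typo in the path
    fails loudly instead of silently creating a new branch in the config.
    """
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        if not isinstance(node.get(key), dict):
            raise KeyError(f"'{key}' is missing or not a mapping in path '{dotted_key}'")
        node = node[key]
    node[leaf] = value
    return cfg


# Illustrative config layout only; check the real file for the exact nesting.
config = {"data": {"file_path": "/old/path"}}
set_dotted(config, "data.file_path", "/datasets/laion_subset")
print(config["data"]["file_path"])  # /datasets/laion_subset
```

Failing on a missing intermediate key is a deliberate choice: silently creating data.file_path at the wrong nesting level would leave the real path unchanged and the run would read the old data.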
Training
We provide the script train.sh to run the training task, and two strategy configs in configs, including train_colossalai.yaml.
For example, you can run training with Colossal-AI by running:
python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml
- You can change --logdir to set where the log information and the last checkpoint are saved.
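When launching runs from Python (for example, one run per log directory), passing the arguments as a list is safer than building a shell string. The sketch only assembles the argument list for the training command shown earlier in this section; the helper name and the /data/logs/run1 directory are hypothetical, and it assumes main.py is in the current working directory. Uncomment the last two lines to actually start training.

```python
import shlex
import sys
from typing import List


def train_command(config: str, logdir: str = "/tmp", postfix: str = "test") -> List[str]:
    """Argument list equivalent to: python main.py --logdir ... -t --postfix ... -b <config>."""
    return [
        sys.executable, "main.py",
        "--logdir", logdir,    # where the logs and the last checkpoint are saved
        "-t",                  # run training
        "--postfix", postfix,  # suffix added to the run name
        "-b", config,          # base config file
    ]


cmd = train_command("configs/train_colossalai.yaml", logdir="/data/logs/run1")
print(shlex.join(cmd))
# import subprocess
# subprocess.run(cmd, check=True)
```

Using sys.executable ensures the run uses the same interpreter (and therefore the same ldm environment) as the launching script.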
Training config
You can change the training config in the yaml file:
- accelerator: accelerator type, default 'gpu'
- devices: number of devices used for training, default 4
- max_epochs: maximum number of training epochs
- precision: whether to use fp16 for training, default 16; you must use fp16 if you want to use Colossal-AI
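The settings in the list above can be checked before a run is launched. The sketch below applies the documented defaults (accelerator 'gpu', devices 4, precision 16) and rejects a Colossal-AI run that does not use fp16. It works on a plain dict standing in for the trainer section of the yaml file; the function name, the use_colossalai flag and the max_epochs value in the example are hypothetical.

```python
from typing import Any, Dict

# Defaults listed in the Training config section.
DEFAULTS = {"accelerator": "gpu", "devices": 4, "precision": 16}


def resolve_trainer_config(cfg: Dict[str, Any], use_colossalai: bool) -> Dict[str, Any]:
    """Fill in the documented defaults and enforce the fp16 requirement for Colossal-AI."""
    resolved = {**DEFAULTS, **cfg}  # values from the yaml file override the defaults
    if use_colossalai and resolved["precision"] != 16:
        raise ValueError("Colossal-AI training requires precision: 16 (fp16)")
    return resolved


print(resolve_trainer_config({"max_epochs": 2}, use_colossalai=True))
# {'accelerator': 'gpu', 'devices': 4, 'precision': 16, 'max_epochs': 2}
```

Failing early here is cheaper than discovering the precision mismatch after the job has been scheduled on GPUs.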
Example
Training on CIFAR10
We provide a fine-tuning example on the CIFAR10 dataset.
You can run it with the config train_colossalai_cifar10.yaml:
python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai_cifar10.yaml
Comments
- Our codebase for the diffusion models builds heavily on OpenAI's ADM codebase, lucidrains, Stable Diffusion, Lightning and Hugging Face. Thanks for open-sourcing!
- The implementation of the transformer encoder is from x-transformers by lucidrains.
- The implementation of flash attention is from HazyResearch.
BibTeX
@article{bian2021colossal,
title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
journal={arXiv preprint arXiv:2110.14883},
year={2021}
}
@misc{rombach2021highresolution,
title={High-Resolution Image Synthesis with Latent Diffusion Models},
author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
year={2021},
eprint={2112.10752},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
@article{dao2022flashattention,
title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
journal={arXiv preprint arXiv:2205.14135},
year={2022}
}