mirror of https://github.com/hpcaitech/ColossalAI
[example] Change some training settings for diffusion (#2195)
parent 2458659919
commit 6642cebdbe
@@ -87,14 +87,15 @@ you should change the `data.file_path` in the `config/train_colossalai.yaml`
 ## Training

-We provide the script `train.sh` to run the training task , and two Stategy in `configs`:`train_colossalai.yaml` and `train_ddp.yaml`
+We provide the script `train_colossalai.sh` to run the training task with ColossalAI,
+and you can also use `train_ddp.sh` to run the training task with DDP for comparison.

-For example, you can run the training from colossalai by
+In `train_colossalai.sh` the main command is:

 ```
 python main.py --logdir /tmp/ -t -b configs/train_colossalai.yaml
 ```

-- you can change the `--logdir` the save the log information and the last checkpoint
+- you can change `--logdir` to decide where to save the log information and the last checkpoint.

 ### Training config
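For quick reference, a launch sketch follows. It only restates the commands above; the custom `--logdir` value is an arbitrary illustration, not a path taken from the repo.

```
# Minimal launch sketch (paths are illustrative): run either strategy via
# the provided scripts, or call main.py directly with a custom log directory.
bash train_colossalai.sh          # ColossalAI strategy
bash train_ddp.sh                 # plain DDP, for comparison
python main.py --logdir ./logs/colossalai -t -b configs/train_colossalai.yaml
```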
@@ -155,6 +156,7 @@ optional arguments:
   --config CONFIG       path to config which constructs model
   --ckpt CKPT           path to checkpoint of model
   --seed SEED           the seed (for reproducible sampling)
+  --use_int8            whether to use quantization method
   --precision {full,autocast}
                         evaluate at this precision
 ```
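A hypothetical invocation using the new `--use_int8` flag is sketched below. The script path and file locations are placeholders (this section does not name the sampling script); only the flags themselves come from the help text above.

```
# Hypothetical sampling call: script path and file locations are placeholders,
# only the flags (--config, --ckpt, --seed, --precision, --use_int8) come from
# the help text above.
python scripts/txt2img.py \
    --config <path/to/config.yaml> \
    --ckpt <path/to/last.ckpt> \
    --seed 42 \
    --precision autocast \
    --use_int8
```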
@@ -80,19 +80,22 @@ model:
 data:
   target: main.DataModuleFromConfig
   params:
-    batch_size: 64
+    batch_size: 128
     wrap: False
+    # num_workers should be 2 * batch_size, and total num less than 1024
+    # e.g. if use 8 devices, no more than 128
+    num_workers: 128
     train:
       target: ldm.data.base.Txt2ImgIterableBaseDataset
       params:
-        file_path: "/data/scratch/diffuser/laion_part0/"
+        file_path: # YOUR DATASET_PATH
         world_size: 1
         rank: 0

 lightning:
   trainer:
     accelerator: 'gpu'
-    devices: 4
+    devices: 8
     log_gpu_memory: all
     max_epochs: 2
     precision: 16
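The new `num_workers` comment is terse; one reading of it, sketched below as a quick sanity check (my interpretation, not code from the repo), is: use `2 * batch_size` workers per device, but keep the total across all devices under 1024, which caps the value at 128 when 8 devices are used.

```
# Sketch of one interpretation of the num_workers comment above.
BATCH_SIZE=128
DEVICES=8
NUM_WORKERS=$(( 2 * BATCH_SIZE ))        # 256 by the 2 * batch_size rule
MAX_PER_DEVICE=$(( 1024 / DEVICES ))     # 128 so the total stays under 1024
if [ "$NUM_WORKERS" -gt "$MAX_PER_DEVICE" ]; then
    NUM_WORKERS=$MAX_PER_DEVICE
fi
echo "num_workers per device: $NUM_WORKERS"   # prints 128 for this config
```

With these numbers the cap is the binding constraint, which is consistent with the `num_workers: 128` value set in both configs.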
@@ -80,25 +80,21 @@ model:
 data:
   target: main.DataModuleFromConfig
   params:
-    batch_size: 16
-    num_workers: 4
+    batch_size: 128
+    # num_workers should be 2 * batch_size, and the total num less than 1024
+    # e.g. if use 8 devices, no more than 128
+    num_workers: 128
     train:
-      target: ldm.data.teyvat.hf_dataset
+      target: ldm.data.base.Txt2ImgIterableBaseDataset
       params:
-        path: Fazzie/Teyvat
-        image_transforms:
-        - target: torchvision.transforms.Resize
-          params:
-            size: 512
-        - target: torchvision.transforms.RandomCrop
-          params:
-            size: 512
-        - target: torchvision.transforms.RandomHorizontalFlip
+        file_path: # YOUR DATAPATH
+        world_size: 1
+        rank: 0

 lightning:
   trainer:
     accelerator: 'gpu'
-    devices: 2
+    devices: 8
     log_gpu_memory: all
     max_epochs: 2
     precision: 16
@@ -1,5 +0,0 @@
-# HF_DATASETS_OFFLINE=1
-# TRANSFORMERS_OFFLINE=1
-# DIFFUSERS_OFFLINE=1
-
-python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml
@@ -0,0 +1,5 @@
+HF_DATASETS_OFFLINE=1
+TRANSFORMERS_OFFLINE=1
+DIFFUSERS_OFFLINE=1
+
+python main.py --logdir /tmp -t -b /configs/train_colossalai.yaml
@@ -0,0 +1,5 @@
+HF_DATASETS_OFFLINE=1
+TRANSFORMERS_OFFLINE=1
+DIFFUSERS_OFFLINE=1
+
+python main.py --logdir /tmp -t -b /configs/train_ddp.yaml
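The new scripts force offline mode via `HF_DATASETS_OFFLINE`, `TRANSFORMERS_OFFLINE`, and `DIFFUSERS_OFFLINE`, which tells the Hugging Face libraries to rely on locally cached files instead of the network. A possible first run that populates the cache is sketched below; leaving the flags unset for that run is an assumption on my part, not something the commit prescribes.

```
# Sketch (assumption, not from the commit): run once without the offline
# flags so datasets and pretrained weights get downloaded and cached,
# then use train_colossalai.sh / train_ddp.sh for offline runs.
python main.py --logdir /tmp -t -b configs/train_colossalai.yaml
```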