[example] Change some training settings for diffusion (#2195)

pull/2197/head
BlueRum 2022-12-26 15:22:20 +08:00 committed by GitHub
parent 2458659919
commit 6642cebdbe
6 changed files with 30 additions and 24 deletions

README.md

@@ -87,14 +87,15 @@ you should change the `data.file_path` in the `config/train_colossalai.yaml`
 ## Training
-We provide the script `train.sh` to run the training task , and two Stategy in `configs`:`train_colossalai.yaml` and `train_ddp.yaml`
-For example, you can run the training from colossalai by
+We provide the script `train_colossalai.sh` to run the training task with ColossalAI,
+and you can also use `train_ddp.sh` to run the training task with DDP for comparison.
+In `train_colossalai.sh` the main command is:
 ```
 python main.py --logdir /tmp/ -t -b configs/train_colossalai.yaml
 ```
-- you can change the `--logdir` the save the log information and the last checkpoint
+- you can change the `--logdir` to decide where to save the log information and the last checkpoint.
 ### Training config
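The README notes that `--logdir` controls where the logs and the last checkpoint are written. As a minimal sketch based on the command above (the log directory is an arbitrary example):

```
# launch ColossalAI training, writing logs and the final checkpoint
# to a custom directory (the path is illustrative)
python main.py --logdir ./logs/diffusion-run1 -t -b configs/train_colossalai.yaml
```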
@@ -155,6 +156,7 @@ optional arguments:
   --config CONFIG       path to config which constructs model
   --ckpt CKPT           path to checkpoint of model
   --seed SEED           the seed (for reproducible sampling)
+  --use_int8            whether to use int8 quantization
   --precision {full,autocast}
                         evaluate at this precision
 ```
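The new `--use_int8` flag turns on int8 quantization at inference time. A minimal sketch of an invocation, assuming the inference entry point is `scripts/txt2img.py` (the script name, prompt, and paths are illustrative, not taken from the diff):

```
python scripts/txt2img.py --prompt "a photo of an astronaut riding a horse" \
    --ckpt /path/to/model.ckpt --config configs/train_colossalai.yaml \
    --seed 42 --precision autocast --use_int8
```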

configs/train_colossalai.yaml

@@ -80,19 +80,22 @@ model:
 data:
   target: main.DataModuleFromConfig
   params:
-    batch_size: 64
+    batch_size: 128
     wrap: False
+    # num_workers should be 2 * batch_size, and the total across devices should be less than 1024
+    # e.g. with 8 devices, no more than 128 per device
+    num_workers: 128
     train:
       target: ldm.data.base.Txt2ImgIterableBaseDataset
       params:
-        file_path: "/data/scratch/diffuser/laion_part0/"
+        file_path: # YOUR DATASET_PATH
         world_size: 1
         rank: 0

 lightning:
   trainer:
     accelerator: 'gpu'
-    devices: 4
+    devices: 8
     log_gpu_memory: all
     max_epochs: 2
     precision: 16
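The new comment encodes a sizing rule: workers per device is the smaller of `2 * batch_size` and `1024 / devices`. A quick shell check of that rule (a hypothetical helper, not part of the repo):

```
# hypothetical sketch of the worker-count rule from the config comment
BATCH_SIZE=128
DEVICES=8
NUM_WORKERS=$(( 2 * BATCH_SIZE < 1024 / DEVICES ? 2 * BATCH_SIZE : 1024 / DEVICES ))
echo "$NUM_WORKERS"   # prints 128, matching num_workers in the config
```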

configs/train_ddp.yaml

@@ -80,25 +80,21 @@ model:
 data:
   target: main.DataModuleFromConfig
   params:
-    batch_size: 16
-    num_workers: 4
+    batch_size: 128
+    # num_workers should be 2 * batch_size, and the total across devices should be less than 1024
+    # e.g. with 8 devices, no more than 128 per device
+    num_workers: 128
     train:
-      target: ldm.data.teyvat.hf_dataset
+      target: ldm.data.base.Txt2ImgIterableBaseDataset
       params:
-        path: Fazzie/Teyvat
-        image_transforms:
-        - target: torchvision.transforms.Resize
-          params:
-            size: 512
-        - target: torchvision.transforms.RandomCrop
-          params:
-            size: 512
-        - target: torchvision.transforms.RandomHorizontalFlip
+        file_path: # YOUR DATAPATH
+        world_size: 1
+        rank: 0

 lightning:
   trainer:
     accelerator: 'gpu'
-    devices: 2
+    devices: 8
     log_gpu_memory: all
     max_epochs: 2
     precision: 16
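Both configs now leave `file_path` as a placeholder for the user to fill in. One hypothetical way to do that before launching (the dataset path is an example, not from the diff):

```
# point file_path at your dataset directory (example path)
sed -i 's|file_path: # YOUR DATAPATH|file_path: /data/my_laion_subset/|' configs/train_ddp.yaml
```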

train.sh (deleted)

@@ -1,5 +0,0 @@
-# HF_DATASETS_OFFLINE=1
-# TRANSFORMERS_OFFLINE=1
-# DIFFUSERS_OFFLINE=1
-python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml

train_colossalai.sh (new file)

@@ -0,0 +1,5 @@
+export HF_DATASETS_OFFLINE=1
+export TRANSFORMERS_OFFLINE=1
+export DIFFUSERS_OFFLINE=1
+python main.py --logdir /tmp -t -b configs/train_colossalai.yaml

train_ddp.sh (new file)

@@ -0,0 +1,5 @@
+export HF_DATASETS_OFFLINE=1
+export TRANSFORMERS_OFFLINE=1
+export DIFFUSERS_OFFLINE=1
+python main.py --logdir /tmp -t -b configs/train_ddp.yaml
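With both scripts in place, a straightforward way to compare the two strategies (assuming you run from the example directory):

```
# run the ColossalAI strategy, then the DDP baseline, and compare the logs
bash train_colossalai.sh
bash train_ddp.sh
```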