[example] Change some training settings for diffusion (#2195)

BlueRum committed by GitHub
parent 2458659919
commit 6642cebdbe

@@ -87,14 +87,15 @@ you should the change the `data.file_path` in the `config/train_colossalai.yaml`
 ## Training
-We provide the script `train.sh` to run the training task , and two Stategy in `configs`:`train_colossalai.yaml` and `train_ddp.yaml`
+We provide the script `train_colossalai.sh` to run the training task with ColossalAI,
+and you can also use `train_ddp.sh` to run the training task with DDP for comparison.
-For example, you can run the training from colossalai by
+In `train_colossalai.sh` the main command is:
 ```
 python main.py --logdir /tmp/ -t -b configs/train_colossalai.yaml
 ```
-- you can change the `--logdir` the save the log information and the last checkpoint
+- you can change `--logdir` to decide where to save the log information and the last checkpoint.
 ### Training config
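For instance, a minimal sketch of the command above with a custom log directory (`./logs` is an arbitrary placeholder path, not part of this diff):

```
# same training command, but keep the logs and the last checkpoint under ./logs
python main.py --logdir ./logs -t -b configs/train_colossalai.yaml
```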
@@ -155,6 +156,7 @@ optional arguments:
   --config CONFIG       path to config which constructs model
   --ckpt CKPT           path to checkpoint of model
   --seed SEED           the seed (for reproducible sampling)
+  --use_int8            whether to use quantization method
   --precision {full,autocast}
                         evaluate at this precision
 ```
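As a usage sketch for the new `--use_int8` flag: only the options come from the help text above; the script name, `--prompt`, and the paths below are assumptions for illustration, so adjust them to the example's actual inference entry point.

```
# hypothetical invocation; scripts/txt2img.py, --prompt, and the paths are placeholders
python scripts/txt2img.py \
    --prompt "a photograph of an astronaut riding a horse" \
    --ckpt /tmp/checkpoints/last.ckpt \
    --seed 42 \
    --use_int8 \
    --precision autocast
```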

@@ -80,19 +80,22 @@ model:
 data:
   target: main.DataModuleFromConfig
   params:
-    batch_size: 64
+    batch_size: 128
     wrap: False
+    # num_workers should be 2 * batch_size, and the total number should be less than 1024
+    # e.g. if using 8 devices, no more than 128
+    num_workers: 128
     train:
       target: ldm.data.base.Txt2ImgIterableBaseDataset
       params:
-        file_path: "/data/scratch/diffuser/laion_part0/"
+        file_path: # YOUR DATASET_PATH
         world_size: 1
         rank: 0

 lightning:
   trainer:
     accelerator: 'gpu'
-    devices: 4
+    devices: 8
     log_gpu_memory: all
     max_epochs: 2
     precision: 16

@@ -80,25 +80,21 @@ model:
 data:
   target: main.DataModuleFromConfig
   params:
-    batch_size: 16
-    num_workers: 4
+    batch_size: 128
+    # num_workers should be 2 * batch_size, and the total number should be less than 1024
+    # e.g. if using 8 devices, no more than 128
+    num_workers: 128
     train:
-      target: ldm.data.teyvat.hf_dataset
+      target: ldm.data.base.Txt2ImgIterableBaseDataset
       params:
-        path: Fazzie/Teyvat
-        image_transforms:
-        - target: torchvision.transforms.Resize
-          params:
-            size: 512
-        - target: torchvision.transforms.RandomCrop
-          params:
-            size: 512
-        - target: torchvision.transforms.RandomHorizontalFlip
+        file_path: # YOUR DATAPATH
+        world_size: 1
+        rank: 0

 lightning:
   trainer:
     accelerator: 'gpu'
-    devices: 2
+    devices: 8
     log_gpu_memory: all
     max_epochs: 2
     precision: 16

@@ -1,5 +0,0 @@
-# HF_DATASETS_OFFLINE=1
-# TRANSFORMERS_OFFLINE=1
-# DIFFUSERS_OFFLINE=1
-python main.py --logdir /tmp/ -t -b configs/Teyvat/train_colossalai_teyvat.yaml

@@ -0,0 +1,5 @@
+HF_DATASETS_OFFLINE=1
+TRANSFORMERS_OFFLINE=1
+DIFFUSERS_OFFLINE=1
+python main.py --logdir /tmp -t -b /configs/train_colossalai.yaml

@@ -0,0 +1,5 @@
+HF_DATASETS_OFFLINE=1
+TRANSFORMERS_OFFLINE=1
+DIFFUSERS_OFFLINE=1
+python main.py --logdir /tmp -t -b /configs/train_ddp.yaml
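A usage sketch for the new scripts, assuming they are run from the example's root directory (not stated in this diff) and that the Hugging Face datasets and models are already cached locally, since the offline flags block downloads:

```
# fill in data.params.train.params.file_path in configs/train_colossalai.yaml first,
# then launch training with the ColossalAI strategy:
bash train_colossalai.sh

# or run the plain PyTorch DDP version to compare:
bash train_ddp.sh
```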