mirror of https://github.com/hpcaitech/ColossalAI
[example] fix save_load bug for dreambooth (#2280)
parent
f027ef7913
commit
1405b4381e
|
@ -1,20 +1,22 @@
|
||||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
export MODEL_NAME= <Your Pretrained Model Path>
|
||||||
export INSTANCE_DIR="input"
|
export INSTANCE_DIR= <Your Input Pics Path>
|
||||||
export OUTPUT_DIR="output"
|
export CLASS_DIR="path-to-class-images"
|
||||||
INSTANCE_PROMPT="a photo of sks dog"
|
export OUTPUT_DIR="path-to-save-model"
|
||||||
HF_DATASETS_OFFLINE=1
|
|
||||||
TRANSFORMERS_OFFLINE=1
|
HF_DATASETS_OFFLINE=1
|
||||||
|
TRANSFORMERS_OFFLINE=1
|
||||||
|
DIFFUSERS_OFFLINE=1
|
||||||
|
|
||||||
torchrun --nproc_per_node 2 --master_port=25641 train_dreambooth_colossalai.py \
|
torchrun --nproc_per_node 2 --master_port=25641 train_dreambooth_colossalai.py \
|
||||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||||
--instance_data_dir=$INSTANCE_DIR \
|
--instance_data_dir=$INSTANCE_DIR \
|
||||||
--output_dir=$OUTPUT_DIR \
|
--output_dir=$OUTPUT_DIR \
|
||||||
|
--instance_prompt="a photo of a dog" \
|
||||||
--resolution=512 \
|
--resolution=512 \
|
||||||
--train_batch_size=1 \
|
--train_batch_size=1 \
|
||||||
--gradient_accumulation_steps=1 \
|
--gradient_accumulation_steps=1 \
|
||||||
--learning_rate=5e-6 \
|
--learning_rate=5e-6 \
|
||||||
--instance_prompt=INSTANCE_PROMPT \
|
|
||||||
--lr_scheduler="constant" \
|
--lr_scheduler="constant" \
|
||||||
--lr_warmup_steps=0 \
|
--lr_warmup_steps=0 \
|
||||||
--max_train_steps=400 \
|
--num_class_images=200 \
|
||||||
--placement="cpu"
|
--placement="cuda" \
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
python train_dreambooth.py \
|
||||||
|
--pretrained_model_name_or_path= ## Your Model Path \
|
||||||
|
--instance_data_dir= ## Your Training Input Pics Path \
|
||||||
|
--output_dir="path-to-save-model" \
|
||||||
|
--instance_prompt="a photo of a dog" \
|
||||||
|
--resolution=512 \
|
||||||
|
--train_batch_size=1 \
|
||||||
|
--gradient_accumulation_steps=1 \
|
||||||
|
--learning_rate=5e-6 \
|
||||||
|
--lr_scheduler="constant" \
|
||||||
|
--lr_warmup_steps=0 \
|
||||||
|
--num_class_images=200 \
|
|
@ -0,0 +1,12 @@
|
||||||
|
from diffusers import StableDiffusionPipeline, DiffusionPipeline
|
||||||
|
import torch
|
||||||
|
|
||||||
|
model_id = <Your Model Path>
|
||||||
|
print(f"Loading model... from{model_id}")
|
||||||
|
|
||||||
|
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
|
||||||
|
|
||||||
|
prompt = "A photo of an apple."
|
||||||
|
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
|
||||||
|
|
||||||
|
image.save("output.png")
|
|
@ -1,19 +0,0 @@
|
||||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
|
||||||
export INSTANCE_DIR="input"
|
|
||||||
export OUTPUT_DIR="output"
|
|
||||||
HF_DATASETS_OFFLINE=1
|
|
||||||
TRANSFORMERS_OFFLINE=1
|
|
||||||
DIFFUSERS_OFFLINE=1
|
|
||||||
|
|
||||||
accelerate launch train_dreambooth.py \
|
|
||||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
|
||||||
--instance_data_dir=$INSTANCE_DIR \
|
|
||||||
--output_dir=$OUTPUT_DIR \
|
|
||||||
--instance_prompt="a photo of sks dog" \
|
|
||||||
--resolution=512 \
|
|
||||||
--train_batch_size=1 \
|
|
||||||
--gradient_accumulation_steps=1 \
|
|
||||||
--learning_rate=5e-6 \
|
|
||||||
--lr_scheduler="constant" \
|
|
||||||
--lr_warmup_steps=0 \
|
|
||||||
--max_train_steps=400
|
|
|
@ -11,6 +11,7 @@ import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
from copy import deepcopy
|
||||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from huggingface_hub import HfFolder, Repository, whoami
|
from huggingface_hub import HfFolder, Repository, whoami
|
||||||
|
@ -359,6 +360,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
|
||||||
placement_policy=placememt_policy,
|
placement_policy=placememt_policy,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
search_range_mb=32)
|
search_range_mb=32)
|
||||||
|
|
||||||
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
|
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
|
||||||
from colossalai.gemini import ChunkManager, GeminiManager
|
from colossalai.gemini import ChunkManager, GeminiManager
|
||||||
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
|
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
|
||||||
|
@ -381,6 +383,7 @@ def main(args):
|
||||||
"gradient_accumulation_steps": args.gradient_accumulation_steps,
|
"gradient_accumulation_steps": args.gradient_accumulation_steps,
|
||||||
"clip_grad_norm": args.max_grad_norm,
|
"clip_grad_norm": args.max_grad_norm,
|
||||||
}
|
}
|
||||||
|
|
||||||
colossalai.launch_from_torch(config=config)
|
colossalai.launch_from_torch(config=config)
|
||||||
pg = ProcessGroup()
|
pg = ProcessGroup()
|
||||||
|
|
||||||
|
@ -465,21 +468,21 @@ def main(args):
|
||||||
|
|
||||||
text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path,
|
text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path,
|
||||||
subfolder="text_encoder",
|
subfolder="text_encoder",
|
||||||
revision=args.revision,
|
revision=args.revision,)
|
||||||
low_cpu_mem_usage=False)
|
|
||||||
|
|
||||||
logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
|
logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
|
||||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,
|
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,
|
||||||
subfolder="vae",
|
subfolder="vae",
|
||||||
revision=args.revision,
|
revision=args.revision,)
|
||||||
low_cpu_mem_usage=False)
|
|
||||||
|
|
||||||
with ColoInitContext(device='cpu'):
|
|
||||||
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
|
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
|
||||||
|
with ColoInitContext():
|
||||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
|
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
|
||||||
subfolder="unet",
|
subfolder="unet",
|
||||||
revision=args.revision,
|
revision=args.revision,
|
||||||
low_cpu_mem_usage=False)
|
low_cpu_mem_usage=False)
|
||||||
|
|
||||||
|
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
text_encoder.requires_grad_(False)
|
text_encoder.requires_grad_(False)
|
||||||
|
@ -597,7 +600,7 @@ def main(args):
|
||||||
for epoch in range(args.num_train_epochs):
|
for epoch in range(args.num_train_epochs):
|
||||||
unet.train()
|
unet.train()
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
# Move batch to gpu
|
# Move batch to gpu
|
||||||
for key, value in batch.items():
|
for key, value in batch.items():
|
||||||
batch[key] = value.to(get_current_device(), non_blocking=True)
|
batch[key] = value.to(get_current_device(), non_blocking=True)
|
||||||
|
@ -653,7 +656,7 @@ def main(args):
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
|
||||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
@ -678,13 +681,15 @@ def main(args):
|
||||||
break
|
break
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
unet=convert_to_torch_module(unet)
|
||||||
|
|
||||||
if gpc.get_local_rank(ParallelMode.DATA) == 0:
|
if gpc.get_local_rank(ParallelMode.DATA) == 0:
|
||||||
pipeline = DiffusionPipeline.from_pretrained(
|
pipeline = DiffusionPipeline.from_pretrained(
|
||||||
args.pretrained_model_name_or_path,
|
args.pretrained_model_name_or_path,
|
||||||
unet=convert_to_torch_module(unet),
|
unet=unet,
|
||||||
revision=args.revision,
|
revision=args.revision,
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline.save_pretrained(args.output_dir)
|
pipeline.save_pretrained(args.output_dir)
|
||||||
logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
|
logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue