mirror of https://github.com/hpcaitech/ColossalAI
[example] fix save_load bug for dreambooth (#2280)
parent f027ef7913
commit 1405b4381e
@@ -1,20 +1,22 @@
-export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export INSTANCE_DIR="input"
-export OUTPUT_DIR="output"
-INSTANCE_PROMPT="a photo of sks dog"
-HF_DATASETS_OFFLINE=1
-TRANSFORMERS_OFFLINE=1
+export MODEL_NAME= <Your Pretrained Model Path>
+export INSTANCE_DIR= <Your Input Pics Path>
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+HF_DATASETS_OFFLINE=1
+TRANSFORMERS_OFFLINE=1
 DIFFUSERS_OFFLINE=1
 
 torchrun --nproc_per_node 2 --master_port=25641 train_dreambooth_colossalai.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --instance_data_dir=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
+  --instance_prompt="a photo of a dog" \
   --resolution=512 \
   --train_batch_size=1 \
   --gradient_accumulation_steps=1 \
   --learning_rate=5e-6 \
-  --instance_prompt=INSTANCE_PROMPT \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --max_train_steps=400 \
-  --placement="cpu"
+  --num_class_images=200 \
+  --placement="cuda"
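The hunk above replaces the hard-coded CompVis/stable-diffusion-v1-4 setup with user-supplied paths, adds CLASS_DIR and --num_class_images=200 (used by DreamBooth when prior preservation is enabled), inlines the instance prompt, and flips --placement from "cpu" to "cuda". A small hypothetical pre-flight check, not part of the commit, can catch placeholders that were never filled in; the variable names follow the script:

    # Hypothetical helper, not in the commit: abort early if the placeholder
    # environment variables above were left unset before torchrun starts.
    import os
    import sys

    for var in ("MODEL_NAME", "INSTANCE_DIR", "CLASS_DIR", "OUTPUT_DIR"):
        if not os.environ.get(var):
            sys.exit(f"{var} is unset; edit train_colossalai.sh before launching")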
@@ -0,0 +1,12 @@
+python train_dreambooth.py \
+  --pretrained_model_name_or_path= ## Your Model Path \
+  --instance_data_dir= ## Your Training Input Pics Path \
+  --output_dir="path-to-save-model" \
+  --instance_prompt="a photo of a dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=1 \
+  --learning_rate=5e-6 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --num_class_images=200 \
@@ -0,0 +1,12 @@
+from diffusers import StableDiffusionPipeline, DiffusionPipeline
+import torch
+
+model_id = <Your Model Path>
+print(f"Loading model... from {model_id}")
+
+pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A photo of an apple."
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("output.png")
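The new inference.py above loads the saved pipeline in float16 and samples once. When comparing checkpoints before and after a fix like this one, a seeded variant is useful; this is a minimal sketch under the same assumptions, with the model path a placeholder:

    # Sketch: fix the RNG seed so images are reproducible across checkpoints.
    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("path-to-save-model",
                                             torch_dtype=torch.float16).to("cuda")
    generator = torch.Generator(device="cuda").manual_seed(0)
    image = pipe("a photo of a dog", num_inference_steps=50,
                 guidance_scale=7.5, generator=generator).images[0]
    image.save("seeded_output.png")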
@@ -1,19 +0,0 @@
-export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export INSTANCE_DIR="input"
-export OUTPUT_DIR="output"
-HF_DATASETS_OFFLINE=1
-TRANSFORMERS_OFFLINE=1
-DIFFUSERS_OFFLINE=1
-
-accelerate launch train_dreambooth.py \
-  --pretrained_model_name_or_path=$MODEL_NAME \
-  --instance_data_dir=$INSTANCE_DIR \
-  --output_dir=$OUTPUT_DIR \
-  --instance_prompt="a photo of sks dog" \
-  --resolution=512 \
-  --train_batch_size=1 \
-  --gradient_accumulation_steps=1 \
-  --learning_rate=5e-6 \
-  --lr_scheduler="constant" \
-  --lr_warmup_steps=0 \
-  --max_train_steps=400
@@ -11,6 +11,7 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from copy import deepcopy
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from huggingface_hub import HfFolder, Repository, whoami
@@ -359,6 +360,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
                           placement_policy=placememt_policy,
                           pin_memory=True,
                           search_range_mb=32)
+
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
         chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
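The keyword arguments visible in this hunk (placement_policy, pin_memory, search_range_mb) belong to ColossalAI's Gemini wrapper; the enclosing call sits outside the hunk, so the following is a sketch of what gemini_zero_dpp plausibly returns on newer 0.1.x versions, not a quote of the file. The misspelled parameter name is kept because it matches the function signature shown above.

    # Sketch only: mirrors the kwargs shown in the hunk; the enclosing call
    # is not part of the diff.
    from colossalai.nn.parallel import GeminiDDP
    from colossalai.utils import get_current_device

    def wrap_with_gemini(model, placememt_policy="cuda"):
        # "cpu" offloads parameters and gradients to pinned host memory;
        # "cuda" keeps them on the GPU (see --placement in train_colossalai.sh).
        return GeminiDDP(model,
                         device=get_current_device(),
                         placement_policy=placememt_policy,
                         pin_memory=True,
                         search_range_mb=32)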
@@ -381,6 +383,7 @@ def main(args):
         "gradient_accumulation_steps": args.gradient_accumulation_steps,
         "clip_grad_norm": args.max_grad_norm,
     }
+
     colossalai.launch_from_torch(config=config)
     pg = ProcessGroup()
 
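For context: colossalai.launch_from_torch picks up the rank, world size, and master address that torchrun exports, so the config dict above is all the script has to supply; ProcessGroup() then denotes the default group over the launched processes. A minimal sketch of that bootstrap, assuming the 0.1.x-era import paths this script appears to use:

    # Sketch of the bootstrap shown above (ColossalAI 0.1.x-era API assumed).
    import colossalai
    from colossalai.tensor import ProcessGroup

    config = {"gradient_accumulation_steps": 1, "clip_grad_norm": 1.0}
    colossalai.launch_from_torch(config=config)  # reads torchrun's env vars
    pg = ProcessGroup()  # default group spanning all launched ranks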
@@ -465,21 +468,21 @@ def main(args):
 
     text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path,
                                                     subfolder="text_encoder",
-                                                    revision=args.revision,
-                                                    low_cpu_mem_usage=False)
+                                                    revision=args.revision,)
 
     logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,
                                         subfolder="vae",
-                                        revision=args.revision,
-                                        low_cpu_mem_usage=False)
+                                        revision=args.revision,)
 
-    with ColoInitContext(device='cpu'):
-        logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
+    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
+    with ColoInitContext():
         unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
                                                     subfolder="unet",
                                                     revision=args.revision,
                                                     low_cpu_mem_usage=False)
 
 
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
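The loading changes above do two things: they drop low_cpu_mem_usage=False for the frozen text encoder and VAE, which stay ordinary torch modules, and they move the logger call out of the ColoInitContext block so that only the trainable UNet is constructed inside it. Parameters created under ColoInitContext come up as ColoTensors that Gemini can manage. A minimal sketch of the pattern, with the import path ColossalAI examples of this era appear to use:

    # Sketch: only the module built inside ColoInitContext gets ColoTensor
    # parameters; frozen components are loaded the ordinary way.
    from colossalai.utils.model.colo_init_context import ColoInitContext
    from diffusers import UNet2DConditionModel

    with ColoInitContext():
        unet = UNet2DConditionModel.from_pretrained(
            "CompVis/stable-diffusion-v1-4", subfolder="unet",
            low_cpu_mem_usage=False)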
@@ -597,7 +600,7 @@ def main(args):
     for epoch in range(args.num_train_epochs):
         unet.train()
         for step, batch in enumerate(train_dataloader):
-
+            torch.cuda.reset_peak_memory_stats()
             # Move batch to gpu
             for key, value in batch.items():
                 batch[key] = value.to(get_current_device(), non_blocking=True)
@@ -653,7 +656,7 @@ def main(args):
 
             optimizer.step()
             lr_scheduler.step()
-
+            logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
             # Checks if the accelerator has performed an optimization step behind the scenes
             progress_bar.update(1)
             global_step += 1
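These two hunks pair up: resetting the peak-memory counter at the top of each step makes the max_memory_allocated() reading after optimizer.step() a per-step peak rather than a run-wide maximum. The pattern in isolation:

    # Per-step peak GPU memory: reset at the start, read after the step.
    import torch

    torch.cuda.reset_peak_memory_stats()
    # ... forward, backward, optimizer.step() ...
    peak_mb = torch.cuda.max_memory_allocated() / 2**20
    print(f"max GPU_mem cost is {peak_mb} MB")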
@@ -678,13 +681,15 @@ def main(args):
                 break
 
     torch.cuda.synchronize()
 
+    unet = convert_to_torch_module(unet)
+
     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            unet=convert_to_torch_module(unet),
+            unet=unet,
             revision=args.revision,
         )
 
         pipeline.save_pretrained(args.output_dir)
         logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
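This last hunk is the save_load fix the commit title refers to. convert_to_torch_module gathers the Gemini-sharded parameters back into an ordinary torch module, and a gather of distributed state needs every rank to participate; calling it only inside the rank-0 branch, as before, means the other rank never joins, so the save can hang or write an incomplete UNet. Hoisting the conversion above the rank check restores the correct ordering. A generic, self-contained illustration of the underlying pattern, with sharded_to_full standing in for convert_to_torch_module (our reading of the bug, not a quote of ColossalAI internals):

    # Pattern behind the fix: run the collective op on every rank first,
    # then let a single rank do the file I/O.
    import torch
    import torch.distributed as dist

    def save_full_model(sharded_to_full, model, path):
        full_model = sharded_to_full(model)   # collective: every rank calls it
        if dist.get_rank() == 0:              # only one rank touches the disk
            torch.save(full_model.state_dict(), path)
        dist.barrier()                        # keep ranks in step afterwards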