ColossalAI/examples/inference/stable_diffusion/sd3_generation.py


import argparse

from diffusers import DiffusionPipeline
from torch import bfloat16
from torch import distributed as dist
from torch import float16, float32

import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# For Stable Diffusion 3, we use the generic DiffusionPipeline from diffusers
MODEL_CLS = DiffusionPipeline

# Map CLI dtype flags to torch dtypes
TORCH_DTYPE_MAP = {
    "fp16": float16,
    "fp32": float32,
    "bf16": bfloat16,
}
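# Example: "--dtype bf16" maps to torch.bfloat16, which from_pretrained then
# uses when loading the pipeline weights below.
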
def infer(args):
    # ==============================
    # Launch colossalai, setup distributed environment
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    # ==============================
    # Load the diffusion pipeline
    # ==============================
    model_path_or_name = args.model
    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))

    # ==============================
    # Initialize InferenceEngine
    # ==============================
    coordinator.print_on_master("Initializing Inference Engine...")
    inference_config = InferenceConfig(
        dtype=args.dtype,
        max_batch_size=args.max_batch_size,
        tp_size=args.tp_size,
        use_cuda_kernel=args.use_cuda_kernel,
        # Patched parallelism spans all launched ranks
        patched_parallelism_size=dist.get_world_size(),
    )
    engine = InferenceEngine(model, inference_config=inference_config, verbose=True)

    # ==============================
    # Generation
    # ==============================
    coordinator.print_on_master("Generating...")
    out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
    if dist.get_rank() == 0:
        # Only rank 0 writes the image, to avoid duplicate files
        out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
    coordinator.print_on_master(out)
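
# Example launch commands (paths assume the ColossalAI repo root as the working directory):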
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
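# Note: patched_parallelism_size is set from the world size above, so with
# --nproc_per_node 2 and --tp_size 1 the second rank is used for patched
# parallelism rather than tensor parallelism.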
if __name__ == "__main__":
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
    parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor parallelism size")
    parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt")
    parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size")
    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
    parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernels; Triton is used by default")
    args = parser.parse_args()
    infer(args)
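
# Minimal usage sketch (an assumption, not part of the original script): infer()
# can also be driven programmatically with an argparse.Namespace instead of the
# CLI, but it still must run under a torchrun/colossalai launcher, since
# launch_from_torch() reads the distributed environment variables it sets.
#
#     from argparse import Namespace
#     infer(Namespace(model="stabilityai/stable-diffusion-3-medium-diffusers",
#                     tp_size=1, prompt="A cat holding a sign", max_batch_size=1,
#                     dtype="fp16", use_cuda_kernel=False))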