import argparse

from diffusers import DiffusionPipeline
from torch import bfloat16, float16, float32
from torch import distributed as dist

import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import DiffusionGenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

# For Stable Diffusion 3, we'll use the following configuration
MODEL_CLS = DiffusionPipeline

TORCH_DTYPE_MAP = {
    "fp16": float16,
    "fp32": float32,
    "bf16": bfloat16,
}
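# e.g. TORCH_DTYPE_MAP.get("bf16") -> torch.bfloat16; the argparse `choices` below
# restrict args.dtype to these keys, so the None fallback passed to from_pretrained
# is never hit in practice.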


def infer(args):
    # ==============================
    # Launch colossalai, setup distributed environment
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    # ==============================
    # Load model and tokenizer
    # ==============================
    model_path_or_name = args.model
    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
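    # from_pretrained accepts either a local path or a Hugging Face Hub id
    # (e.g. "stabilityai/stable-diffusion-3-medium-diffusers"); a Hub id is
    # downloaded and cached on first use.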

    # ==============================
    # Initialize InferenceEngine
    # ==============================
    coordinator.print_on_master("Initializing Inference Engine...")
    inference_config = InferenceConfig(
        dtype=args.dtype,
        max_batch_size=args.max_batch_size,
        tp_size=args.tp_size,
        use_cuda_kernel=args.use_cuda_kernel,
        patched_parallelism_size=dist.get_world_size(),
    )
    engine = InferenceEngine(model, inference_config=inference_config, verbose=True)
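    # Note: patched_parallelism_size is tied to the world size, so every process
    # started by the launcher joins the patched-parallelism group (presumably
    # splitting the denoising work across ranks, hence the output filename below).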

    # ==============================
    # Generation
    # ==============================
    coordinator.print_on_master("Generating...")
    out = engine.generate(prompts=[args.prompt], generation_config=DiffusionGenerationConfig())[0]
    if dist.get_rank() == 0:
        out.save(f"cat_parallel_size{dist.get_world_size()}.jpg")
    coordinator.print_on_master(out)


# Usage (via the ColossalAI launcher):
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m MODEL_PATH

# Stable Diffusion 3:
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers" --tp_size 1

# PixArt-alpha:
# colossalai run --nproc_per_node 1 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
# colossalai run --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "PixArt-alpha/PixArt-XL-2-1024-MS" --tp_size 1
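# The script should also be launchable with plain torchrun, since
# colossalai.launch_from_torch() reads the standard torch.distributed environment
# variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) that torchrun sets, e.g.:
# torchrun --nproc_per_node 2 examples/inference/stable_diffusion/sd3_generation.py -m "stabilityai/stable-diffusion-3-medium-diffusers"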


if __name__ == "__main__":
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
    parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
    parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt")
    parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size")
    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
    parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
    args = parser.parse_args()

    infer(args)