From 1cf6d92d7c93a26e29cadeb71bb34ee96b149a28 Mon Sep 17 00:00:00 2001
From: BlueRum <70618399+ht-zhou@users.noreply.github.com>
Date: Fri, 23 Dec 2022 16:06:29 +0800
Subject: [PATCH] [example] diffuser, support quant inference for stable
 diffusion (#2186)

---
 examples/images/diffusion/scripts/img2img.py | 16 +++-
 examples/images/diffusion/scripts/txt2img.py | 21 ++++-
 examples/images/diffusion/scripts/utils.py   | 83 ++++++++++++++++++++
 3 files changed, 116 insertions(+), 4 deletions(-)
 create mode 100644 examples/images/diffusion/scripts/utils.py

diff --git a/examples/images/diffusion/scripts/img2img.py b/examples/images/diffusion/scripts/img2img.py
index e8ccfa259..877538d47 100644
--- a/examples/images/diffusion/scripts/img2img.py
+++ b/examples/images/diffusion/scripts/img2img.py
@@ -22,6 +22,7 @@ from imwatermark import WatermarkEncoder
 from scripts.txt2img import put_watermark
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
+from utils import replace_module, getModelSize
 
 
 def chunk(it, size):
@@ -44,7 +45,6 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)
 
-    model.cuda()
     model.eval()
     return model
 
@@ -183,6 +183,12 @@
         choices=["full", "autocast"],
         default="autocast"
     )
+    parser.add_argument(
+        "--use_int8",
+        type=bool,
+        default=False,
+        help="use int8 for inference",
+    )
 
     opt = parser.parse_args()
     seed_everything(opt.seed)
@@ -193,6 +199,12 @@
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)
 
+    # quantize model
+    if opt.use_int8:
+        model = replace_module(model)
+    # # to compute the model size
+    # getModelSize(model)
+
     sampler = DDIMSampler(model)
 
     os.makedirs(opt.outdir, exist_ok=True)
@@ -280,3 +292,5 @@ def main():
 
 if __name__ == "__main__":
     main()
+    # # to compute the mem allocated
+    # print(torch.cuda.max_memory_allocated() / 1024 / 1024)
diff --git a/examples/images/diffusion/scripts/txt2img.py b/examples/images/diffusion/scripts/txt2img.py
index 15993008f..364ebac6c 100644
--- a/examples/images/diffusion/scripts/txt2img.py
+++ b/examples/images/diffusion/scripts/txt2img.py
@@ -20,6 +20,7 @@ from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+from utils import replace_module, getModelSize
 
 torch.set_grad_enabled(False)
 
@@ -43,7 +44,6 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)
 
-    model.cuda()
     model.eval()
     return model
 
@@ -174,6 +174,12 @@
         default=1,
         help="repeat each prompt in file this often",
     )
+    parser.add_argument(
+        "--use_int8",
+        type=bool,
+        default=False,
+        help="use int8 for inference",
+    )
     opt = parser.parse_args()
     return opt
 
@@ -191,10 +197,17 @@
 
     config = OmegaConf.load(f"{opt.config}")
     model = load_model_from_config(config, f"{opt.ckpt}")
-
+
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    model = model.to(device)
+    model = model.to(device)
+
+    # quantize model
+    if opt.use_int8:
+        model = replace_module(model)
+    # # to compute the model size
+    # getModelSize(model)
+
     if opt.plms:
         sampler = PLMSSampler(model)
     elif opt.dpm:
         sampler = DPMSolverSampler(model)
@@ -290,3 +303,5 @@
 if __name__ == "__main__":
     opt = parse_args()
     main(opt)
+    # # to compute the mem allocated
+    # print(torch.cuda.max_memory_allocated() / 1024 / 1024)
diff --git a/examples/images/diffusion/scripts/utils.py b/examples/images/diffusion/scripts/utils.py
new file mode 100644
index 000000000..c954b22ca
--- /dev/null
+++ b/examples/images/diffusion/scripts/utils.py
@@ -0,0 +1,83 @@
+import bitsandbytes as bnb
+import torch.nn as nn
+import torch
+
+class Linear8bit(nn.Linear):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=False,
+        memory_efficient_backward=False,
+        threshold=6.0,
+        weight_data=None,
+        bias_data=None
+    ):
+        super(Linear8bit, self).__init__(
+            input_features, output_features, bias
+        )
+        self.state = bnb.MatmulLtState()
+        self.bias = bias_data
+        self.state.threshold = threshold
+        self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
+        if threshold > 0.0 and not has_fp16_weights:
+            self.state.use_pool = True
+
+        self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False))
+        self.weight = weight_data
+        self.quant()
+
+
+    def quant(self):
+        weight = self.weight.data.contiguous().half().cuda()
+        CB, _, SCB, _, _ = bnb.functional.double_quant(weight)
+        delattr(self, "weight")
+        setattr(self, "weight", nn.Parameter(CB, requires_grad=False))
+        delattr(self, "SCB")
+        setattr(self, "SCB", nn.Parameter(SCB, requires_grad=False))
+        del weight
+
+    def forward(self, x):
+        self.state.is_training = self.training
+
+        if self.bias is not None and self.bias.dtype != torch.float16:
+            self.bias.data = self.bias.data.half()
+
+        self.state.CB = self.weight.data
+        self.state.SCB = self.SCB.data
+
+        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
+        del self.state.CxB
+        return out
+
+def replace_module(model):
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_module(module)
+
+        if isinstance(module, nn.Linear) and "out_proj" not in name:
+            model._modules[name] = Linear8bit(
+                input_features=module.in_features,
+                output_features=module.out_features,
+                threshold=6.0,
+                weight_data=module.weight,
+                bias_data=module.bias,
+            )
+    return model
+
+def getModelSize(model):
+    param_size = 0
+    param_sum = 0
+    for param in model.parameters():
+        param_size += param.nelement() * param.element_size()
+        param_sum += param.nelement()
+    buffer_size = 0
+    buffer_sum = 0
+    for buffer in model.buffers():
+        buffer_size += buffer.nelement() * buffer.element_size()
+        buffer_sum += buffer.nelement()
+    all_size = (param_size + buffer_size) / 1024 / 1024
+    print('Model Size: {:.3f}MB'.format(all_size))
+    return (param_size, param_sum, buffer_size, buffer_sum, all_size)
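
The pieces above compose as: load the checkpoint, optionally swap every eligible
nn.Linear for Linear8bit via replace_module(), then sample as usual. A minimal
sketch of that flow on a toy module, assuming a CUDA device and the bitsandbytes
package; TinyBlock, its layer names, and the input shape are illustrative
inventions, while replace_module, getModelSize, Linear8bit, and the skipped
"out_proj" name come from utils.py above:

    import torch
    import torch.nn as nn
    from utils import replace_module, getModelSize

    class TinyBlock(nn.Module):
        # toy stand-in for an attention block inside the diffusion model
        def __init__(self):
            super().__init__()
            self.to_q = nn.Linear(64, 64)      # eligible: replaced by Linear8bit
            self.out_proj = nn.Linear(64, 64)  # name contains "out_proj": left unquantized

        def forward(self, x):
            return self.out_proj(self.to_q(x))

    model = TinyBlock().half().cuda()
    getModelSize(model)            # parameter/buffer bytes before quantization
    model = replace_module(model)  # int8 weights roughly halve the fp16 footprint
    getModelSize(model)

    with torch.no_grad():
        out = model(torch.randn(1, 8, 64, device="cuda", dtype=torch.float16))
    print(out.shape)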
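
One caveat when invoking the new flag: argparse's type=bool calls bool() on the
raw string, so any non-empty value, including "False", parses as True, and
--use_int8 therefore enables quantization whenever it is given a value. A sketch
of the conventional store_true alternative (not what the patch ships):

    import argparse

    parser = argparse.ArgumentParser()
    # a value-less switch: absent -> False, present -> True,
    # which sidesteps the bool("False") == True pitfall of type=bool
    parser.add_argument("--use_int8", action="store_true", help="use int8 for inference")

    print(parser.parse_args([]).use_int8)              # False
    print(parser.parse_args(["--use_int8"]).use_int8)  # True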
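
The commented-out max_memory_allocated lines in both scripts indicate how the
memory saving can be checked. A self-contained sketch of that measurement
pattern, where the matmul is an arbitrary placeholder workload:

    import torch

    torch.cuda.reset_peak_memory_stats()  # restart the allocator's high-water mark
    x = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    y = x @ x                 # placeholder workload
    torch.cuda.synchronize()  # make sure the kernel has finished
    # peak bytes held by the caching allocator, reported in MiB as in the scripts
    print(torch.cuda.max_memory_allocated() / 1024 / 1024)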