[example] diffuser, support quant inference for stable diffusion (#2186)

pull/2189/head^2
BlueRum 2022-12-23 16:06:29 +08:00 committed by GitHub
parent bc0e271e71
commit 1cf6d92d7c
3 changed files with 116 additions and 4 deletions


@@ -22,6 +22,7 @@ from imwatermark import WatermarkEncoder
from scripts.txt2img import put_watermark
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from utils import replace_module, getModelSize
def chunk(it, size):
@@ -44,7 +45,6 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:")
print(u)
model.cuda()
model.eval()
return model
@@ -183,6 +183,12 @@ def main():
choices=["full", "autocast"],
default="autocast"
)
parser.add_argument(
"--use_int8",
type=bool,
default=False,
help="use int8 for inference",
)
opt = parser.parse_args()
seed_everything(opt.seed)
@@ -193,6 +199,12 @@ def main():
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
# quantize model
if opt.use_int8:
model = replace_module(model)
# # to compute the model size
# getModelSize(model)
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)
@@ -280,3 +292,5 @@ def main():
if __name__ == "__main__":
main()
# # to compute the mem allocated
# print(torch.cuda.max_memory_allocated() / 1024 / 1024)
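
A note on reading the new flag: argparse's type=bool simply calls bool() on the command-line string, so any non-empty value, including --use_int8 False, evaluates to True; the default of False only holds when the flag is omitted. A plain on/off switch avoids that surprise. This is a sketch of the alternative, not part of the commit:

import argparse

parser = argparse.ArgumentParser()
# Present -> True, absent -> False; no string-to-bool conversion involved.
parser.add_argument(
    "--use_int8",
    action="store_true",
    help="replace nn.Linear layers with int8 Linear8bit modules for inference",
)
print(parser.parse_args(["--use_int8"]).use_int8)  # True
print(parser.parse_args([]).use_int8)              # False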


@@ -20,6 +20,7 @@ from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
from utils import replace_module, getModelSize
torch.set_grad_enabled(False)
@@ -43,7 +44,6 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:")
print(u)
model.cuda()
model.eval()
return model
@@ -174,6 +174,12 @@ def parse_args():
default=1,
help="repeat each prompt in file this often",
)
parser.add_argument(
"--use_int8",
type=bool,
default=False,
help="use int8 for inference",
)
opt = parser.parse_args()
return opt
@@ -191,10 +197,17 @@ def main(opt):
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
model = model.to(device)
# quantize model
if opt.use_int8:
model = replace_module(model)
# # to compute the model size
# getModelSize(model)
if opt.plms:
sampler = PLMSSampler(model)
elif opt.dpm:
@@ -290,3 +303,5 @@ def main(opt):
if __name__ == "__main__":
opt = parse_args()
main(opt)
# # to compute the mem allocated
# print(torch.cuda.max_memory_allocated() / 1024 / 1024)
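
The commented-out getModelSize and torch.cuda.max_memory_allocated calls in both scripts hint at how the footprint of the quantized model can be checked. A minimal standalone sketch of that measurement pattern, assuming a CUDA device; the single fp16 linear layer is only a stand-in for the diffusion model:

import torch
import torch.nn as nn

layer = nn.Linear(4096, 4096).half().cuda()
x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    _ = layer(x)

# Peak GPU memory allocated during the forward pass, in MiB.
print(torch.cuda.max_memory_allocated() / 1024 / 1024)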


@@ -0,0 +1,83 @@
import bitsandbytes as bnb
import torch.nn as nn
import torch


class Linear8bit(nn.Linear):
    """nn.Linear replacement that stores int8 weights and runs LLM.int8() matmuls via bitsandbytes."""

    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=False,
        memory_efficient_backward=False,
        threshold=6.0,
        weight_data=None,
        bias_data=None
    ):
        super(Linear8bit, self).__init__(
            input_features, output_features, bias
        )
        self.state = bnb.MatmulLtState()
        # Replace the freshly initialized bias with the original module's bias (may be None).
        self.bias = bias_data
        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        # SCB holds the per-row quantization scales; quant() fills it in.
        self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False))
        self.weight = weight_data
        self.quant()

    def quant(self):
        # Quantize the fp16 weight into int8 values (CB) plus per-row scales (SCB).
        weight = self.weight.data.contiguous().half().cuda()
        CB, _, SCB, _, _ = bnb.functional.double_quant(weight)
        delattr(self, "weight")
        setattr(self, "weight", nn.Parameter(CB, requires_grad=False))
        delattr(self, "SCB")
        setattr(self, "SCB", nn.Parameter(SCB, requires_grad=False))
        del weight

    def forward(self, x):
        self.state.is_training = self.training

        if self.bias is not None and self.bias.dtype != torch.float16:
            self.bias.data = self.bias.data.half()

        # Hand the int8 weight and its scales to bitsandbytes for the matmul.
        self.state.CB = self.weight.data
        self.state.SCB = self.SCB.data
        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
        # Free the layout-transformed weight copy cached by bnb.matmul.
        del self.state.CxB
        return out


def replace_module(model):
    """Recursively swap nn.Linear layers (except those named 'out_proj') for Linear8bit."""
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_module(module)

        if isinstance(module, nn.Linear) and "out_proj" not in name:
            model._modules[name] = Linear8bit(
                input_features=module.in_features,
                output_features=module.out_features,
                threshold=6.0,
                weight_data=module.weight,
                bias_data=module.bias,
            )
    return model


def getModelSize(model):
    """Print and return the total size of the model's parameters and buffers in MB."""
    param_size = 0
    param_sum = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
        param_sum += param.nelement()
    buffer_size = 0
    buffer_sum = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
        buffer_sum += buffer.nelement()
    all_size = (param_size + buffer_size) / 1024 / 1024
    print('Model Size: {:.3f}MB'.format(all_size))
    return (param_size, param_sum, buffer_size, buffer_sum, all_size)
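
A minimal usage sketch for the two helpers above, assuming a CUDA device, an installed bitsandbytes, and that this new file is importable as utils (as the two scripts do); the toy MLP only stands in for the real UNet and text-encoder modules:

import torch
import torch.nn as nn
from utils import replace_module, getModelSize

toy = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
getModelSize(toy)            # size with fp32 weights

toy = replace_module(toy)    # nn.Linear layers become Linear8bit
getModelSize(toy)            # roughly a quarter of the fp32 size

x = torch.randn(2, 1024, device="cuda", dtype=torch.float16)
with torch.no_grad():
    out = toy(x)
print(out.dtype, out.shape)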