mirror of https://github.com/hpcaitech/ColossalAI
[example] diffuser, support quant inference for stable diffusion (#2186)
parent bc0e271e71
commit 1cf6d92d7c
@@ -22,6 +22,7 @@ from imwatermark import WatermarkEncoder
 from scripts.txt2img import put_watermark
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
+from utils import replace_module, getModelSize


 def chunk(it, size):
@@ -44,7 +45,6 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)

-    model.cuda()
     model.eval()
     return model

@@ -183,6 +183,12 @@ def main():
         choices=["full", "autocast"],
         default="autocast"
     )
+    parser.add_argument(
+        "--use_int8",
+        type=bool,
+        default=False,
+        help="use int8 for inference",
+    )
     opt = parser.parse_args()
     seed_everything(opt.seed)

@@ -193,6 +199,12 @@ def main():
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)

+    # quantize model
+    if opt.use_int8:
+        model = replace_module(model)
+        # # to compute the model size
+        # getModelSize(model)
+
     sampler = DDIMSampler(model)

     os.makedirs(opt.outdir, exist_ok=True)
@@ -280,3 +292,5 @@ def main():

 if __name__ == "__main__":
     main()
+    # # to compute the mem allocated
+    # print(torch.cuda.max_memory_allocated() / 1024 / 1024)
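A note on the `--use_int8` flag added above: argparse applies `type=bool` to the raw argument string, and any non-empty string, including "False", is truthy, so `--use_int8 False` still enables quantization; only omitting the flag yields False. A minimal sketch of the pitfall and the conventional `store_true` alternative (the parsers here are illustrative, not part of the diff):

import argparse

# type=bool, as in the diff: bool("False") is True, so the argument's
# string value is effectively ignored.
buggy = argparse.ArgumentParser()
buggy.add_argument("--use_int8", type=bool, default=False)
print(buggy.parse_args(["--use_int8", "False"]).use_int8)  # True

# The usual flag idiom: present means True, absent means False.
fixed = argparse.ArgumentParser()
fixed.add_argument("--use_int8", action="store_true", help="use int8 for inference")
print(fixed.parse_args(["--use_int8"]).use_int8)  # True
print(fixed.parse_args([]).use_int8)              # False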
@@ -20,6 +20,7 @@ from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+from utils import replace_module, getModelSize

 torch.set_grad_enabled(False)

@@ -43,7 +44,6 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)

-    model.cuda()
     model.eval()
     return model

@@ -174,6 +174,12 @@ def parse_args():
         default=1,
         help="repeat each prompt in file this often",
     )
+    parser.add_argument(
+        "--use_int8",
+        type=bool,
+        default=False,
+        help="use int8 for inference",
+    )
     opt = parser.parse_args()
     return opt

@@ -191,10 +197,17 @@ def main(opt):

     config = OmegaConf.load(f"{opt.config}")
     model = load_model_from_config(config, f"{opt.ckpt}")

     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    model = model.to(device)
+
+    model = model.to(device)
+
+    # quantize model
+    if opt.use_int8:
+        model = replace_module(model)
+        # # to compute the model size
+        # getModelSize(model)

     if opt.plms:
         sampler = PLMSSampler(model)
     elif opt.dpm:
@@ -290,3 +303,5 @@ def main(opt):
 if __name__ == "__main__":
     opt = parse_args()
     main(opt)
+    # # to compute the mem allocated
+    # print(torch.cuda.max_memory_allocated() / 1024 / 1024)
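Both scripts import `replace_module` and `getModelSize` from the new `utils` module added in the final hunk below. Its hand-rolled `Linear8bit` builds directly on bitsandbytes primitives (`MatmulLtState`, `double_quant`, `bnb.matmul`); bitsandbytes also ships `bnb.nn.Linear8bitLt`, which packages the same LLM.int8() bookkeeping. A hedged sketch of an equivalent swap using the library class; the helper name is hypothetical and constructor details can vary across bitsandbytes versions:

import bitsandbytes as bnb
import torch.nn as nn

def swap_linear_for_int8(module: nn.Linear, threshold: float = 6.0) -> nn.Module:
    # Hypothetical helper, not part of the commit: replace one nn.Linear with
    # the library-provided int8 layer instead of the hand-rolled Linear8bit.
    int8_linear = bnb.nn.Linear8bitLt(
        module.in_features,
        module.out_features,
        bias=module.bias is not None,
        has_fp16_weights=False,
        threshold=threshold,
    )
    # Int8Params defers quantization until the module is moved to the GPU.
    int8_linear.weight = bnb.nn.Int8Params(module.weight.data, requires_grad=False)
    if module.bias is not None:
        int8_linear.bias = module.bias
    return int8_linear.cuda()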
@@ -0,0 +1,83 @@
+import bitsandbytes as bnb
+import torch.nn as nn
+import torch
+
+
+class Linear8bit(nn.Linear):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=False,
+        memory_efficient_backward=False,
+        threshold=6.0,
+        weight_data=None,
+        bias_data=None
+    ):
+        super(Linear8bit, self).__init__(
+            input_features, output_features, bias
+        )
+        self.state = bnb.MatmulLtState()
+        self.bias = bias_data
+        self.state.threshold = threshold
+        self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
+        if threshold > 0.0 and not has_fp16_weights:
+            self.state.use_pool = True
+
+        self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False))
+        self.weight = weight_data
+        self.quant()
+
+    def quant(self):
+        weight = self.weight.data.contiguous().half().cuda()
+        CB, _, SCB, _, _ = bnb.functional.double_quant(weight)
+        delattr(self, "weight")
+        setattr(self, "weight", nn.Parameter(CB, requires_grad=False))
+        delattr(self, "SCB")
+        setattr(self, "SCB", nn.Parameter(SCB, requires_grad=False))
+        del weight
+
+    def forward(self, x):
+        self.state.is_training = self.training
+
+        if self.bias is not None and self.bias.dtype != torch.float16:
+            self.bias.data = self.bias.data.half()
+
+        self.state.CB = self.weight.data
+        self.state.SCB = self.SCB.data
+
+        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
+        del self.state.CxB
+        return out
+
+
+def replace_module(model):
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_module(module)
+
+        if isinstance(module, nn.Linear) and "out_proj" not in name:
+            model._modules[name] = Linear8bit(
+                input_features=module.in_features,
+                output_features=module.out_features,
+                threshold=6.0,
+                weight_data=module.weight,
+                bias_data=module.bias,
+            )
+    return model
+
+
+def getModelSize(model):
+    param_size = 0
+    param_sum = 0
+    for param in model.parameters():
+        param_size += param.nelement() * param.element_size()
+        param_sum += param.nelement()
+    buffer_size = 0
+    buffer_sum = 0
+    for buffer in model.buffers():
+        buffer_size += buffer.nelement() * buffer.element_size()
+        buffer_sum += buffer.nelement()
+    all_size = (param_size + buffer_size) / 1024 / 1024
+    print('Model Size: {:.3f}MB'.format(all_size))
+    return (param_size, param_sum, buffer_size, buffer_sum, all_size)
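A hedged usage sketch for the helpers above. The toy model and shapes are illustrative; a CUDA device and bitsandbytes are required, since `Linear8bit.quant()` moves every weight to `.half().cuda()` at construction time:

import torch
import torch.nn as nn
from utils import replace_module, getModelSize

# Toy stand-in for the diffusion model; any module tree with nn.Linear leaves works.
toy = nn.Sequential(
    nn.Linear(512, 1024),
    nn.GELU(),
    nn.Linear(1024, 512),
)

getModelSize(toy)          # fp32 baseline, about 4 MB for this toy stack
toy = replace_module(toy)  # swaps each nn.Linear (names containing "out_proj" excluded)
getModelSize(toy)          # roughly 4x smaller: int8 weights plus scale statistics

# Inference must run in fp16 on the GPU, matching what Linear8bit.forward expects.
x = torch.randn(1, 512).half().cuda()
with torch.no_grad():
    y = toy(x)
print(y.shape)  # torch.Size([1, 512])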