mirror of https://github.com/hpcaitech/ColossalAI
[example] diffuser, support quant inference for stable diffusion (#2186)
parent bc0e271e71
commit 1cf6d92d7c
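In summary, the change wires an optional int8 inference path into the Stable Diffusion example: two inference scripts gain a --use_int8 flag that routes the freshly loaded model through replace_module(), and a new utils module supplies the bitsandbytes-backed Linear8bit layer together with a getModelSize() helper for checking the resulting footprint.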
@@ -22,6 +22,7 @@ from imwatermark import WatermarkEncoder
 from scripts.txt2img import put_watermark
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
+from utils import replace_module, getModelSize


 def chunk(it, size):
@@ -44,7 +45,6 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)

-    model.cuda()
     model.eval()
     return model

@@ -183,6 +183,12 @@ def main():
         choices=["full", "autocast"],
         default="autocast"
     )
+    parser.add_argument(
+        "--use_int8",
+        action="store_true",
+        default=False,
+        help="use int8 for inference",
+    )

     opt = parser.parse_args()
     seed_everything(opt.seed)
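One note on the new flag: it is declared with action="store_true" because argparse's type=bool is a classic trap: bool() is applied to the raw string, and any non-empty string, including "False", is truthy, so a type=bool flag can never be switched off from the command line. A minimal demonstration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--buggy", type=bool, default=False)  # bool() on the raw string
    parser.add_argument("--use_int8", action="store_true")    # idiomatic boolean flag

    print(parser.parse_args(["--buggy", "False"]).buggy)  # True, not False!
    print(parser.parse_args(["--use_int8"]).use_int8)     # True
    print(parser.parse_args([]).use_int8)                 # False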
@@ -193,6 +199,12 @@ def main():
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)

+    # quantize model
+    if opt.use_int8:
+        model = replace_module(model)
+        # # to compute the model size
+        # getModelSize(model)
+
     sampler = DDIMSampler(model)

     os.makedirs(opt.outdir, exist_ok=True)
@@ -280,3 +292,5 @@ def main():

 if __name__ == "__main__":
     main()
+    # # to compute the mem allocated
+    # print(torch.cuda.max_memory_allocated() / 1024 / 1024)
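The commented-out line reports peak GPU memory in MiB. For a scoped measurement, PyTorch's allocator statistics can be reset before sampling (standard torch API; the placement here is a suggestion, not part of the commit):

    import torch

    torch.cuda.reset_peak_memory_stats()
    # ... run the sampling loop ...
    print("peak allocated: {:.1f} MiB".format(torch.cuda.max_memory_allocated() / 1024 / 1024))

The second script, which also wires in the PLMS and DPM-Solver samplers, receives the same changes: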
@@ -20,6 +20,7 @@ from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+from utils import replace_module, getModelSize

 torch.set_grad_enabled(False)

@@ -43,7 +44,6 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)

-    model.cuda()
     model.eval()
     return model

@@ -174,6 +174,12 @@ def parse_args():
         default=1,
         help="repeat each prompt in file this often",
     )
+    parser.add_argument(
+        "--use_int8",
+        action="store_true",
+        default=False,
+        help="use int8 for inference",
+    )
     opt = parser.parse_args()
     return opt

@@ -191,10 +197,16 @@ def main(opt):

     config = OmegaConf.load(f"{opt.config}")
     model = load_model_from_config(config, f"{opt.ckpt}")

     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)

+    # quantize model
+    if opt.use_int8:
+        model = replace_module(model)
+        # # to compute the model size
+        # getModelSize(model)
+
     if opt.plms:
         sampler = PLMSSampler(model)
     elif opt.dpm:
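A hypothetical sanity check, not part of the commit, for confirming the swap before sampling (assumes the utils.py added below is importable):

    import torch.nn as nn
    from utils import Linear8bit

    def count_linear_kinds(model: nn.Module):
        # Linear8bit subclasses nn.Linear, so match exact types.
        n_int8 = sum(isinstance(m, Linear8bit) for m in model.modules())
        n_fp = sum(type(m) is nn.Linear for m in model.modules())
        return n_int8, n_fp

    # e.g. inside main(), right after `model = replace_module(model)`:
    # print(count_linear_kinds(model))  # the second count covers the skipped out_proj layers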
@@ -290,3 +303,5 @@ def main(opt):
 if __name__ == "__main__":
     opt = parse_args()
     main(opt)
+    # # to compute the mem allocated
+    # print(torch.cuda.max_memory_allocated() / 1024 / 1024)
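Finally, the utils module that both scripts import is added as a new file; it implements the bitsandbytes-backed int8 linear layer and the two helpers: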
@@ -0,0 +1,83 @@
+import bitsandbytes as bnb
+import torch.nn as nn
+import torch
+
+
+class Linear8bit(nn.Linear):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=False,
+        memory_efficient_backward=False,
+        threshold=6.0,
+        weight_data=None,
+        bias_data=None
+    ):
+        super(Linear8bit, self).__init__(
+            input_features, output_features, bias
+        )
+        self.state = bnb.MatmulLtState()
+        self.bias = bias_data
+        self.state.threshold = threshold
+        self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
+        if threshold > 0.0 and not has_fp16_weights:
+            self.state.use_pool = True
+
+        self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False))
+        self.weight = weight_data
+        self.quant()
+
+    def quant(self):
+        weight = self.weight.data.contiguous().half().cuda()
+        CB, _, SCB, _, _ = bnb.functional.double_quant(weight)
+        delattr(self, "weight")
+        setattr(self, "weight", nn.Parameter(CB, requires_grad=False))
+        delattr(self, "SCB")
+        setattr(self, "SCB", nn.Parameter(SCB, requires_grad=False))
+        del weight
+
+    def forward(self, x):
+        self.state.is_training = self.training
+
+        if self.bias is not None and self.bias.dtype != torch.float16:
+            self.bias.data = self.bias.data.half()
+
+        self.state.CB = self.weight.data
+        self.state.SCB = self.SCB.data
+
+        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
+        del self.state.CxB
+        return out
+
+
+def replace_module(model):
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_module(module)
+
+        if isinstance(module, nn.Linear) and "out_proj" not in name:
+            model._modules[name] = Linear8bit(
+                input_features=module.in_features,
+                output_features=module.out_features,
+                threshold=6.0,
+                weight_data=module.weight,
+                bias_data=module.bias,
+            )
+    return model
+
+
+def getModelSize(model):
+    param_size = 0
+    param_sum = 0
+    for param in model.parameters():
+        param_size += param.nelement() * param.element_size()
+        param_sum += param.nelement()
+    buffer_size = 0
+    buffer_sum = 0
+    for buffer in model.buffers():
+        buffer_size += buffer.nelement() * buffer.element_size()
+        buffer_sum += buffer.nelement()
+    all_size = (param_size + buffer_size) / 1024 / 1024
+    print('Model Size: {:.3f}MB'.format(all_size))
+    return (param_size, param_sum, buffer_size, buffer_sum, all_size)
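What Linear8bit does: quant() runs bnb.functional.double_quant on the fp16 weight, keeping the row-major int8 matrix as the new weight and the per-row scale statistics in SCB; forward() hands both to bnb.matmul via MatmulLtState, which applies the LLM.int8() mixed-precision decomposition, computing activation outlier columns above threshold=6.0 in fp16. A minimal usage sketch, assuming a CUDA device and a bitsandbytes release that still exposes double_quant (the 0.3x/0.4x API):

    import torch
    import torch.nn as nn
    from utils import replace_module, getModelSize  # the file added above

    # Toy fp16 model standing in for the diffusion U-Net.
    toy = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 8)).half().cuda()

    getModelSize(toy)          # ~0.02 MB with fp16 weights
    toy = replace_module(toy)  # swaps each nn.Linear (names containing "out_proj" are skipped)
    getModelSize(toy)          # roughly half: int8 weights plus small fp32 SCB scales

    with torch.no_grad():
        x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
        print(toy(x).shape)    # torch.Size([4, 8])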
|
Loading…
Reference in New Issue