import torch
import torch.nn as nn
import transformers
from transformers import LlamaConfig, LlamaForCausalLM

from .model_utils import find_layers
from .quant import make_quant


def load_quant(pretrained: str, checkpoint: str, wbits: int, groupsize: int):
    config = LlamaConfig.from_pretrained(pretrained)

    def noop(*args, **kwargs):
        pass

    # Skip random weight initialization; every parameter is overwritten by the
    # quantized checkpoint below, so initializing them would be wasted work.
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    # Build the model skeleton in fp16 without running transformers' init hooks.
    torch.set_default_dtype(torch.half)
    transformers.modeling_utils._init_weights = False
    model = LlamaForCausalLM(config)
    torch.set_default_dtype(torch.float)
    model = model.eval()

    # Swap every quantizable linear layer (except the output head) for its
    # quantized equivalent before loading the packed weights.
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    make_quant(model, layers, wbits, groupsize)

    print(f'Loading model with {wbits} bits...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint))
    else:
        model.load_state_dict(torch.load(checkpoint))
    model.seqlen = 2048
    print('Done.')

    return model
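

# Usage sketch (illustrative, not part of the original module): the model path,
# checkpoint filename, and group size below are placeholder values and assume a
# checkpoint produced by the accompanying quantization script.
#
#   model = load_quant(
#       pretrained='path/to/llama-hf',             # HF directory with config.json
#       checkpoint='llama-4bit-128g.safetensors',  # packed quantized weights
#       wbits=4,
#       groupsize=128,
#   )
#   model = model.to('cuda')
#   # model.seqlen is 2048; pair with a tokenizer loaded from `pretrained` for generation.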