# mirror of https://github.com/hpcaitech/ColossalAI

import torch
import transformers
from transformers import LlamaConfig, LlamaForCausalLM

from .model_utils import find_layers
from .quant import make_quant


def load_quant(pretrained: str, checkpoint: str, wbits: int, groupsize: int):
    """Build a LlamaForCausalLM skeleton and load `wbits`-bit quantized
    weights from `checkpoint`; `pretrained` is used only for the config."""
    config = LlamaConfig.from_pretrained(pretrained)

    def noop(*args, **kwargs):
        pass

    # Skip PyTorch's default weight initialization: every parameter is about
    # to be overwritten by the checkpoint, so initializing a model this large
    # is wasted work.
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop
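
    # Note: these monkey-patches are process-global and never restored, so
    # models built later in the same process will also skip initialization.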

    # Build the skeleton in fp16 with transformers' own weight initialization
    # disabled, then restore the default dtype for everything that follows.
    torch.set_default_dtype(torch.half)
    transformers.modeling_utils._init_weights = False
    model = LlamaForCausalLM(config)
    torch.set_default_dtype(torch.float)
    model = model.eval()

    # Collect the model's quantizable layers (a name -> module dict), keep
    # lm_head in full precision, and swap the rest for wbits-bit equivalents.
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    make_quant(model, layers, wbits, groupsize)
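
    # Load the quantized weights. safetensors checkpoints avoid unpickling
    # arbitrary Python objects, so prefer them when available.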
    print(f'Loading model with {wbits} bits...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint))
    else:
        model.load_state_dict(torch.load(checkpoint))
    model.seqlen = 2048
    print('Done.')

    return model
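

# A minimal usage sketch. The model name, checkpoint path, and the
# 4-bit / group-size-128 settings are illustrative assumptions, not values
# prescribed by this module.
if __name__ == '__main__':
    model = load_quant(
        pretrained='huggyllama/llama-7b',             # hypothetical config source
        checkpoint='llama-7b-4bit-128g.safetensors',  # hypothetical checkpoint
        wbits=4,
        groupsize=128,
    )
    model = model.cuda()  # quantized inference kernels typically expect CUDA
    print(f'Loaded with seqlen={model.seqlen}')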