mirror of https://github.com/hpcaitech/ColossalAI
[Fix] Grok-1 use tokenizer from the same pretrained path (#5532)
* [fix] use tokenizer from the same pretrained path
* trust remote code
parent 00525f7772
commit 36c4bb2893
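In effect, both inference scripts now load the tokenizer from the same `args.pretrained` path as the model, instead of the separately published Xenova LlamaTokenizerFast, so tokenizer and weights always stay in sync. A minimal sketch of the resulting pattern, assuming a checkpoint that bundles its own tokenizer (the `pretrained` value below is a placeholder):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    pretrained = "hpcai-tech/grok-1"  # placeholder: any path/repo that ships its own tokenizer

    torch.set_default_dtype(torch.bfloat16)

    # Tokenizer and model come from the same pretrained path, so their
    # vocabularies match by construction; trust_remote_code is required
    # because Grok-1 ships custom modeling and tokenizer code.
    tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        pretrained,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    model.eval()

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0]))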
examples/language/grok-1/inference.py

@@ -1,7 +1,7 @@
 import time
 
 import torch
-from transformers import AutoModelForCausalLM, LlamaTokenizerFast
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils import get_defualt_parser, inference, print_output
 
 if __name__ == "__main__":
@@ -9,6 +9,9 @@ if __name__ == "__main__":
     args = parser.parse_args()
     start = time.time()
     torch.set_default_dtype(torch.bfloat16)
+
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
+
     model = AutoModelForCausalLM.from_pretrained(
         args.pretrained,
         trust_remote_code=True,
@@ -18,10 +21,6 @@ if __name__ == "__main__":
     model.eval()
     init_time = time.time() - start
 
-    # A transformers-compatible version of the grok-1 tokenizer by Xenova
-    # https://huggingface.co/Xenova/grok-1-tokenizer
-    tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
-
     for text in args.text:
         output = inference(
             model,
examples/language/grok-1/inference_tp.py

@@ -2,7 +2,7 @@ import time
 
 import torch
 from grok1_policy import Grok1ForCausalLMPolicy
-from transformers import AutoModelForCausalLM, LlamaTokenizerFast
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils import get_defualt_parser, inference, print_output
 
 import colossalai
@@ -27,6 +27,9 @@ if __name__ == "__main__":
     )
     booster = Booster(plugin=plugin)
     torch.set_default_dtype(torch.bfloat16)
+
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
+
     with LazyInitContext(default_device=get_current_device()):
         model = AutoModelForCausalLM.from_pretrained(
             args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
@@ -35,10 +38,6 @@ if __name__ == "__main__":
     model.eval()
     init_time = time.time() - start
 
-    # A transformers-compatible version of the grok-1 tokenizer by Xenova
-    # https://huggingface.co/Xenova/grok-1-tokenizer
-    tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
-
     for text in args.text:
         output = inference(
             model.unwrap(),
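For the tensor-parallel script the tokenizer fix is identical; only the model is wrapped by the booster, which is why inference runs on model.unwrap(). A rough sketch of how the change slots into that flow, assuming ColossalAI's Booster/HybridParallelPlugin API of this era; the launch call and plugin sizes are illustrative assumptions, while the Booster, LazyInitContext, and unwrap() lines mirror the diff above:

    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin
    from colossalai.lazy import LazyInitContext
    from colossalai.utils import get_current_device
    from grok1_policy import Grok1ForCausalLMPolicy  # local module from the example directory
    from transformers import AutoModelForCausalLM, AutoTokenizer

    colossalai.launch_from_torch({})  # run under torchrun; some versions drop the config dict
    plugin = HybridParallelPlugin(
        tp_size=8, pp_size=1,  # hypothetical parallel sizes
        custom_policy=Grok1ForCausalLMPolicy(),
    )
    booster = Booster(plugin=plugin)
    torch.set_default_dtype(torch.bfloat16)

    pretrained = "hpcai-tech/grok-1"  # placeholder pretrained path
    # Same fix as the single-process script: tokenizer comes from the model's own path.
    tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)

    # Materialize weights lazily on the current device, then let the booster shard them.
    with LazyInitContext(default_device=get_current_device()):
        model = AutoModelForCausalLM.from_pretrained(
            pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
        )
    model, *_ = booster.boost(model)
    model.eval()

    inputs = tokenizer("The capital of France is", return_tensors="pt").to(get_current_device())
    outputs = model.unwrap().generate(**inputs, max_new_tokens=32)  # unwrap, as in the diff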