[Fix] Grok-1 use tokenizer from the same pretrained path (#5532)

* [fix] use tokenizer from the same pretrained path

* trust remote code
pull/5535/head
Yuanheng Zhao 2024-03-28 16:30:04 +08:00 committed by GitHub
parent 00525f7772
commit 36c4bb2893
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 10 deletions

View File

@ -1,7 +1,7 @@
import time
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizerFast
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import get_defualt_parser, inference, print_output
if __name__ == "__main__":
@ -9,6 +9,9 @@ if __name__ == "__main__":
args = parser.parse_args()
start = time.time()
torch.set_default_dtype(torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
trust_remote_code=True,
@ -18,10 +21,6 @@ if __name__ == "__main__":
model.eval()
init_time = time.time() - start
# A transformers-compatible version of the grok-1 tokenizer by Xenova
# https://huggingface.co/Xenova/grok-1-tokenizer
tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
for text in args.text:
output = inference(
model,

View File

@ -2,7 +2,7 @@ import time
import torch
from grok1_policy import Grok1ForCausalLMPolicy
from transformers import AutoModelForCausalLM, LlamaTokenizerFast
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import get_defualt_parser, inference, print_output
import colossalai
@ -27,6 +27,9 @@ if __name__ == "__main__":
)
booster = Booster(plugin=plugin)
torch.set_default_dtype(torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
with LazyInitContext(default_device=get_current_device()):
model = AutoModelForCausalLM.from_pretrained(
args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
@ -35,10 +38,6 @@ if __name__ == "__main__":
model.eval()
init_time = time.time() - start
# A transformers-compatible version of the grok-1 tokenizer by Xenova
# https://huggingface.co/Xenova/grok-1-tokenizer
tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
for text in args.text:
output = inference(
model.unwrap(),