diff --git a/examples/language/grok-1/inference.py b/examples/language/grok-1/inference.py index a73820af9..faef7ae9d 100644 --- a/examples/language/grok-1/inference.py +++ b/examples/language/grok-1/inference.py @@ -1,7 +1,7 @@ import time import torch -from transformers import AutoModelForCausalLM, LlamaTokenizerFast +from transformers import AutoModelForCausalLM, AutoTokenizer from utils import get_defualt_parser, inference, print_output if __name__ == "__main__": @@ -9,6 +9,9 @@ if __name__ == "__main__": args = parser.parse_args() start = time.time() torch.set_default_dtype(torch.bfloat16) + + tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( args.pretrained, trust_remote_code=True, @@ -18,10 +21,6 @@ if __name__ == "__main__": model.eval() init_time = time.time() - start - # A transformers-compatible version of the grok-1 tokenizer by Xenova - # https://huggingface.co/Xenova/grok-1-tokenizer - tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer") - for text in args.text: output = inference( model, diff --git a/examples/language/grok-1/inference_tp.py b/examples/language/grok-1/inference_tp.py index 604de1487..cf05880dc 100644 --- a/examples/language/grok-1/inference_tp.py +++ b/examples/language/grok-1/inference_tp.py @@ -2,7 +2,7 @@ import time import torch from grok1_policy import Grok1ForCausalLMPolicy -from transformers import AutoModelForCausalLM, LlamaTokenizerFast +from transformers import AutoModelForCausalLM, AutoTokenizer from utils import get_defualt_parser, inference, print_output import colossalai @@ -27,6 +27,9 @@ if __name__ == "__main__": ) booster = Booster(plugin=plugin) torch.set_default_dtype(torch.bfloat16) + + tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True) + with LazyInitContext(default_device=get_current_device()): model = AutoModelForCausalLM.from_pretrained( args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16 @@ -35,10 +38,6 @@ if __name__ == "__main__": model.eval() init_time = time.time() - start - # A transformers-compatible version of the grok-1 tokenizer by Xenova - # https://huggingface.co/Xenova/grok-1-tokenizer - tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer") - for text in args.text: output = inference( model.unwrap(),