Mirror of https://github.com/hpcaitech/ColossalAI

[Fix] Grok-1: use tokenizer from the same pretrained path (#5532)

* [fix] use tokenizer from the same pretrained path
* trust remote code

parent 00525f7772
commit 36c4bb2893
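The change is mechanical in both files: drop the hardcoded LlamaTokenizerFast pulled from Xenova/grok-1-tokenizer and instead load the tokenizer from args.pretrained, the same path the model comes from, with trust_remote_code=True since Grok-1 ships custom code. A minimal sketch of the resulting pattern, assuming a checkpoint that bundles its own tokenizer (the model id below is illustrative, not taken from the diff):

from transformers import AutoModelForCausalLM, AutoTokenizer

pretrained = "hpcai-tech/grok-1"  # illustrative id; the scripts take this from args.pretrained
# trust_remote_code=True is needed because the checkpoint ships custom modeling/tokenizer code
tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(pretrained, trust_remote_code=True)

The first diff below applies this to the plain inference script.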
@@ -1,7 +1,7 @@
 import time
 
 import torch
-from transformers import AutoModelForCausalLM, LlamaTokenizerFast
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils import get_defualt_parser, inference, print_output
 
 if __name__ == "__main__":
@@ -9,6 +9,9 @@ if __name__ == "__main__":
     args = parser.parse_args()
     start = time.time()
     torch.set_default_dtype(torch.bfloat16)
+
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
+
     model = AutoModelForCausalLM.from_pretrained(
         args.pretrained,
         trust_remote_code=True,
@@ -18,10 +21,6 @@ if __name__ == "__main__":
     model.eval()
     init_time = time.time() - start
 
-    # A transformers-compatible version of the grok-1 tokenizer by Xenova
-    # https://huggingface.co/Xenova/grok-1-tokenizer
-    tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
-
     for text in args.text:
         output = inference(
             model,
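The second changed file, evidently the tensor-parallel variant of the script (note Grok1ForCausalLMPolicy, the Booster, and model.unwrap() at the call site), gets the identical substitution. For context, a hedged sketch of how a matched tokenizer/model pair is typically driven; the prompt and generation arguments are illustrative and not taken from utils.inference:

# Hypothetical usage; the real scripts call utils.inference instead.
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))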
@@ -2,7 +2,7 @@ import time
 
 import torch
 from grok1_policy import Grok1ForCausalLMPolicy
-from transformers import AutoModelForCausalLM, LlamaTokenizerFast
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils import get_defualt_parser, inference, print_output
 
 import colossalai
@@ -27,6 +27,9 @@ if __name__ == "__main__":
     )
     booster = Booster(plugin=plugin)
     torch.set_default_dtype(torch.bfloat16)
+
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
+
     with LazyInitContext(default_device=get_current_device()):
         model = AutoModelForCausalLM.from_pretrained(
             args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
@@ -35,10 +38,6 @@ if __name__ == "__main__":
     model.eval()
     init_time = time.time() - start
 
-    # A transformers-compatible version of the grok-1 tokenizer by Xenova
-    # https://huggingface.co/Xenova/grok-1-tokenizer
-    tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
-
     for text in args.text:
         output = inference(
             model.unwrap(),
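Why loading from the same path matters: a third-party port such as Xenova/grok-1-tokenizer may work, but nothing guarantees its vocabulary stays aligned with the checkpoint's embedding table. A hypothetical sanity check (not part of the commit) makes the failure mode concrete:

# Every token id must fall inside the model's embedding table; a mismatched
# tokenizer can emit ids >= num_embeddings and break (or corrupt) the lookup.
vocab_size = model.get_input_embeddings().num_embeddings
assert len(tokenizer) <= vocab_size, (
    f"tokenizer defines {len(tokenizer)} tokens but the model embeds only {vocab_size}"
)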