diff --git a/tests/test_hf_model.py b/tests/test_hf_model.py index dabfdf1..a3f11b3 100644 --- a/tests/test_hf_model.py +++ b/tests/test_hf_model.py @@ -125,7 +125,7 @@ class TestMMModel: # it will be loaded as float32 and might cause OOM Error. model = AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype=torch.float16, + model_name, torch_dtype=torch.float32, trust_remote_code=True).cuda() tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)