diff --git a/tests/test_hf_model.py b/tests/test_hf_model.py index ddee7b1..798d3a6 100644 --- a/tests/test_hf_model.py +++ b/tests/test_hf_model.py @@ -177,7 +177,7 @@ class TestMMModel: model_name, trust_remote_code=True, device='cuda:0').eval() else: model = AutoModelForCausalLM.from_pretrained( - model_name, torch_dtype=torch.float16, + model_name, torch_dtype=torch.float32, trust_remote_code=True).cuda() tokenizer = AutoTokenizer.from_pretrained(model_name,