diff --git a/utils.py b/utils.py index bfb20b3..45015f9 100644 --- a/utils.py +++ b/utils.py @@ -40,16 +40,14 @@ def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = if num_gpus < 2 and device_map is None: model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda() else: - from accelerate import load_checkpoint_and_dispatch + from accelerate import dispatch_model - model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs) - model = model.eval() + model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half() if device_map is None: device_map = auto_configure_device_map(num_gpus) - model = load_checkpoint_and_dispatch( - model, checkpoint_path, device_map=device_map, offload_folder="offload", offload_state_dict=True).half() + model = dispatch_model(model, device_map=device_map) return model