diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py index 9aa3558d9..390e12511 100644 --- a/colossalai/utils/safetensors.py +++ b/colossalai/utils/safetensors.py @@ -61,4 +61,4 @@ def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None f_writer.write(header_bytes) for tensor in tensors: - f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset) + f_writer.write_tensor(tensor)