Browse Source

[ckpt] add safetensors util

pull/6088/head
botbw 1 month ago
parent
commit
162251ab78
  1. 49
      colossalai/utils/safetensors.py

49
colossalai/utils/safetensors.py

@ -0,0 +1,49 @@
# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
import json
from dataclasses import asdict, dataclass
from typing import Dict, List, Tuple
import torch
from safetensors.torch import _TYPES
_TYPES_INV = {v: k for k, v in _TYPES.items()}
@dataclass
class TensorInfo:
dtype: str
shape: List[int]
data_offsets: Tuple[int, int]
@dataclass
class PreparedData:
n: int
header_bytes: bytes
offset: int
def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:
sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0]))
tensors = []
metadata = {}
offset = 0
for name, tensor in sorted_data:
n = tensor.numel() * tensor.element_size()
tensor_info = TensorInfo(
dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
)
offset += n
metadata[name] = asdict(tensor_info)
tensors.append(tensor)
metadata_buf = json.dumps(metadata).encode("utf-8")
extra = (8 - len(metadata_buf) % 8) % 8
metadata_buf += b" " * extra
n = len(metadata_buf)
return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors
Loading…
Cancel
Save