mirror of https://github.com/hpcaitech/ColossalAI
botbw
1 month ago
1 changed files with 49 additions and 0 deletions
@ -0,0 +1,49 @@
|
||||
# A Python safetensors serializer, adapted from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
||||
import json |
||||
from dataclasses import asdict, dataclass |
||||
from typing import Dict, List, Tuple |
||||
|
||||
import torch |
||||
from safetensors.torch import _TYPES |
||||
|
||||
# Reverse lookup of safetensors' ``_TYPES`` table: each value of ``_TYPES``
# becomes a key here and maps back to its original key. ``prepare`` indexes it
# with ``tensor.dtype`` to obtain the dtype string stored in the header.
_TYPES_INV = {torch_dtype: type_name for type_name, torch_dtype in _TYPES.items()}
||||
|
||||
|
||||
@dataclass
class TensorInfo:
    """Per-tensor header record.

    ``prepare`` converts each instance to a plain dict with ``asdict`` and
    stores it in the JSON header under the tensor's name, so the field names
    and their order are part of the serialized output.
    """

    # Dtype name as a string (looked up from ``_TYPES_INV`` in ``prepare``).
    dtype: str
    # Tensor dimensions as a list of ints.
    shape: List[int]
    # ``(begin, end)`` byte offsets of the tensor's payload inside the
    # concatenated data section; ``end - begin`` is the payload size in bytes.
    data_offsets: Tuple[int, int]
||||
|
||||
|
||||
@dataclass
class PreparedData:
    """Result of ``prepare``: the encoded header plus the sizes needed to write a file."""

    # Length in bytes of ``header_bytes`` (after padding).
    n: int
    # UTF-8 encoded JSON header, space-padded by ``prepare`` to a multiple of 8 bytes.
    header_bytes: bytes
    # Total size in bytes of all tensor payloads laid out back to back.
    offset: int
||||
|
||||
|
||||
def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor]]:
    """Build the safetensors JSON header for ``data`` and decide the write order.

    Tensor payloads are assigned consecutive, non-overlapping byte ranges in
    the data section, and each range is recorded in the header together with
    the tensor's dtype name and shape.

    Args:
        data: Mapping from tensor name to tensor.

    Returns:
        A pair ``(prepared, tensors)``:

        * ``prepared.header_bytes`` is the UTF-8 encoded JSON header, padded
          with spaces to a multiple of 8 bytes; ``prepared.n`` is its length
          and ``prepared.offset`` is the total byte size of all payloads.
        * ``tensors`` lists the tensors in the same order as the byte ranges
          recorded in the header.

    Raises:
        KeyError: If a tensor's dtype has no entry in ``_TYPES_INV``.
    """

    def _layout_key(item):
        # ``torch.dtype`` objects do not support ``<``, so putting the dtype
        # itself into the sort key raises ``TypeError`` as soon as two tensors
        # have different dtypes. Sort on orderable values instead: widest
        # element size first (as the reference Rust serializer does), then the
        # dtype's string name, then the tensor name. For inputs with a single
        # dtype this reduces to ordering by name, exactly as before.
        name, tensor = item
        return (-tensor.element_size(), str(tensor.dtype), name)

    ordered = sorted(data.items(), key=_layout_key)

    tensors = []
    metadata = {}
    offset = 0

    for name, tensor in ordered:
        nbytes = tensor.numel() * tensor.element_size()
        tensor_info = TensorInfo(
            dtype=_TYPES_INV[tensor.dtype],
            shape=list(tensor.shape),
            data_offsets=(offset, offset + nbytes),
        )
        offset += nbytes
        metadata[name] = asdict(tensor_info)
        tensors.append(tensor)

    metadata_buf = json.dumps(metadata).encode("utf-8")

    # Pad the header with spaces so that its length is a multiple of 8 bytes.
    extra = (8 - len(metadata_buf) % 8) % 8
    metadata_buf += b" " * extra

    return PreparedData(n=len(metadata_buf), header_bytes=metadata_buf, offset=offset), tensors
Loading…
Reference in new issue