mirror of https://github.com/hpcaitech/ColossalAI
[chore] refactor
parent
162251ab78
commit
ad6558e91c
|
@ -6,6 +6,10 @@ from typing import Dict, List, Tuple
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import _TYPES
|
from safetensors.torch import _TYPES
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tensornvme.async_file_io import AsyncFileWriter
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
|
||||||
_TYPES_INV = {v: k for k, v in _TYPES.items()}
|
_TYPES_INV = {v: k for k, v in _TYPES.items()}
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,3 +51,14 @@ def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Ten
|
||||||
n = len(metadata_buf)
|
n = len(metadata_buf)
|
||||||
|
|
||||||
return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors
|
return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors
|
||||||
|
|
||||||
|
|
||||||
|
def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
|
||||||
|
prepared_data, tensors = prepare(state_dict)
|
||||||
|
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
|
||||||
|
|
||||||
|
f_writer.write(n.to_bytes(8, byteorder="little"))
|
||||||
|
f_writer.write(header_bytes)
|
||||||
|
|
||||||
|
for tensor in tensors:
|
||||||
|
f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
|
||||||
|
|
Loading…
Reference in New Issue