You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/utils/checkpoint_io/backend.py

75 lines
1.9 KiB

import shutil
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, List, Type
from .reader import CheckpointReader, DiskCheckpointReader
from .writer import CheckpointWriter, DiskCheckpointWriter
# Registry: maps each registered backend name to the class stored under it.
_backends: Dict[str, Type['CheckpointIOBackend']] = {}


def register(name: str):
    """Build a class decorator that records a class in ``_backends``.

    The uniqueness check runs when ``register`` itself is called (while the
    decorator is being created), not when the returned decorator is later
    applied to a class.

    Args:
        name: Key under which the decorated class is stored.

    Returns:
        A decorator that stores the class it receives under ``name`` and
        hands the same class back unchanged.

    Raises:
        AssertionError: If ``name`` is already a key of ``_backends``.
    """
    assert name not in _backends, f'"{name}" is registered'

    def _store(backend_cls):
        # Record the class, then return it so the decorated name still
        # refers to the original class object.
        _backends[name] = backend_cls
        return backend_cls

    return _store
def get_backend(name: str) -> 'CheckpointIOBackend':
    """Create a fresh instance of the backend registered under ``name``.

    Args:
        name: Key that was previously passed to ``register``.

    Returns:
        The result of calling the registered class with no arguments.

    Raises:
        AssertionError: If ``name`` is not a key of ``_backends``.
    """
    assert name in _backends, f'Unsupported backend "{name}"'
    backend_cls = _backends[name]
    return backend_cls()
class CheckpointIOBackend(ABC):
    """Abstract interface for a checkpoint I/O backend.

    Concrete subclasses must implement all four abstract methods below.
    Every instance starts with an empty ``temps`` list of strings.
    """

    def __init__(self) -> None:
        super().__init__()
        # Strings identifying temporary locations; how entries are added and
        # consumed is left to the concrete subclass.
        self.temps: List[str] = []

    @abstractmethod
    def get_writer(
        self,
        base_name: str,
        overwrite: bool = False,
        rank: int = 0,
        world_size: int = 1,
    ) -> CheckpointWriter:
        """Return a ``CheckpointWriter`` for ``base_name``.

        ``overwrite``, ``rank`` and ``world_size`` are settings for the
        writer; their exact meaning is defined by the concrete backend.
        """

    @abstractmethod
    def get_reader(self, base_name: str) -> CheckpointReader:
        """Return a ``CheckpointReader`` for ``base_name``."""

    @abstractmethod
    def get_temp(self, base_name: str) -> str:
        """Return a string naming a temporary location for ``base_name``."""

    @abstractmethod
    def clean_temp(self) -> None:
        """Clean up temporary locations; details are backend-specific."""
@register('disk')
class CheckpointDiskIO(CheckpointIOBackend):
    """Checkpoint backend registered under the name ``'disk'``.

    Hands out ``DiskCheckpointWriter`` / ``DiskCheckpointReader`` objects and
    creates temporary directories with ``tempfile.mkdtemp``, remembering each
    one in ``self.temps`` so that ``clean_temp`` can delete it later.
    """

    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        """Return a ``DiskCheckpointWriter`` built from the given arguments.

        ``base_name`` and ``overwrite`` are passed positionally; ``rank`` and
        ``world_size`` are passed as keyword arguments.
        """
        return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size)

    def get_reader(self, base_name: str) -> CheckpointReader:
        """Return a ``DiskCheckpointReader`` constructed from ``base_name``."""
        return DiskCheckpointReader(base_name)

    def get_temp(self, base_name: str) -> str:
        """Create a new temporary directory inside ``base_name`` and track it.

        Args:
            base_name: Existing directory in which the temporary directory is
                created (passed as ``dir=`` to ``tempfile.mkdtemp``).

        Returns:
            The path of the newly created temporary directory.
        """
        temp_dir_name = tempfile.mkdtemp(dir=base_name)
        self.temps.append(temp_dir_name)
        return temp_dir_name

    def clean_temp(self) -> None:
        """Delete every tracked temporary directory and stop tracking it.

        Directories are removed in the order they were created. An entry is
        dropped from ``self.temps`` only after its removal succeeded, so:

        * calling this method again afterwards is a no-op instead of failing
          on directories that no longer exist;
        * if ``shutil.rmtree`` raises, the failing directory and all later
          ones remain in ``self.temps`` and the exception propagates.
        """
        while self.temps:
            shutil.rmtree(self.temps[0])
            # Forget the directory only once it is really gone.
            del self.temps[0]