ColossalAI/colossalai/utils/checkpoint_io/backend.py

75 lines
1.9 KiB
Python
Raw Normal View History

import shutil
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, List, Type
from .reader import CheckpointReader, DiskCheckpointReader
from .writer import CheckpointWriter, DiskCheckpointWriter
# Registry of checkpoint IO backend classes, keyed by backend name.
# ``register`` stores classes here; ``get_backend`` looks a class up by name
# and instantiates it.
_backends: Dict[str, Type['CheckpointIOBackend']] = {}
def register(name: str):
    """Build a class decorator that stores the decorated class in ``_backends``.

    The uniqueness check happens when ``register(name)`` itself is called,
    i.e. before the decorated class is seen.

    Args:
        name: Key under which the decorated class is stored.

    Returns:
        A decorator that records the class under ``name`` and hands the
        class back unchanged.

    Raises:
        AssertionError: If ``name`` is already a key of ``_backends``.
    """
    assert name not in _backends, f'"{name}" is registered'

    def _record(cls):
        # Store first, then return the very same class object so the
        # decorated name keeps pointing at it.
        _backends[name] = cls
        return cls

    return _record
def get_backend(name: str) -> 'CheckpointIOBackend':
    """Instantiate the backend class stored in ``_backends`` under ``name``.

    Each call constructs a new object by calling the registered class with
    no arguments.

    Args:
        name: Key of a previously registered backend class.

    Returns:
        A freshly constructed instance of the registered class.

    Raises:
        AssertionError: If ``name`` is not a key of ``_backends``.
    """
    assert name in _backends, f'Unsupported backend "{name}"'
    backend_cls = _backends[name]
    return backend_cls()
class CheckpointIOBackend(ABC):
    """Abstract interface that every checkpoint IO backend implements.

    Subclasses must provide a writer factory, a reader factory and a pair of
    methods for obtaining and cleaning up temporary locations.
    """

    def __init__(self) -> None:
        super().__init__()
        # Names handed out as temporary locations; starts out empty.
        self.temps: List[str] = []

    @abstractmethod
    def get_writer(
        self,
        base_name: str,
        overwrite: bool = False,
        rank: int = 0,
        world_size: int = 1,
    ) -> CheckpointWriter:
        """Return a ``CheckpointWriter`` for ``base_name``.

        Args:
            base_name: Name of the checkpoint location to write to.
            overwrite: Overwrite flag; defaults to ``False``.
            rank: Rank value; defaults to ``0``.
            world_size: World size value; defaults to ``1``.
        """

    @abstractmethod
    def get_reader(self, base_name: str) -> CheckpointReader:
        """Return a ``CheckpointReader`` for ``base_name``."""

    @abstractmethod
    def get_temp(self, base_name: str) -> str:
        """Return the name of a temporary location derived from ``base_name``."""

    @abstractmethod
    def clean_temp(self) -> None:
        """Clean up temporary locations; the details are left to the subclass."""
@register('disk')
class CheckpointDiskIO(CheckpointIOBackend):
    """Checkpoint IO backend registered under the name ``'disk'``.

    Produces ``DiskCheckpointWriter`` / ``DiskCheckpointReader`` objects and
    creates its temporary directories with ``tempfile.mkdtemp``.
    """

    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        """Return a ``DiskCheckpointWriter`` built from the given arguments.

        Args:
            base_name: Forwarded as the writer's first positional argument.
            overwrite: Forwarded as the writer's second positional argument.
            rank: Forwarded as the ``rank`` keyword argument.
            world_size: Forwarded as the ``world_size`` keyword argument.
        """
        return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size)

    def get_reader(self, base_name: str) -> CheckpointReader:
        """Return a ``DiskCheckpointReader`` constructed with ``base_name``."""
        return DiskCheckpointReader(base_name)

    def get_temp(self, base_name: str) -> str:
        """Create a new temporary directory inside ``base_name`` and track it.

        Args:
            base_name: Directory passed to ``tempfile.mkdtemp`` as ``dir``.

        Returns:
            The path of the newly created directory, which is also appended
            to ``self.temps`` so that ``clean_temp`` can delete it later.
        """
        temp_dir_name = tempfile.mkdtemp(dir=base_name)
        self.temps.append(temp_dir_name)
        return temp_dir_name

    def clean_temp(self) -> None:
        """Delete every tracked temporary directory and stop tracking it.

        Directories are removed in the order they were created. Each entry
        is dropped from ``self.temps`` as soon as its directory is gone, so
        calling this method again (or after further ``get_temp`` calls) does
        not try to remove an already-deleted path.

        Raises:
            OSError: Propagated from ``shutil.rmtree``; the failing entry and
                any entries after it stay in ``self.temps``.
        """
        # Iterate over a snapshot because entries are removed as we go.
        for temp_dir_name in list(self.temps):
            shutil.rmtree(temp_dir_name)
            self.temps.remove(temp_dir_name)