# ColossalAI/colossalai/booster/plugin/plugin_base.py
# (original listing: 81 lines, 2.1 KiB, Python)

from abc import ABC, abstractmethod
from typing import Callable, Iterator, List, Optional, Tuple
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader, Dataset
from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import OptimizerWrapper
__all__ = ["Plugin"]


class Plugin(ABC):
    """Abstract interface that every booster plugin must implement.

    All methods below are declared with ``@abstractmethod``, so a subclass can
    only be instantiated once it overrides every one of them.
    """

    @abstractmethod
    def supported_devices(self) -> List[str]:
        """Return the devices supported by this plugin, as a list of strings."""

    @abstractmethod
    def supported_precisions(self) -> List[str]:
        """Return the precisions supported by this plugin, as a list of strings."""

    @abstractmethod
    def control_precision(self) -> bool:
        """Whether the plugin controls the training precision."""

    @abstractmethod
    def control_device(self) -> bool:
        """Whether the plugin controls the device."""

    @abstractmethod
    def support_no_sync(self) -> bool:
        """Whether the plugin supports :meth:`no_sync`."""

    @abstractmethod
    def configure(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[
        nn.Module,
        Optional[OptimizerWrapper],
        Optional[Callable],
        Optional[DataLoader],
        Optional[LRScheduler],
    ]:
        """Prepare the model and the optional training components for this plugin.

        Args:
            model: The model to configure.
            optimizer: Optional optimizer for ``model``.
            criterion: Optional loss function.
            dataloader: Optional dataloader.
            lr_scheduler: Optional learning-rate scheduler.

        Returns:
            A 5-tuple ``(model, optimizer, criterion, dataloader, lr_scheduler)``.
            Every argument except ``model`` defaults to ``None``, so the matching
            tuple entries are typed ``Optional`` (presumably an implementation
            returns ``None`` for any component it was not given).
        """

    @abstractmethod
    def control_checkpoint_io(self) -> bool:
        """Whether the plugin controls checkpoint I/O."""

    @abstractmethod
    def get_checkpoint_io(self) -> CheckpointIO:
        """Return the checkpoint I/O object for this plugin.

        Only invoked when :meth:`control_checkpoint_io` is True.
        """

    @abstractmethod
    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        """Context manager to disable gradient synchronization.

        Args:
            model: The model whose gradient synchronization is disabled.
            optimizer: The wrapped optimizer associated with ``model``.
        """

    @abstractmethod
    def prepare_dataloader(
        self,
        dataset: Dataset,
        batch_size: int,
        shuffle: bool = False,
        seed: int = 1024,
        drop_last: bool = False,
        pin_memory: bool = False,
        num_workers: int = 0,
        **kwargs,
    ) -> DataLoader:
        """Prepare a dataloader for distributed training.

        Args:
            dataset: The dataset to load from.
            batch_size: Number of samples per batch.
            shuffle: Whether to shuffle the data.
            seed: Random seed (presumably used for shuffling/sampling; confirm
                against implementations).
            drop_last: Whether to drop the last incomplete batch.
            pin_memory: Whether to use pinned memory.
            num_workers: Number of worker processes for loading.
            **kwargs: Extra keyword arguments (presumably forwarded to the
                underlying ``DataLoader``; confirm against implementations).

        Returns:
            A ``torch.utils.data.DataLoader``.
        """