from abc import ABC, abstractmethod
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader, Dataset

from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import OptimizerWrapper

__all__ = ["Plugin"]


class Plugin(ABC):
    """Base class for booster plugins, which customize how training components
    are set up for a particular parallelism or acceleration scheme."""

    @abstractmethod
    def supported_devices(self) -> List[str]:
        """Return the device types this plugin can run on, e.g. ["cuda"]."""

    @abstractmethod
    def supported_precisions(self) -> List[str]:
        """Return the precisions this plugin supports, e.g. ["fp16", "bf16", "fp32"]."""

    @abstractmethod
    def control_precision(self) -> bool:
        """Whether the plugin takes over mixed-precision handling from the booster."""

    @abstractmethod
    def control_device(self) -> bool:
        """Whether the plugin takes over device placement of the model."""

    @abstractmethod
    def support_no_sync(self) -> bool:
        """Whether the plugin supports the `no_sync` context manager."""

    @abstractmethod
    def support_lora(self) -> bool:
        """Whether the plugin supports LoRA fine-tuning via `enable_lora`."""

    @abstractmethod
    def configure(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        """Wrap the given training components for this plugin's execution scheme
        and return them."""

    @abstractmethod
    def control_checkpoint_io(self) -> bool:
        """Whether the plugin controls checkpoint I/O."""

    @abstractmethod
    def get_checkpoint_io(self) -> CheckpointIO:
        """Return the checkpoint I/O object for this plugin; only invoked when
        `control_checkpoint_io()` returns True."""

    @abstractmethod
    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        """Context manager to disable gradient synchronization."""

    @abstractmethod
    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
        """Add LoRA modules to the given model. Should only be called from
        `booster.enable_lora()`."""

    @abstractmethod
    def prepare_dataloader(
        self,
        dataset: Dataset,
        batch_size: int,
        shuffle: bool = False,
        seed: int = 1024,
        drop_last: bool = False,
        pin_memory: bool = False,
        num_workers: int = 0,
        **kwargs,
    ) -> DataLoader:
        """Prepare a dataloader for distributed training. The dataset will be wrapped
        in a `torch.utils.data.DataLoader`; extra keyword arguments are forwarded to
        its constructor."""
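

# Typical usage (sketch): a plugin is handed to the Booster, whose `boost`
# call delegates to `configure`. The training objects below are assumed to
# have been built already.
#
#   from colossalai.booster import Booster
#
#   booster = Booster(plugin=NaivePlugin())
#   model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
#       model, optimizer, criterion, dataloader, lr_scheduler
#   )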