ColossalAI/colossalai/pipeline/policy/__init__.py

23 lines
730 B
Python
Raw Normal View History

from typing import Any, Dict, List, Optional, Tuple, Type
from torch import Tensor
from torch.nn import Module, Parameter
from colossalai.pipeline.stage_manager import PipelineStageManager
from .base import Policy
from .bert import BertModel, BertModelPolicy
POLICY_MAP: Dict[Type[Module], Type[Policy]] = {
BertModel: BertModelPolicy,
}
def pipeline_parallelize(
model: Module,
stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
if type(model) not in POLICY_MAP:
raise NotImplementedError(f"Policy for {type(model)} not implemented")
policy = POLICY_MAP[type(model)](stage_manager)
return policy.parallelize_model(model)