from typing import Dict, List

from colossalai.legacy.tensor import ComputePattern
from colossalai.legacy.tensor.distspec import _DistSpec


class ColoModule(object):
    """Registry of shardable parameter names and their allowed dist specs.

    The registry is a nested mapping::

        compute_pattern -> mode -> parameter name -> dist spec

    Subclasses are expected to override :meth:`register`; the base
    implementation raises ``NotImplementedError``.
    """

    def __init__(self):
        # Names of the parameters that take part in sharding.
        self._shard_params: List[str] = []
        # compute_pattern -> mode -> {param name -> dist spec}
        self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}

    def _register_shard_params(self, params: List[str]):
        """Record ``params`` as the list of shardable parameter names.

        The list is stored by reference, not copied.
        """
        self._shard_params = params

    def _register_allowed_patterns(
        self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], mode="default"
    ):
        """Register ``dist_specs`` for ``compute_pattern`` under ``mode``.

        Args:
            compute_pattern: Key of the outer registry level.
            dist_specs: Mapping from parameter name to its dist spec. Its keys
                must match the registered shard parameter names exactly
                (order is irrelevant).
            mode: Name under which the specs are stored; an existing entry
                for the same ``(compute_pattern, mode)`` is overwritten.

        Raises:
            AssertionError: If the keys of ``dist_specs`` differ from the
                registered shard parameter names.
        """
        # ``list.sort()`` sorts in place and returns None, so comparing the
        # results of two ``.sort()`` calls is always ``None == None``. Use
        # ``sorted`` so the check is real and ``_shard_params`` keeps the
        # order it was registered with.
        assert sorted(dist_specs.keys()) == sorted(
            self._shard_params
        ), "Every registered param should have dist_spec."
        if compute_pattern not in self._allowed_patterns:
            self._allowed_patterns[compute_pattern] = {}
        self._allowed_patterns[compute_pattern][mode] = dist_specs

    def _set_default(self, compute_pattern: ComputePattern, target_mode):
        """Make the ``"default"`` mode refer to the specs of ``target_mode``.

        The same dict object is shared between both modes (no copy).

        Raises:
            KeyError: If ``compute_pattern`` or ``target_mode`` is not registered.
        """
        self._allowed_patterns[compute_pattern]["default"] = self._allowed_patterns[compute_pattern][target_mode]

    def has_compute_pattern(self, compute_pattern: ComputePattern):
        """Return True if anything is registered for ``compute_pattern``."""
        return compute_pattern in self._allowed_patterns

    def get_dist_specs(self, compute_pattern: ComputePattern):
        """Return the ``mode -> dist specs`` mapping for ``compute_pattern``.

        Raises:
            AssertionError: If ``compute_pattern`` is not registered.
        """
        assert self.has_compute_pattern(compute_pattern)
        return self._allowed_patterns[compute_pattern]

    def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode="default"):
        """Return True if ``mode`` is registered under ``compute_pattern``."""
        return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern]

    def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode="default"):
        """Return the dist specs stored for ``(compute_pattern, mode)``.

        Raises:
            AssertionError: If that combination is not registered.
        """
        assert self.has_compute_pattern_with_mode(compute_pattern, mode)
        return self._allowed_patterns[compute_pattern][mode]

    def get_param_names(self):
        """Return the registered shard parameter names (the stored list itself)."""
        return self._shard_params

    def register(self, compute_pattern, pg):
        """Hook to be overridden by subclasses; the base class always raises.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError