2023-09-11 08:24:28 +00:00
|
|
|
from typing import Dict, List
|
|
|
|
|
2023-09-18 08:31:06 +00:00
|
|
|
from colossalai.legacy.tensor import ComputePattern
|
|
|
|
from colossalai.legacy.tensor.distspec import _DistSpec
|
2022-05-26 03:50:44 +00:00
|
|
|
|
|
|
|
|
|
|
|
class ColoModule(object):
    """Registry of shardable parameters and their allowed distribution specs.

    The instance stores two pieces of state:

    * ``_shard_params``: the names of the parameters registered for sharding.
    * ``_allowed_patterns``: a nested mapping
      ``compute_pattern -> mode -> param_name -> dist_spec``.

    ``register`` is left unimplemented here and raises ``NotImplementedError``;
    presumably subclasses override it to populate the registry (confirm
    against the concrete subclasses).
    """

    def __init__(self):
        # Names of the parameters registered via `_register_shard_params`.
        self._shard_params: List[str] = []
        # compute_pattern -> mode -> param_name -> dist_spec
        self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}

    def _register_shard_params(self, params: List[str]):
        """Record ``params`` as the list of parameter names to shard.

        The list is stored by reference, not copied.
        """
        self._shard_params = params

    def _register_allowed_patterns(
        self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], mode="default"
    ):
        """Register ``dist_specs`` for ``compute_pattern`` under ``mode``.

        Args:
            compute_pattern: Key under which the specs are stored.
            dist_specs: Mapping from parameter name to its dist spec. Its keys
                must be exactly the registered shard parameter names.
            mode: Name of the variant within ``compute_pattern``.

        Raises:
            AssertionError: If the keys of ``dist_specs`` do not match the
                registered shard parameter names.
        """
        # `list.sort()` sorts in place and returns None, so comparing the
        # results of two `.sort()` calls is always `None == None`. `sorted`
        # returns new lists, which makes the check effective and leaves
        # `self._shard_params` in its original order.
        assert sorted(dist_specs) == sorted(
            self._shard_params
        ), "Every registered param should have dist_spec."
        self._allowed_patterns.setdefault(compute_pattern, {})[mode] = dist_specs

    def _set_default(self, compute_pattern: ComputePattern, target_mode):
        """Make ``target_mode`` the ``"default"`` mode of ``compute_pattern``.

        The ``"default"`` entry refers to the same dict object as
        ``target_mode``. A ``KeyError`` propagates if either
        ``compute_pattern`` or ``target_mode`` has not been registered.
        """
        modes = self._allowed_patterns[compute_pattern]
        modes["default"] = modes[target_mode]

    def has_compute_pattern(self, compute_pattern: ComputePattern):
        """Return whether any mode is registered for ``compute_pattern``."""
        return compute_pattern in self._allowed_patterns

    def get_dist_specs(self, compute_pattern: ComputePattern):
        """Return the ``mode -> dist_specs`` mapping of ``compute_pattern``.

        Raises:
            AssertionError: If ``compute_pattern`` is not registered.
        """
        assert self.has_compute_pattern(compute_pattern)
        return self._allowed_patterns[compute_pattern]

    def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode="default"):
        """Return whether ``mode`` is registered under ``compute_pattern``."""
        return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern]

    def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode="default"):
        """Return the dist specs registered for ``compute_pattern`` and ``mode``.

        Raises:
            AssertionError: If that combination is not registered.
        """
        assert self.has_compute_pattern_with_mode(compute_pattern, mode)
        return self._allowed_patterns[compute_pattern][mode]

    def get_param_names(self):
        """Return the registered shard parameter names (the stored list itself)."""
        return self._shard_params

    def register(self, compute_pattern, pg):
        """Not implemented in this class; always raises ``NotImplementedError``."""
        raise NotImplementedError
|