2023-09-19 06:20:26 +00:00
|
|
|
from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec
|
2023-09-11 08:24:28 +00:00
|
|
|
|
2022-05-26 10:15:42 +00:00
|
|
|
from .colo_module import ColoModule
|
|
|
|
|
2022-06-03 10:04:22 +00:00
|
|
|
|
2022-05-26 10:15:42 +00:00
|
|
|
class ColoEmbedding(ColoModule):
    """ColoModule description of an embedding layer whose only shardable parameter is ``weight``.

    Only ``ComputePattern.TP1D`` is handled. For it, two sharding modes are
    registered: "row" (split ``weight`` along dim 0) and "col" (split along
    the last dim). "row" is the default.
    """

    def __init__(self):
        super().__init__()
        # "weight" is the only parameter the registered dist specs refer to.
        self._register_shard_params(["weight"])

    def register(self, compute_pattern, pg: ProcessGroup):
        """Register the sharding specs for ``compute_pattern`` if not already registered.

        Patterns other than ``ComputePattern.TP1D`` are ignored silently;
        no error is raised for them.

        Args:
            compute_pattern: the ComputePattern to enable for this module.
            pg: process group whose ``tp_world_size()`` gives the number of
                shards used by the registered specs.
        """
        if compute_pattern in self._allowed_patterns:
            # Already registered; nothing to do.
            return
        if ComputePattern.TP1D == compute_pattern:
            self._set_TP1D(pg)

    def _set_TP1D(self, pg: ProcessGroup):
        """Register the TP1D "row" and "col" modes for ``weight`` and make "row" the default.

        "row" uses ``ShardSpec([0], [world_size])`` and "col" uses
        ``ShardSpec([-1], [world_size])``, where ``world_size`` is
        ``pg.tp_world_size()``. Presumably ShardSpec(dims, num_partitions)
        splits ``weight`` along each listed dim into that many pieces —
        ShardSpec is defined in colossalai.legacy.tensor; confirm there.
        """
        compute_pattern = ComputePattern.TP1D
        world_size = pg.tp_world_size()
        # (mode name, dim of ``weight`` split across the tensor-parallel group),
        # registered in this order. NOTE(review): these were previously labelled
        # "Row/Col Linear"; if ``weight`` is (num_embeddings, embedding_dim) as in
        # torch.nn.Embedding, "row" shards the vocabulary and "col" the embedding
        # dimension — confirm against the embedding op implementation.
        for mode, shard_dim in (("row", 0), ("col", -1)):
            self._register_allowed_patterns(
                compute_pattern=compute_pattern,
                dist_specs={
                    "weight": ShardSpec([shard_dim], [world_size]),
                },
                mode=mode,
            )

        self._set_default(compute_pattern=compute_pattern, target_mode="row")
|