@@ -4,10 +4,8 @@ import torch
 import torch.distributed as dist
 # from colossalai.nn.layer.utils import divide
 from numpy import prod
 from packaging import version
 
 from colossalai.logging import get_dist_logger
-from colossalai.tensor.distspec import _DistSpec
+from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
 from colossalai.tensor.process_group import ProcessGroup
@@ -171,11 +169,21 @@ class DistSpecManager:
                           pg: ProcessGroup) -> torch.Tensor:
         assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec"
         assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec"
-        forward_trans_handle = getattr(DistSpecManager, f'_{old_dist_spec.placement.value}2{dist_spec.placement.value}')
+
+        trans_func_key = (old_dist_spec.placement, dist_spec.placement)
+        trans_funcs = {
+            (DistPlacementPattern.REPLICATE, DistPlacementPattern.REPLICATE): DistSpecManager._r2r,
+            (DistPlacementPattern.REPLICATE, DistPlacementPattern.SHARD): DistSpecManager._r2s,
+            (DistPlacementPattern.SHARD, DistPlacementPattern.REPLICATE): DistSpecManager._s2r,
+            (DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s
+        }
+
+        forward_trans_handle = trans_funcs[trans_func_key]
         if not DistSpecManager._use_autograd_function:
             return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg)
-        backward_trans_handle = getattr(DistSpecManager,
-                                        f'_{dist_spec.placement.value}2{old_dist_spec.placement.value}')
+
+        backward_trans_handle = trans_funcs[(dist_spec.placement, old_dist_spec.placement)]
+
         return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle,
                                        backward_trans_handle)
 
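For reference, a minimal standalone sketch of the dispatch-table pattern the second hunk switches to (the Placement enum and the _r2r-style stubs below are hypothetical stand-ins, not ColossalAI code): keying an explicit dict on (old placement, new placement) pairs replaces the string-built getattr lookup, so an unsupported conversion fails with a clear KeyError instead of an AttributeError on a mangled method name, and the backward handle is simply the table entry for the reversed pair.

from enum import Enum


class Placement(Enum):
    # Hypothetical stand-in for DistPlacementPattern
    REPLICATE = 'r'
    SHARD = 's'


# Stub handlers standing in for DistSpecManager._r2r/_r2s/_s2r/_s2s
def _r2r(t): return t
def _r2s(t): return f'shard({t})'
def _s2r(t): return f'gather({t})'
def _s2s(t): return f'reshard({t})'


# Explicit dispatch table keyed on (source placement, target placement)
TRANS_FUNCS = {
    (Placement.REPLICATE, Placement.REPLICATE): _r2r,
    (Placement.REPLICATE, Placement.SHARD): _r2s,
    (Placement.SHARD, Placement.REPLICATE): _s2r,
    (Placement.SHARD, Placement.SHARD): _s2s,
}

# Forward transform plus its inverse for the backward pass,
# mirroring how handle_trans_spec picks both handles
forward = TRANS_FUNCS[(Placement.REPLICATE, Placement.SHARD)]
backward = TRANS_FUNCS[(Placement.SHARD, Placement.REPLICATE)]
assert forward('x') == 'shard(x)' and backward('x') == 'gather(x)'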