[tensor] Refactor handle_trans_spec in DistSpecManager

pull/3699/head
YH 2023-05-06 18:55:37 +09:00 committed by GitHub
parent 2da5d81dec
commit 2629f9717d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 6 deletions

View File

@ -4,10 +4,8 @@ import torch
import torch.distributed as dist import torch.distributed as dist
# from colossalai.nn.layer.utils import divide # from colossalai.nn.layer.utils import divide
from numpy import prod from numpy import prod
from packaging import version
from colossalai.logging import get_dist_logger from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor.distspec import _DistSpec
from colossalai.tensor.process_group import ProcessGroup from colossalai.tensor.process_group import ProcessGroup
@ -171,11 +169,21 @@ class DistSpecManager:
pg: ProcessGroup) -> torch.Tensor: pg: ProcessGroup) -> torch.Tensor:
assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec" assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec"
assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec" assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec"
forward_trans_handle = getattr(DistSpecManager, f'_{old_dist_spec.placement.value}2{dist_spec.placement.value}')
trans_func_key = (old_dist_spec.placement, dist_spec.placement)
trans_funcs = {
(DistPlacementPattern.REPLICATE, DistPlacementPattern.REPLICATE): DistSpecManager._r2r,
(DistPlacementPattern.REPLICATE, DistPlacementPattern.SHARD): DistSpecManager._r2s,
(DistPlacementPattern.SHARD, DistPlacementPattern.REPLICATE): DistSpecManager._s2r,
(DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s
}
forward_trans_handle = trans_funcs[trans_func_key]
if not DistSpecManager._use_autograd_function: if not DistSpecManager._use_autograd_function:
return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg) return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg)
backward_trans_handle = getattr(DistSpecManager,
f'_{dist_spec.placement.value}2{old_dist_spec.placement.value}') backward_trans_handle = trans_funcs[(dist_spec.placement, old_dist_spec.placement)]
return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle, return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle,
backward_trans_handle) backward_trans_handle)