[tensor] redistribute among different process groups (#1247)

* make it faster

* [tensor] rename convert_to_dist -> redistribute

* [tensor] ShardSpec and ReplicaSpec

* [tensor] redistribute among diff pgs

* polish code
pull/1250/head
Jiarui Fang 2022-07-12 10:24:05 +08:00 committed by GitHub
parent 9bcd2fd4af
commit 1aad903c15
8 changed files with 48 additions and 17 deletions
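As a quick orientation before the diffs, a minimal sketch of the renamed API described in the commit message (a sketch, not the repository's own example: import paths, shapes, and degrees are assumptions drawn from the diffs below, and it must run inside a 4-rank context launched via colossalai.launch):

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec, ReplicaSpec

pg = ProcessGroup(tp_degree=2, dp_degree=2)                # 2-way TP x 2-way DP over 4 ranks
t = ColoTensor.from_torch_tensor(torch.randn(8, 8), ColoTensorSpec(pg))
t = t.redistribute(ShardSpec([0], [pg.tp_world_size()]))   # shard dim 0 across the TP group
t = t.redistribute(ReplicaSpec())                          # gather back to a replicated tensor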


@@ -13,7 +13,6 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
# Output:P
partial_output = torch.mm(mat1, mat2)
# Reduce(Output)
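For reference, the partial-product identity that the Reduce(Output) step relies on, written in plain single-process PyTorch (shapes are illustrative; no ColossalAI APIs involved):

import torch

mat1 = torch.randn(4, 6)
mat2 = torch.randn(6, 5)
# split mat1 column-wise and mat2 row-wise into P = 2 shards, as 1D-row parallelism does
m1_a, m1_b = mat1.chunk(2, dim=-1)
m2_a, m2_b = mat2.chunk(2, dim=0)
partial_sum = m1_a @ m2_a + m1_b @ m2_b   # the sum that the all-reduce performs across ranks
assert torch.allclose(partial_sum, mat1 @ mat2, atol=1e-5)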


@@ -14,7 +14,6 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
sparse: bool = False) -> ColoTensor:
# embedding_1Dcol splits the weight (lookup table) into (num_embeddings, embedding_dim/P)
# Gather the split lookup table
input_tensor = input_tensor.redistribute(ReplicaSpec())
output_parallel = F.embedding(input_tensor,
@@ -47,7 +46,6 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
# Find index in this shard and mask those not here
# Reduce all
pg = weight.get_process_group()
input_tensor = input_tensor.redistribute(ReplicaSpec())
# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)


@@ -32,9 +32,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# All-Gather(Output)
# Input:B
compute_spec = weight.compute_spec
input_tensor = input_tensor.redistribute(ReplicaSpec())
input_parallel = reduce_grad(input_tensor, weight.get_process_group())
output_parallel = F.linear(input_parallel, weight, bias)
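The column-parallel identity behind this hunk, again in plain single-process PyTorch (illustrative shapes; each shard computes a slice of the output features, and concatenating the slices is what the All-Gather(Output) step does across ranks):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
w = torch.randn(6, 8)                      # (out_features, in_features)
b = torch.randn(6)
# split weight and bias along out_features into P = 2 shards
w_a, w_b = w.chunk(2, dim=0)
b_a, b_b = b.chunk(2, dim=0)
out = torch.cat([F.linear(x, w_a, b_a), F.linear(x, w_b, b_b)], dim=-1)
assert torch.allclose(out, F.linear(x, w, b), atol=1e-5)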


@@ -186,7 +186,28 @@ class ColoTensor(torch.Tensor):
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
self.dist_spec = dist_spec
def redistribute(self, dist_spec: _DistSpec) -> 'ColoTensor':
def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
"""redistribute
Redistribute the tensor among processes. The rule is like this:
1. If the pg is None, then redistributed tensor payload among TP process group. Keep the
DP process group still as replicated.
2. If the pg is not not None and not equal to the cureent process group.
First, convert the tensor as replicated among TP process group.
Second, reset the process group.
Third, conver the tensor (new replicated both among tp and dp process group) to the new dist_spec.
Args:
dist_spec (_DistSpec): the new dist spec.
pg (Optional[ProcessGroup], optional): the new process group . Defaults to None.
Returns:
ColoTensor: a redistributed colotensor
"""
if pg is not None and pg != self.get_process_group():
# the new pg differs from the current one, so first convert the tensor to replicated
self._redistribute(ReplicaSpec())
self.process_group = pg
ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
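A brief usage sketch of rule 2 (it mirrors the _run_redistributed test added further down, uses the same imports as the sketch near the top of this page, and assumes a 4-rank launch; the shape is arbitrary):

pg_a = ProcessGroup(tp_degree=2, dp_degree=2)
pg_b = ProcessGroup(tp_degree=4, dp_degree=1)
x = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), ColoTensorSpec(pg_a))
x = x.redistribute(ShardSpec([0], [pg_a.tp_world_size()]))          # sharded within pg_a
x = x.redistribute(ShardSpec([-1], [pg_b.tp_world_size()]), pg_b)   # replicate, switch pg, re-shard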
@@ -202,7 +223,6 @@ class ColoTensor(torch.Tensor):
"""
return self.redistribute(ReplicaSpec())
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
tensor = tensor.as_subclass(ColoTensor)


@@ -117,13 +117,13 @@ class ProcessGroup:
if not isinstance(obj, ProcessGroup):
return False
if self._rank != obj._rank:
assert False
return False
if self._rank_list != obj._rank_list:
assert False
return False
if self._tp_rank_list != obj._tp_rank_list:
assert False
return False
if self._dp_rank_list != obj._dp_rank_list:
assert False
return False
if self._tp_degree != obj._tp_degree:
return False
if self._dp_degree != obj._dp_degree:
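With this change, comparing mismatched groups is an ordinary boolean check rather than a crash; a short sketch (assumes a 4-rank launch and the ProcessGroup import shown earlier; degrees chosen for illustration):

pg_a = ProcessGroup(tp_degree=2, dp_degree=2)
pg_b = ProcessGroup(tp_degree=4, dp_degree=1)
assert pg_a != pg_b                                    # previously this tripped an assert False
assert pg_a == ProcessGroup(tp_degree=2, dp_degree=2)  # identical layouts still compare equal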


@@ -164,7 +164,7 @@ def run_check_shared_param():
# TODO(jiaruifang) optimize this line
if not model.cls.predictions.bias.has_initialized:
model.cls.predictions.bias.pg = pg
model.cls.predictions.bias.dist_spec = distspec.replicate()
model.cls.predictions.bias.dist_spec = ReplicaSpec()
model.cls.predictions.bias.has_initialized = True
model.cls.predictions.bias.set_tensor_spec(*col_spec)
try:


@@ -9,7 +9,6 @@ from colossalai.utils import get_current_device
from torch.nn import Parameter
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import distspec
def _run_layer_norm():


@@ -5,7 +5,7 @@ from numpy import allclose
import colossalai
from colossalai.utils import free_port
from colossalai.tensor import distspec, ColoTensorSpec
from colossalai.tensor import ColoTensorSpec
from colossalai.core import global_context as gpc
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
@@ -85,7 +85,7 @@ def _run_tensor_shard_init(world_size):
shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
t.set_dist_spec(distspec.replicate())
t.set_dist_spec(ReplicaSpec())
assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
@@ -102,10 +102,26 @@ def _run_tensor_replicated_init(world_size):
def _run_process_group(world_size):
pg1 = ProcessGroup()
pg2 = ProcessGroup()
assert pg1 == pg2
def _run_redistributed(world_size):
if world_size != 4:
return
pg1 = ProcessGroup(tp_degree=2, dp_degree=2)
pg2 = ProcessGroup(tp_degree=4, dp_degree=1)
spec1 = ColoTensorSpec(pg1)
t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
t1 = t1.redistribute(ShardSpec([0], [pg1.tp_world_size()]))
assert t1.is_sharded()
t1 = t1.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2)
assert t1.is_sharded()
pg3 = ProcessGroup(tp_degree=1, dp_degree=4)
t1 = t1.redistribute(ReplicaSpec(), pg3)
assert t1.is_replicate()
def run_dist_tests(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
_run_tensor_shard_init(world_size)
@@ -115,6 +131,7 @@ def run_dist_tests(rank, world_size, port):
_run_tensor_indexing()
_run_operand(world_size)
_run_wrapped_tensor_func()
_run_redistributed(world_size)
@pytest.mark.dist
@@ -126,4 +143,4 @@ def test_dist_cases(world_size):
if __name__ == '__main__':
test_dist_cases(1)
test_dist_cases(4)