mirror of https://github.com/hpcaitech/ColossalAI
[tensor] redistribute among different process groups (#1247)
* make it faster * [tensor] rename convert_to_dist -> redistribute * [tensor] ShardSpec and ReplicaSpec * [tensor] redistribute among diff pgs * polish codepull/1250/head
parent
9bcd2fd4af
commit
1aad903c15
|
@ -13,7 +13,6 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
|
|||
|
||||
mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
|
||||
|
||||
|
||||
# Output:P
|
||||
partial_output = torch.mm(mat1, mat2)
|
||||
# Reduce(Output)
|
||||
|
|
|
@ -14,7 +14,6 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
|
|||
sparse: bool = False) -> ColoTensor:
|
||||
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
|
||||
# Gather splitted lookup table
|
||||
|
||||
input_tensor = input_tensor.redistribute(ReplicaSpec())
|
||||
|
||||
output_parallel = F.embedding(input_tensor,
|
||||
|
@ -47,7 +46,6 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
|
|||
# Find index in this shard and mask those not here
|
||||
# Reduce all
|
||||
pg = weight.get_process_group()
|
||||
|
||||
input_tensor = input_tensor.redistribute(ReplicaSpec())
|
||||
|
||||
# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
|
|
|
@ -32,9 +32,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
|
|||
# All-Gather(Output)
|
||||
# Input:B
|
||||
compute_spec = weight.compute_spec
|
||||
|
||||
input_tensor = input_tensor.redistribute(ReplicaSpec())
|
||||
|
||||
input_parallel = reduce_grad(input_tensor, weight.get_process_group())
|
||||
|
||||
output_parallel = F.linear(input_parallel, weight, bias)
|
||||
|
|
|
@ -186,7 +186,28 @@ class ColoTensor(torch.Tensor):
|
|||
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
|
||||
self.dist_spec = dist_spec
|
||||
|
||||
def redistribute(self, dist_spec: _DistSpec) -> 'ColoTensor':
|
||||
def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
|
||||
"""redistribute
|
||||
Redistribute the tensor among processes. The rule is like this:
|
||||
1. If the pg is None, then redistributed tensor payload among TP process group. Keep the
|
||||
DP process group still as replicated.
|
||||
2. If the pg is not not None and not equal to the cureent process group.
|
||||
First, convert the tensor as replicated among TP process group.
|
||||
Second, reset the process group.
|
||||
Third, conver the tensor (new replicated both among tp and dp process group) to the new dist_spec.
|
||||
|
||||
Args:
|
||||
dist_spec (_DistSpec): the new dist spec.
|
||||
pg (Optional[ProcessGroup], optional): the new process group . Defaults to None.
|
||||
|
||||
Returns:
|
||||
ColoTensor: a redistributed colotensor
|
||||
"""
|
||||
if pg is not None and pg != self.get_process_group():
|
||||
print('here _redistribute')
|
||||
# if the pg is not equal, convert the current tensor to replicated
|
||||
self._redistribute(ReplicaSpec())
|
||||
self.process_group = pg
|
||||
ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
|
||||
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
|
||||
|
||||
|
@ -202,7 +223,6 @@ class ColoTensor(torch.Tensor):
|
|||
"""
|
||||
return self.redistribute(ReplicaSpec())
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
|
||||
tensor = tensor.as_subclass(ColoTensor)
|
||||
|
|
|
@ -117,13 +117,13 @@ class ProcessGroup:
|
|||
if not isinstance(obj, ProcessGroup):
|
||||
return False
|
||||
if self._rank != obj._rank:
|
||||
assert False
|
||||
return False
|
||||
if self._rank_list != obj._rank_list:
|
||||
assert False
|
||||
return False
|
||||
if self._tp_rank_list != obj._tp_rank_list:
|
||||
assert False
|
||||
return False
|
||||
if self._dp_rank_list != obj._dp_rank_list:
|
||||
assert False
|
||||
return False
|
||||
if self._tp_degree != obj._tp_degree:
|
||||
return False
|
||||
if self._dp_degree != obj._dp_degree:
|
||||
|
|
|
@ -164,7 +164,7 @@ def run_check_shared_param():
|
|||
# TODO(jiaruifang) optimize this line
|
||||
if not model.cls.predictions.bias.has_initialized:
|
||||
model.cls.predictions.bias.pg = pg
|
||||
model.cls.predictions.bias.dist_spec = distspec.replicate()
|
||||
model.cls.predictions.bias.dist_spec = ReplicaSpec()
|
||||
model.cls.predictions.bias.has_initialized = True
|
||||
model.cls.predictions.bias.set_tensor_spec(*col_spec)
|
||||
try:
|
||||
|
|
|
@ -9,7 +9,6 @@ from colossalai.utils import get_current_device
|
|||
from torch.nn import Parameter
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import distspec
|
||||
|
||||
|
||||
def _run_layer_norm():
|
||||
|
|
|
@ -5,7 +5,7 @@ from numpy import allclose
|
|||
|
||||
import colossalai
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import distspec, ColoTensorSpec
|
||||
from colossalai.tensor import ColoTensorSpec
|
||||
from colossalai.core import global_context as gpc
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
|
@ -85,7 +85,7 @@ def _run_tensor_shard_init(world_size):
|
|||
shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
|
||||
tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
|
||||
t.set_dist_spec(distspec.replicate())
|
||||
t.set_dist_spec(ReplicaSpec())
|
||||
|
||||
assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
|
||||
|
||||
|
@ -102,10 +102,26 @@ def _run_tensor_replicated_init(world_size):
|
|||
def _run_process_group(world_size):
|
||||
pg1 = ProcessGroup()
|
||||
pg2 = ProcessGroup()
|
||||
|
||||
assert pg1 == pg2
|
||||
|
||||
|
||||
def _run_redistributed(world_size):
|
||||
if world_size != 4:
|
||||
return
|
||||
pg1 = ProcessGroup(tp_degree=2, dp_degree=2)
|
||||
pg2 = ProcessGroup(tp_degree=4, dp_degree=1)
|
||||
|
||||
spec1 = ColoTensorSpec(pg1)
|
||||
t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
|
||||
t1 = t1.redistribute(ShardSpec([0], [pg1.tp_world_size()]))
|
||||
assert t1.is_sharded()
|
||||
t1 = t1.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2)
|
||||
assert t1.is_sharded()
|
||||
pg3 = ProcessGroup(tp_degree=1, dp_degree=4)
|
||||
t1 = t1.redistribute(ReplicaSpec(), pg3)
|
||||
assert t1.is_replicate()
|
||||
|
||||
|
||||
def run_dist_tests(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
_run_tensor_shard_init(world_size)
|
||||
|
@ -115,6 +131,7 @@ def run_dist_tests(rank, world_size, port):
|
|||
_run_tensor_indexing()
|
||||
_run_operand(world_size)
|
||||
_run_wrapped_tensor_func()
|
||||
_run_redistributed(world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
@ -126,4 +143,4 @@ def test_dist_cases(world_size):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_dist_cases(1)
|
||||
test_dist_cases(4)
|
||||
|
|
Loading…
Reference in New Issue