mirror of https://github.com/hpcaitech/ColossalAI
[tensor] redistribute among different process groups (#1247)
* make it faster
* [tensor] rename convert_to_dist -> redistribute
* [tensor] ShardSpec and ReplicaSpec
* [tensor] redistribute among diff pgs
* polish code

pull/1250/head
parent 9bcd2fd4af
commit 1aad903c15
@@ -13,7 +13,6 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor,

     mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))

     # Output:P
     partial_output = torch.mm(mat1, mat2)
     # Reduce(Output)
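A minimal single-process sketch of the row-parallel matmul pattern this hunk touches: mat1 is sharded along its last dim, mat2 along dim 0, each rank computes a partial product (Output:P), and summing the partials plays the role of the all-reduce. Shapes, names, and the shard count below are illustrative, not taken from the commit.

import torch

# Hypothetical tensor-parallel world size of 2 and illustrative shapes.
tp_world_size = 2
mat1 = torch.randn(4, 6)   # replicated activation
mat2 = torch.randn(6, 5)   # weight, sharded along dim 0 across TP ranks

# Shard mat1 along its last dim and mat2 along dim 0, mirroring
# ShardSpec([-1], ...) / ShardSpec([0], ...) over the TP group.
mat1_shards = mat1.chunk(tp_world_size, dim=-1)
mat2_shards = mat2.chunk(tp_world_size, dim=0)

# Each "rank" holds a partial product: Output:P in the comments above.
partials = [torch.mm(a, b) for a, b in zip(mat1_shards, mat2_shards)]

# Reduce(Output): summing the partials stands in for the all-reduce.
output = sum(partials)
assert torch.allclose(output, torch.mm(mat1, mat2), atol=1e-5)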
@@ -14,7 +14,6 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                          sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table

     input_tensor = input_tensor.redistribute(ReplicaSpec())

     output_parallel = F.embedding(input_tensor,
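A single-process sketch of the column-split lookup the comments above describe: the table is split to (num_embeddings, embedding_dim / P), each shard embeds the replicated indices, and concatenating the per-shard slices plays the role of gathering the split lookup table. Sizes and names are illustrative, not from the commit.

import torch
import torch.nn.functional as F

tp_world_size = 2
num_embeddings, embedding_dim = 10, 8
weight = torch.randn(num_embeddings, embedding_dim)
indices = torch.tensor([[1, 3], [4, 7]])   # replicated input (ReplicaSpec)

# Split the lookup table to (num_embeddings, embedding_dim / P).
weight_shards = weight.chunk(tp_world_size, dim=-1)

# Each shard embeds the full replicated index tensor, then the slices are
# "gathered" by concatenation along the embedding dim.
output = torch.cat([F.embedding(indices, w) for w in weight_shards], dim=-1)
assert torch.allclose(output, F.embedding(indices, weight))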
@@ -47,7 +46,6 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # Find index in this shard and mask those not here
     # Reduce all
     pg = weight.get_process_group()

     input_tensor = input_tensor.redistribute(ReplicaSpec())

     # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
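And a sketch of the row-split variant referenced by the comments above: each shard owns a contiguous range of embedding rows, indices outside that range are masked, and summing the per-shard results stands in for the final all-reduce. All names and sizes here are illustrative.

import torch
import torch.nn.functional as F

tp_world_size = 2
num_embeddings, embedding_dim = 8, 4
weight = torch.randn(num_embeddings, embedding_dim)
indices = torch.tensor([[0, 5], [3, 7]])   # replicated input (ReplicaSpec)
per_shard = num_embeddings // tp_world_size

partials = []
for rank, w in enumerate(weight.chunk(tp_world_size, dim=0)):
    lo, hi = rank * per_shard, (rank + 1) * per_shard
    # Find indices in this shard and mask those not here.
    mask = (indices < lo) | (indices >= hi)
    local = (indices - lo).masked_fill(mask, 0)
    out = F.embedding(local, w)
    out = out.masked_fill(mask.unsqueeze(-1), 0.0)   # zero rows owned by other shards
    partials.append(out)

# Reduce all: the sum stands in for the all-reduce across the TP group.
output = sum(partials)
assert torch.allclose(output, F.embedding(indices, weight))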
@@ -32,9 +32,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional
     # All-Gather(Output)
     # Input:B
     compute_spec = weight.compute_spec

     input_tensor = input_tensor.redistribute(ReplicaSpec())

     input_parallel = reduce_grad(input_tensor, weight.get_process_group())

     output_parallel = F.linear(input_parallel, weight, bias)
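A single-process sketch of the column-parallel linear this hunk sits in: the weight and bias are sharded along the output dimension, each shard computes a slice of the output, and concatenating the slices stands in for All-Gather(Output). The backward-only all-reduce that reduce_grad inserts is not modeled here; names and shapes are illustrative.

import torch
import torch.nn.functional as F

tp_world_size = 2
x = torch.randn(3, 6)        # Input:B, replicated after redistribute(ReplicaSpec())
weight = torch.randn(8, 6)   # full weight; sharded along dim 0 (output features)
bias = torch.randn(8)

# Each "rank" holds a slice of the output features.
w_shards = weight.chunk(tp_world_size, dim=0)
b_shards = bias.chunk(tp_world_size, dim=0)

# All-Gather(Output): concatenating the per-shard outputs along the last dim
# stands in for the gather across the TP group.
output = torch.cat([F.linear(x, w, b) for w, b in zip(w_shards, b_shards)], dim=-1)
assert torch.allclose(output, F.linear(x, weight, bias), atol=1e-5)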
@@ -186,7 +186,28 @@ class ColoTensor(torch.Tensor):
         self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
         self.dist_spec = dist_spec

-    def redistribute(self, dist_spec: _DistSpec) -> 'ColoTensor':
+    def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
+        """redistribute
+        Redistribute the tensor among processes. The rule is as follows:
+        1. If pg is None, redistribute the tensor payload across the TP process group,
+           and keep it replicated across the DP process group.
+        2. If pg is not None and not equal to the current process group:
+           first, convert the tensor to be replicated across the TP process group;
+           second, reset the process group to pg;
+           third, convert the tensor (now replicated across both the TP and DP process groups) to the new dist_spec.
+
+        Args:
+            dist_spec (_DistSpec): the new dist spec.
+            pg (Optional[ProcessGroup], optional): the new process group. Defaults to None.
+
+        Returns:
+            ColoTensor: a redistributed ColoTensor
+        """
+        if pg is not None and pg != self.get_process_group():
+            print('here _redistribute')
+            # if the pg is not equal, convert the current tensor to replicated
+            self._redistribute(ReplicaSpec())
+            self.process_group = pg
         ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
         return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
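Taken together with the _run_redistributed test added further down in this commit, the new pg argument can be exercised roughly as follows. This is a hedged sketch, assuming a 4-rank colossalai.launch context and that ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec and ReplicaSpec are all importable from colossalai.tensor at this commit.

import torch
from colossalai.tensor import (ColoTensor, ColoTensorSpec, ProcessGroup,
                               ReplicaSpec, ShardSpec)

# Rule 1: no pg given, so only the placement inside the current TP group
# changes; the DP dimension stays replicated.
pg1 = ProcessGroup(tp_degree=2, dp_degree=2)
t = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), ColoTensorSpec(pg1))
t = t.redistribute(ShardSpec([0], [pg1.tp_world_size()]))

# Rule 2: a different process group is passed, so the tensor is first made
# replicated, its process group is reset to pg2, and only then re-sharded
# under the new dist_spec.
pg2 = ProcessGroup(tp_degree=4, dp_degree=1)
t = t.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2)

# Moving to a pure data-parallel group replicates the payload again.
pg3 = ProcessGroup(tp_degree=1, dp_degree=4)
t = t.redistribute(ReplicaSpec(), pg3)
assert t.is_replicate()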
@@ -202,7 +223,6 @@ class ColoTensor(torch.Tensor):
         """
         return self.redistribute(ReplicaSpec())

     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
         tensor = tensor.as_subclass(ColoTensor)
@@ -117,13 +117,13 @@ class ProcessGroup:
         if not isinstance(obj, ProcessGroup):
             return False
         if self._rank != obj._rank:
-            assert False
+            return False
         if self._rank_list != obj._rank_list:
-            assert False
+            return False
         if self._tp_rank_list != obj._tp_rank_list:
-            assert False
+            return False
         if self._dp_rank_list != obj._dp_rank_list:
-            assert False
+            return False
         if self._tp_degree != obj._tp_degree:
             return False
         if self._dp_degree != obj._dp_degree:
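Because __eq__ now returns False instead of tripping an assert on any mismatched field, ProcessGroup comparison becomes a plain value check, which is what the new redistribute path relies on (pg != self.get_process_group()). A small hedged sketch, again assuming a 4-rank launch and that ProcessGroup is importable from colossalai.tensor:

from colossalai.tensor import ProcessGroup

pg_a = ProcessGroup(tp_degree=2, dp_degree=2)
pg_b = ProcessGroup(tp_degree=4, dp_degree=1)
pg_c = ProcessGroup(tp_degree=2, dp_degree=2)

# Previously a mismatch in rank lists or degrees hit `assert False`; now it
# simply makes the comparison evaluate to False, so it can drive control flow.
assert pg_a != pg_b
assert pg_a == pg_c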
@@ -164,7 +164,7 @@ def run_check_shared_param():
     # TODO(jiaruifang) optimize this line
     if not model.cls.predictions.bias.has_initialized:
         model.cls.predictions.bias.pg = pg
-        model.cls.predictions.bias.dist_spec = distspec.replicate()
+        model.cls.predictions.bias.dist_spec = ReplicaSpec()
         model.cls.predictions.bias.has_initialized = True
     model.cls.predictions.bias.set_tensor_spec(*col_spec)
     try:
@@ -9,7 +9,6 @@ from colossalai.utils import get_current_device
 from torch.nn import Parameter
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec


 def _run_layer_norm():
@@ -5,7 +5,7 @@ from numpy import allclose

 import colossalai
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, ColoTensorSpec
+from colossalai.tensor import ColoTensorSpec
 from colossalai.core import global_context as gpc
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
@@ -85,7 +85,7 @@ def _run_tensor_shard_init(world_size):
     shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
     tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
-    t.set_dist_spec(distspec.replicate())
+    t.set_dist_spec(ReplicaSpec())

     assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
@@ -102,10 +102,26 @@ def _run_tensor_replicated_init(world_size):
 def _run_process_group(world_size):
     pg1 = ProcessGroup()
     pg2 = ProcessGroup()

     assert pg1 == pg2


+def _run_redistributed(world_size):
+    if world_size != 4:
+        return
+    pg1 = ProcessGroup(tp_degree=2, dp_degree=2)
+    pg2 = ProcessGroup(tp_degree=4, dp_degree=1)
+
+    spec1 = ColoTensorSpec(pg1)
+    t1 = ColoTensor.from_torch_tensor(torch.randn(2, 3, 4), spec1)
+    t1 = t1.redistribute(ShardSpec([0], [pg1.tp_world_size()]))
+    assert t1.is_sharded()
+    t1 = t1.redistribute(ShardSpec([-1], [pg2.tp_world_size()]), pg2)
+    assert t1.is_sharded()
+    pg3 = ProcessGroup(tp_degree=1, dp_degree=4)
+    t1 = t1.redistribute(ReplicaSpec(), pg3)
+    assert t1.is_replicate()
+
+
 def run_dist_tests(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     _run_tensor_shard_init(world_size)
@@ -115,6 +131,7 @@ def run_dist_tests(rank, world_size, port):
     _run_tensor_indexing()
     _run_operand(world_size)
     _run_wrapped_tensor_func()
+    _run_redistributed(world_size)


 @pytest.mark.dist
@@ -126,4 +143,4 @@ def test_dist_cases(world_size):


 if __name__ == '__main__':
-    test_dist_cases(1)
+    test_dist_cases(4)