[ColoTensor] add independent process group (#1179)

pull/1182/head
Jiarui Fang 2022-06-29 10:03:09 +08:00 committed by GitHub
parent 26ba87272d
commit 7487215b95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 116 additions and 45 deletions

View File

@ -7,8 +7,10 @@ from .dist_spec_mgr import DistSpecManager
from .param_op_hook import ParamOpHook, ParamOpHookManager
from .chunk import ChunkManager, TensorState
from . import distspec
from .process_group import ProcessGroup
__all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState'
'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState',
'ProcessGroup'
]

View File

@ -0,0 +1,69 @@
import torch
from typing import List, Optional
class ProcessGroup:
"""
Process Group contains group partition for Tensor Parallel and Data Parallel.
WARNING, the ProcessGroup must be used after torch.distributed.initialize()
args:
rank: the global rank of the current process.
ranks: List[int], a list of rank id belongings to this process group.
backend: str, the backend of the process group.
tp_degree: Optional[int], tensor parallelism degree, default None means 1
dp_degree: Optional[int], data parallelism degree, default None means len(ranks)
"""
def __init__(self,
rank: int,
ranks: List[int],
backend: str = 'nccl',
tp_degree: Optional[int] = None,
dp_degree: Optional[int] = None) -> None:
self._rank = rank
self._rank_list = ranks
self._backend = backend
self._world_size = len(self._rank_list)
assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
if dp_degree is None and tp_degree is None:
self._dp_degree = self._world_size
self._tp_degree = 1
if dp_degree and not tp_degree:
self._dp_degree = dp_degree
assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None"
self._tp_degree = self._world_size / dp_degree
if not dp_degree and tp_degree:
self._tp_degree = tp_degree
assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
self._dp_degree = self._world_size / tp_degree
self._tp_rank_list = []
self._dp_rank_list = []
for rank_id in range(self._world_size):
# rank_id and self._rank in the same tp group
if rank_id % self._tp_degree == self._rank % self._tp_degree:
self._dp_rank_list.append(rank_id)
if rank_id // self._tp_degree == self._rank // self._tp_degree:
self._tp_rank_list.append(rank_id)
self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend=backend)
self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend=backend)
def world_size(self):
return self._world_size
def dp_world_size(self):
return len(self._dp_rank_list)
def tp_world_size(self):
return len(self._tp_rank_list)
def dp_process_group(self):
return self._dp_process_group
def tp_process_group(self):
return self._tp_process_group

View File

@ -10,7 +10,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
ComputeSpec, ColoTensor, DistSpecManager
ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColoOptimizer
@ -18,34 +18,30 @@ from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed
def init_1d_row_linear(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
def init_1d_row_linear(weight, pg: ProcessGroup):
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(spec)
def init_1d_col_linear(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
def init_1d_col_linear(weight, pg):
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(spec)
def init_1d_row_embedding(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
def init_1d_row_embedding(weight, pg):
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(spec)
def init_1d_col_embedding(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
def init_1d_col_embedding(weight, pg):
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(spec)
@ -69,6 +65,9 @@ def run_1d_hybrid_tp(model_name):
for p1, p2 in zip(model.parameters(), model_torch.parameters()):
p2.data.copy_(p1.data)
rank = gpc.get_local_rank(ParallelMode.GLOBAL)
world_size = gpc.get_world_size(ParallelMode.GLOBAL)
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
if 'bert' == model_name:
for name, p in model.named_parameters():
if not isinstance(p, ColoTensor):
@ -76,29 +75,29 @@ def run_1d_hybrid_tp(model_name):
# print(name)
# num_class = type_vocab_size = 2 | (8, 2)
if 'classifier' in name and 'weight' in name:
init_1d_row_linear(p)
init_1d_row_linear(p, pg)
# num_class = vocab_size = 30524 | (30524, 8)
if 'word_embeddings' in name and 'weight' in name:
init_1d_row_embedding(p)
init_1d_row_embedding(p, pg)
# num_class = seq_len = 512 | (512, 8)
if 'position_embeddings' in name and 'weight' in name:
init_1d_row_embedding(p)
init_1d_row_embedding(p, pg)
# num_class = type_vocab_size = 2 | (2, 8)
if 'token_type_embeddings' in name and 'weight' in name:
init_1d_col_embedding(p)
init_1d_col_embedding(p, pg)
elif "simple_net" == model_name:
# A naive way to set spec for all weights in Linear
for name, p in model.named_parameters():
if not isinstance(p, ColoTensor):
continue
if 'embed' in name and 'weight' in name:
init_1d_col_embedding(p)
init_1d_col_embedding(p, pg)
if 'proj1' in name and ('weight' in name or 'bias' in name):
init_1d_col_linear(p)
init_1d_col_linear(p, pg)
if 'proj2' in name and 'weight' in name:
init_1d_row_linear(p)
init_1d_row_linear(p, pg)
if 'classifier' in name and ('weight' in name or 'bias' in name):
init_1d_col_linear(p)
init_1d_col_linear(p, pg)
model = model.cuda()
colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
@ -112,8 +111,8 @@ def run_1d_hybrid_tp(model_name):
data = data.to(get_current_device())
label = label.to(get_current_device())
torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
# Bcast rank0 data to all processes
if criterion:
output = model(data)
@ -221,6 +220,10 @@ def run_1d_row_tp(model_name: str):
with ColoInitContext(device=get_current_device()):
model = model_builder(checkpoint=True)
rank = gpc.get_local_rank(ParallelMode.GLOBAL)
world_size = gpc.get_world_size(ParallelMode.GLOBAL)
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
set_seed(1)
if rank == 0:
model_torch = model_builder(checkpoint=True)
@ -230,9 +233,9 @@ def run_1d_row_tp(model_name: str):
if not isinstance(p, ColoTensor):
continue
if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
init_1d_row_linear(p)
init_1d_row_linear(p, pg)
if 'embed' in name and 'weight' in name:
init_1d_row_embedding(p)
init_1d_row_embedding(p, pg)
model = model.cuda()
@ -330,10 +333,11 @@ def run_pretrain_load_dist(rank, world_size, port):
# The test case has to download huggingface pretrained models from the internet
# So we manually trigger the test.
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def _test_pretrain_load(world_size):
def test_pretrain_load(world_size):
run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
@ -342,4 +346,4 @@ if __name__ == '__main__':
# test_model_parameters()
# test_colo_optimizer()
# test_model(4)
_test_pretrain_load(4)
test_pretrain_load(4)

View File

@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import distspec, TensorSpec, ColoTensor
from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup
from colossalai.context import ParallelMode
from functools import partial
@ -21,14 +21,6 @@ def test_tensor_indexing():
assert allclose(torch_t[:, 1], colo_t[:, 1])
@pytest.mark.skip
# FIXME(ver217): support lazy init
def test_lazy_init_tensor():
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor.numel() == 0
assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
def test_wrapped_tensor_func():
t_ref = torch.randn(4, 5)
t = ColoTensor.from_torch_tensor(t_ref.clone())
@ -62,10 +54,12 @@ def test_operand():
def _run_view(world_size):
t_ref = torch.randn(4, 5)
rank = gpc.get_global_rank()
pg = ProcessGroup(rank, list(range(world_size)))
assert pg.dp_world_size() == world_size, f"{pg.dp_world_size()} vs {world_size}"
t = ColoTensor.from_torch_tensor(
t_ref,
TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0],
num_partitions=[world_size])))
TensorSpec(distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])))
assert t.size_global()[0] == 4 * world_size
assert t.size_global(1) == 5
@ -81,8 +75,10 @@ def _run_view(world_size):
def _run_tensor_shard_init(world_size):
t_ref = torch.randn(4, 5)
print(gpc.get_group(ParallelMode.DATA).size())
shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
rank = gpc.get_global_rank()
pg = ProcessGroup(rank, list(range(world_size)))
shard_spec = distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])
tensor_spec = TensorSpec(shard_spec)
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))