mirror of https://github.com/hpcaitech/ColossalAI
[ColoTensor] add independent process group (#1179)
parent
26ba87272d
commit
7487215b95
|
@ -7,8 +7,10 @@ from .dist_spec_mgr import DistSpecManager
|
|||
from .param_op_hook import ParamOpHook, ParamOpHookManager
|
||||
from .chunk import ChunkManager, TensorState
|
||||
from . import distspec
|
||||
from .process_group import ProcessGroup
|
||||
|
||||
__all__ = [
|
||||
'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
|
||||
'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState'
|
||||
'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState',
|
||||
'ProcessGroup'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
import torch
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class ProcessGroup:
|
||||
"""
|
||||
Process Group contains group partition for Tensor Parallel and Data Parallel.
|
||||
WARNING, the ProcessGroup must be used after torch.distributed.initialize()
|
||||
args:
|
||||
rank: the global rank of the current process.
|
||||
ranks: List[int], a list of rank id belongings to this process group.
|
||||
backend: str, the backend of the process group.
|
||||
tp_degree: Optional[int], tensor parallelism degree, default None means 1
|
||||
dp_degree: Optional[int], data parallelism degree, default None means len(ranks)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
rank: int,
|
||||
ranks: List[int],
|
||||
backend: str = 'nccl',
|
||||
tp_degree: Optional[int] = None,
|
||||
dp_degree: Optional[int] = None) -> None:
|
||||
self._rank = rank
|
||||
self._rank_list = ranks
|
||||
self._backend = backend
|
||||
self._world_size = len(self._rank_list)
|
||||
assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
|
||||
|
||||
if dp_degree is None and tp_degree is None:
|
||||
self._dp_degree = self._world_size
|
||||
self._tp_degree = 1
|
||||
|
||||
if dp_degree and not tp_degree:
|
||||
self._dp_degree = dp_degree
|
||||
assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None"
|
||||
self._tp_degree = self._world_size / dp_degree
|
||||
|
||||
if not dp_degree and tp_degree:
|
||||
self._tp_degree = tp_degree
|
||||
assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
|
||||
self._dp_degree = self._world_size / tp_degree
|
||||
|
||||
self._tp_rank_list = []
|
||||
self._dp_rank_list = []
|
||||
|
||||
for rank_id in range(self._world_size):
|
||||
# rank_id and self._rank in the same tp group
|
||||
if rank_id % self._tp_degree == self._rank % self._tp_degree:
|
||||
self._dp_rank_list.append(rank_id)
|
||||
if rank_id // self._tp_degree == self._rank // self._tp_degree:
|
||||
self._tp_rank_list.append(rank_id)
|
||||
|
||||
self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend=backend)
|
||||
self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend=backend)
|
||||
|
||||
def world_size(self):
|
||||
return self._world_size
|
||||
|
||||
def dp_world_size(self):
|
||||
return len(self._dp_rank_list)
|
||||
|
||||
def tp_world_size(self):
|
||||
return len(self._tp_rank_list)
|
||||
|
||||
def dp_process_group(self):
|
||||
return self._dp_process_group
|
||||
|
||||
def tp_process_group(self):
|
||||
return self._tp_process_group
|
|
@ -10,7 +10,7 @@ from colossalai.utils.cuda import get_current_device
|
|||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
|
||||
ComputeSpec, ColoTensor, DistSpecManager
|
||||
ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.optimizer import ColoOptimizer
|
||||
|
@ -18,34 +18,30 @@ from functools import partial
|
|||
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||
|
||||
|
||||
def init_1d_row_linear(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_row_linear(weight, pg: ProcessGroup):
|
||||
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_linear(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_col_linear(weight, pg):
|
||||
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_row_embedding(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_row_embedding(weight, pg):
|
||||
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_embedding(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
def init_1d_col_embedding(weight, pg):
|
||||
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
@ -69,6 +65,9 @@ def run_1d_hybrid_tp(model_name):
|
|||
for p1, p2 in zip(model.parameters(), model_torch.parameters()):
|
||||
p2.data.copy_(p1.data)
|
||||
|
||||
rank = gpc.get_local_rank(ParallelMode.GLOBAL)
|
||||
world_size = gpc.get_world_size(ParallelMode.GLOBAL)
|
||||
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
|
||||
if 'bert' == model_name:
|
||||
for name, p in model.named_parameters():
|
||||
if not isinstance(p, ColoTensor):
|
||||
|
@ -76,29 +75,29 @@ def run_1d_hybrid_tp(model_name):
|
|||
# print(name)
|
||||
# num_class = type_vocab_size = 2 | (8, 2)
|
||||
if 'classifier' in name and 'weight' in name:
|
||||
init_1d_row_linear(p)
|
||||
init_1d_row_linear(p, pg)
|
||||
# num_class = vocab_size = 30524 | (30524, 8)
|
||||
if 'word_embeddings' in name and 'weight' in name:
|
||||
init_1d_row_embedding(p)
|
||||
init_1d_row_embedding(p, pg)
|
||||
# num_class = seq_len = 512 | (512, 8)
|
||||
if 'position_embeddings' in name and 'weight' in name:
|
||||
init_1d_row_embedding(p)
|
||||
init_1d_row_embedding(p, pg)
|
||||
# num_class = type_vocab_size = 2 | (2, 8)
|
||||
if 'token_type_embeddings' in name and 'weight' in name:
|
||||
init_1d_col_embedding(p)
|
||||
init_1d_col_embedding(p, pg)
|
||||
elif "simple_net" == model_name:
|
||||
# A naive way to set spec for all weights in Linear
|
||||
for name, p in model.named_parameters():
|
||||
if not isinstance(p, ColoTensor):
|
||||
continue
|
||||
if 'embed' in name and 'weight' in name:
|
||||
init_1d_col_embedding(p)
|
||||
init_1d_col_embedding(p, pg)
|
||||
if 'proj1' in name and ('weight' in name or 'bias' in name):
|
||||
init_1d_col_linear(p)
|
||||
init_1d_col_linear(p, pg)
|
||||
if 'proj2' in name and 'weight' in name:
|
||||
init_1d_row_linear(p)
|
||||
init_1d_row_linear(p, pg)
|
||||
if 'classifier' in name and ('weight' in name or 'bias' in name):
|
||||
init_1d_col_linear(p)
|
||||
init_1d_col_linear(p, pg)
|
||||
|
||||
model = model.cuda()
|
||||
colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
|
||||
|
@ -112,8 +111,8 @@ def run_1d_hybrid_tp(model_name):
|
|||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
|
||||
torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
|
||||
# Bcast rank0 data to all processes
|
||||
if criterion:
|
||||
output = model(data)
|
||||
|
@ -221,6 +220,10 @@ def run_1d_row_tp(model_name: str):
|
|||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder(checkpoint=True)
|
||||
|
||||
rank = gpc.get_local_rank(ParallelMode.GLOBAL)
|
||||
world_size = gpc.get_world_size(ParallelMode.GLOBAL)
|
||||
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
|
||||
|
||||
set_seed(1)
|
||||
if rank == 0:
|
||||
model_torch = model_builder(checkpoint=True)
|
||||
|
@ -230,9 +233,9 @@ def run_1d_row_tp(model_name: str):
|
|||
if not isinstance(p, ColoTensor):
|
||||
continue
|
||||
if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
|
||||
init_1d_row_linear(p)
|
||||
init_1d_row_linear(p, pg)
|
||||
if 'embed' in name and 'weight' in name:
|
||||
init_1d_row_embedding(p)
|
||||
init_1d_row_embedding(p, pg)
|
||||
|
||||
model = model.cuda()
|
||||
|
||||
|
@ -330,10 +333,11 @@ def run_pretrain_load_dist(rank, world_size, port):
|
|||
|
||||
# The test case has to download huggingface pretrained models from the internet
|
||||
# So we manually trigger the test.
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def _test_pretrain_load(world_size):
|
||||
def test_pretrain_load(world_size):
|
||||
run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
@ -342,4 +346,4 @@ if __name__ == '__main__':
|
|||
# test_model_parameters()
|
||||
# test_colo_optimizer()
|
||||
# test_model(4)
|
||||
_test_pretrain_load(4)
|
||||
test_pretrain_load(4)
|
||||
|
|
|
@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
|
|||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import distspec, TensorSpec, ColoTensor
|
||||
from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup
|
||||
from colossalai.context import ParallelMode
|
||||
from functools import partial
|
||||
|
||||
|
@ -21,14 +21,6 @@ def test_tensor_indexing():
|
|||
assert allclose(torch_t[:, 1], colo_t[:, 1])
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
# FIXME(ver217): support lazy init
|
||||
def test_lazy_init_tensor():
|
||||
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
|
||||
assert lazy_t._torch_tensor.numel() == 0
|
||||
assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
|
||||
|
||||
|
||||
def test_wrapped_tensor_func():
|
||||
t_ref = torch.randn(4, 5)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone())
|
||||
|
@ -62,10 +54,12 @@ def test_operand():
|
|||
|
||||
def _run_view(world_size):
|
||||
t_ref = torch.randn(4, 5)
|
||||
rank = gpc.get_global_rank()
|
||||
pg = ProcessGroup(rank, list(range(world_size)))
|
||||
assert pg.dp_world_size() == world_size, f"{pg.dp_world_size()} vs {world_size}"
|
||||
t = ColoTensor.from_torch_tensor(
|
||||
t_ref,
|
||||
TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0],
|
||||
num_partitions=[world_size])))
|
||||
TensorSpec(distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])))
|
||||
|
||||
assert t.size_global()[0] == 4 * world_size
|
||||
assert t.size_global(1) == 5
|
||||
|
@ -81,8 +75,10 @@ def _run_view(world_size):
|
|||
|
||||
def _run_tensor_shard_init(world_size):
|
||||
t_ref = torch.randn(4, 5)
|
||||
print(gpc.get_group(ParallelMode.DATA).size())
|
||||
shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
|
||||
|
||||
rank = gpc.get_global_rank()
|
||||
pg = ProcessGroup(rank, list(range(world_size)))
|
||||
shard_spec = distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])
|
||||
tensor_spec = TensorSpec(shard_spec)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
|
||||
t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
|
||||
|
|
Loading…
Reference in New Issue