From 7487215b9569bf9a7b36a49fd3fef47e22b07e89 Mon Sep 17 00:00:00 2001
From: Jiarui Fang <fangjiarui123@gmail.com>
Date: Wed, 29 Jun 2022 10:03:09 +0800
Subject: [PATCH] [ColoTensor] add independent process group (#1179)

---
 colossalai/tensor/__init__.py      |  4 +-
 colossalai/tensor/process_group.py | 72 ++++++++++++++++++++++++++++++
 tests/test_tensor/test_model.py    | 66 ++++++++++++++--------------
 tests/test_tensor/test_tensor.py   | 22 ++++------
 4 files changed, 119 insertions(+), 45 deletions(-)
 create mode 100644 colossalai/tensor/process_group.py

diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py
index f830731a0..b71db453d 100644
--- a/colossalai/tensor/__init__.py
+++ b/colossalai/tensor/__init__.py
@@ -7,8 +7,10 @@ from .dist_spec_mgr import DistSpecManager
 from .param_op_hook import ParamOpHook, ParamOpHookManager
 from .chunk import ChunkManager, TensorState
 from . import distspec
+from .process_group import ProcessGroup
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
-    'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState'
+    'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState',
+    'ProcessGroup'
 ]
diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py
new file mode 100644
index 000000000..c60e8127e
--- /dev/null
+++ b/colossalai/tensor/process_group.py
@@ -0,0 +1,72 @@
+import torch
+from typing import List, Optional
+
+
+class ProcessGroup:
+    """
+    Process Group contains group partition for Tensor Parallel and Data Parallel.
+    WARNING, the ProcessGroup must be used after torch.distributed.initialize()
+    args:
+        rank: the global rank of the current process.
+        ranks: List[int], a list of rank id belongings to this process group.
+        backend: str, the backend of the process group.
+        tp_degree: Optional[int], tensor parallelism degree, default None means 1
+        dp_degree: Optional[int], data parallelism degree, default None means len(ranks)
+    """
+
+    def __init__(self,
+                 rank: int,
+                 ranks: List[int],
+                 backend: str = 'nccl',
+                 tp_degree: Optional[int] = None,
+                 dp_degree: Optional[int] = None) -> None:
+        self._rank = rank
+        self._rank_list = ranks
+        self._backend = backend
+        self._world_size = len(self._rank_list)
+        assert torch.distributed.is_initialized(), "ProcessGroup must be constructed after torch.distributed.init_process_group()"
+
+        if dp_degree is None and tp_degree is None:
+            self._dp_degree = self._world_size
+            self._tp_degree = 1
+        elif dp_degree and not tp_degree:
+            assert self._world_size % dp_degree == 0, f"world size {self._world_size} should be divisible by DP degree {dp_degree} when TP degree is None"
+            self._dp_degree = dp_degree
+            self._tp_degree = self._world_size // dp_degree
+        elif tp_degree and not dp_degree:
+            assert self._world_size % tp_degree == 0, f"world size {self._world_size} should be divisible by TP degree {tp_degree} when DP degree is None"
+            self._tp_degree = tp_degree
+            self._dp_degree = self._world_size // tp_degree
+        else:
+            assert dp_degree * tp_degree == self._world_size, f"DP degree {dp_degree} * TP degree {tp_degree} should equal world size {self._world_size}"
+            self._dp_degree = dp_degree
+            self._tp_degree = tp_degree
+
+        self._tp_rank_list = []
+        self._dp_rank_list = []
+
+        for rank_id in range(self._world_size):
+            # ranks occupying the same position inside their TP groups form a DP group
+            if rank_id % self._tp_degree == self._rank % self._tp_degree:
+                self._dp_rank_list.append(rank_id)
+            # ranks in the same consecutive block of tp_degree ranks form a TP group
+            if rank_id // self._tp_degree == self._rank // self._tp_degree:
+                self._tp_rank_list.append(rank_id)
+
+        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend=backend)
+        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend=backend)
+
+    def world_size(self):
+        return self._world_size
+
+    def dp_world_size(self):
+        return len(self._dp_rank_list)
+
+    def tp_world_size(self):
+        return len(self._tp_rank_list)
+
+    def dp_process_group(self):
+        return self._dp_process_group
+
+    def tp_process_group(self):
+        return self._tp_process_group
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index 2880af885..b5d1bc752 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -10,7 +10,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
-    ComputeSpec, ColoTensor, DistSpecManager
+    ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColoOptimizer
@@ -18,34 +18,30 @@ from functools import partial
 from _utils import tensor_equal, tensor_shard_equal, set_seed
 
 
-def init_1d_row_linear(weight):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_row_linear(weight, pg: ProcessGroup):
+    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)
 
 
-def init_1d_col_linear(weight):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_col_linear(weight, pg):
+    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)
 
 
-def init_1d_row_embedding(weight):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_row_embedding(weight, pg):
+    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)
 
 
-def init_1d_col_embedding(weight):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_col_embedding(weight, pg):
+    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)
 
@@ -69,6 +65,9 @@ def run_1d_hybrid_tp(model_name):
         for p1, p2 in zip(model.parameters(), model_torch.parameters()):
             p2.data.copy_(p1.data)
 
+    rank = gpc.get_local_rank(ParallelMode.GLOBAL)
+    world_size = gpc.get_world_size(ParallelMode.GLOBAL)
+    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
     if 'bert' == model_name:
         for name, p in model.named_parameters():
             if not isinstance(p, ColoTensor):
@@ -76,29 +75,29 @@ def run_1d_hybrid_tp(model_name):
             # print(name)
             # num_class = type_vocab_size = 2 | (8, 2)
             if 'classifier' in name and 'weight' in name:
-                init_1d_row_linear(p)
+                init_1d_row_linear(p, pg)
             # num_class = vocab_size = 30524 | (30524, 8)
             if 'word_embeddings' in name and 'weight' in name:
-                init_1d_row_embedding(p)
+                init_1d_row_embedding(p, pg)
             # num_class = seq_len = 512 | (512, 8)
             if 'position_embeddings' in name and 'weight' in name:
-                init_1d_row_embedding(p)
+                init_1d_row_embedding(p, pg)
             # num_class = type_vocab_size = 2 | (2, 8)
             if 'token_type_embeddings' in name and 'weight' in name:
-                init_1d_col_embedding(p)
+                init_1d_col_embedding(p, pg)
     elif "simple_net" == model_name:
         # A naive way to set spec for all weights in Linear
         for name, p in model.named_parameters():
             if not isinstance(p, ColoTensor):
                 continue
             if 'embed' in name and 'weight' in name:
-                init_1d_col_embedding(p)
+                init_1d_col_embedding(p, pg)
             if 'proj1' in name and ('weight' in name or 'bias' in name):
-                init_1d_col_linear(p)
+                init_1d_col_linear(p, pg)
             if 'proj2' in name and 'weight' in name:
-                init_1d_row_linear(p)
+                init_1d_row_linear(p, pg)
             if 'classifier' in name and ('weight' in name or 'bias' in name):
-                init_1d_col_linear(p)
+                init_1d_col_linear(p, pg)
 
     model = model.cuda()
     colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
@@ -112,8 +111,8 @@ def run_1d_hybrid_tp(model_name):
         data = data.to(get_current_device())
         label = label.to(get_current_device())
 
-        torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
-        torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
+        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
         # Bcast rank0 data to all processes
         if criterion:
             output = model(data)
@@ -221,6 +220,10 @@ def run_1d_row_tp(model_name: str):
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
 
+    rank = gpc.get_local_rank(ParallelMode.GLOBAL)
+    world_size = gpc.get_world_size(ParallelMode.GLOBAL)
+    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
+
     set_seed(1)
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
@@ -230,9 +233,9 @@ def run_1d_row_tp(model_name: str):
         if not isinstance(p, ColoTensor):
             continue
         if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
-            init_1d_row_linear(p)
+            init_1d_row_linear(p, pg)
         if 'embed' in name and 'weight' in name:
-            init_1d_row_embedding(p)
+            init_1d_row_embedding(p, pg)
 
     model = model.cuda()
 
@@ -330,10 +333,11 @@ def run_pretrain_load_dist(rank, world_size, port):
 
 # The test case has to download huggingface pretrained models from the internet
 # So we manually trigger the test.
+@pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
-def _test_pretrain_load(world_size):
+def test_pretrain_load(world_size):
     run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
@@ -342,4 +346,4 @@ if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
     # test_model(4)
-    _test_pretrain_load(4)
+    test_pretrain_load(4)
diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py
index b04f14f57..9960e2cd0 100644
--- a/tests/test_tensor/test_tensor.py
+++ b/tests/test_tensor/test_tensor.py
@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, TensorSpec, ColoTensor
+from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup
 from colossalai.context import ParallelMode
 from functools import partial
 
@@ -21,14 +21,6 @@ def test_tensor_indexing():
     assert allclose(torch_t[:, 1], colo_t[:, 1])
 
 
-@pytest.mark.skip
-# FIXME(ver217): support lazy init
-def test_lazy_init_tensor():
-    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
-    assert lazy_t._torch_tensor.numel() == 0
-    assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
-
-
 def test_wrapped_tensor_func():
     t_ref = torch.randn(4, 5)
     t = ColoTensor.from_torch_tensor(t_ref.clone())
@@ -62,10 +54,12 @@ def test_operand():
 
 def _run_view(world_size):
     t_ref = torch.randn(4, 5)
+    rank = gpc.get_global_rank()
+    pg = ProcessGroup(rank, list(range(world_size)))
+    assert pg.dp_world_size() == world_size, f"{pg.dp_world_size()} vs {world_size}"
     t = ColoTensor.from_torch_tensor(
         t_ref,
-        TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0],
-                                  num_partitions=[world_size])))
+        TensorSpec(distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])))
 
     assert t.size_global()[0] == 4 * world_size
     assert t.size_global(1) == 5
@@ -81,8 +75,10 @@ def _run_view(world_size):
 
 def _run_tensor_shard_init(world_size):
     t_ref = torch.randn(4, 5)
-    print(gpc.get_group(ParallelMode.DATA).size())
-    shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
+
+    rank = gpc.get_global_rank()
+    pg = ProcessGroup(rank, list(range(world_size)))
+    shard_spec = distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])
     tensor_spec = TensorSpec(shard_spec)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
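
For context, a condensed, hypothetical sketch (not part of this patch) of how the updated tests drive the new class once torch.distributed has been initialized, e.g. inside a worker spawned by mp.spawn; it mirrors _run_view above:

import torch
import torch.distributed as dist

from colossalai.tensor import ColoTensor, ProcessGroup, TensorSpec, distspec

def view_worker_sketch(world_size: int):
    # assumes dist.init_process_group(...) has already run in this process
    rank = dist.get_rank()
    # tp_degree/dp_degree omitted: dp_degree defaults to world_size, tp_degree to 1
    pg = ProcessGroup(rank, list(range(world_size)))

    # shard dim 0 across the DP group, exactly as _run_view does above
    spec = TensorSpec(
        distspec.shard(process_group=pg.dp_process_group(), dims=[0],
                       num_partitions=[pg.dp_world_size()]))
    t = ColoTensor.from_torch_tensor(torch.randn(4, 5), spec)
    assert t.size_global()[0] == 4 * world_size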