diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py
index 379497b48..376c582ab 100644
--- a/colossalai/context/parallel_context.py
+++ b/colossalai/context/parallel_context.py
@@ -449,6 +449,7 @@ class ParallelContext:
                 dist.destroy_process_group(group)
         # destroy global process group
         dist.destroy_process_group()
+        self._groups.clear()
 
     def set_device(self, device_ordinal: int = None):
         """Sets distributed processes to be bound to devices.
diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py
index 052e564e9..e98f3c18b 100644
--- a/colossalai/testing/comparison.py
+++ b/colossalai/testing/comparison.py
@@ -13,7 +13,7 @@ def assert_not_equal(a: Tensor, b: Tensor):
 def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8):
     assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}'
 
-def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-2, atol: float = 1e-3):
+def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
     assert_close(a, b, rtol, atol)
 
 def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 6427e4c8a..c9fe7d4d5 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -46,6 +46,7 @@ def free_port():
     while True:
         try:
             sock = socket.socket()
+            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             port = random.randint(20000, 65000)
             sock.bind(('localhost', port))
             sock.close()
diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_amp/test_naive_fp16.py
index d3de7c851..c6805ad51 100644
--- a/tests/test_amp/test_naive_fp16.py
+++ b/tests/test_amp/test_naive_fp16.py
@@ -5,6 +5,7 @@ import pytest
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_naive_amp
 from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.testing import assert_close_loose
 from colossalai.utils import free_port
 from functools import partial
 
@@ -48,7 +49,7 @@ def run_naive_amp():
         # forward pass
         amp_output = amp_model(data)
         torch_output = torch_model(data)
-        assert torch.allclose(amp_output, torch_output, rtol=1e-3, atol=1e-3), f'{amp_output} vs {torch_output}'
+        assert_close_loose(amp_output, torch_output)
 
         # backward
         amp_optimizer.backward(amp_output.mean())
@@ -56,7 +57,7 @@ def run_naive_amp():
 
         # check grad
         for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
-            torch.allclose(amp_param.grad, torch_param.grad.half(), rtol=1e-3, atol=1e-3)
+            assert_close_loose(amp_param.grad, torch_param.grad.half())
 
         # step
         amp_optimizer.step()
@@ -64,7 +65,7 @@ def run_naive_amp():
 
         # check updated param
         for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
-            torch.allclose(amp_param, torch_param.half(), rtol=1e-3, atol=1e-3)
+            assert_close_loose(amp_param, torch_param.half())
 
 
 def run_dist(rank, world_size, port):
diff --git a/tests/test_context/test_2d_init.py b/tests/test_context/test_2d_init.py
deleted file mode 100644
index 117b6e0d6..000000000
--- a/tests/test_context/test_2d_init.py
+++ /dev/null
@@ -1,105 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from functools import partial
-from pathlib import Path
-
-import pytest
-import torch
-import torch.multiprocessing as mp
-from colossalai import launch
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import free_port
-
-CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2d_init.py').absolute()
-
-
-def check_data_parallel_rank(rank):
-    if rank in [0, 1, 2, 3, 4, 5, 6, 7]:
-        assert gpc.get_local_rank(ParallelMode.DATA) == 0
-    elif rank in [8, 9, 10, 11, 12, 13, 14, 15]:
-        assert gpc.get_local_rank(ParallelMode.DATA) == 1
-
-
-def check_pipeline_parallel_rank(rank):
-    if rank in [0, 1, 2, 3]:
-        assert gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-    elif rank in [4, 5, 6, 7]:
-        assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1
-    elif rank in [8, 9, 10, 11]:
-        assert gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-    elif rank in [12, 13, 14, 15]:
-        assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1
-
-
-def check_model_parallel_rank(rank):
-    for i in range(8):
-        if rank in [i, i+8]:
-            assert gpc.get_local_rank(ParallelMode.MODEL) == i
-
-
-def check_tensor_parallel_rank(rank):
-    if rank in [0, 4, 8, 12]:
-        assert gpc.get_local_rank(ParallelMode.TENSOR) == 0
-    elif rank in [1, 5, 9, 13]:
-        assert gpc.get_local_rank(ParallelMode.TENSOR) == 1
-    elif rank in [2, 6, 10, 14]:
-        assert gpc.get_local_rank(ParallelMode.TENSOR) == 2
-    elif rank in [3, 7, 11, 15]:
-        assert gpc.get_local_rank(ParallelMode.TENSOR) == 3
-
-
-def check_2d_parallel_rank(rank):
-    if rank in [0, 4, 8, 12]:
-        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0
-        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 0
-    elif rank in [1, 5, 9, 13]:
-        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0
-        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 1
-    elif rank in [2, 6, 10, 14]:
-        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 1
-        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 0
-    elif rank in [3, 7, 11, 15]:
-        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 1
-        assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 1
-
-
-def init_2d(rank, world_size, backend, port, host):
-    dist_args = dict(
-        config=CONFIG_PATH,
-        rank=rank,
-        world_size=world_size,
-        backend=backend,
-        port=port,
-        host=host,
-        verbose=True
-    )
-    launch(**dist_args)
-
-    check_tensor_parallel_rank(rank)
-    check_data_parallel_rank(rank)
-    check_2d_parallel_rank(rank)
-    check_pipeline_parallel_rank(rank)
-    check_model_parallel_rank(rank)
-    gpc.destroy()
-    torch.cuda.empty_cache()
-
-
-@pytest.mark.cpu
-def test_2d_init():
-    """
-    As no computation or communication is done, we can run this test on CPU.
-    """
-    world_size = 16
-    test_fn = partial(init_2d,
-                      world_size=world_size,
-                      backend='gloo',
-                      port=free_port(),
-                      host='localhost'
-                      )
-    mp.spawn(test_fn, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_2d_init()
diff --git a/tests/test_context/test_2p5d_init.py b/tests/test_context/test_2p5d_init.py
deleted file mode 100644
index ef6789710..000000000
--- a/tests/test_context/test_2p5d_init.py
+++ /dev/null
@@ -1,128 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from functools import partial
-from pathlib import Path
-
-import pytest
-import torch
-import torch.multiprocessing as mp
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
-from colossalai.utils import free_port
-
-CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2p5d_init.py').absolute()
-
-
-def check_data_parallel_rank(rank):
-    dp_rank = gpc.get_local_rank(ParallelMode.DATA)
-
-    if rank in list(range(16)):
-        assert dp_rank == 0
-    elif rank in list(range(16, 32)):
-        assert dp_rank == 1
-
-
-def check_pipeline_parallel_rank(rank):
-    ppr = gpc.get_local_rank(ParallelMode.PIPELINE)
-
-    if rank in list(range(8)):
-        assert ppr == 0
-    elif rank in list(range(8, 16)):
-        assert ppr == 1
-    elif rank in list(range(16, 24)):
-        assert ppr == 0
-    elif rank in list(range(24, 32)):
-        assert ppr == 1
-
-
-def check_model_parallel_rank(rank):
-    for i in range(16):
-        if rank in [i, i+16]:
-            assert gpc.get_local_rank(ParallelMode.MODEL) == i
-
-
-def check_tensor_parallel_rank(rank):
-    tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
-
-    for i in range(8):
-        ranks = list(range(i, 32, 8))
-        if rank in ranks:
-            assert tp_rank == i, f'{rank}:{tp_rank}'
-
-
-def check_2p5d_parallel_rank(rank):
-    rp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
-    cp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
-    dp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
-    xp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_XZ)
-
-    # check for row parallel group
-    for i in range(2):
-        ranks = list(range(i, 32, 2))
-        if rank in ranks:
-            assert rp_rank == i
-
-    # check for col parallel group
-    for i in range(2):
-        ranks = list(range(i * 2, 32, 4))
-        ranks_plus_ones = [val + 1 for val in ranks]
-        ranks.extend(ranks_plus_ones)
-        if rank in ranks:
-            assert cp_rank == i
-
-    # check for depth parallel group
-    for i in range(2):
-        ranks = []
-        for j in range(i * 4, 32, 8):
-            ranks.extend([j + k for k in range(4)])
-        if rank in ranks:
-            assert dp_rank == i
-
-    # check for xz parallel group
-    for i in range(2):
-        ranks = list(range(i * 2, 32, 8))
-        ranks_plus_one = [val + 1 for val in ranks]
-        ranks.extend(ranks_plus_one)
-        if rank in ranks:
-            assert xp_rank == i
-
-
-def init_2halfd(rank, world_size, backend, port, host):
-    dist_args = dict(
-        config=CONFIG_PATH,
-        rank=rank,
-        world_size=world_size,
-        backend=backend,
-        port=port,
-        host=host,
-        verbose=True
-    )
-    launch(**dist_args)
-    check_data_parallel_rank(rank)
-    check_pipeline_parallel_rank(rank)
-    check_tensor_parallel_rank(rank)
-    check_2p5d_parallel_rank(rank)
-    check_model_parallel_rank(rank)
-    gpc.destroy()
-    torch.cuda.empty_cache()
-
-
-@pytest.mark.cpu
-def test_2halfd_init():
-    """
-    As no computation or communication is done, we can run this test on CPU.
-    """
-    world_size = 32
-    test_fn = partial(init_2halfd,
-                      world_size=world_size,
-                      backend='gloo',
-                      port=free_port(),
-                      host='localhost'
-                      )
-    mp.spawn(test_fn, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_2halfd_init()
diff --git a/tests/test_context/test_3d_init.py b/tests/test_context/test_3d_init.py
deleted file mode 100644
index 12f0f1ea5..000000000
--- a/tests/test_context/test_3d_init.py
+++ /dev/null
@@ -1,120 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from functools import partial
-from pathlib import Path
-
-import pytest
-import torch
-import torch.multiprocessing as mp
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
-from colossalai.utils import free_port
-
-CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_3d_init.py').absolute()
-
-
-def check_data_parallel_rank(rank):
-    dp_rank = gpc.get_local_rank(ParallelMode.DATA)
-
-    if rank in list(range(16)):
-        assert dp_rank == 0
-    elif rank in list(range(16, 32)):
-        assert dp_rank == 1
-
-
-def check_pipeline_parallel_rank(rank):
-    ppr = gpc.get_local_rank(ParallelMode.PIPELINE)
-
-    if rank in list(range(8)):
-        assert ppr == 0
-    elif rank in list(range(8, 16)):
-        assert ppr == 1
-    elif rank in list(range(16, 24)):
-        assert ppr == 0
-    elif rank in list(range(24, 32)):
-        assert ppr == 1
-
-
-def check_model_parallel_rank(rank):
-    for i in range(16):
-        if rank in [i, i+16]:
-            assert gpc.get_local_rank(ParallelMode.MODEL) == i
-
-
-def check_tensor_parallel_rank(rank):
-    tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
-
-    for i in range(8):
-        ranks = list(range(i, 32, 8))
-        if rank in ranks:
-            assert tp_rank == i
-
-
-def check_3d_parallel_rank(rank):
-    ip_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
-    wp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
-    op_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
-
-    # check for input parallel group
-    for i in range(2):
-        _ranks = list(range(i * 2, 32, 4))
-        _ranks_plus_one = [val + 1 for val in _ranks]
-        input_ranks = _ranks + _ranks_plus_one
-        if rank in input_ranks:
-            assert ip_rank == i
-
-    # check for weight parallel group
-    for i in range(2):
-        ranks = list(range(i, 32, 2))
-
-        if rank in ranks:
-            assert wp_rank == i
-
-    # check for output parallel group
-    for i in range(2):
-        ranks = []
-        for j in range(i * 4, 32, 8):
-            ranks.extend([j + k for k in range(4)])
-        if rank in ranks:
-            assert op_rank == i
-
-
-def init_3d(rank, world_size, backend, port, host):
-    dist_args = dict(
-        config=CONFIG_PATH,
-        rank=rank,
-        world_size=world_size,
-        backend=backend,
-        port=port,
-        host=host,
-        verbose=True
-    )
-    launch(**dist_args)
-    check_tensor_parallel_rank(rank)
-    check_3d_parallel_rank(rank)
-    check_data_parallel_rank(rank)
-    check_pipeline_parallel_rank(rank)
-    check_model_parallel_rank(rank)
-    gpc.destroy()
-    torch.cuda.empty_cache()
-
-
-@pytest.mark.cpu
-def test_3d_init():
-    """
-    As no computation or communication is done, we can run this test on CPU.
-    """
-    world_size = 32
-    test_fn = partial(init_3d,
-                      world_size=world_size,
-                      backend='gloo',
-                      port=free_port(),
-                      host='localhost'
-                      )
-    mp.spawn(test_fn, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_3d_init()
diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_context/test_hybrid_parallel.py
new file mode 100644
index 000000000..d4075ef0b
--- /dev/null
+++ b/tests/test_context/test_hybrid_parallel.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from functools import partial
+from pathlib import Path
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+from colossalai import launch
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.utils import free_port
+from colossalai.context import reset_seeds
+from colossalai.global_variables import tensor_parallel_env as tp_env
+
+CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py'))
+
+
+def check_data_parallel_rank(rank):
+    global_world_size = gpc.get_world_size(ParallelMode.GLOBAL)
+    mp_size = gpc.get_world_size(ParallelMode.MODEL)
+    num_dp_groups = global_world_size // mp_size
+    dp_local_rank = gpc.get_local_rank(ParallelMode.DATA)
+
+    assert gpc.get_world_size(ParallelMode.DATA) == num_dp_groups
+
+    for group_idx in range(num_dp_groups):
+        ranks_in_dp_group = range(group_idx * mp_size, (group_idx + 1) * mp_size)
+        if rank in ranks_in_dp_group:
+            assert dp_local_rank == group_idx
+
+
+def check_pipeline_parallel_rank(rank):
+    mp_world_size = gpc.get_world_size(ParallelMode.MODEL)
+    tp_world_size = gpc.get_world_size(ParallelMode.TENSOR)
+    num_pipeline_stage = mp_world_size // tp_world_size
+    pipeline_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+
+    for stage_idx in range(num_pipeline_stage):
+        ranks_in_current_stage = range(stage_idx * tp_world_size, (stage_idx + 1) * tp_world_size)
+        if rank in ranks_in_current_stage:
+            assert stage_idx == pipeline_local_rank
+
+
+def check_model_parallel_rank(rank):
+    mp_size = gpc.get_world_size(ParallelMode.MODEL)
+    rank_within_mp_group = rank % mp_size
+    mp_local_rank = gpc.get_local_rank(ParallelMode.MODEL)
+    assert rank_within_mp_group == mp_local_rank
+
+
+def check_tensor_parallel_rank(rank):
+    if tp_env.mode == '2d':
+        check_2d_tensor_parallel_rank(rank)
+    elif tp_env == '2.5d':
+        check_2p5d_tensor_parallel_rank(rank)
+    elif tp_env == '3d':
+        check_3d_tensor_parallel_rank(rank)
+
+
+def get_tp_info():
+    global_world_size = gpc.get_world_size(ParallelMode.GLOBAL)
+    tp_world_size = gpc.get_world_size(ParallelMode.TENSOR)
+    num_tp_groups = global_world_size // tp_world_size
+    tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+    return tp_local_rank, tp_world_size, num_tp_groups
+
+
+def check_2d_tensor_parallel_rank(rank):
+    tp_local_rank, tp_world_size, num_tp_groups = get_tp_info()
+
+    for group_id in range(num_tp_groups):
+        ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size)
+
+        if rank in ranks_in_current_tp_group:
+            col_local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
+            row_local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
+
+            assert col_local_rank == tp_local_rank // tp_env.summa_dim
+            assert row_local_rank == tp_local_rank % tp_env.summa_dim
+
+
+def check_2p5d_tensor_parallel_rank(rank):
+    tp_local_rank, tp_world_size, num_tp_groups = get_tp_info()
+
+    for group_id in range(num_tp_groups):
+        ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size)
+
+        if rank in ranks_in_current_tp_group:
+            rp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
+            cp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
+            dp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
+            xp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_XZ)
+
+            assert rp_rank == tp_local_rank % tp_env.summa_dim
+            assert cp_rank == tp_local_rank // tp_env.tesseract_dim
+            assert dp_rank == tp_local_rank // (tp_env.summa_dim**2)
+            assert xp_rank == tp_local_rank // tp_env.summa_dim
+
+
+def check_3d_tensor_parallel_rank(rank):
+    tp_local_rank, tp_world_size, num_tp_groups = get_tp_info()
+
+    for group_id in range(num_tp_groups):
+        ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size)
+
+        if rank in ranks_in_current_tp_group:
+            ip_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
+            wp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
+            op_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+
+            assert ip_rank == tp_local_rank % tp_env.depth_3d
+            assert wp_rank == tp_local_rank // tp_env.depth_3d
+            assert op_rank == tp_local_rank // (tp_env.depth_3d**2)
+
+
+def init_context(config_path, rank, world_size, backend, port, host):
+    dist_args = dict(config=config_path,
+                     rank=rank,
+                     world_size=world_size,
+                     backend=backend,
+                     port=port,
+                     host=host,
+                     verbose=True)
+    launch(**dist_args)
+
+    check_tensor_parallel_rank(rank)
+    check_data_parallel_rank(rank)
+    check_pipeline_parallel_rank(rank)
+    check_model_parallel_rank(rank)
+    gpc.destroy()
+    torch.cuda.empty_cache()
+
+
+def run_dist(rank, world_size, backend, port_list, host):
+    for config_path, port in zip(CONFIG_PATH_LIST, port_list):
+        init_context(config_path=config_path, rank=rank, world_size=world_size, backend=backend, port=port, host=host)
+        reset_seeds()
+
+
+@pytest.mark.cpu
+def test_context():
+    """
+    As no computation or communication is done, we can run this test on CPU.
+    """
+    world_size = 32
+    port_list = []
+
+    for _ in range(len(CONFIG_PATH_LIST)):
+        while True:
+            port = free_port()
+            if port not in port_list:
+                port_list.append(port)
+                break
+
+    test_fn = partial(run_dist, world_size=world_size, backend='gloo', port_list=port_list, host='localhost')
+    mp.spawn(test_fn, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_context()