mirror of https://github.com/hpcaitech/ColossalAI
optimized context test time consumption (#446)
parent
496cbb0760
commit
b72b8445c6
|
@ -449,6 +449,7 @@ class ParallelContext:
|
||||||
dist.destroy_process_group(group)
|
dist.destroy_process_group(group)
|
||||||
# destroy global process group
|
# destroy global process group
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
self._groups.clear()
|
||||||
|
|
||||||
def set_device(self, device_ordinal: int = None):
|
def set_device(self, device_ordinal: int = None):
|
||||||
"""Sets distributed processes to be bound to devices.
|
"""Sets distributed processes to be bound to devices.
|
||||||
|
|
|
@ -13,7 +13,7 @@ def assert_not_equal(a: Tensor, b: Tensor):
|
||||||
def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8):
|
def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8):
|
||||||
assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}'
|
assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}'
|
||||||
|
|
||||||
def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-2, atol: float = 1e-3):
|
def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
|
||||||
assert_close(a, b, rtol, atol)
|
assert_close(a, b, rtol, atol)
|
||||||
|
|
||||||
def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
|
def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
|
||||||
|
|
|
@ -46,6 +46,7 @@ def free_port():
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
sock = socket.socket()
|
sock = socket.socket()
|
||||||
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
port = random.randint(20000, 65000)
|
port = random.randint(20000, 65000)
|
||||||
sock.bind(('localhost', port))
|
sock.bind(('localhost', port))
|
||||||
sock.close()
|
sock.close()
|
||||||
|
|
|
@ -5,6 +5,7 @@ import pytest
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
from colossalai.amp import convert_to_naive_amp
|
from colossalai.amp import convert_to_naive_amp
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
|
from colossalai.testing import assert_close_loose
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
@ -48,7 +49,7 @@ def run_naive_amp():
|
||||||
# forward pass
|
# forward pass
|
||||||
amp_output = amp_model(data)
|
amp_output = amp_model(data)
|
||||||
torch_output = torch_model(data)
|
torch_output = torch_model(data)
|
||||||
assert torch.allclose(amp_output, torch_output, rtol=1e-3, atol=1e-3), f'{amp_output} vs {torch_output}'
|
assert_close_loose(amp_output, torch_output)
|
||||||
|
|
||||||
# backward
|
# backward
|
||||||
amp_optimizer.backward(amp_output.mean())
|
amp_optimizer.backward(amp_output.mean())
|
||||||
|
@ -56,7 +57,7 @@ def run_naive_amp():
|
||||||
|
|
||||||
# check grad
|
# check grad
|
||||||
for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
|
for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
|
||||||
torch.allclose(amp_param.grad, torch_param.grad.half(), rtol=1e-3, atol=1e-3)
|
assert_close_loose(amp_param.grad, torch_param.grad.half())
|
||||||
|
|
||||||
# step
|
# step
|
||||||
amp_optimizer.step()
|
amp_optimizer.step()
|
||||||
|
@ -64,7 +65,7 @@ def run_naive_amp():
|
||||||
|
|
||||||
# check updated param
|
# check updated param
|
||||||
for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
|
for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
|
||||||
torch.allclose(amp_param, torch_param.half(), rtol=1e-3, atol=1e-3)
|
assert_close_loose(amp_param, torch_param.half())
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
|
|
|
@ -1,105 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
from colossalai import launch
|
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
|
||||||
from colossalai.core import global_context as gpc
|
|
||||||
from colossalai.utils import free_port
|
|
||||||
|
|
||||||
CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2d_init.py').absolute()
|
|
||||||
|
|
||||||
|
|
||||||
def check_data_parallel_rank(rank):
|
|
||||||
if rank in [0, 1, 2, 3, 4, 5, 6, 7]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.DATA) == 0
|
|
||||||
elif rank in [8, 9, 10, 11, 12, 13, 14, 15]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.DATA) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def check_pipeline_parallel_rank(rank):
|
|
||||||
if rank in [0, 1, 2, 3]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.PIPELINE) == 0
|
|
||||||
elif rank in [4, 5, 6, 7]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1
|
|
||||||
elif rank in [8, 9, 10, 11]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.PIPELINE) == 0
|
|
||||||
elif rank in [12, 13, 14, 15]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.PIPELINE) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def check_model_parallel_rank(rank):
|
|
||||||
for i in range(8):
|
|
||||||
if rank in [i, i+8]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.MODEL) == i
|
|
||||||
|
|
||||||
|
|
||||||
def check_tensor_parallel_rank(rank):
|
|
||||||
if rank in [0, 4, 8, 12]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.TENSOR) == 0
|
|
||||||
elif rank in [1, 5, 9, 13]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.TENSOR) == 1
|
|
||||||
elif rank in [2, 6, 10, 14]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.TENSOR) == 2
|
|
||||||
elif rank in [3, 7, 11, 15]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.TENSOR) == 3
|
|
||||||
|
|
||||||
|
|
||||||
def check_2d_parallel_rank(rank):
|
|
||||||
if rank in [0, 4, 8, 12]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0
|
|
||||||
assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 0
|
|
||||||
elif rank in [1, 5, 9, 13]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 0
|
|
||||||
assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 1
|
|
||||||
elif rank in [2, 6, 10, 14]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 1
|
|
||||||
assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 0
|
|
||||||
elif rank in [3, 7, 11, 15]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) == 1
|
|
||||||
assert gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def init_2d(rank, world_size, backend, port, host):
|
|
||||||
dist_args = dict(
|
|
||||||
config=CONFIG_PATH,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
backend=backend,
|
|
||||||
port=port,
|
|
||||||
host=host,
|
|
||||||
verbose=True
|
|
||||||
)
|
|
||||||
launch(**dist_args)
|
|
||||||
|
|
||||||
check_tensor_parallel_rank(rank)
|
|
||||||
check_data_parallel_rank(rank)
|
|
||||||
check_2d_parallel_rank(rank)
|
|
||||||
check_pipeline_parallel_rank(rank)
|
|
||||||
check_model_parallel_rank(rank)
|
|
||||||
gpc.destroy()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.cpu
|
|
||||||
def test_2d_init():
|
|
||||||
"""
|
|
||||||
As no computation or communication is done, we can run this test on CPU.
|
|
||||||
"""
|
|
||||||
world_size = 16
|
|
||||||
test_fn = partial(init_2d,
|
|
||||||
world_size=world_size,
|
|
||||||
backend='gloo',
|
|
||||||
port=free_port(),
|
|
||||||
host='localhost'
|
|
||||||
)
|
|
||||||
mp.spawn(test_fn, nprocs=world_size)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test_2d_init()
|
|
|
@ -1,128 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
|
||||||
from colossalai.core import global_context as gpc
|
|
||||||
from colossalai.initialize import launch
|
|
||||||
from colossalai.utils import free_port
|
|
||||||
|
|
||||||
CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2p5d_init.py').absolute()
|
|
||||||
|
|
||||||
|
|
||||||
def check_data_parallel_rank(rank):
|
|
||||||
dp_rank = gpc.get_local_rank(ParallelMode.DATA)
|
|
||||||
|
|
||||||
if rank in list(range(16)):
|
|
||||||
assert dp_rank == 0
|
|
||||||
elif rank in list(range(16, 32)):
|
|
||||||
assert dp_rank == 1
|
|
||||||
|
|
||||||
|
|
||||||
def check_pipeline_parallel_rank(rank):
|
|
||||||
ppr = gpc.get_local_rank(ParallelMode.PIPELINE)
|
|
||||||
|
|
||||||
if rank in list(range(8)):
|
|
||||||
assert ppr == 0
|
|
||||||
elif rank in list(range(8, 16)):
|
|
||||||
assert ppr == 1
|
|
||||||
elif rank in list(range(16, 24)):
|
|
||||||
assert ppr == 0
|
|
||||||
elif rank in list(range(24, 32)):
|
|
||||||
assert ppr == 1
|
|
||||||
|
|
||||||
|
|
||||||
def check_model_parallel_rank(rank):
|
|
||||||
for i in range(16):
|
|
||||||
if rank in [i, i+16]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.MODEL) == i
|
|
||||||
|
|
||||||
|
|
||||||
def check_tensor_parallel_rank(rank):
|
|
||||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
|
||||||
|
|
||||||
for i in range(8):
|
|
||||||
ranks = list(range(i, 32, 8))
|
|
||||||
if rank in ranks:
|
|
||||||
assert tp_rank == i, f'{rank}:{tp_rank}'
|
|
||||||
|
|
||||||
|
|
||||||
def check_2p5d_parallel_rank(rank):
|
|
||||||
rp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
|
||||||
cp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
|
||||||
dp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
|
||||||
xp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_XZ)
|
|
||||||
|
|
||||||
# check for row parallel group
|
|
||||||
for i in range(2):
|
|
||||||
ranks = list(range(i, 32, 2))
|
|
||||||
if rank in ranks:
|
|
||||||
assert rp_rank == i
|
|
||||||
|
|
||||||
# check for col parallel group
|
|
||||||
for i in range(2):
|
|
||||||
ranks = list(range(i * 2, 32, 4))
|
|
||||||
ranks_plus_ones = [val + 1 for val in ranks]
|
|
||||||
ranks.extend(ranks_plus_ones)
|
|
||||||
if rank in ranks:
|
|
||||||
assert cp_rank == i
|
|
||||||
|
|
||||||
# check for depth parallel group
|
|
||||||
for i in range(2):
|
|
||||||
ranks = []
|
|
||||||
for j in range(i * 4, 32, 8):
|
|
||||||
ranks.extend([j + k for k in range(4)])
|
|
||||||
if rank in ranks:
|
|
||||||
assert dp_rank == i
|
|
||||||
|
|
||||||
# check for xz parallel group
|
|
||||||
for i in range(2):
|
|
||||||
ranks = list(range(i * 2, 32, 8))
|
|
||||||
ranks_plus_one = [val + 1 for val in ranks]
|
|
||||||
ranks.extend(ranks_plus_one)
|
|
||||||
if rank in ranks:
|
|
||||||
assert xp_rank == i
|
|
||||||
|
|
||||||
|
|
||||||
def init_2halfd(rank, world_size, backend, port, host):
|
|
||||||
dist_args = dict(
|
|
||||||
config=CONFIG_PATH,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
backend=backend,
|
|
||||||
port=port,
|
|
||||||
host=host,
|
|
||||||
verbose=True
|
|
||||||
)
|
|
||||||
launch(**dist_args)
|
|
||||||
check_data_parallel_rank(rank)
|
|
||||||
check_pipeline_parallel_rank(rank)
|
|
||||||
check_tensor_parallel_rank(rank)
|
|
||||||
check_2p5d_parallel_rank(rank)
|
|
||||||
check_model_parallel_rank(rank)
|
|
||||||
gpc.destroy()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.cpu
|
|
||||||
def test_2halfd_init():
|
|
||||||
"""
|
|
||||||
As no computation or communication is done, we can run this test on CPU.
|
|
||||||
"""
|
|
||||||
world_size = 32
|
|
||||||
test_fn = partial(init_2halfd,
|
|
||||||
world_size=world_size,
|
|
||||||
backend='gloo',
|
|
||||||
port=free_port(),
|
|
||||||
host='localhost'
|
|
||||||
)
|
|
||||||
mp.spawn(test_fn, nprocs=world_size)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test_2halfd_init()
|
|
|
@ -1,120 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- encoding: utf-8 -*-
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
from colossalai.context.parallel_mode import ParallelMode
|
|
||||||
from colossalai.core import global_context as gpc
|
|
||||||
from colossalai.initialize import launch
|
|
||||||
from colossalai.utils import free_port
|
|
||||||
|
|
||||||
CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_3d_init.py').absolute()
|
|
||||||
|
|
||||||
|
|
||||||
def check_data_parallel_rank(rank):
|
|
||||||
dp_rank = gpc.get_local_rank(ParallelMode.DATA)
|
|
||||||
|
|
||||||
if rank in list(range(16)):
|
|
||||||
assert dp_rank == 0
|
|
||||||
elif rank in list(range(16, 32)):
|
|
||||||
assert dp_rank == 1
|
|
||||||
|
|
||||||
|
|
||||||
def check_pipeline_parallel_rank(rank):
|
|
||||||
ppr = gpc.get_local_rank(ParallelMode.PIPELINE)
|
|
||||||
|
|
||||||
if rank in list(range(8)):
|
|
||||||
assert ppr == 0
|
|
||||||
elif rank in list(range(8, 16)):
|
|
||||||
assert ppr == 1
|
|
||||||
elif rank in list(range(16, 24)):
|
|
||||||
assert ppr == 0
|
|
||||||
elif rank in list(range(24, 32)):
|
|
||||||
assert ppr == 1
|
|
||||||
|
|
||||||
|
|
||||||
def check_model_parallel_rank(rank):
|
|
||||||
for i in range(16):
|
|
||||||
if rank in [i, i+16]:
|
|
||||||
assert gpc.get_local_rank(ParallelMode.MODEL) == i
|
|
||||||
|
|
||||||
|
|
||||||
def check_tensor_parallel_rank(rank):
|
|
||||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
|
||||||
|
|
||||||
for i in range(8):
|
|
||||||
ranks = list(range(i, 32, 8))
|
|
||||||
if rank in ranks:
|
|
||||||
assert tp_rank == i
|
|
||||||
|
|
||||||
|
|
||||||
def check_3d_parallel_rank(rank):
|
|
||||||
ip_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
|
||||||
wp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
|
||||||
op_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
|
||||||
|
|
||||||
# check for input parallel group
|
|
||||||
for i in range(2):
|
|
||||||
_ranks = list(range(i * 2, 32, 4))
|
|
||||||
_ranks_plus_one = [val + 1 for val in _ranks]
|
|
||||||
input_ranks = _ranks + _ranks_plus_one
|
|
||||||
if rank in input_ranks:
|
|
||||||
assert ip_rank == i
|
|
||||||
|
|
||||||
# check for weight parallel group
|
|
||||||
for i in range(2):
|
|
||||||
ranks = list(range(i, 32, 2))
|
|
||||||
|
|
||||||
if rank in ranks:
|
|
||||||
assert wp_rank == i
|
|
||||||
|
|
||||||
# check for output parallel group
|
|
||||||
for i in range(2):
|
|
||||||
ranks = []
|
|
||||||
for j in range(i * 4, 32, 8):
|
|
||||||
ranks.extend([j + k for k in range(4)])
|
|
||||||
if rank in ranks:
|
|
||||||
assert op_rank == i
|
|
||||||
|
|
||||||
|
|
||||||
def init_3d(rank, world_size, backend, port, host):
|
|
||||||
dist_args = dict(
|
|
||||||
config=CONFIG_PATH,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
backend=backend,
|
|
||||||
port=port,
|
|
||||||
host=host,
|
|
||||||
verbose=True
|
|
||||||
)
|
|
||||||
launch(**dist_args)
|
|
||||||
check_tensor_parallel_rank(rank)
|
|
||||||
check_3d_parallel_rank(rank)
|
|
||||||
check_data_parallel_rank(rank)
|
|
||||||
check_pipeline_parallel_rank(rank)
|
|
||||||
check_model_parallel_rank(rank)
|
|
||||||
gpc.destroy()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.cpu
|
|
||||||
def test_3d_init():
|
|
||||||
"""
|
|
||||||
As no computation or communication is done, we can run this test on CPU.
|
|
||||||
"""
|
|
||||||
world_size = 32
|
|
||||||
test_fn = partial(init_3d,
|
|
||||||
world_size=world_size,
|
|
||||||
backend='gloo',
|
|
||||||
port=free_port(),
|
|
||||||
host='localhost'
|
|
||||||
)
|
|
||||||
mp.spawn(test_fn, nprocs=world_size)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test_3d_init()
|
|
|
@ -0,0 +1,162 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
from colossalai import launch
|
||||||
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from colossalai.context import reset_seeds
|
||||||
|
from colossalai.global_variables import tensor_parallel_env as tp_env
|
||||||
|
|
||||||
|
CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py'))
|
||||||
|
|
||||||
|
|
||||||
|
def check_data_parallel_rank(rank):
|
||||||
|
global_world_size = gpc.get_world_size(ParallelMode.GLOBAL)
|
||||||
|
mp_size = gpc.get_world_size(ParallelMode.MODEL)
|
||||||
|
num_dp_groups = global_world_size // mp_size
|
||||||
|
dp_local_rank = gpc.get_local_rank(ParallelMode.DATA)
|
||||||
|
|
||||||
|
assert gpc.get_world_size(ParallelMode.DATA) == num_dp_groups
|
||||||
|
|
||||||
|
for group_idx in range(num_dp_groups):
|
||||||
|
ranks_in_dp_group = range(group_idx * mp_size, (group_idx + 1) * mp_size)
|
||||||
|
if rank in ranks_in_dp_group:
|
||||||
|
assert dp_local_rank == group_idx
|
||||||
|
|
||||||
|
|
||||||
|
def check_pipeline_parallel_rank(rank):
|
||||||
|
mp_world_size = gpc.get_world_size(ParallelMode.MODEL)
|
||||||
|
tp_world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||||
|
num_pipeline_stage = mp_world_size // tp_world_size
|
||||||
|
pipeline_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||||
|
|
||||||
|
for stage_idx in range(num_pipeline_stage):
|
||||||
|
ranks_in_current_stage = range(stage_idx * tp_world_size, (stage_idx + 1) * tp_world_size)
|
||||||
|
if rank in ranks_in_current_stage:
|
||||||
|
assert stage_idx == pipeline_local_rank
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_parallel_rank(rank):
|
||||||
|
mp_size = gpc.get_world_size(ParallelMode.MODEL)
|
||||||
|
rank_within_mp_group = rank % mp_size
|
||||||
|
mp_local_rank = gpc.get_local_rank(ParallelMode.MODEL)
|
||||||
|
assert rank_within_mp_group == mp_local_rank
|
||||||
|
|
||||||
|
|
||||||
|
def check_tensor_parallel_rank(rank):
|
||||||
|
if tp_env.mode == '2d':
|
||||||
|
check_2d_tensor_parallel_rank(rank)
|
||||||
|
elif tp_env == '2.5d':
|
||||||
|
check_2p5d_tensor_parallel_rank(rank)
|
||||||
|
elif tp_env == '3d':
|
||||||
|
check_3d_tensor_parallel_rank(rank)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tp_info():
|
||||||
|
global_world_size = gpc.get_world_size(ParallelMode.GLOBAL)
|
||||||
|
tp_world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||||
|
num_tp_groups = global_world_size // tp_world_size
|
||||||
|
tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||||
|
return tp_local_rank, tp_world_size, num_tp_groups
|
||||||
|
|
||||||
|
|
||||||
|
def check_2d_tensor_parallel_rank(rank):
|
||||||
|
tp_local_rank, tp_world_size, num_tp_groups = get_tp_info()
|
||||||
|
|
||||||
|
for group_id in range(num_tp_groups):
|
||||||
|
ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size)
|
||||||
|
|
||||||
|
if rank in ranks_in_current_tp_group:
|
||||||
|
col_local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||||
|
row_local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||||
|
|
||||||
|
assert col_local_rank == tp_local_rank // tp_env.summa_dim
|
||||||
|
assert row_local_rank == tp_local_rank % tp_env.summa_dim
|
||||||
|
|
||||||
|
|
||||||
|
def check_2p5d_tensor_parallel_rank(rank):
|
||||||
|
tp_local_rank, tp_world_size, num_tp_groups = get_tp_info()
|
||||||
|
|
||||||
|
for group_id in range(num_tp_groups):
|
||||||
|
ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size)
|
||||||
|
|
||||||
|
if rank in ranks_in_current_tp_group:
|
||||||
|
rp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||||
|
cp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||||
|
dp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||||
|
xp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_XZ)
|
||||||
|
|
||||||
|
assert rp_rank == tp_local_rank % tp_env.summa_dim
|
||||||
|
assert cp_rank == tp_local_rank // tp_env.tesseract_dim
|
||||||
|
assert dp_rank == tp_local_rank // (tp_env.summa_dim**2)
|
||||||
|
assert xp_rank == tp_local_rank // tp_env.summa_dim
|
||||||
|
|
||||||
|
|
||||||
|
def check_3d_tensor_parallel_rank(rank):
|
||||||
|
tp_local_rank, tp_world_size, num_tp_groups = get_tp_info()
|
||||||
|
|
||||||
|
for group_id in range(num_tp_groups):
|
||||||
|
ranks_in_current_tp_group = range(group_id * tp_world_size, (group_id + 1) * tp_world_size)
|
||||||
|
|
||||||
|
if rank in ranks_in_current_tp_group:
|
||||||
|
ip_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||||
|
wp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||||
|
op_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||||
|
|
||||||
|
assert ip_rank == tp_local_rank % tp_env.depth_3d
|
||||||
|
assert wp_rank == tp_local_rank // tp_env.depth_3d
|
||||||
|
assert op_rank == tp_local_rank // (tp_env.depth_3d**2)
|
||||||
|
|
||||||
|
|
||||||
|
def init_context(config_path, rank, world_size, backend, port, host):
|
||||||
|
dist_args = dict(config=config_path,
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
backend=backend,
|
||||||
|
port=port,
|
||||||
|
host=host,
|
||||||
|
verbose=True)
|
||||||
|
launch(**dist_args)
|
||||||
|
|
||||||
|
check_tensor_parallel_rank(rank)
|
||||||
|
check_data_parallel_rank(rank)
|
||||||
|
check_pipeline_parallel_rank(rank)
|
||||||
|
check_model_parallel_rank(rank)
|
||||||
|
gpc.destroy()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, backend, port_list, host):
|
||||||
|
for config_path, port in zip(CONFIG_PATH_LIST, port_list):
|
||||||
|
init_context(config_path=config_path, rank=rank, world_size=world_size, backend=backend, port=port, host=host)
|
||||||
|
reset_seeds()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.cpu
|
||||||
|
def test_context():
|
||||||
|
"""
|
||||||
|
As no computation or communication is done, we can run this test on CPU.
|
||||||
|
"""
|
||||||
|
world_size = 32
|
||||||
|
port_list = []
|
||||||
|
|
||||||
|
for _ in range(len(CONFIG_PATH_LIST)):
|
||||||
|
while True:
|
||||||
|
port = free_port()
|
||||||
|
if port not in port_list:
|
||||||
|
port_list.append(port)
|
||||||
|
break
|
||||||
|
|
||||||
|
test_fn = partial(run_dist, world_size=world_size, backend='gloo', port_list=port_list, host='localhost')
|
||||||
|
mp.spawn(test_fn, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_context()
|
Loading…
Reference in New Issue