[test] fixed rerun_on_exception and adapted test cases (#487)

pull/528/head^2
Frank Lee 2022-03-25 17:25:12 +08:00 committed by GitHub
parent 4d322b79da
commit 3601b2bad0
31 changed files with 143 additions and 135 deletions

View File

@@ -1,6 +1,7 @@
 import re
 from typing import Callable, List, Any
 from functools import partial
+from inspect import signature


 def parameterize(argument: str, values: List[Any]) -> Callable:
@@ -105,6 +106,12 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
     If max_try is None, it will rerun forever if the exception keeps occurring
     """

+    def _match_lines(lines, pattern):
+        for line in lines:
+            if re.match(pattern, line):
+                return True
+        return False
+
     def _wrapper(func):

         def _run_until_success(*args, **kwargs):
@@ -115,15 +122,25 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
             while max_try is None or try_count < max_try:
                 try:
                     try_count += 1
-                    func(*args, **kwargs)
+                    ret = func(*args, **kwargs)
+                    return ret
                 except exception_type as e:
-                    if pattern is None or re.match(pattern, str(e)):
+                    error_lines = str(e).split('\n')
+                    if try_count < max_try and (pattern is None or _match_lines(error_lines, pattern)):
+                        print('Exception is caught, retrying...')
                         # when pattern is not specified, we always skip the exception
                         # when pattern is specified, we only skip when pattern is matched
                         continue
                     else:
+                        print('Maximum number of attempts is reached or pattern is not matched, no more retrying...')
                         raise e

+        # Override signature
+        # otherwise pytest.mark.parametrize will raise the following error:
+        # function does not use argument xxx
+        sig = signature(func)
+        _run_until_success.__signature__ = sig
+
         return _run_until_success

     return _wrapper
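
Note: the reason the pattern is now matched line by line via _match_lines is that the message attached to an exception such as mp.ProcessRaisedException is a multi-line traceback, and re.match anchors at the start of the whole string without "." crossing newlines, so the old re.match(pattern, str(e)) never reached the relevant line. A minimal, self-contained illustration (the traceback text is a shortened stand-in, not real output):

import re

error_message = ("-- Process 1 terminated with the following error:\n"
                 "Traceback (most recent call last):\n"
                 "RuntimeError: Address already in use")

pattern = ".*Address already in use.*"

# matching the whole message fails, matching each line succeeds
print(bool(re.match(pattern, error_message)))                               # False
print(any(re.match(pattern, line) for line in error_message.split('\n')))   # True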

View File

@@ -1,8 +1,9 @@
 import torch
-import torch.multiprocessing as mp
 import colossalai
-from colossalai.testing import assert_close_loose
+import torch.multiprocessing as mp
+from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
+from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.testing import assert_close_loose, rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
@@ -83,6 +84,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_naive_amp():
     world_size = 1
     run_func = partial(run_dist, world_size=world_size, port=free_port())
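
Note: every test file below follows the same recipe -- a free port is picked per run and the spawn-based test is decorated so that an "Address already in use" failure raised inside a child process triggers a rerun. A minimal sketch of that shape (the test body and config here are placeholders, not code from this commit):

from functools import partial

import pytest
import torch.multiprocessing as mp

import colossalai
from colossalai.testing import rerun_on_exception
from colossalai.utils import free_port


def run_dist(rank, world_size, port):
    # every worker joins a fresh process group on the port picked by the parent process
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # ... the actual assertions of an individual test would go here ...


@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_example():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)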

View File

@@ -9,6 +9,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.utils import free_port, get_current_device
+from colossalai.testing import rerun_on_exception

 CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
@@ -63,6 +64,7 @@ def check_layer(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_comm():
     world_size = 4
     run_func = partial(check_layer, world_size=world_size, port=free_port())

View File

@@ -13,6 +13,7 @@ from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
 from colossalai.context import reset_seeds
 from colossalai.global_variables import tensor_parallel_env as tp_env
+from colossalai.testing import rerun_on_exception

 CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py'))
@@ -140,6 +141,7 @@ def run_dist(rank, world_size, backend, port_list, host):
 @pytest.mark.cpu
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_context():
     """
     As no computation or communication is done, we can run this test on CPU.

View File

@@ -12,29 +12,26 @@ import torch.multiprocessing as mp
 from torch.utils.data import DataLoader

 import colossalai
-from colossalai.builder import build_dataset, build_data_sampler, build_transform
+from colossalai.builder import build_dataset, build_transform
 from torchvision import transforms
 from colossalai.context import ParallelMode, Config
 from colossalai.core import global_context as gpc
-from colossalai.utils import get_dataloader
+from colossalai.utils import get_dataloader, free_port
+from colossalai.testing import rerun_on_exception

 CONFIG = Config(
     dict(
-        train_data=dict(
-            dataset=dict(
-                type='CIFAR10',
-                root=Path(os.environ['DATA']),
-                train=True,
-                download=True,
-            ),
-            dataloader=dict(
-                batch_size=8,
-            ),
-            transform_pipeline=[
-                dict(type='ToTensor'),
-                dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
-            ]
+        train_data=dict(dataset=dict(
+            type='CIFAR10',
+            root=Path(os.environ['DATA']),
+            train=True,
+            download=True,
         ),
+                        dataloader=dict(batch_size=8,),
+                        transform_pipeline=[
+                            dict(type='ToTensor'),
+                            dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+                        ]),
         parallel=dict(
             pipeline=dict(size=1),
             tensor=dict(size=1, mode=None),
@@ -43,15 +40,8 @@ CONFIG = Config(
 ))


-def run_data_sampler(rank, world_size):
-    dist_args = dict(
-        config=CONFIG,
-        rank=rank,
-        world_size=world_size,
-        backend='gloo',
-        port='29903',
-        host='localhost'
-    )
+def run_data_sampler(rank, world_size, port):
+    dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
     colossalai.launch(**dist_args)
     print('finished initialization')
@@ -71,15 +61,16 @@ def run_data_sampler(rank, world_size):
     dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))

     if gpc.get_local_rank(ParallelMode.DATA) != 0:
-        assert not torch.equal(img,
-                               img_to_compare), 'Same image was distributed across ranks but expected it to be different'
+        assert not torch.equal(
+            img, img_to_compare), 'Same image was distributed across ranks but expected it to be different'

     torch.cuda.empty_cache()


 @pytest.mark.cpu
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_data_sampler():
     world_size = 4
-    test_func = partial(run_data_sampler, world_size=world_size)
+    test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
     mp.spawn(test_func, nprocs=world_size)

View File

@@ -16,45 +16,33 @@ import colossalai
 from colossalai.builder import build_dataset, build_transform
 from colossalai.context import ParallelMode, Config
 from colossalai.core import global_context as gpc
+from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception

 CONFIG = Config(
     dict(
-        train_data=dict(
-            dataset=dict(
-                type='CIFAR10',
-                root=Path(os.environ['DATA']),
-                train=True,
-                download=True,
-            ),
-            dataloader=dict(
-                num_workers=2,
-                batch_size=2,
-                shuffle=True
-            ),
-            transform_pipeline=[
-                dict(type='ToTensor'),
-                dict(type='RandomCrop', size=32),
-                dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
-            ]
+        train_data=dict(dataset=dict(
+            type='CIFAR10',
+            root=Path(os.environ['DATA']),
+            train=True,
+            download=True,
         ),
+                        dataloader=dict(num_workers=2, batch_size=2, shuffle=True),
+                        transform_pipeline=[
+                            dict(type='ToTensor'),
+                            dict(type='RandomCrop', size=32),
+                            dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+                        ]),
         parallel=dict(
             pipeline=dict(size=1),
             tensor=dict(size=1, mode=None),
         ),
         seed=1024,
-    )
-)
+    ))


-def run_data_sampler(rank, world_size):
-    dist_args = dict(
-        config=CONFIG,
-        rank=rank,
-        world_size=world_size,
-        backend='gloo',
-        port='29904',
-        host='localhost'
-    )
+def run_data_sampler(rank, world_size, port):
+    dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
     colossalai.launch(**dist_args)

     dataset_cfg = gpc.config.train_data.dataset
@@ -91,9 +79,10 @@ def run_data_sampler(rank, world_size):
 @pytest.mark.cpu
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_data_sampler():
     world_size = 4
-    test_func = partial(run_data_sampler, world_size=world_size)
+    test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
     mp.spawn(test_func, nprocs=world_size)

View File

@@ -13,8 +13,9 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn import LinearWarmupLR
 from colossalai.nn.loss import CrossEntropyLoss
 from colossalai.trainer import Trainer, hooks
-from colossalai.utils import MultiTimer, free_port, get_dataloader
+from colossalai.utils import free_port, get_dataloader
 from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
+from colossalai.testing import rerun_on_exception
 from model_zoo.vit import vit_tiny_patch4_32
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -79,6 +80,7 @@ def run_trainer(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_hybrid_parallel():
     world_size = 8
     run_func = partial(run_trainer, world_size=world_size, port=free_port())

View File

@ -4,11 +4,10 @@ import colossalai
import pytest import pytest
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.amp import AMP_TYPE from colossalai.amp import AMP_TYPE
from colossalai.context import Config
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.utils import free_port from colossalai.utils import free_port
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import parameterize from colossalai.testing import parameterize, rerun_on_exception
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
fp16=dict(mode=None), fp16=dict(mode=None),
@ -57,6 +56,7 @@ def run_engine(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_engine(): def test_engine():
world_size = 2 world_size = 2
run_func = partial(run_engine, world_size=world_size, port=free_port()) run_func = partial(run_engine, world_size=world_size, port=free_port())

View File

@@ -10,28 +10,15 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers
 from colossalai.initialize import launch
 from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
 from checks_1d.check_layer_1d import *

-CONFIG = dict(
-    parallel=dict(
-        pipeline=dict(size=1),
-        tensor=dict(
-            size=4,
-            mode='1d'
-        )
-    ),
-)
+CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),)


 def check_layer(rank, world_size, port):
     disable_existing_loggers()
-    launch(config=CONFIG,
-           rank=rank,
-           world_size=world_size,
-           host='localhost',
-           port=port,
-           backend='nccl')
+    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     check_linear_col()
     check_linear_row()
@@ -48,6 +35,7 @@ def check_layer(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_1d():
     world_size = 4
     run_func = partial(check_layer, world_size=world_size, port=free_port())

View File

@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
 from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
                                       check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
                                       check_vocab_parallel_classifier_given_embed_weight,
@@ -18,7 +18,7 @@ from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check
                                       check_vocab_parallel_loss)
 from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB

-CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')), )
+CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),)


 def check_operations():
@@ -55,6 +55,7 @@ def check_layer_and_operation(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_2d():
     world_size = 4
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

View File

@@ -7,16 +7,14 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
 from checks_2p5d.check_layer_2p5d import *
 from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB

-CONFIG = dict(
-    parallel=dict(
-        pipeline=dict(size=1),
-        tensor=dict(size=4, mode='2.5d', depth=1),
-    ),
-)
+CONFIG = dict(parallel=dict(
+    pipeline=dict(size=1),
+    tensor=dict(size=4, mode='2.5d', depth=1),
+),)


 def check_operations():
@@ -41,12 +39,7 @@ def check_layer():
 def check_layer_and_operation(rank, world_size, port):
     disable_existing_loggers()
-    launch(config=CONFIG,
-           rank=rank,
-           world_size=world_size,
-           host='localhost',
-           port=port,
-           backend='nccl')
+    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     torch.backends.cuda.matmul.allow_tf32 = False
     torch.backends.cudnn.allow_tf32 = False
@@ -58,6 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_2p5d():
     world_size = 4
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

View File

@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
 from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
                                       check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
                                       check_vocab_parallel_classifier_given_embed_weight,
@@ -51,6 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_3d():
     world_size = 8
     run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

View File

@@ -7,14 +7,10 @@ import pytest
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
+from colossalai.testing import rerun_on_exception
 from functools import partial

-CONFIG = dict(
-    parallel=dict(
-        tensor=dict(size=4, mode='sequence')
-    )
-)
+CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))


 def check_ring_qk(rank, world_size):
@@ -26,14 +22,14 @@ def check_ring_qk(rank, world_size):
     sub_seq_length = seq_length // world_size

     # create master tensors
-    q = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
-    k = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
+    q = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
+    k = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
     dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
     dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))

     # create distributed tensors
-    sub_q = q.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
-    sub_k = k.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
+    sub_q = q.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
+    sub_k = k.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()

     # set autograd attributes
     q.requires_grad = True
@@ -53,7 +49,7 @@ def check_ring_qk(rank, world_size):
     sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length)

     # check master and distributed attetion scores
-    sub_master_a = a[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2)

     # run master backward
@@ -61,11 +57,11 @@ def check_ring_qk(rank, world_size):
     a.mean().backward()

     # run distributed backward
-    partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     torch.autograd.backward(sub_a, partial_master_a_grad)

     # check master and distributed grads
-    partial_master_q_grad = q.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    partial_master_q_grad = q.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \
         'attention score cannot match'
@@ -79,14 +75,14 @@ def check_ring_av(rank, world_size):
     sub_seq_length = seq_length // world_size

     # create master tensors
-    a = torch.rand(batch_size*num_heads, seq_length, seq_length).cuda()
-    v = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
+    a = torch.rand(batch_size * num_heads, seq_length, seq_length).cuda()
+    v = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
     dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
     dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))

     # create distributed tensors
-    sub_a = a.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
-    sub_v = v.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
+    sub_a = a.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
+    sub_v = v.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()

     # set autograd attributes
     a.requires_grad = True
@@ -108,7 +104,7 @@ def check_ring_av(rank, world_size):
     # print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}')

     # check master and distributed output
-    sub_master_out = out[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    sub_master_out = out[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2)

     # # run master backward
@@ -116,23 +112,17 @@ def check_ring_av(rank, world_size):
     out.mean().backward()

     # # run distributed backward
-    partial_master_out_grad = out.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    partial_master_out_grad = out.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     torch.autograd.backward(sub_out, partial_master_out_grad)

     # # check master and distributed grads
-    partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
+    partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
     assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \
         'attention output cannot match'


 def run_test(rank, world_size):
-    colossalai.launch(
-        rank=rank,
-        world_size=world_size,
-        config=CONFIG,
-        host='localhost',
-        port=29500
-    )
+    colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=29500)

     # check_ring_qk(rank, world_size)
     check_ring_av(rank, world_size)
@@ -142,6 +132,7 @@ def run_test(rank, world_size):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_sequence():
     world_size = 4
     run_func = partial(run_test, world_size=world_size)

View File

@@ -11,6 +11,7 @@ from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
 from colossalai.engine.gradient_handler import MoeGradientHandler
 from colossalai.testing import assert_equal_in_group
+from colossalai.testing import rerun_on_exception

 BATCH_SIZE = 4
 DIM = 16
@@ -62,6 +63,7 @@ def run_test(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_grad_handler():
     world_size = 4
     run_func = partial(run_test, world_size=world_size, port=free_port())

View File

@@ -9,6 +9,7 @@ from colossalai.core import global_context as gpc
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.testing import rerun_on_exception

 BATCH_SIZE = 16
 NUM_EXPERTS = 4
@@ -86,6 +87,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
 @pytest.mark.parametrize("hidden_size", [32, 144])
 @pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
 @pytest.mark.parametrize("router", [Top1Router, Top2Router])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_moe_kernel(rs, hidden_size, data_type, router):
     world_size = 4
     run_func = partial(run_routing,

View File

@@ -8,7 +8,7 @@ from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Experts
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
-from colossalai.testing import assert_equal_in_group
+from colossalai.testing import assert_equal_in_group, rerun_on_exception

 D_MODEL = 4
 D_FF = 8
@@ -60,6 +60,7 @@ def run_test(rank, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_moe_initialization():
     world_size = 4
     run_func = partial(run_test, port=free_port())

View File

@@ -15,6 +15,7 @@ from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import get_dist_logger
 from colossalai.utils import free_port, get_current_device
+from colossalai.testing import rerun_on_exception

 BATCH_SIZE = 4
 SEQ_LENGTH = 2
@@ -92,6 +93,7 @@ def run_check(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_p2p():
     world_size = 4
     run_func = partial(run_check, world_size=world_size, port=free_port())

View File

@@ -10,19 +10,14 @@ from colossalai.initialize import launch
 from colossalai.logging import get_dist_logger
 from functools import partial
 from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception

 DIR_PATH = osp.dirname(osp.realpath(__file__))
 CONFIG_PATH = osp.join(DIR_PATH, 'resnet_config.py')


 def run_partition(rank, world_size, port):
-    launch(config=CONFIG_PATH,
-           rank=rank,
-           world_size=world_size,
-           host='localhost',
-           port=port,
-           backend='nccl'
-           )
+    launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     logger = get_dist_logger()
     logger.info('finished initialization')
@@ -37,6 +32,7 @@ def run_partition(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_partition():
     world_size = 4
     run_func = partial(run_partition, world_size=world_size, port=free_port())

View File

@@ -14,6 +14,7 @@ from colossalai.core import global_context as gpc
 from colossalai.engine.schedule import PipelineSchedule
 from colossalai.initialize import launch
 from colossalai.utils import free_port, get_dataloader, print_rank_0
+from colossalai.testing import rerun_on_exception
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -67,6 +68,7 @@ def run_schedule(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_pipeline_schedule():
     world_size = 4
     run_func = partial(run_schedule, world_size=world_size, port=free_port())

View File

@@ -9,7 +9,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.trainer import Trainer
 from colossalai.utils import MultiTimer, free_port
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_on_exception

 BATCH_SIZE = 4
 IMG_SIZE = 32
@@ -51,6 +51,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_trainer_no_pipeline():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size, port=free_port())

View File

@@ -17,6 +17,7 @@ from torch.optim import Adam
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 from torchvision.models import resnet18
+from colossalai.testing import rerun_on_exception

 BATCH_SIZE = 4
 IMG_SIZE = 32
@@ -85,6 +86,7 @@ def run_trainer_with_pipeline(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_trainer_with_pipeline():
     world_size = 4
     run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())

View File

@@ -1,7 +1,7 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
 from colossalai.zero.sharded_param import ShardedTensor
 import colossalai
@@ -47,6 +47,7 @@ def run_tensor_move(rank):
     GLOBAL_MODEL_DATA_TRACER.close()


+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_tensor_move():
     mp.spawn(run_tensor_move, nprocs=1)

View File

@@ -10,6 +10,7 @@ import torch.nn as nn
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import free_port, get_dataloader
+from colossalai.testing import rerun_on_exception
 from torch.optim import Adam
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -86,6 +87,7 @@ def run_no_pipeline(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_engine():
     world_size = 4
     func = partial(run_no_pipeline, world_size=world_size, port=free_port())

View File

@@ -14,9 +14,9 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
-from colossalai.testing import parameterize
 from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
 from functools import partial
+from colossalai.testing import parameterize, rerun_on_exception


 def checkpoint_wrapper(module, enable=True):
@@ -102,6 +102,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_zero_clip_grad():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size, port=free_port())

View File

@@ -14,6 +14,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
@@ -57,6 +58,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_zero_init_context(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
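
Note: parametrized tests like the one above are why the __signature__ override was added to rerun_on_exception: pytest decides which arguments a test accepts by inspecting its signature, and a bare (*args, **kwargs) wrapper hides the parametrized world_size argument. A small standalone illustration (plain_wrapper is a made-up helper for illustration, not repository code):

from inspect import signature


def plain_wrapper(func):
    # forwards everything but exposes only (*args, **kwargs)
    def inner(*args, **kwargs):
        return func(*args, **kwargs)
    return inner


def test_zero_init_context(world_size):
    pass


wrapped = plain_wrapper(test_zero_init_context)
print(signature(wrapped))      # (*args, **kwargs) -> the parametrized 'world_size' looks unused

wrapped.__signature__ = signature(test_zero_init_context)
print(signature(wrapped))      # (world_size) -> pytest.mark.parametrize resolves it again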

View File

@@ -14,6 +14,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -63,6 +64,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_shard_model_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

View File

@@ -10,6 +10,7 @@ from colossalai.utils import free_port
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_param import ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
+from colossalai.testing import rerun_on_exception
 from tests.test_zero_data_parallel.common import CONFIG, allclose
@@ -35,6 +36,7 @@ def _run_shard_tensor(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_shard_tensor(world_size):
     run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
@@ -55,6 +57,7 @@ def _run_shard_param_v2(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_shard_param_v2(world_size):
     run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

View File

@@ -15,6 +15,7 @@ from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -106,6 +107,7 @@ def _run_dist(rank, world_size, port):
 # use_cpuadam = True can be used with cpu_offload = False
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_sharded_optim_v2(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

View File

@@ -14,6 +14,7 @@ from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import TensorShardStrategy
 from torchvision.models import resnet50
+from colossalai.testing import rerun_on_exception


 def run_dist(rank, world_size, port):
@@ -71,6 +72,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_sharded_optim_with_sync_bn():
     """
     This test is to make sure that buffers are synchronized between ranks

View File

@@ -14,6 +14,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
@@ -51,6 +52,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_zero_state_dict(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

View File

@@ -13,6 +13,7 @@ from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -96,6 +97,7 @@ def run_dist(rank, world_size, port, parallel_config):
 @pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2, 4])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_mp_engine(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
     mp.spawn(run_func, nprocs=world_size)
@@ -103,6 +105,7 @@ def test_mp_engine(world_size):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_zero_engine(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
     mp.spawn(run_func, nprocs=world_size)