[test] fixed rerun_on_exception and adapted test cases (#487)

pull/528/head^2
Frank Lee 2022-03-25 17:25:12 +08:00 committed by GitHub
parent 4d322b79da
commit 3601b2bad0
31 changed files with 143 additions and 135 deletions

View File

@ -1,6 +1,7 @@
import re
from typing import Callable, List, Any
from functools import partial
from inspect import signature
def parameterize(argument: str, values: List[Any]) -> Callable:
@ -105,6 +106,12 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
If max_try is None, it will rerun forever if the exception keeps occurring
"""
def _match_lines(lines, pattern):
for line in lines:
if re.match(pattern, line):
return True
return False
def _wrapper(func):
def _run_until_success(*args, **kwargs):
@ -115,15 +122,25 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
while max_try is None or try_count < max_try:
try:
try_count += 1
func(*args, **kwargs)
ret = func(*args, **kwargs)
return ret
except exception_type as e:
if pattern is None or re.match(pattern, str(e)):
error_lines = str(e).split('\n')
if try_count < max_try and (pattern is None or _match_lines(error_lines, pattern)):
print('Exception is caught, retrying...')
# when no pattern is specified, we always retry on the exception
# when a pattern is specified, we only retry when the pattern is matched
continue
else:
print('Maximum number of attempts is reached or pattern is not matched, no more retrying...')
raise e
# Override signature
# otherwise pytest.mark.parametrize will raise the following error:
# function does not use argument xxx
sig = signature(func)
_run_until_success.__signature__ = sig
return _run_until_success
return _wrapper
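
For reference, a minimal usage sketch of the updated decorator (a hypothetical test, assuming a max_try keyword as described in the docstring above; not part of this commit). The exception message is split into lines and each line is matched against the pattern, and the __signature__ override is what keeps pytest.mark.parametrize working on the wrapped test.

# Hypothetical usage sketch, not part of this commit.
import pytest
import torch.multiprocessing as mp
from functools import partial
from colossalai.testing import rerun_on_exception
from colossalai.utils import free_port


def run_dist(rank, world_size, port):
    # launch a distributed job on the given port (details omitted)
    ...


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
# retry only when some line of the ProcessRaisedException message
# matches the pattern, and give up after five attempts
@rerun_on_exception(exception_type=mp.ProcessRaisedException,
                    pattern=".*Address already in use.*",
                    max_try=5)
def test_something(world_size):
    # thanks to the __signature__ override, pytest.mark.parametrize
    # still sees the world_size argument of the wrapped function
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)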

View File

@ -1,8 +1,9 @@
import torch
import torch.multiprocessing as mp
import colossalai
from colossalai.testing import assert_close_loose
import torch.multiprocessing as mp
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import assert_close_loose, rerun_on_exception
from colossalai.utils import free_port
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
@ -83,6 +84,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_naive_amp():
world_size = 1
run_func = partial(run_dist, world_size=world_size, port=free_port())

View File

@ -9,6 +9,7 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_on_exception
CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
@ -63,6 +64,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_comm():
world_size = 4
run_func = partial(check_layer, world_size=world_size, port=free_port())

View File

@ -13,6 +13,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.context import reset_seeds
from colossalai.global_variables import tensor_parallel_env as tp_env
from colossalai.testing import rerun_on_exception
CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py'))
@ -140,6 +141,7 @@ def run_dist(rank, world_size, backend, port_list, host):
@pytest.mark.cpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_context():
"""
As no computation or communication is done, we can run this test on CPU.

View File

@ -12,29 +12,26 @@ import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import colossalai
from colossalai.builder import build_dataset, build_data_sampler, build_transform
from colossalai.builder import build_dataset, build_transform
from torchvision import transforms
from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc
from colossalai.utils import get_dataloader
from colossalai.utils import get_dataloader, free_port
from colossalai.testing import rerun_on_exception
CONFIG = Config(
dict(
train_data=dict(
dataset=dict(
type='CIFAR10',
root=Path(os.environ['DATA']),
train=True,
download=True,
),
dataloader=dict(
batch_size=8,
),
transform_pipeline=[
dict(type='ToTensor'),
dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
train_data=dict(dataset=dict(
type='CIFAR10',
root=Path(os.environ['DATA']),
train=True,
download=True,
),
dataloader=dict(batch_size=8,),
transform_pipeline=[
dict(type='ToTensor'),
dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]),
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None),
@ -43,15 +40,8 @@ CONFIG = Config(
))
def run_data_sampler(rank, world_size):
dist_args = dict(
config=CONFIG,
rank=rank,
world_size=world_size,
backend='gloo',
port='29903',
host='localhost'
)
def run_data_sampler(rank, world_size, port):
dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
colossalai.launch(**dist_args)
print('finished initialization')
@ -71,15 +61,16 @@ def run_data_sampler(rank, world_size):
dist.broadcast(img_to_compare, src=0, group=gpc.get_group(ParallelMode.DATA))
if gpc.get_local_rank(ParallelMode.DATA) != 0:
assert not torch.equal(img,
img_to_compare), 'Same image was distributed across ranks but expected it to be different'
assert not torch.equal(
img, img_to_compare), 'Same image was distributed across ranks but expected it to be different'
torch.cuda.empty_cache()
@pytest.mark.cpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_data_sampler():
world_size = 4
test_func = partial(run_data_sampler, world_size=world_size)
test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
mp.spawn(test_func, nprocs=world_size)
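
Note that the hard-coded ports ('29903', '29904') these data sampler tests used to share are replaced with free_port() throughout the commit. As a rough illustration only, and not the actual colossalai.utils.free_port implementation, such a helper is typically written by binding to port 0 so the OS returns an unused port:

# Illustrative sketch of a free-port helper; the real free_port may differ.
import socket


def pick_free_port() -> int:
    # binding to port 0 asks the OS for any currently unused port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('localhost', 0))
        return s.getsockname()[1]

Because another process can still claim the port between selection and use, the ".*Address already in use.*" retry pattern applied above covers the remaining race.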

View File

@ -16,45 +16,33 @@ import colossalai
from colossalai.builder import build_dataset, build_transform
from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
CONFIG = Config(
dict(
train_data=dict(
dataset=dict(
type='CIFAR10',
root=Path(os.environ['DATA']),
train=True,
download=True,
),
dataloader=dict(
num_workers=2,
batch_size=2,
shuffle=True
),
transform_pipeline=[
dict(type='ToTensor'),
dict(type='RandomCrop', size=32),
dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
train_data=dict(dataset=dict(
type='CIFAR10',
root=Path(os.environ['DATA']),
train=True,
download=True,
),
dataloader=dict(num_workers=2, batch_size=2, shuffle=True),
transform_pipeline=[
dict(type='ToTensor'),
dict(type='RandomCrop', size=32),
dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]),
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None),
),
seed=1024,
)
)
))
def run_data_sampler(rank, world_size):
dist_args = dict(
config=CONFIG,
rank=rank,
world_size=world_size,
backend='gloo',
port='29904',
host='localhost'
)
def run_data_sampler(rank, world_size, port):
dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost')
colossalai.launch(**dist_args)
dataset_cfg = gpc.config.train_data.dataset
@ -91,9 +79,10 @@ def run_data_sampler(rank, world_size):
@pytest.mark.cpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_data_sampler():
world_size = 4
test_func = partial(run_data_sampler, world_size=world_size)
test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
mp.spawn(test_func, nprocs=world_size)

View File

@ -13,8 +13,9 @@ from colossalai.logging import get_dist_logger
from colossalai.nn import LinearWarmupLR
from colossalai.nn.loss import CrossEntropyLoss
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, free_port, get_dataloader
from colossalai.utils import free_port, get_dataloader
from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
from colossalai.testing import rerun_on_exception
from model_zoo.vit import vit_tiny_patch4_32
from torchvision import transforms
from torchvision.datasets import CIFAR10
@ -79,6 +80,7 @@ def run_trainer(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_hybrid_parallel():
world_size = 8
run_func = partial(run_trainer, world_size=world_size, port=free_port())

View File

@ -4,11 +4,10 @@ import colossalai
import pytest
import torch.multiprocessing as mp
from colossalai.amp import AMP_TYPE
from colossalai.context import Config
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import parameterize
from colossalai.testing import parameterize, rerun_on_exception
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
fp16=dict(mode=None),
@ -57,6 +56,7 @@ def run_engine(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_engine():
world_size = 2
run_func = partial(run_engine, world_size=world_size, port=free_port())

View File

@ -10,28 +10,15 @@ from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from checks_1d.check_layer_1d import *
CONFIG = dict(
parallel=dict(
pipeline=dict(size=1),
tensor=dict(
size=4,
mode='1d'
)
),
)
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),)
def check_layer(rank, world_size, port):
disable_existing_loggers()
launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_linear_col()
check_linear_row()
@ -48,6 +35,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_1d():
world_size = 4
run_func = partial(check_layer, world_size=world_size, port=free_port())

View File

@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
@ -18,7 +18,7 @@ from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check
check_vocab_parallel_loss)
from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')), )
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),)
def check_operations():
@ -55,6 +55,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_2d():
world_size = 4
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

View File

@ -7,16 +7,14 @@ from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from checks_2p5d.check_layer_2p5d import *
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
CONFIG = dict(
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode='2.5d', depth=1),
),
)
CONFIG = dict(parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode='2.5d', depth=1),
),)
def check_operations():
@ -41,12 +39,7 @@ def check_layer():
def check_layer_and_operation(rank, world_size, port):
disable_existing_loggers()
launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
@ -58,6 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_2p5d():
world_size = 4
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

View File

@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
@ -51,6 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_3d():
world_size = 8
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())

View File

@ -7,14 +7,10 @@ import pytest
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.testing import rerun_on_exception
from functools import partial
CONFIG = dict(
parallel=dict(
tensor=dict(size=4, mode='sequence')
)
)
CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
def check_ring_qk(rank, world_size):
@ -26,14 +22,14 @@ def check_ring_qk(rank, world_size):
sub_seq_length = seq_length // world_size
# create master tensors
q = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
k = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
q = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
k = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
# create distributed tensors
sub_q = q.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
sub_k = k.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
sub_q = q.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
sub_k = k.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
# set autograd attributes
q.requires_grad = True
@ -53,7 +49,7 @@ def check_ring_qk(rank, world_size):
sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length)
# check master and distributed attention scores
sub_master_a = a[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2)
# run master backward
@ -61,11 +57,11 @@ def check_ring_qk(rank, world_size):
a.mean().backward()
# run distributed backward
partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
torch.autograd.backward(sub_a, partial_master_a_grad)
# check master and distributed grads
partial_master_q_grad = q.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
partial_master_q_grad = q.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \
'attention score cannot match'
@ -79,14 +75,14 @@ def check_ring_av(rank, world_size):
sub_seq_length = seq_length // world_size
# create master tensors
a = torch.rand(batch_size*num_heads, seq_length, seq_length).cuda()
v = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda()
a = torch.rand(batch_size * num_heads, seq_length, seq_length).cuda()
v = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
# create distributed tensors
sub_a = a.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
sub_v = v.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous()
sub_a = a.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
sub_v = v.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
# set autograd attributes
a.requires_grad = True
@ -108,7 +104,7 @@ def check_ring_av(rank, world_size):
# print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}')
# check master and distributed output
sub_master_out = out[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
sub_master_out = out[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2)
# # run master backward
@ -116,23 +112,17 @@ def check_ring_av(rank, world_size):
out.mean().backward()
# # run distributed backward
partial_master_out_grad = out.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
partial_master_out_grad = out.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
torch.autograd.backward(sub_out, partial_master_out_grad)
# # check master and distributed grads
partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length]
partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \
'attention output cannot match'
def run_test(rank, world_size):
colossalai.launch(
rank=rank,
world_size=world_size,
config=CONFIG,
host='localhost',
port=29500
)
colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=29500)
# check_ring_qk(rank, world_size)
check_ring_av(rank, world_size)
@ -142,6 +132,7 @@ def run_test(rank, world_size):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_sequence():
world_size = 4
run_func = partial(run_test, world_size=world_size)

View File

@ -11,6 +11,7 @@ from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param
from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.testing import assert_equal_in_group
from colossalai.testing import rerun_on_exception
BATCH_SIZE = 4
DIM = 16
@ -62,6 +63,7 @@ def run_test(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_grad_handler():
world_size = 4
run_func = partial(run_test, world_size=world_size, port=free_port())

View File

@ -9,6 +9,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.testing import rerun_on_exception
BATCH_SIZE = 16
NUM_EXPERTS = 4
@ -86,6 +87,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
@pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
@pytest.mark.parametrize("router", [Top1Router, Top2Router])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_moe_kernel(rs, hidden_size, data_type, router):
world_size = 4
run_func = partial(run_routing,

View File

@ -8,7 +8,7 @@ from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Experts
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param
from colossalai.testing import assert_equal_in_group
from colossalai.testing import assert_equal_in_group, rerun_on_exception
D_MODEL = 4
D_FF = 8
@ -60,6 +60,7 @@ def run_test(rank, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_moe_initialization():
world_size = 4
run_func = partial(run_test, port=free_port())

View File

@ -15,6 +15,7 @@ from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_on_exception
BATCH_SIZE = 4
SEQ_LENGTH = 2
@ -92,6 +93,7 @@ def run_check(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_p2p():
world_size = 4
run_func = partial(run_check, world_size=world_size, port=free_port())

View File

@ -10,19 +10,14 @@ from colossalai.initialize import launch
from colossalai.logging import get_dist_logger
from functools import partial
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, 'resnet_config.py')
def run_partition(rank, world_size, port):
launch(config=CONFIG_PATH,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl'
)
launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
logger = get_dist_logger()
logger.info('finished initialization')
@ -37,6 +32,7 @@ def run_partition(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_partition():
world_size = 4
run_func = partial(run_partition, world_size=world_size, port=free_port())

View File

@ -14,6 +14,7 @@ from colossalai.core import global_context as gpc
from colossalai.engine.schedule import PipelineSchedule
from colossalai.initialize import launch
from colossalai.utils import free_port, get_dataloader, print_rank_0
from colossalai.testing import rerun_on_exception
from torchvision import transforms
from torchvision.datasets import CIFAR10
@ -67,6 +68,7 @@ def run_schedule(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_pipeline_schedule():
world_size = 4
run_func = partial(run_schedule, world_size=world_size, port=free_port())

View File

@ -9,7 +9,7 @@ from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import MultiTimer, free_port
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import parameterize
from colossalai.testing import parameterize, rerun_on_exception
BATCH_SIZE = 4
IMG_SIZE = 32
@ -51,6 +51,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_trainer_no_pipeline():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())

View File

@ -17,6 +17,7 @@ from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from colossalai.testing import rerun_on_exception
BATCH_SIZE = 4
IMG_SIZE = 32
@ -85,6 +86,7 @@ def run_trainer_with_pipeline(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_trainer_with_pipeline():
world_size = 4
run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())

View File

@ -1,7 +1,7 @@
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.zero.sharded_param import ShardedTensor
import colossalai
@ -47,6 +47,7 @@ def run_tensor_move(rank):
GLOBAL_MODEL_DATA_TRACER.close()
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_tensor_move():
mp.spawn(run_tensor_move, nprocs=1)

View File

@ -10,6 +10,7 @@ import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port, get_dataloader
from colossalai.testing import rerun_on_exception
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
@ -86,6 +87,7 @@ def run_no_pipeline(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_engine():
world_size = 4
func = partial(run_no_pipeline, world_size=world_size, port=free_port())

View File

@ -14,9 +14,9 @@ from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from colossalai.testing import parameterize
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from functools import partial
from colossalai.testing import parameterize, rerun_on_exception
def checkpoint_wrapper(module, enable=True):
@ -102,6 +102,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_zero_clip_grad():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())

View File

@ -14,6 +14,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.testing import rerun_on_exception
from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG
@ -57,6 +58,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_zero_init_context(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -14,6 +14,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from colossalai.testing import rerun_on_exception
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
@ -63,6 +64,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_shard_model_v2(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -10,6 +10,7 @@ from colossalai.utils import free_port
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_param import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.testing import rerun_on_exception
from tests.test_zero_data_parallel.common import CONFIG, allclose
@ -35,6 +36,7 @@ def _run_shard_tensor(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_shard_tensor(world_size):
run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
@ -55,6 +57,7 @@ def _run_shard_param_v2(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_shard_param_v2(world_size):
run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -15,6 +15,7 @@ from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.testing import rerun_on_exception
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
@ -106,6 +107,7 @@ def _run_dist(rank, world_size, port):
# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_sharded_optim_v2(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -14,6 +14,7 @@ from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from torchvision.models import resnet50
from colossalai.testing import rerun_on_exception
def run_dist(rank, world_size, port):
@ -71,6 +72,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_sharded_optim_with_sync_bn():
"""
This test is to make sure that buffers are synchronized between ranks

View File

@ -14,6 +14,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from colossalai.testing import rerun_on_exception
from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG
@ -51,6 +52,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_zero_state_dict(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)

View File

@ -13,6 +13,7 @@ from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.testing import rerun_on_exception
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
@ -96,6 +97,7 @@ def run_dist(rank, world_size, port, parallel_config):
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_mp_engine(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
mp.spawn(run_func, nprocs=world_size)
@ -103,6 +105,7 @@ def test_mp_engine(world_size):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_zero_engine(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
mp.spawn(run_func, nprocs=world_size)