reorganize colotensor directory (#1062)

* reorganize colotensor directory

* polish code
Jiarui Fang 2022-06-03 18:04:22 +08:00 committed by GitHub
parent 3d10be33bd
commit a00644079e
25 changed files with 130 additions and 66 deletions

View File

@@ -4,3 +4,7 @@ from .lr_scheduler import *
from .metric import *
from .model import *
from .optimizer import *
+from ._ops import *
+from .modules import ColoLinear, ColoEmbedding
+from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
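Note: with these re-exports, the module-sharding helpers become reachable through colossalai.nn (the package whose __init__ is edited above). A minimal registration sketch under the new paths, illustrative only and not part of this commit; the (module type, ColoModule instance) pairing follows the register_colo_module signature shown in the next file:

    import torch
    from colossalai.nn import ColoLinear, ColoEmbedding, register_colo_module

    # Register sharding-aware descriptions for the matching torch modules.
    register_colo_module(torch.nn.Linear, ColoLinear())
    register_colo_module(torch.nn.Embedding, ColoEmbedding())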

View File

@@ -10,6 +10,7 @@ def register_colo_module(module_type: type, colo_module: ColoModule):
    global _COLOSSAL_MODULES
    _COLOSSAL_MODULES[module_type] = colo_module

def is_colo_module(module: torch.nn.Module):
    global _COLOSSAL_MODULES
    for module_type in _COLOSSAL_MODULES.keys():
@@ -17,6 +18,7 @@ def is_colo_module(module: torch.nn.Module):
            return True
    return False

def get_colo_module(module: torch.nn.Module):
    global _COLOSSAL_MODULES
    if is_colo_module(module):
@@ -26,6 +28,7 @@ def get_colo_module(module: torch.nn.Module):
    else:
        return None

def check_colo_module(module: torch.nn.Module, recursive=True):
    if is_colo_module(module):
        colo_module = get_colo_module(module)
@@ -41,14 +44,16 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
                    compute_pattern = cur_compute_pattern
                else:
                    if cur_compute_pattern != compute_pattern:
-                        raise Exception(f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.')
+                        raise Exception(
+                            f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.')
            else:
                continue

        if compute_pattern is not None:
            colo_module.register(compute_pattern)
            if not colo_module.has_compute_pattern(compute_pattern):
-                raise Exception(f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
+                raise Exception(
+                    f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')

            match_specs = False
            allowed_specs = colo_module.get_dist_specs(compute_pattern)
@@ -73,6 +78,7 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
        for submodule in module.children():
            check_colo_module(submodule, recursive=True)

def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, mode='default'):
    compute_pattern = parallel_action.compute_pattern
    if is_colo_module(module):
@@ -99,4 +105,3 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, mode='default'):
    if recursive == True:
        for submodule in module.children():
            init_colo_module(submodule, parallel_action, recursive=True, mode=mode)
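Note: a hedged usage sketch of the functions defined in this file, following the pattern used by the tests later in this diff. It assumes a 1D tensor-parallel group has already been set up via colossalai.launch; the Linear(4, 8) model and the 'col' mode are taken from those tests:

    import torch
    from colossalai.utils import ColoInitContext, get_current_device
    from colossalai.tensor import ComputePattern, ParallelAction
    from colossalai.nn import init_colo_module, check_colo_module

    # Parameters created inside ColoInitContext become ColoParameters,
    # which is what init_colo_module expects to re-shard.
    with ColoInitContext(device=get_current_device()):
        model = torch.nn.Linear(4, 8)

    parallel_action = ParallelAction(ComputePattern.TP1D)
    init_colo_module(model, parallel_action, recursive=True, mode='col')
    check_colo_module(model, recursive=True)  # raises if the resulting specs are inconsistent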

View File

@@ -4,6 +4,7 @@ from typing import List, Dict

class ColoModule(object):

    def __init__(self):
        self._shard_params: List[str] = []
        # Example:
@@ -21,8 +22,12 @@ class ColoModule(object):
    def _register_shard_params(self, params: List[str]):
        self._shard_params = params

-    def _register_allowed_patterns(self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], mode='default'):
-        assert list(dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.'
+    def _register_allowed_patterns(self,
+                                   compute_pattern: ComputePattern,
+                                   dist_specs: Dict[str, _DistSpec],
+                                   mode='default'):
+        assert list(
+            dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.'
        if not compute_pattern in self._allowed_patterns:
            self._allowed_patterns[compute_pattern] = {}
        self._allowed_patterns[compute_pattern][mode] = dist_specs

View File

@@ -3,7 +3,9 @@ from colossalai.tensor import ComputePattern, distspec
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode

class ColoEmbedding(ColoModule):

    def __init__(self):
        super(ColoEmbedding, self).__init__()
        self._register_shard_params(['weight'])
@@ -19,7 +21,9 @@ class ColoEmbedding(ColoModule):
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'weight':
+                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
+                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
            },
            mode='row',
        )
@@ -28,7 +32,9 @@ class ColoEmbedding(ColoModule):
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'weight':
+                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
+                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
            },
            mode='col',
        )

View File

@@ -3,7 +3,9 @@ from colossalai.tensor import ComputePattern, distspec
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode

class ColoLinear(ColoModule):

    def __init__(self):
        super(ColoLinear, self).__init__()
        self._register_shard_params(['weight', 'bias'])
@@ -19,8 +21,11 @@ class ColoLinear(ColoModule):
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                'bias': None
+                'weight':
+                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
+                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'bias':
+                    None
            },
            mode='row',
        )
@@ -29,8 +34,12 @@ class ColoLinear(ColoModule):
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                'bias': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
+                'weight':
+                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
+                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'bias':
+                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
+                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
            },
            mode='col',
        )

View File

@@ -7,7 +7,9 @@ from .lamb import Lamb
from .lars import Lars
from .cpu_adam import CPUAdam
from .hybrid_adam import HybridAdam
+from .colo_optimizer import ColoOptimizer

__all__ = [
-    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT'
+    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam',
+    'CPU_ADAM_CNT', 'ColoOptimizer'
]
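Note: ColoOptimizer is now exposed from the optimizer package rather than from colossalai.tensor. A brief sketch of the updated import; the constructor call is an assumption modeled on the surrounding tests, since the signature is not shown in this diff:

    import torch
    from colossalai.nn.optimizer import ColoOptimizer  # was: from colossalai.tensor import ColoOptimizer

    model = torch.nn.Linear(4, 8)
    # Assumed signature: (named parameters, wrapped optimizer class, optimizer kwargs).
    optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)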

View File

@@ -1,21 +1,14 @@
from .spec import ComputePattern, ParallelAction, TensorSpec
-from .op_wrapper import (
-    colo_op_impl,)
from .colo_tensor import ColoTensor
from .colo_parameter import ColoParameter
from .utils import convert_parameter, named_params_with_colotensor
-from ._ops import *
-from .optim.colo_optimizer import ColoOptimizer
from . import distspec
from .dist_spec_mgr import DistSpecManager
from .param_op_hook import ParamOpHook, use_param_op_hooks
from .chunk import ChunkManager, TensorState
-from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
-from .modules import ColoLinear, ColoEmbedding

__all__ = [
-    'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
-    'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager',
-    'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', 'ColoLinear',
-    'ColoEmbedding', 'ParamOpHook', 'use_param_op_hooks', 'ChunkManager', 'TensorState'
+    'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ParallelAction', 'named_params_with_colotensor',
+    'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'use_param_op_hooks', 'ChunkManager', 'TensorState'
]
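Note: a short before/after summary of where the relocated symbols live once this commit is applied, inferred from the hunks in this diff (illustrative, not part of the commit):

    # Still exported by colossalai.tensor:
    from colossalai.tensor import ColoTensor, ColoParameter, TensorSpec, ComputePattern, \
        ParallelAction, distspec, DistSpecManager

    # Moved under colossalai.nn by this commit:
    from colossalai.nn import register_colo_module, init_colo_module, check_colo_module
    from colossalai.nn import ColoLinear, ColoEmbedding
    from colossalai.nn.optimizer import ColoOptimizer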

View File

@@ -1,9 +1,9 @@
-from .colo_tensor import ColoTensor
-from .const import TensorType
+from colossalai.tensor.colo_tensor import ColoTensor
+from colossalai.tensor.const import TensorType
import torch
from colossalai.tensor import TensorSpec, distspec
from copy import copy
-from .param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
+from colossalai.tensor.param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
from typing import Optional

View File

@@ -1,11 +1,29 @@
from colossalai.tensor.distspec import _DistSpec
-from colossalai.nn.layer.utils import divide
+# from colossalai.nn.layer.utils import divide
from numpy import prod
from contextlib import contextmanager
import torch
import torch.distributed as dist

+# TODO(jiaruifang) circle import, move the divide to colossalai.commons.
+# colossalai.tensor shall not import any submodule from colossal.nn
+def divide(numerator, denominator):
+    """Only allow exact division.
+    Args:
+        numerator (int): Numerator of the division.
+        denominator (int): Denominator of the division.
+    Returns:
+        int: the result of exact division.
+    """
+    assert denominator != 0, 'denominator can not be zero'
+    assert numerator % denominator == 0, \
+        '{} is not divisible by {}'.format(numerator, denominator)
+    return numerator // denominator

class TransformDistSpec(torch.autograd.Function):

    @staticmethod
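Note: a quick sanity check of the divide helper added above; its behaviour follows directly from the two assertions in its body:

    assert divide(8, 2) == 4
    try:
        divide(7, 2)
    except AssertionError as err:
        print(err)  # '7 is not divisible by 2'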

View File

@@ -1,10 +1,8 @@
import torch
-from colossalai.tensor.colo_tensor import ColoTensor
from typing import Iterator, Tuple, Union
import torch.nn as nn

-from colossalai.tensor import ColoTensor
+from colossalai.tensor.colo_tensor import ColoTensor

# The function is credited to PyTorch Team

View File

@@ -1,11 +1,12 @@
from .utils import InsertPostInitMethodToModuleSubClasses
import torch
-from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
+from colossalai.tensor import ColoTensor, ColoParameter
+from colossalai.nn import register_colo_module, init_colo_module, \
    ColoLinear, ColoEmbedding
-import types

from torch import nn
-from typing import Iterator, Tuple, Union, Optional
+from typing import Iterator, Tuple, Union

# find named_params includes replica
@@ -24,6 +25,7 @@ def _named_params_with_replica(
        name = mod_prefix + ('.' if mod_prefix else '') + name
        yield name, val

def ColoModulize(module):
    """
    Replacing the parameters() and named_parameters() with our customized ones

View File

@@ -1,9 +1,12 @@
from colossalai.utils import free_port, ColoInitContext, get_current_device
from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, init_colo_module
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from functools import partial
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
+from colossalai.nn import init_colo_module
from colossalai.nn.parallel import ColoDDP

import colossalai
@@ -11,7 +14,9 @@ import torch
import torch.multiprocessing as mp
import pytest

class Net(torch.nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.embed = torch.nn.Embedding(20, 4)
@@ -27,6 +32,7 @@ class Net(torch.nn.Module):
        x = self.proj(x)
        return x

def run_hybrid_device(use_ddp):
    with ColoInitContext(device=get_current_device()):
        model = Net()
@@ -36,7 +42,6 @@ def run_hybrid_device(use_ddp):
        model = ColoDDP(model)
        real_model = model.module
    print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
    #print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
    parallel_action = ParallelAction(ComputePattern.TP1D)
@@ -54,6 +59,7 @@ def run_hybrid_device(use_ddp):
    out = model(data)
    out.sum().backward()

def run_dist(rank, world_size, port, use_ddp):
    if use_ddp and world_size == 1:
        return
@@ -62,6 +68,7 @@ def run_dist(rank, world_size, port, use_ddp):
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_hybrid_device(use_ddp)

@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@@ -71,5 +78,6 @@ def _test_hybrid_device(world_size, use_ddp):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
    mp.spawn(run_func, nprocs=world_size)

if __name__ == '__main__':
    _test_hybrid_device(1, False)

View File

@@ -10,9 +10,10 @@ from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec, ComputePattern, \
-    ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager
+    ParallelAction, ColoTensor, DistSpecManager
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.nn.optimizer import ColoOptimizer

from functools import partial
from _utils import set_seed

View File

@@ -1,24 +1,28 @@
from copy import copy
-from colossalai.utils.cuda import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
-import torch
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.tensor import ColoTensor, distspec
+import pytest
from functools import partial
-import colossalai
-import pytest
import torch
import torch.multiprocessing as mp
+import torch.nn.functional as F
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+from colossalai.nn import init_colo_module, check_colo_module
+from _utils import tensor_equal, tensor_shard_equal, set_seed
+import colossalai
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import distspec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, register_colo_module, init_colo_module, check_colo_module
-from _utils import tensor_equal, tensor_shard_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs

def run_model_with_spec(mode, model_name):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -103,6 +107,7 @@ def run_model_with_spec(mode, model_name):
        if i > 3:
            break

def run_linear_with_spec(mode):
    with ColoInitContext(device=get_current_device()):
        model = torch.nn.Linear(4, 8)
@@ -122,6 +127,7 @@ def run_linear_with_spec(mode):
    assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad)
    assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad)

def run_check_shared_param():
    from transformers import BertForMaskedLM, BertConfig
    hidden_dim = 8
@@ -157,12 +163,14 @@ def run_check_shared_param():
    except Exception as e:
        assert 'incorrectly sharded' in str(e)

def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_linear_with_spec('col')
    run_linear_with_spec('row')

def run_dist_model(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -170,11 +178,13 @@ def run_dist_model(rank, world_size, port):
        run_model_with_spec('col', model_name)
        run_model_with_spec('row', model_name)

def run_dist_check(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_check_shared_param()

@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
@@ -182,6 +192,7 @@ def test_module_linear_1d(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
@@ -189,6 +200,7 @@ def test_module_model(world_size):
    run_func = partial(run_dist_model, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
@@ -196,5 +208,6 @@ def test_module_check(world_size):
    run_func = partial(run_dist_check, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

if __name__ == '__main__':
    test_module_check(2)