diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index f595e6773..e6febeeb4 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -8,10 +8,10 @@ jobs: detect: name: Detect file change if: | - github.event.pull_request.draft == false && - github.base_ref == 'main' && - github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && - contains( github.event.pull_request.labels.*.name, 'Run Build and Test') + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && + contains( github.event.pull_request.labels.*.name, 'Run Build and Test') outputs: changedExtenisonFiles: ${{ steps.find-extension-change.outputs.all_changed_files }} anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }} @@ -27,10 +27,10 @@ jobs: - name: Locate base commit id: locate-base-sha run: | - curBranch=$(git rev-parse --abbrev-ref HEAD) - commonCommit=$(git merge-base origin/main $curBranch) - echo $commonCommit - echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT + curBranch=$(git rev-parse --abbrev-ref HEAD) + commonCommit=$(git merge-base origin/main $curBranch) + echo $commonCommit + echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT - name: Find the changed extension-related files id: find-extension-change @@ -63,7 +63,6 @@ jobs: echo "$file was changed" done - build: name: Build and Test Colossal-AI needs: detect @@ -124,7 +123,7 @@ jobs: - name: Execute Unit Testing if: needs.detect.outputs.anyLibraryFileChanged == 'true' run: | - PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/ + CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/ env: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py index 8c7848525..4c05a3431 100644 --- a/applications/Chat/tests/test_checkpoint.py +++ b/applications/Chat/tests/test_checkpoint.py @@ -1,19 +1,16 @@ import os import tempfile from contextlib import nullcontext -from functools import partial import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from coati.models.gpt import GPTActor from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy from transformers.models.gpt2.configuration_gpt2 import GPT2Config from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4) @@ -90,8 +87,7 @@ def run_dist(rank, world_size, port, strategy): @pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini']) @rerun_if_address_is_in_use() def test_checkpoint(world_size, strategy): - run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size, strategy=strategy) if __name__ == '__main__': diff --git a/applications/Chat/tests/test_data.py b/applications/Chat/tests/test_data.py index 577309a0f..2e4d4ceac 100644 --- a/applications/Chat/tests/test_data.py +++ b/applications/Chat/tests/test_data.py @@ -1,11 +1,9 @@ import os from copy import deepcopy -from functools import partial import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from coati.experience_maker import 
NaiveExperienceMaker from coati.models.base import RewardModel from coati.models.gpt import GPTActor, GPTCritic @@ -13,8 +11,7 @@ from coati.replay_buffer import NaiveReplayBuffer from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy from transformers.models.gpt2.configuration_gpt2 import GPT2Config -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4) @@ -114,8 +111,7 @@ def run_dist(rank, world_size, port, strategy): @pytest.mark.parametrize('strategy', ['ddp', 'colossalai']) @rerun_if_address_is_in_use() def test_data(world_size, strategy): - run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size, strategy=strategy) if __name__ == '__main__': diff --git a/colossalai/cli/benchmark/benchmark.py b/colossalai/cli/benchmark/benchmark.py index f40f8f2f9..97a9f4572 100644 --- a/colossalai/cli/benchmark/benchmark.py +++ b/colossalai/cli/benchmark/benchmark.py @@ -10,7 +10,8 @@ from colossalai.context import Config from colossalai.context.random import reset_seeds from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.utils import MultiTimer, free_port +from colossalai.testing import free_port +from colossalai.utils import MultiTimer from .models import MLP diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py index e3dd500de..c53e0f44c 100644 --- a/colossalai/testing/__init__.py +++ b/colossalai/testing/__init__.py @@ -1,7 +1,17 @@ -from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group -from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use, skip_if_not_enough_gpus +from .comparison import assert_close, assert_close_loose, assert_equal, assert_equal_in_group, assert_not_equal +from .pytest_wrapper import run_on_environment_flag +from .utils import ( + clear_cache_before_run, + free_port, + parameterize, + rerun_if_address_is_in_use, + rerun_on_exception, + skip_if_not_enough_gpus, + spawn, +) __all__ = [ 'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize', - 'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus' + 'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn', + 'clear_cache_before_run', 'run_on_environment_flag' ] diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py index 64c1d6e7b..eac83e6d7 100644 --- a/colossalai/testing/utils.py +++ b/colossalai/testing/utils.py @@ -1,8 +1,13 @@ +import gc +import random import re -import torch -from typing import Callable, List, Any +import socket from functools import partial from inspect import signature +from typing import Any, Callable, List + +import torch +import torch.multiprocessing as mp from packaging import version @@ -43,7 +48,7 @@ def parameterize(argument: str, values: List[Any]) -> Callable: # > davis: hello # > davis: bye # > davis: stop - + Args: argument (str): the name of the argument to parameterize values (List[Any]): a list of values to iterate for this argument @@ -85,13 +90,13 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non def test_method(): print('hey') raise 
RuntimeError('Address already in use')
-
+
        # rerun for infinite times if Runtime error occurs
        @rerun_on_exception(exception_type=RuntimeError, max_try=None)
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')
-
+
        # rerun only the exception message is matched with pattern
        # for infinite times if Runtime error occurs
        @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
@@ -101,10 +106,10 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non

     Args:
         exception_type (Exception, Optional): The type of exception to detect for rerun
-        pattern (str, Optional): The pattern to match the exception message. 
+        pattern (str, Optional): The pattern to match the exception message.
            If the pattern is not None and matches the exception message,
            the exception will be detected for rerun
-        max_try (int, Optional): Maximum reruns for this function. The default value is 5. 
+        max_try (int, Optional): Maximum reruns for this function. The default value is 5.
            If max_try is None, it will rerun foreven if exception keeps occurings
    """
@@ -202,3 +207,72 @@ def skip_if_not_enough_gpus(min_gpus: int):
         return _execute_by_gpu_num

     return _wrap_func
+
+
+def free_port() -> int:
+    """Get a free port on localhost.
+
+    Returns:
+        int: A free port on localhost.
+    """
+    while True:
+        port = random.randint(20000, 65000)
+        try:
+            with socket.socket() as sock:
+                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                sock.bind(("localhost", port))
+                return port
+        except OSError:
+            continue
+
+
+def spawn(func, nprocs=1, **kwargs):
+    """
+    This function is used to spawn processes for testing.
+
+    Usage:
+        # the test function must accept the arguments rank, world_size and port
+        def do_something(rank, world_size, port):
+            ...
+
+        spawn(do_something, nprocs=8)
+
+        # extra keyword arguments can also be passed
+        def do_something(rank, world_size, port, arg1, arg2):
+            ...
+
+        spawn(do_something, nprocs=8, arg1=1, arg2=2)
+
+    Args:
+        func (Callable): The function to be spawned.
+        nprocs (int, optional): The number of processes to spawn. Defaults to 1.
+    """
+    port = free_port()
+    wrapped_func = partial(func, world_size=nprocs, port=port, **kwargs)
+    mp.spawn(wrapped_func, nprocs=nprocs)
+
+
+def clear_cache_before_run():
+    """
+    This function is a wrapper to clear the CUDA and Python cache before executing the function.
+
+    Usage:
+        @clear_cache_before_run()
+        def test_something():
+            ...
+    """
+
+    def _wrap_func(f):
+
+        def _clear_cache(*args, **kwargs):
+            torch.cuda.empty_cache()
+            torch.cuda.reset_peak_memory_stats()
+            torch.cuda.reset_max_memory_allocated()
+            torch.cuda.reset_max_memory_cached()
+            torch.cuda.synchronize()
+            gc.collect()
+            f(*args, **kwargs)
+
+        return _clear_cache
+
+    return _wrap_func
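As a minimal sketch (not part of this patch, and using placeholder names `run_dist` / `test_all_reduce`), a distributed test written against the helpers added above looks like this, replacing the `functools.partial` + `free_port()` + `mp.spawn` pattern that the remaining hunks delete:

```python
import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # spawn() passes world_size and a freshly picked free port; mp.spawn supplies the rank.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    x = torch.ones(1).cuda()
    dist.all_reduce(x)    # sums the ones across all ranks
    assert x.item() == world_size


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_all_reduce():
    spawn(run_dist, 4)


if __name__ == '__main__':
    test_all_reduce()
```

`spawn` picks a free port itself and forwards `world_size` and `port` as keyword arguments, which is why the per-test `free_port()` and `partial` plumbing removed throughout this patch is no longer needed.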
diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index 3f16bd91e..7b2e8480c 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -7,7 +7,6 @@ from .common import (
     count_zeros_fp32,
     disposable,
     ensure_path_exists,
-    free_port,
     is_ddp_ignored,
     is_dp_rank_0,
     is_model_parallel_parameter,
@@ -37,7 +36,6 @@ from .timer import MultiTimer, Timer

 __all__ = [
     'checkpoint',
-    'free_port',
     'print_rank_0',
     'sync_model_param',
     'is_ddp_ignored',
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index e15981140..95b3b8014 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -50,23 +50,6 @@ def ensure_path_exists(filename: str):
         Path(dirpath).mkdir(parents=True, exist_ok=True)


-def free_port() -> int:
-    """Get a free port on localhost.
-
-    Returns:
-        int: A free port on localhost.
-    """
-    while True:
-        port = random.randint(20000, 65000)
-        try:
-            with socket.socket() as sock:
-                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-                sock.bind(("localhost", port))
-                return port
-        except OSError:
-            continue
-
-
 def sync_model_param(model, parallel_mode):
     r"""Make sure data parameters are consistent during Data Parallel Mode.
diff --git a/docs/requirements-doc-test.txt b/docs/requirements-doc-test.txt
index 6a6bb3bee..79e04bd56 100644
--- a/docs/requirements-doc-test.txt
+++ b/docs/requirements-doc-test.txt
@@ -4,3 +4,4 @@ packaging
 tensornvme
 psutil
 transformers
+pytest
diff --git a/docs/source/en/basics/colotensor_concept.md b/docs/source/en/basics/colotensor_concept.md
index 2d8acd88d..1b855c03b 100644
--- a/docs/source/en/basics/colotensor_concept.md
+++ b/docs/source/en/basics/colotensor_concept.md
@@ -56,12 +56,12 @@ Let's see an example. A ColoTensor is initialized and sharded on 8 GPUs using tp
 ```python
 import torch
 import torch.multiprocessing as mp
-from colossalai.utils import free_port, print_rank_0
+from colossalai.utils import print_rank_0
 from functools import partial

 import colossalai
 from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
-from colossalai.utils import free_port
+from colossalai.testing import spawn

 import torch
@@ -83,8 +83,7 @@ def run_dist_tests(rank, world_size, port):
     print_rank_0(f"shape {t1.shape}, {t1.data}")

 def test_dist_cases(world_size):
-    run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist_tests, world_size)

 if __name__ == '__main__':
     test_dist_cases(4)
diff --git a/docs/source/zh-Hans/basics/colotensor_concept.md b/docs/source/zh-Hans/basics/colotensor_concept.md
index cac5b9a4b..d6a332df2 100644
--- a/docs/source/zh-Hans/basics/colotensor_concept.md
+++ b/docs/source/zh-Hans/basics/colotensor_concept.md
@@ -57,12 +57,12 @@ ColoTensor 包含额外的属性[ColoTensorSpec](https://colossalai.readthedocs.
```python import torch import torch.multiprocessing as mp -from colossalai.utils import free_port, print_rank_0 +from colossalai.utils import print_rank_0 from functools import partial import colossalai from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern -from colossalai.utils import free_port +from colossalai.testing import spawn import torch @@ -84,8 +84,7 @@ def run_dist_tests(rank, world_size, port): print_rank_0(f"shape {t1.shape}, {t1.data}") def test_dist_cases(world_size): - run_func = partial(run_dist_tests, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist_tests, world_size) if __name__ == '__main__': test_dist_cases(4) diff --git a/examples/images/vit/test_vit.py b/examples/images/vit/test_vit.py index 6a587e1df..c0ae35bca 100644 --- a/examples/images/vit/test_vit.py +++ b/examples/images/vit/test_vit.py @@ -1,11 +1,9 @@ import os import random -from functools import partial import numpy as np import pytest import torch -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from vit import get_training_components @@ -15,8 +13,7 @@ from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext @@ -156,8 +153,7 @@ def run_dist(rank, world_size, port, use_ddp): @pytest.mark.parametrize('use_ddp', [False, True]) @rerun_if_address_is_in_use() def test_vit(world_size, use_ddp): - run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size, use_ddp=use_ddp) if __name__ == '__main__': diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py index 729d1ce44..89415c23f 100644 --- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -1,20 +1,20 @@ -import time -import pytest import argparse -from functools import partial +import time +import pytest import torch +from model_zoo import GPTLMLoss, get_gpt2_components from torch.utils._pytree import tree_map -import torch.multiprocessing as mp import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.fx.profiler import parameter_size -from colossalai.utils import free_port, get_current_device from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML -from model_zoo import get_gpt2_components, GPTLMLoss +from colossalai.fx.profiler import parameter_size +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import spawn +from colossalai.utils import get_current_device + def parse_args(): parser = argparse.ArgumentParser() @@ -24,6 +24,7 @@ def parse_args(): parser.add_argument('--memory_budget', type=float, default=16) return parser.parse_args() + @pytest.mark.skipif(NOT_NVML, 
reason='pynvml is not installed') def train_gpt(args): memory_budget = args.memory_budget * 1024 * 1024 * 1024 @@ -33,13 +34,16 @@ def train_gpt(args): # build model model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size) - label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device()) + label = torch.randint(low=0, high=128, size=( + 64, + 8, + ), device=get_current_device()) criterion = GPTLMLoss() start_time = time.time() model = model_builder() model.train() - param_size = parameter_size(model) / 1024 ** 2 / 2 + param_size = parameter_size(model) / 1024**2 / 2 init_time = time.time() - start_time print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s") @@ -74,21 +78,20 @@ def train_gpt(args): torch.cuda.synchronize() exec_time = sum(sorted(time_list)[:5]) / 5 - runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 - runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 + runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 + runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 print(f'solver_type: {solver_type} | model_type: {model_type}') - print( - f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|' - ) + print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' + f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') print(time_list) + def run(rank, world_size, port, args): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') train_gpt(args) + if __name__ == '__main__': args = parse_args() - run_func = partial(run, world_size=1, port=free_port(), args=args) - mp.spawn(run_func, nprocs=1) + spawn(run, 1, args=args) diff --git a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py index 6ceb7fd87..e331fc8fc 100644 --- a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py +++ b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py @@ -1,18 +1,13 @@ from functools import partial from time import time -from typing import Dict, Optional, Tuple, Union import psutil import torch -import torch.multiprocessing as mp -import torch.nn as nn import transformers from gpt_modules import GPT2LMHeadModel, GPTLMLoss -from torch.fx import GraphModule -from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model +from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize from colossalai.core import global_context as gpc -from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch_from_torch from colossalai.logging import disable_existing_loggers, get_dist_logger diff --git a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py index 5decfc695..5a68aae18 100644 --- a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py +++ b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py @@ -1,19 +1,14 @@ -import time -from argparse import ArgumentParser from copy import deepcopy from functools import partial -import matplotlib.pyplot as plt -import numpy as np import torch -import torch.multiprocessing as mp import 
torchvision.models as tm from bench_utils import bench, data_gen_resnet import colossalai from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor from colossalai.fx import metainfo_trace, symbolic_trace -from colossalai.utils import free_port +from colossalai.testing import spawn def _benchmark(rank, world_size, port): @@ -50,9 +45,7 @@ def _benchmark(rank, world_size, port): def auto_activation_checkpoint_batchsize_benchmark(): - world_size = 1 - run_func_module = partial(_benchmark, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_benchmark, 1) if __name__ == "__main__": diff --git a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py index ab0f2ef66..aa5c47294 100644 --- a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py +++ b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py @@ -4,14 +4,13 @@ from functools import partial import matplotlib.pyplot as plt import torch -import torch.multiprocessing as mp import torchvision.models as tm from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium import colossalai from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor from colossalai.fx import metainfo_trace, symbolic_trace -from colossalai.utils import free_port +from colossalai.testing import spawn def _benchmark(rank, world_size, port, args): @@ -77,8 +76,7 @@ def _benchmark(rank, world_size, port, args): def auto_activation_checkpoint_benchmark(args): world_size = 1 - run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_benchmark, world_size, args=args) if __name__ == "__main__": diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 05c0e6ac5..82b6173b3 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -12,3 +12,4 @@ contexttimer einops triton==2.0.0.dev20221202 git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn +requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_amp/test_naive_fp16.py index c01de469b..6ce4c7f49 100644 --- a/tests/test_amp/test_naive_fp16.py +++ b/tests/test_amp/test_naive_fp16.py @@ -1,14 +1,11 @@ import copy -from functools import partial import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp -from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn from tests.components_to_test.registry import non_distributed_component_funcs @@ -87,10 +84,9 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() +@clear_cache_before_run() def test_naive_amp(): - world_size = 1 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 1) if __name__ == '__main__': diff --git a/tests/test_amp/test_torch_fp16.py b/tests/test_amp/test_torch_fp16.py index e65dd8cde..6451aa626 100644 --- a/tests/test_amp/test_torch_fp16.py +++ b/tests/test_amp/test_torch_fp16.py @@ -1,14 +1,11 @@ import copy 
-from functools import partial import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp -from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn from tests.components_to_test.registry import non_distributed_component_funcs @@ -87,10 +84,9 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() +@clear_cache_before_run() def test_torch_amp(): - world_size = 1 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 1) if __name__ == '__main__': diff --git a/tests/test_analyzer/test_fx/test_bias_addition.py b/tests/test_analyzer/test_fx/test_bias_addition.py index 61951e9a5..f7b5eb140 100644 --- a/tests/test_analyzer/test_fx/test_bias_addition.py +++ b/tests/test_analyzer/test_fx/test_bias_addition.py @@ -3,7 +3,7 @@ import torch from packaging import version from torch.utils.checkpoint import checkpoint -from colossalai.testing.utils import parameterize +from colossalai.testing.utils import clear_cache_before_run, parameterize try: from colossalai._analyzer.fx import symbolic_trace @@ -81,6 +81,7 @@ class AddmmModel(torch.nn.Module): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() @parameterize("bias", [True, False]) @parameterize("bias_addition_split", [True, False]) @parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)]) diff --git a/tests/test_analyzer/test_fx/test_mod_dir.py b/tests/test_analyzer/test_fx/test_mod_dir.py index 15e0c2ec2..f62147b29 100644 --- a/tests/test_analyzer/test_fx/test_mod_dir.py +++ b/tests/test_analyzer/test_fx/test_mod_dir.py @@ -1,6 +1,8 @@ import pytest import torch +from colossalai.testing import clear_cache_before_run, parameterize + try: from colossalai._analyzer.fx import symbolic_trace except: @@ -62,9 +64,10 @@ class AModel(torch.nn.Module): @pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') -@pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("bias_addition_split", [True, False]) -@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)]) +@clear_cache_before_run() +@parameterize("bias", [True, False]) +@parameterize("bias_addition_split", [True, False]) +@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)]) def test_mod_dir(bias, bias_addition_split, shape): model = AModel(bias=bias) x = torch.rand(shape) @@ -75,4 +78,4 @@ def test_mod_dir(bias, bias_addition_split, shape): if __name__ == '__main__': - test_mod_dir(True, True, (3, 3, 3)) + test_mod_dir(bias=True, bias_addition_split=True, shape=(3, 3, 3)) diff --git a/tests/test_analyzer/test_fx/test_nested_ckpt.py b/tests/test_analyzer/test_fx/test_nested_ckpt.py index c31aab675..bd16f5a4f 100644 --- a/tests/test_analyzer/test_fx/test_nested_ckpt.py +++ b/tests/test_analyzer/test_fx/test_nested_ckpt.py @@ -1,7 +1,9 @@ +import pytest import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint -import pytest + +from colossalai.testing import clear_cache_before_run try: from colossalai._analyzer.fx import symbolic_trace @@ -42,6 +44,7 @@ class MyModule(nn.Module): @pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@clear_cache_before_run() def test_nested_ckpt(): 
model = MyModule() x = torch.rand(10, 10) diff --git a/tests/test_analyzer/test_fx/test_shape_prop.py b/tests/test_analyzer/test_fx/test_shape_prop.py index 08f4ff2cb..a849feb79 100644 --- a/tests/test_analyzer/test_fx/test_shape_prop.py +++ b/tests/test_analyzer/test_fx/test_shape_prop.py @@ -3,7 +3,7 @@ import torch import torchvision.models as tm from packaging import version -from colossalai.testing.utils import parameterize +from colossalai.testing.utils import clear_cache_before_run, parameterize from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models try: @@ -32,6 +32,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() @parameterize('m', tm_models) def test_torchvision_shape_prop(m): with MetaTensorMode(): @@ -46,6 +47,7 @@ def test_torchvision_shape_prop(m): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() @parameterize('m', tmm_models) def test_timm_shape_prop(m): with MetaTensorMode(): diff --git a/tests/test_analyzer/test_fx/test_symbolic_profile.py b/tests/test_analyzer/test_fx/test_symbolic_profile.py index be781599f..17deee7a7 100644 --- a/tests/test_analyzer/test_fx/test_symbolic_profile.py +++ b/tests/test_analyzer/test_fx/test_symbolic_profile.py @@ -3,7 +3,7 @@ import torch import torchvision.models as tm from packaging import version -from colossalai.testing.utils import parameterize +from colossalai.testing.utils import clear_cache_before_run, parameterize from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models try: @@ -19,6 +19,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() @parameterize('m', tm_models) def test_torchvision_profile(m, verbose=False, bias_addition_split=False): with MetaTensorMode(): @@ -33,6 +34,7 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() @parameterize('m', tmm_models) def test_timm_profile(m, verbose=False, bias_addition_split=False): with MetaTensorMode(): diff --git a/tests/test_analyzer/test_subclasses/test_aten.py b/tests/test_analyzer/test_subclasses/test_aten.py index 591a8d617..b7858110a 100644 --- a/tests/test_analyzer/test_subclasses/test_aten.py +++ b/tests/test_analyzer/test_subclasses/test_aten.py @@ -1,9 +1,11 @@ from typing import Any, Callable, Union -import pytest +import pytest import torch import torch.nn as nn +from colossalai.testing import clear_cache_before_run + try: from colossalai._analyzer._subclasses import MetaTensor except: @@ -72,6 +74,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac @pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@clear_cache_before_run() def test_meta_aten(): for (aten_op, requires_backward), v in registered_meta.items(): for f, x in v: diff --git a/tests/test_analyzer/test_subclasses/test_flop_tensor.py b/tests/test_analyzer/test_subclasses/test_flop_tensor.py index 752836141..da3829e40 100644 --- a/tests/test_analyzer/test_subclasses/test_flop_tensor.py +++ b/tests/test_analyzer/test_subclasses/test_flop_tensor.py @@ -4,6 +4,7 @@ import torch.nn.functional as F import 
torchvision.models as tm from packaging import version +from colossalai.testing import clear_cache_before_run, parameterize from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models try: @@ -39,7 +40,8 @@ odd_cases = [ @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') -@pytest.mark.parametrize('func, args, kwargs', odd_cases) +@clear_cache_before_run() +@parameterize('func, args, kwargs', odd_cases) def test_flop_count_function(func, args, kwargs): rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True) assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}' diff --git a/tests/test_analyzer/test_subclasses/test_meta_mode.py b/tests/test_analyzer/test_subclasses/test_meta_mode.py index 160d411f6..d2a0a1b9c 100644 --- a/tests/test_analyzer/test_subclasses/test_meta_mode.py +++ b/tests/test_analyzer/test_subclasses/test_meta_mode.py @@ -3,6 +3,8 @@ import torch import torchvision.models as tm from packaging import version +from colossalai.testing import clear_cache_before_run, parameterize + try: from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode except: @@ -30,7 +32,8 @@ def run_and_compare(model): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') -@pytest.mark.parametrize('m', tm_models + tmm_models) +@clear_cache_before_run() +@parameterize('m', tm_models + tmm_models) def test_meta_mode_shape(m): run_and_compare(m()) diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py index f8dd0b16b..f184f64b3 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py @@ -3,7 +3,6 @@ import copy import pytest import torch import torch.fx -import torch.multiprocessing as mp import torchvision.models as tm import colossalai @@ -13,7 +12,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta # from colossalai.fx.passes.algorithms import solver_rotor # from colossalai.fx.passes.algorithms.operation import Sequence from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn if is_compatible_with_meta(): from colossalai.fx.profiler.tensor import MetaTensor @@ -26,8 +25,8 @@ except: withcodegen = False -def _run_C_solver_consistency_test(rank=0): - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') +def _run_C_solver_consistency_test(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]: model = M() @@ -70,8 +69,9 @@ def _run_C_solver_consistency_test(rank=0): @pytest.mark.skip("TODO(lyl): refactor all tests.") @pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0") +@rerun_if_address_is_in_use() def test_C_solver_consistency(): - mp.spawn(_run_C_solver_consistency_test, nprocs=1) + spawn(_run_C_solver_consistency_test, 1) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py index 89600ea09..db268b91d 100644 --- 
a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py @@ -4,7 +4,6 @@ from typing import Callable import pytest import torch -import torch.multiprocessing as mp import torchvision.models as tm from torch.fx import GraphModule @@ -15,7 +14,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.graph_module import ColoGraphModule # from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn if is_compatible_with_meta(): from colossalai.fx.profiler.tensor import MetaTensor @@ -68,8 +67,8 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' -def _run_ckpt_solver(rank): - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') +def _run_ckpt_solver(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') MODEL_LIST = [tm.densenet121] torch.backends.cudnn.deterministic = True @@ -98,12 +97,13 @@ def _run_ckpt_solver(rank): @pytest.mark.skip("TODO(super-dainiu): refactor all tests.") @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@rerun_if_address_is_in_use() def test_ckpt_solver(): - mp.spawn(_run_ckpt_solver, nprocs=1) + spawn(_run_ckpt_solver, 1) -def _run_ckpt_solver_torch11(rank): - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') +def _run_ckpt_solver_torch11(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') MODEL_LIST = [tm.densenet121] torch.backends.cudnn.deterministic = True @@ -131,8 +131,9 @@ def _run_ckpt_solver_torch11(rank): @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") +@rerun_if_address_is_in_use() def test_ckpt_solver_torch11(): - mp.spawn(_run_ckpt_solver_torch11, nprocs=1) + spawn(_run_ckpt_solver_torch11, 1) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py index 0f90ba0b0..59880815d 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py @@ -8,6 +8,7 @@ from colossalai.fx.graph_module import ColoGraphModule # from colossalai.fx.passes.algorithms import linearize, solver_rotor # from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss) from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.testing import clear_cache_before_run if is_compatible_with_meta(): from colossalai.fx.profiler.tensor import MetaTensor @@ -24,6 +25,7 @@ except: @pytest.mark.skip(reason='TODO: modify the logger') @pytest.mark.skip("TODO(lyl): refactor all tests.") @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") +@clear_cache_before_run() def test_linearize(): MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]} tracer = ColoTracer() 
@@ -84,6 +86,7 @@ def test_linearize(): @pytest.mark.skip("TODO(lyl): refactor all tests.") @pytest.mark.skip(reason="torch11 meta tensor not implemented") @pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") +@clear_cache_before_run() def test_linearize_torch11(): MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]} tracer = ColoTracer() diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index c925843fb..80f134fd8 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -1,9 +1,7 @@ import time -from functools import partial import pytest import torch -import torch.multiprocessing as mp from torch.utils._pytree import tree_map import colossalai @@ -12,8 +10,8 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.fx.profiler import parameter_size from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize -from colossalai.utils import free_port, get_current_device +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper from tests.test_auto_parallel.test_offload.model_utils import * from tests.test_tensor.common_utils import set_seed @@ -140,9 +138,9 @@ def run_dist(rank, world_size, port): @pytest.mark.skip("this test failed") @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@rerun_if_address_is_in_use() def test_perf(): - run_func = partial(run_dist, world_size=1, port=free_port()) - mp.spawn(run_func, nprocs=1) + spawn(run_dist, 1) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_offload/test_solver.py b/tests/test_auto_parallel/test_offload/test_solver.py index 2efbb750f..aa2c9a368 100644 --- a/tests/test_auto_parallel/test_offload/test_solver.py +++ b/tests/test_auto_parallel/test_offload/test_solver.py @@ -3,20 +3,20 @@ import torch.fx from torch.fx import GraphModule from torch.utils._pytree import tree_map +from colossalai.auto_parallel.offload.region_manager import RegionManager +from colossalai.auto_parallel.offload.solver import NOT_NVML, SolverFactory from colossalai.fx import ColoTracer, is_compatible_with_meta from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.auto_parallel.offload.region_manager import RegionManager -from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML -from colossalai.testing import parameterize +from colossalai.testing import clear_cache_before_run, parameterize from tests.test_auto_parallel.test_offload.model_utils import * + @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@clear_cache_before_run() @parameterize('model_name', ['gpt2_', 'bert_']) @parameterize('memory_budget', [4000]) @parameterize('solver_name', ['syn', 'asyn']) -def solver_test(model_name: str, - memory_budget: float, - solver_name: str): +def solver_test(model_name: str, memory_budget: float, solver_name: str): get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen = get_components_func() @@ -52,11 +52,16 @@ def solver_test(model_name: str, for region in region_list: need_offload = region.need_offload to_prefetch = region.fwd_prefetch_region.r_id if 
region.fwd_prefetch_region is not None else None - print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}') + print( + f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + ) for region in region_list.__reversed__(): need_offload = region.need_offload to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None - print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}') + print( + f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + ) + if __name__ == '__main__': - solver_test() \ No newline at end of file + solver_test() diff --git a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py index d0d107610..429e89aae 100644 --- a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py +++ b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py @@ -6,6 +6,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.testing import clear_cache_before_run class TestModule(torch.nn.Module): @@ -26,6 +27,7 @@ def insert_narrow(gm, x_node): return gm +@clear_cache_before_run() def test_node_args_converting_pass(): model = TestModule() physical_mesh_id = torch.arange(0, 4) diff --git a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py index 7d4fd844a..bca81201c 100644 --- a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py +++ b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py @@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.testing import clear_cache_before_run class TestModule(torch.nn.Module): @@ -36,6 +37,7 @@ def recover_narrow(gm, narrow_node): @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@clear_cache_before_run() def test_size_value_converting_pass(): model = TestModule() physical_mesh_id = torch.arange(0, 4) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py index 6d1b28912..9fbe674ef 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py @@ -2,7 +2,6 @@ from functools import partial import pytest import torch -import torch.multiprocessing as mp try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model @@ -13,9 +12,7 @@ except: from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import assert_close, 
rerun_if_address_is_in_use, run_on_environment_flag, spawn class LinearModel(torch.nn.Module): @@ -86,11 +83,8 @@ def check_conv_module(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_bias_addition_module(): - world_size = 4 - run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port()) - mp.spawn(run_func_linear, nprocs=world_size) - run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port()) - mp.spawn(run_func_conv, nprocs=world_size) + spawn(check_linear_module, 4) + spawn(check_conv_module, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py index 7a4c8d32e..398458306 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py @@ -1,9 +1,7 @@ -from functools import partial -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from torch.utils.checkpoint import checkpoint from transformers.pytorch_utils import Conv1D @@ -17,9 +15,7 @@ except: from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn HIDDEN_SIZE = 16 @@ -65,9 +61,7 @@ def check_act_ckpt(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_mlp_layer(): - world_size = 4 - run_func = partial(check_act_ckpt, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_act_ckpt, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py index 7c3277c69..6908a1781 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py @@ -1,9 +1,7 @@ import copy -from functools import partial import pytest import torch -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP try: @@ -15,9 +13,7 @@ except: from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn class MLP(torch.nn.Module): @@ -102,9 +98,7 @@ def check_compatibility_with_ddp(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_compatibility_with_ddp(): - world_size = 4 - run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_compatibility_with_ddp, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py 
b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index e4435a049..05704acbf 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -1,10 +1,7 @@ import copy -from functools import partial import pytest import torch -import torch.multiprocessing as mp -from torch.nn.parallel import DistributedDataParallel as DDP try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model @@ -17,10 +14,9 @@ from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.nn.optimizer import HybridAdam from colossalai.tensor.process_group import ProcessGroup -from colossalai.testing import assert_close, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port, get_current_device -from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper +from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn +from colossalai.utils import get_current_device +from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper class MLP(torch.nn.Module): @@ -110,9 +106,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_auto_parallel_with_gemini(): - world_size = 4 - run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_auto_parallel_with_gemini, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py index e7fccad36..a0b407b24 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py @@ -10,8 +10,7 @@ from colossalai._analyzer.fx.passes import shape_prop_pass # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks -from colossalai.testing import parameterize -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, parameterize, run_on_environment_flag NUM_REPEAT_BLOCKS = 4 BATCH_SIZE = 1 @@ -81,6 +80,7 @@ class NonRepeatModel(nn.Module): @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() @parameterize('model_cls', [RepeatModel, NonRepeatModel]) def test_repeat_blocks(model_cls): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py index 8688890ef..48d2672c6 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py @@ -1,12 +1,10 @@ import copy import random -from functools import partial from typing import Dict import numpy as np import pytest import torch -import torch.multiprocessing as mp import transformers from torch.fx import GraphModule @@ -30,9 +28,8 @@ from colossalai.device.device_mesh import DeviceMesh from 
colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.tensor.shape_consistency import to_global -from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use +from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model BATCH_SIZE = 1 @@ -190,9 +187,7 @@ def check_attention_layer(rank, model_cls, world_size, port): @parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model]) @rerun_if_address_is_in_use() def test_mlp_layer(model_cls): - world_size = 4 - run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_attention_layer, 4, model_cls=model_cls) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py index 5f0688d5f..5a8c3c4bf 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py @@ -1,5 +1,4 @@ import torch -import torch.nn as nn import transformers from torch.fx import GraphModule @@ -7,10 +6,10 @@ from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP from colossalai.auto_parallel.tensor_shard.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.testing import parameterize +from colossalai.testing import clear_cache_before_run, parameterize from colossalai.testing.pytest_wrapper import run_on_environment_flag from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model @@ -20,6 +19,7 @@ HIDDEN_DIM = 384 @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() @parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) def test_self_attention_block(model_cls): config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py index 8d4212438..d10b222c0 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py @@ -6,6 +6,8 @@ from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes import shape_prop_pass from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import 
clear_cache_before_run class LinearModel(nn.Module): @@ -26,6 +28,7 @@ class LinearModel(nn.Module): @pytest.mark.skip('meta tensor has some bugs in 1.11') +@clear_cache_before_run() def test_liveness_analysis(): model = LinearModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py index e41ac4fa6..e0a2133e6 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py @@ -1,23 +1,14 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp -import torch.nn as nn from colossalai.auto_parallel.meta_profiler import meta_register from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port -from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results +from colossalai.testing.utils import clear_cache_before_run, parameterize +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() @parameterize('func', [ torch.nn.functional.softmax, torch.nn.functional.relu, diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py index 1b745d890..68ccc7835 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai.device.device_mesh import DeviceMesh @@ -10,8 +7,7 @@ from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing.utils import rerun_if_address_is_in_use, spawn from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy @@ -62,9 +58,7 @@ def _binary_elementwise_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_meta_concrete_info_match(): - world_size = 4 - run_func_module = partial(_binary_elementwise_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_binary_elementwise_mem_test, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py 
b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py index a973a8182..c6f7b88f4 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py @@ -1,17 +1,12 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing.utils import rerun_if_address_is_in_use, spawn from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy @@ -25,7 +20,7 @@ class ConvFunctionModule(nn.Module): return nn.functional.conv2d(input, self.conv_weight) -def _conv_module_mem_test(rank, bias, world_size, port): +def _conv_module_mem_test(rank, world_size, port, bias): """This function is for conv memory test Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL @@ -62,9 +57,7 @@ def _conv_module_mem_test(rank, bias, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_conv_meta_concrete_info_match(bias=False): - world_size = 4 - run_func_module = partial(_conv_module_mem_test, bias=bias, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_conv_module_mem_test, 4, bias=bias) def _conv_function_mem_test(rank, world_size, port): @@ -103,9 +96,7 @@ def _conv_function_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_conv_function_concrete_info_match(): - world_size = 4 - run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_conv_function_mem_test, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py index 5f3d2df50..e3f76a95c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py @@ -1,33 +1,16 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp -import torch.nn as nn - -from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType +from colossalai.testing.utils import 
clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import meta_register @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() def test_embedding_meta_info(): meta_func = meta_register.get(torch.nn.Embedding) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py index ddc8e3c6a..fb3ded339 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py @@ -1,24 +1,14 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing.utils import rerun_if_address_is_in_use, spawn from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy -if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register - class MyModule(nn.Module): @@ -63,9 +53,7 @@ def _linear_module_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_module_meta_concrete_info_match(): - world_size = 4 - run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_linear_module_mem_test, 4) def _linear_function_mem_test(rank, world_size, port): @@ -101,9 +89,7 @@ def _linear_function_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_function_meta_concrete_info_match(): - world_size = 4 - run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_linear_function_mem_test, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py index 1242b9db0..2d2d77f0c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py @@ -1,26 +1,8 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp -import torch.nn as nn - -from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) 
-from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem +from colossalai.testing.utils import clear_cache_before_run, parameterize from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': @@ -28,6 +10,7 @@ if torch.__version__ >= '1.12.0': @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() @parameterize( 'tensor_shapes', [ diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py index d3342d310..808172977 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py @@ -1,29 +1,17 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import meta_register def _batchnorm_module_mem_test(rank, world_size, port): @@ -62,9 +50,7 @@ def _batchnorm_module_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_batchnorm_meta_concrete_info_match(): - world_size = 4 - run_func_module = partial(_batchnorm_module_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_batchnorm_module_mem_test, 4) @pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations') diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py index 529686d27..4cddf4e19 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py @@ -1,17 +1,12 @@ -from functools import partial - import pytest import torch 
-import torch.multiprocessing as mp import torch.nn as nn from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing.utils import rerun_if_address_is_in_use, spawn from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy @@ -51,9 +46,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_adaptiveavgpool_meta_concrete_info_match(): - world_size = 4 - run_func_module = partial(_adaptiveavgpool_module_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_adaptiveavgpool_module_mem_test, 4) def _maxpool_module_mem_test(rank, world_size, port): @@ -92,9 +85,7 @@ def _maxpool_module_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_maxpool_meta_concrete_info_match(): - world_size = 4 - run_func_module = partial(_maxpool_module_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_maxpool_module_mem_test, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py index a544e9a3c..6e8145885 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py @@ -1,26 +1,9 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType +from colossalai.testing.utils import clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': @@ -37,6 +20,7 @@ class SplitModule(nn.Module): @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() def test_tensor_meta_info(): """test tensor related meta information We will just use torch.Tensor.split for the test diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py index 2ae13ea2b..b4564312e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py +++ 
b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py @@ -1,24 +1,8 @@ import pytest import torch -import torch.multiprocessing as mp -import torch.nn as nn - -from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem +from colossalai.testing.utils import clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': @@ -26,6 +10,7 @@ if torch.__version__ >= '1.12.0': @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() def test_where_meta_info(): meta_func = meta_register.get(torch.where) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py index ffc15e403..80e6a6c14 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler @@ -11,9 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -45,7 +40,7 @@ class AddBMMTorchFunctionModule(nn.Module): return output -def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port): +def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = module(using_kwargs).cuda() @@ -249,14 +244,13 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por @parameterize('using_kwargs', [True, False]) @rerun_if_address_is_in_use() def test_2d_device_mesh(module, bias_shape, using_kwargs): - world_size = 4 - run_func = partial(check_2d_device_mesh, - module=module, - bias_shape=bias_shape, - world_size=world_size, - using_kwargs=using_kwargs, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn( + 
check_2d_device_mesh, + 4, + module=module, + bias_shape=bias_shape, + using_kwargs=using_kwargs, + ) @pytest.mark.skip("skip due to bias cases not ready") @@ -267,14 +261,13 @@ def test_2d_device_mesh(module, bias_shape, using_kwargs): @parameterize('using_kwargs', [True, False]) @rerun_if_address_is_in_use() def test_1d_device_mesh(module, bias_shape, using_kwargs): - world_size = 4 - run_func = partial(check_1d_device_mesh, - module=module, - bias_shape=bias_shape, - using_kwargs=using_kwargs, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn( + check_1d_device_mesh, + 4, + module=module, + bias_shape=bias_shape, + using_kwargs=using_kwargs, + ) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py index 35f12ce83..fe6554cd8 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -17,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -45,7 +40,7 @@ class AddmmModel_with_param(nn.Module): return x -def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port): +def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') if model_cls == AddmmModel: @@ -189,13 +184,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port) @parameterize('model_cls', [AddmmModel, AddmmModel_with_param]) @rerun_if_address_is_in_use() def test_addmm_handler(input_shape, model_cls): - world_size = 4 - run_func_function = partial(check_addmm_function_handler, - input_shape=input_shape, - model_cls=model_cls, - world_size=world_size, - port=free_port()) - mp.spawn(run_func_function, nprocs=world_size) + spawn(check_addmm_function_handler, 4, input_shape=input_shape, model_cls=model_cls) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py index 2069b5e8a..b47b3508a 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -13,9 +10,7 @@ 
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -114,9 +109,7 @@ def check_bn_module_handler(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_bn_module_handler(): - world_size = 4 - run_func = partial(check_bn_module_handler, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_bn_module_handler, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py index dca5f6e22..800bc11a5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py @@ -1,9 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp -import torch.nn as nn import torch.nn.functional as F from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -19,9 +15,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy WEIGHT_SHAPE = (32, 16) @@ -168,9 +162,7 @@ def check_linear_module_handler(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(): - world_size = 4 - run_func_module = partial(check_linear_module_handler, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(check_linear_module_handler, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py index 14d4a73fb..c29a065d1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py @@ -1,14 +1,10 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp -import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass from colossalai._analyzer.fx.tracer.tracer import ColoTracer -from colossalai.auto_parallel.tensor_shard.node_handler import
LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, OperationDataType, @@ -18,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -35,7 +29,7 @@ class LinearModule(torch.nn.Module): return x -def check_linear_module_handler(rank, bias, world_size, port): +def check_linear_module_handler(rank, world_size, port, bias): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = LinearModule(16, 32, bias=bias).cuda() @@ -157,9 +151,7 @@ def check_linear_module_handler(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(bias=True): - world_size = 4 - run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(check_linear_module_handler, 4, bias=bias) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py index 2414749f6..83f3aafe2 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy -def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port): +def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -149,7 +144,7 @@ class BEOpModelWithIntConst(nn.Module): return out -def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port): +def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls):
disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -236,13 +231,12 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, wo @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_handler_with_tensor(op, other_dim): - world_size = 4 - run_func_tensor = partial(check_binary_elementwise_handler_with_tensor, - op=op, - other_dim=other_dim, - world_size=world_size, - port=free_port()) - mp.spawn(run_func_tensor, nprocs=world_size) + spawn( + check_binary_elementwise_handler_with_tensor, + 4, + op=op, + other_dim=other_dim, + ) @run_on_environment_flag(name='AUTO_PARALLEL') @@ -252,14 +246,13 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim): @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_handler_with_int(op, model_cls, other_dim): - world_size = 4 - run_func_int = partial(check_binary_elementwise_handler_with_int, - op=op, - model_cls=model_cls, - other_dim=other_dim, - world_size=world_size, - port=free_port()) - mp.spawn(run_func_int, nprocs=world_size) + spawn( + check_binary_elementwise_handler_with_int, + 4, + op=op, + model_cls=model_cls, + other_dim=other_dim, + ) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py index 34c20c1ac..f4fdc458f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -207,11 +202,8 @@ def check_1d_device_mesh(rank, module, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_bmm_handler(module): - world_size = 4 - run_func_2d = partial(check_2d_device_mesh, module=module, world_size=world_size, port=free_port()) - mp.spawn(run_func_2d, nprocs=world_size) - run_func_1d = partial(check_1d_device_mesh, module=module, world_size=world_size, port=free_port()) - mp.spawn(run_func_1d, nprocs=world_size) + spawn(check_2d_device_mesh, 4, module=module) + spawn(check_1d_device_mesh, 4, module=module) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py index fe1a0d726..f9632b1cd 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py @@ -1,8 +1,5 @@ -from functools import partial - 
import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy -def check_conv_module_handler(rank, bias, world_size, port): +def check_conv_module_handler(rank, world_size, port, bias): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda() @@ -155,7 +150,7 @@ class ConvModel(nn.Module): return x -def check_conv_function_handler(rank, bias, world_size, port): +def check_conv_function_handler(rank, world_size, port, bias): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = ConvModel().cuda() @@ -302,9 +297,7 @@ def check_conv_function_handler(rank, bias, world_size, port): # @parameterize('bias', [True, False]) @rerun_if_address_is_in_use() def test_conv_module_handler(bias=False): - world_size = 4 - run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_conv_module_handler, 4, bias=bias) @run_on_environment_flag(name='AUTO_PARALLEL') @@ -314,9 +307,7 @@ def test_conv_module_handler(bias=False): # @parameterize('bias', [True, False]) @rerun_if_address_is_in_use() def test_conv_function_handler(bias=False): - world_size = 4 - run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_conv_function_handler, 4, bias=bias) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py index 8e5b7512c..64f56ba98 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py @@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHan from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag class ReshapeModel(nn.Module): @@ -23,6 +23,7 @@ class ReshapeModel(nn.Module): @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_reshape_handler(): model = ReshapeModel() tracer = ColoTracer(bias_addition_split=True) diff --git 
a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py index a61d2ed5c..4fa0313b1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -16,9 +13,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy NUM_EMBEDDINGS = 16 @@ -272,18 +268,14 @@ def check_embedding_function_handler(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_embedding_module_handler(): - world_size = 4 - run_func = partial(check_embedding_module_handler, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_embedding_module_handler, 4) @run_on_environment_flag(name='AUTO_PARALLEL') @pytest.mark.dist @rerun_if_address_is_in_use() def test_embedding_function_handler(): - world_size = 4 - run_func = partial(check_embedding_function_handler, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_embedding_function_handler, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py index fb6113309..a089df743 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py @@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import clear_cache_before_run class GetattrModel(nn.Module): @@ -22,6 +23,7 @@ class GetattrModel(nn.Module): @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@clear_cache_before_run() def test_getattr_handler(): model = GetattrModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py index 9a29808eb..a2e0968b1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py @@ -2,7 +2,6 @@ from functools import partial import pytest import torch -import torch.multiprocessing as mp import 
torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -14,12 +13,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import Li from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -103,12 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): # @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))]) @parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))]) def test_getitem_from_tensor_handler(getitem_index): - world_size = 4 - run_func = partial(check_getitem_from_tensor_handler, - getitem_index=getitem_index, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_getitem_from_tensor_handler, 4, getitem_index=getitem_index) class GetItemFromTupleModel(nn.Module): @@ -123,6 +115,7 @@ class GetItemFromTupleModel(nn.Module): @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_getitem_from_tuple_handler(): model = GetItemFromTupleModel() tracer = ColoTracer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py index edd7bae6c..ad72c2026 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -11,12 +8,10 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -104,9 +99,7 @@ def check_ln_module_handler(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_ln_module_handler(): - world_size = 4 - run_func = 
partial(check_ln_module_handler, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_ln_module_handler, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index bec5c3dc5..ec695cd8f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -18,14 +15,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.utils import parameterize -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy -def check_linear_module_handler(rank, bias, input_shape, world_size, port): +def check_linear_module_handler(rank, world_size, port, bias, input_shape): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda() @@ -172,7 +168,7 @@ class LinearModel(nn.Module): return x -def check_linear_function_handler(rank, bias, input_shape, world_size, port): +def check_linear_function_handler(rank, world_size, port, bias, input_shape): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = LinearModel().cuda() @@ -313,19 +309,18 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(input_shape, bias=False): - world_size = 4 - run_func_module = partial(check_linear_module_handler, - bias=bias, - input_shape=input_shape, - world_size=world_size, - port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) - run_func_function = partial(check_linear_function_handler, - bias=bias, - input_shape=input_shape, - world_size=world_size, - port=free_port()) - mp.spawn(run_func_function, nprocs=world_size) + spawn( + check_linear_module_handler, + 4, + bias=bias, + input_shape=input_shape, + ) + spawn( + check_linear_function_handler, + 4, + bias=bias, + input_shape=input_shape, + ) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py index 46c3ff443..938acd3d1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py @@ -18,7 +18,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( StrategiesVector, ) from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing.utils 
import parameterize +from colossalai.testing.utils import clear_cache_before_run, parameterize class MatMulModule(nn.Module): @@ -28,6 +28,7 @@ class MatMulModule(nn.Module): @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() @parameterize( 'tensor_shapes', [ diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py index aacc7d9ae..6bff9f964 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py @@ -1,4 +1,3 @@ -import pytest import torch import torch.nn as nn @@ -8,11 +7,11 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.meta_patch.patched_module import linear -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_norm_pool_handler(): model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py index 5efbb4f5f..5259455d2 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py @@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import clear_cache_before_run, parameterize class OutputModel(nn.Module): @@ -23,7 +23,7 @@ class OutputModel(nn.Module): @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') @parameterize('output_option', ['distributed', 'replicated']) -@rerun_if_address_is_in_use() +@clear_cache_before_run() def test_output_handler(output_option): model = OutputModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py index 0a5ad3e35..f071cd120 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py @@ -2,7 +2,6 @@ from functools import partial import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -15,9 
+14,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -55,7 +53,7 @@ class LinearReshapeModel(nn.Module): return permute_node -def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port): +def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') if call_function == torch.permute: @@ -328,14 +326,13 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, @parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))]) @parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel]) def test_view_handler(call_function, reshape_dims, model_cls): - world_size = 4 - run_func = partial(check_view_handler, - call_function=call_function, - reshape_dims=reshape_dims, - model_cls=model_cls, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn( + check_view_handler, + 4, + call_function=call_function, + reshape_dims=reshape_dims, + model_cls=model_cls, + ) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py index 5e8fb51ed..6d02b0e0b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py @@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import clear_cache_before_run, parameterize class PlaceholderModel(nn.Module): @@ -22,7 +22,7 @@ class PlaceholderModel(nn.Module): @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') @parameterize('placeholder_option', ['distributed', 'replicated']) -@rerun_if_address_is_in_use() +@clear_cache_before_run() def test_placeholder_handler(placeholder_option): model = PlaceholderModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py index e589fff99..14c364c45 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py @@ -1,5 +1,4 @@ import torch -import torch.multiprocessing as mp import torch.nn as nn from 
colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -9,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHan from colossalai.auto_parallel.tensor_shard.options import ShardOption from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag class LinearModel(nn.Module): @@ -108,6 +107,7 @@ def check_shard_option(shard_option): @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_shard_option(): # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]: for shard_option in [ShardOption.SHARD_LAST_AXIS]: diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py index db463a4e9..75ae0416e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn import torch.nn.functional as F @@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -33,7 +28,7 @@ class LinearSplitModel(nn.Module): return softmax_node -def check_split_handler(rank, softmax_dim, model_cls, world_size, port): +def check_split_handler(rank, world_size, port, softmax_dim, model_cls): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = model_cls(softmax_dim=softmax_dim).cuda() @@ -176,13 +171,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port): @parameterize('softmax_dim', [0, 1, 2, 3]) @parameterize('model_cls', [LinearSplitModel]) def test_split_handler(softmax_dim, model_cls): - world_size = 4 - run_func = partial(check_split_handler, - softmax_dim=softmax_dim, - model_cls=model_cls, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_split_handler, 4, softmax_dim=softmax_dim, model_cls=model_cls) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py index db59ea60e..f860c629b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as 
nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -47,7 +42,7 @@ class LinearSplitModel(nn.Module): return split_node -def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port): +def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = model_cls(split_size=split_size, split_dim=split_dim).cuda() @@ -258,14 +253,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port @parameterize('split_dim', [0, 1, 2]) @parameterize('model_cls', [ConvSplitModel, LinearSplitModel]) def test_split_handler(split_size, split_dim, model_cls): - world_size = 4 - run_func = partial(check_split_handler, - split_size=split_size, - split_dim=split_dim, - model_cls=model_cls, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_split_handler, 4, split_size=split_size, split_dim=split_dim, model_cls=model_cls) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py index add51d73f..c11291eca 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -14,9 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -36,7 +31,7 @@ class LinearSumModel(nn.Module): return sum_node -def check_sum_handler(rank, sum_dims, keepdim, world_size, port): +def check_sum_handler(rank, world_size, port, sum_dims, keepdim): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda() @@ -228,9 +223,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port): @parameterize('sum_dims', [(0, 2), 1]) @parameterize('keepdim', 
[False, True]) def test_sum_handler(sum_dims, keepdim): - world_size = 4 - run_func = partial(check_sum_handler, sum_dims=sum_dims, keepdim=keepdim, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py index f54b208c3..5b6ac051a 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py @@ -7,7 +7,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag class TensorConstructorModel(nn.Module): @@ -22,6 +22,7 @@ class TensorConstructorModel(nn.Module): @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_where_handler(): model = TensorConstructorModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py index bd8808973..f4e6dafdf 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py @@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag class ReLuModel(nn.Module): @@ -24,6 +24,7 @@ class ReLuModel(nn.Module): @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_elementwise_handler(): model = ReLuModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py index 300e8f94e..fbb194d8e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -15,9 +12,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from 
colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -255,13 +251,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): @parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)]) @parameterize('model_cls', [ConvViewModel, LinearViewModel]) def test_view_handler(tgt_shape, model_cls): - world_size = 4 - run_func = partial(check_view_handler, - tgt_shape=tgt_shape, - model_cls=model_cls, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_view_handler, 4, tgt_shape=tgt_shape, model_cls=model_cls) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py index c150ebd90..bd7635ac1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py @@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import clear_cache_before_run class ConvModel(nn.Module): @@ -21,6 +22,7 @@ class ConvModel(nn.Module): @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@clear_cache_before_run() def test_where_handler(): model = ConvModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py index fb47baab9..0d93e4e40 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -10,10 +10,11 @@ from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_cost_graph(): physical_mesh_id = torch.arange(0, 8) mesh_shape = (2, 4) diff --git a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py index 9a2240d62..d07145e48 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py +++ b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py @@ -8,7 +8,7 @@ import colossalai from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE from colossalai.fx.graph_module 
import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py index cb250d640..15610e2b5 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py @@ -9,7 +9,7 @@ from colossalai.autochunk.utils import flat_list from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py index 17a5abf4c..9e4cb7ee9 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py @@ -1,10 +1,8 @@ -from functools import partial from typing import Dict, List, Tuple import pytest import torch import torch.fx -import torch.multiprocessing as mp try: from fastfold.model.nn.evoformer import EvoformerBlock @@ -15,6 +13,7 @@ except: from test_autochunk_alphafold_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn def get_model(): @@ -66,18 +65,19 @@ def get_chunk_target() -> Dict: not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", ) -@pytest.mark.parametrize("max_memory", [None, 20, 24]) -@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len) +@clear_cache_before_run() +@parameterize("max_memory", [None, 20, 24]) +@parameterize("data_args", [(32, 64)]) def test_evoformer_block(data_args, max_memory): - run_func = partial( + spawn( run_test, + 1, data_args=data_args, max_memory=max_memory, get_model=get_model, get_data=get_data, get_chunk_target=get_chunk_target, ) - mp.spawn(run_func, nprocs=1) if __name__ == "__main__": diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py index 5210c1c8d..6b47033e1 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py @@ -1,10 +1,8 @@ -from functools import partial from typing import List, Tuple import pytest import torch import torch.fx -import torch.multiprocessing as mp try: from fastfold.model.nn.evoformer import EvoformerStack @@ -15,6 +13,7 @@ except: from test_autochunk_alphafold_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn def get_model(): @@ -61,17 +60,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", 
) -@pytest.mark.parametrize("max_memory", [None, 20, 24]) -@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len) +@clear_cache_before_run() +@parameterize("max_memory", [None, 20, 24]) +@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) def test_evoformer_stack(data_args, max_memory): - run_func = partial( + spawn( run_test, + 1, data_args=data_args, max_memory=max_memory, get_model=get_model, get_data=get_data, ) - mp.spawn(run_func, nprocs=1) if __name__ == "__main__": diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py index ad955479e..b4c577c18 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py @@ -1,10 +1,8 @@ -from functools import partial from typing import Dict, List, Tuple import pytest import torch import torch.fx -import torch.multiprocessing as mp try: from fastfold.model.nn.evoformer import ExtraMSABlock @@ -14,6 +12,7 @@ except: from test_autochunk_alphafold_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn def get_model(): @@ -57,17 +56,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", ) -@pytest.mark.parametrize("max_memory", [None, 20, 24]) -@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len) +@clear_cache_before_run() +@parameterize("max_memory", [None, 20, 24]) +@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) def test_extramsa_block(data_args, max_memory): - run_func = partial( + spawn( run_test, + 1, data_args=data_args, max_memory=max_memory, get_model=get_model, get_data=get_data, ) - mp.spawn(run_func, nprocs=1) if __name__ == "__main__": diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py index 529250fe8..e245f10d4 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py @@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py index 16c5b10ff..ff0d4a1b5 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py @@ -1,9 +1,7 @@ -from functools import partial from typing import List, Tuple import pytest import torch -import torch.multiprocessing as mp try: from diffusers import UNet2DModel @@ -16,6 +14,7 @@ except: from test_autochunk_diffuser_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn 
BATCH_SIZE = 1 HEIGHT = 448 @@ -37,17 +36,18 @@ def get_data(shape: tuple) -> Tuple[List, List]: not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("shape", [LATENTS_SHAPE]) -@pytest.mark.parametrize("max_memory", [None, 150, 300]) +@clear_cache_before_run() +@parameterize("model", MODELS) +@parameterize("shape", [LATENTS_SHAPE]) +@parameterize("max_memory", [None, 150, 300]) def test_evoformer_block(model, shape, max_memory): - run_func = partial( + spawn( run_test, + 1, max_memory=max_memory, model=model, data=get_data(shape), ) - mp.spawn(run_func, nprocs=1) if __name__ == "__main__": diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py index 018a2557a..384706639 100644 --- a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py @@ -1,9 +1,7 @@ -from functools import partial from typing import List, Tuple import pytest import torch -import torch.multiprocessing as mp try: from transformers import GPT2Config, GPT2Model @@ -16,6 +14,7 @@ except: from test_autochunk_transformer_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn BATCH_SIZE = 1 SEQ_LENGTH = 512 @@ -35,18 +34,19 @@ def get_data(shape: tuple) -> Tuple[List, List]: not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)]) -@pytest.mark.parametrize("max_memory", [None, 6, 8]) +@clear_cache_before_run() +@parameterize("model", MODELS) +@parameterize("shape", [(BATCH_SIZE, SEQ_LENGTH)]) +@parameterize("max_memory", [None, 6, 8]) def test_autochunk_gpt(model, shape, max_memory): - run_func = partial( + spawn( run_test, + 1, data=get_data(shape), max_memory=max_memory, model=model, config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4), ) - mp.spawn(run_func, nprocs=1) if __name__ == "__main__": diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py index bc5eda7ed..faba138cd 100644 --- a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py @@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py index 2b7cbf139..a98aa0e03 100644 --- a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py +++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py @@ -1,9 +1,7 @@ -from functools import partial from typing import List, Tuple import pytest import torch -import torch.multiprocessing as mp try: from 
timm.models.vision_transformer import vit_large_patch16_384 as vit @@ -16,6 +14,7 @@ except: from test_autochunk_vit_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn def get_data() -> Tuple[List, List]: @@ -28,16 +27,17 @@ def get_data() -> Tuple[List, List]: not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("max_memory", [None, 32, 40]) +@clear_cache_before_run() +@parameterize("model", MODELS) +@parameterize("max_memory", [None, 32, 40]) def test_evoformer_block(model, max_memory): - run_func = partial( + spawn( run_test, + 1, max_memory=max_memory, model=model, data=get_data(), ) - mp.spawn(run_func, nprocs=1) if __name__ == "__main__": diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py index 035dd5979..317606fc4 100644 --- a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py +++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py @@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen diff --git a/tests/test_booster/test_accelerator.py b/tests/test_booster/test_accelerator.py index 6958a87e2..895c494d0 100644 --- a/tests/test_booster/test_accelerator.py +++ b/tests/test_booster/test_accelerator.py @@ -1,27 +1,14 @@ -from functools import partial - -import torch.multiprocessing as mp import torch.nn as nn from colossalai.booster.accelerator import Accelerator -from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.testing import clear_cache_before_run, parameterize +@clear_cache_before_run() @parameterize('device', ['cpu', 'cuda']) -def run_accelerator(device): +def test_accelerator(device): acceleartor = Accelerator(device) model = nn.Linear(8, 8) model = acceleartor.configure_model(model) assert next(model.parameters()).device.type == device del model, acceleartor - - -def run_dist(rank): - run_accelerator() - - -@rerun_if_address_is_in_use() -def test_accelerator(): - world_size = 1 - run_func = partial(run_dist) - mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py b/tests/test_booster/test_mixed_precision/test_fp16_torch.py index bacf29014..963387da2 100644 --- a/tests/test_booster/test_mixed_precision/test_fp16_torch.py +++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py @@ -1,13 +1,9 @@ -from functools import partial - import torch -import torch.multiprocessing as mp from torch.optim import Adam import colossalai from colossalai.booster.mixed_precision import FP16TorchMixedPrecision -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -41,6 +37,4 @@ def run_torch_amp(rank, world_size, port): @rerun_if_address_is_in_use() def test_torch_ddp_plugin(): - world_size = 1 - run_func = partial(run_torch_amp, 
world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_torch_amp, 1) diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 169983a76..a3c63fd09 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -1,17 +1,12 @@ -from functools import partial - -import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin from colossalai.nn.optimizer import HybridAdam from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -119,9 +114,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True): @rerun_if_address_is_in_use() def test_gemini_plugin(early_stop: bool = True): - world_size = 2 - run_func = partial(run_dist, world_size=world_size, port=free_port(), early_stop=early_stop) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 2, early_stop=early_stop) if __name__ == '__main__': diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index 71e8582cc..c225a1a06 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -1,8 +1,5 @@ -from functools import partial - import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import SGD @@ -10,8 +7,7 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import TorchDDPPlugin from colossalai.interface import OptimizerWrapper -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -103,6 +99,4 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_torch_ddp_plugin(): - world_size = 2 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 2) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index dfbb16af4..0f78184f7 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -6,6 +6,7 @@ from torch.optim import Adam from torchvision.models import resnet18 from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.testing import clear_cache_before_run, parameterize # ======== # Note: @@ -15,7 +16,8 @@ from colossalai.checkpoint_io import GeneralCheckpointIO # ======== -@pytest.mark.parametrize('use_safetensors', [True, False]) +@clear_cache_before_run() +@parameterize('use_safetensors', [True, False]) def test_unsharded_checkpoint(use_safetensors: bool): # create a model and optimizer model = resnet18() diff --git a/tests/test_cluster/test_device_mesh_manager.py b/tests/test_cluster/test_device_mesh_manager.py index b79814735..b42ef1fe0 100644 --- a/tests/test_cluster/test_device_mesh_manager.py +++ 
b/tests/test_cluster/test_device_mesh_manager.py @@ -1,14 +1,9 @@ -from functools import partial - import torch -import torch.multiprocessing as mp from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer import ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port +from colossalai.testing import spawn def check_device_mesh_manager(rank, world_size, port): @@ -31,9 +26,7 @@ def check_device_mesh_manager(rank, world_size, port): def test_device_mesh_manager(): - world_size = 4 - run_func = partial(check_device_mesh_manager, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_device_mesh_manager, 4) if __name__ == '__main__': diff --git a/tests/test_comm/test_boardcast_send_recv_v2.py b/tests/test_comm/test_boardcast_send_recv_v2.py index 1520d6054..253f6f21c 100644 --- a/tests/test_comm/test_boardcast_send_recv_v2.py +++ b/tests/test_comm/test_boardcast_send_recv_v2.py @@ -1,17 +1,12 @@ -from functools import partial -from typing import List - import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from colossalai.communication.p2p_v2 import _send_object, _recv_object, init_process_group + +from colossalai.communication.p2p_v2 import _recv_object, _send_object from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch -from colossalai.utils import free_port, get_current_device -from colossalai.testing import rerun_if_address_is_in_use from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn disable_existing_loggers() world_size = 4 @@ -45,9 +40,7 @@ def check_layer(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_object_list_p2p(): - disable_existing_loggers() - run_func = partial(check_layer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, world_size) if __name__ == '__main__': diff --git a/tests/test_comm/test_comm.py b/tests/test_comm/test_comm.py index 07cb67730..747596bd2 100644 --- a/tests/test_comm/test_comm.py +++ b/tests/test_comm/test_comm.py @@ -1,15 +1,13 @@ -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp + from colossalai.communication import all_gather, all_reduce, reduce_scatter from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch -from colossalai.utils import free_port, get_current_device -from colossalai.testing import rerun_if_address_is_in_use +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) @@ -66,9 +64,7 @@ def check_layer(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_comm(): - world_size = 4 - run_func = partial(check_layer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, 4) if __name__ == '__main__': diff --git a/tests/test_comm/test_object_list_p2p.py b/tests/test_comm/test_object_list_p2p.py index 701e3e8ad..e9d7630c1 100644 --- 
a/tests/test_comm/test_object_list_p2p.py +++ b/tests/test_comm/test_object_list_p2p.py @@ -1,15 +1,18 @@ -from functools import partial - import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from colossalai.communication.p2p import send_forward, recv_forward, send_backward, recv_backward, send_forward_recv_backward, send_backward_recv_forward + +from colossalai.communication.p2p import ( + recv_backward, + recv_forward, + send_backward, + send_backward_recv_forward, + send_forward, + send_forward_recv_backward, +) from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch -from colossalai.utils import free_port, get_current_device -from colossalai.testing import rerun_if_address_is_in_use +from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict(pipeline=2)) torch.manual_seed(123) @@ -96,9 +99,7 @@ def check_layer(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_object_list_p2p(): - world_size = 2 - run_func = partial(check_layer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, 2) if __name__ == '__main__': diff --git a/tests/test_comm/test_object_list_p2p_v2.py b/tests/test_comm/test_object_list_p2p_v2.py index c639ac9f8..cae38385b 100644 --- a/tests/test_comm/test_object_list_p2p_v2.py +++ b/tests/test_comm/test_object_list_p2p_v2.py @@ -1,16 +1,12 @@ -from functools import partial - import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from colossalai.communication.p2p_v2 import send_forward, recv_forward, send_backward, recv_backward, init_process_group -from colossalai.context import ParallelMode, Initializer_Pipeline + +from colossalai.communication.p2p_v2 import recv_backward, recv_forward, send_backward, send_forward +from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch -from colossalai.utils import free_port, get_current_device -from colossalai.testing import rerun_if_address_is_in_use from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn disable_existing_loggers() @@ -121,10 +117,7 @@ def check_layer(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_object_list_p2p(): - disable_existing_loggers() - run_func = partial(check_layer, world_size=world_size, port=free_port()) - disable_existing_loggers() - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, world_size) if __name__ == '__main__': diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_context/test_hybrid_parallel.py index f311b1d2e..9f26a5af5 100644 --- a/tests/test_context/test_hybrid_parallel.py +++ b/tests/test_context/test_hybrid_parallel.py @@ -1,19 +1,17 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial from pathlib import Path + import pytest import torch -import torch.multiprocessing as mp from colossalai import launch +from colossalai.context import reset_seeds from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.utils import free_port -from colossalai.context import reset_seeds from colossalai.global_variables import tensor_parallel_env as tp_env -from colossalai.testing import rerun_if_address_is_in_use +from 
colossalai.testing import free_port, rerun_if_address_is_in_use, spawn CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py')) @@ -134,9 +132,14 @@ def init_context(config_path, rank, world_size, backend, port, host): torch.cuda.empty_cache() -def run_dist(rank, world_size, backend, port_list, host): - for config_path, port in zip(CONFIG_PATH_LIST, port_list): - init_context(config_path=config_path, rank=rank, world_size=world_size, backend=backend, port=port, host=host) +def run_dist(rank, world_size, port, backend, port_list, host): + for config_path, current_port in zip(CONFIG_PATH_LIST, port_list): + init_context(config_path=config_path, + rank=rank, + world_size=world_size, + backend=backend, + port=current_port, + host=host) reset_seeds() @@ -156,8 +159,7 @@ def test_context(): port_list.append(port) break - test_fn = partial(run_dist, world_size=world_size, backend='gloo', port_list=port_list, host='localhost') - mp.spawn(test_fn, nprocs=world_size) + spawn(run_dist, world_size, backend='gloo', port_list=port_list, host='localhost') if __name__ == '__main__': diff --git a/tests/test_data/test_data_parallel_sampler.py b/tests/test_data/test_data_parallel_sampler.py index 54fa44bdc..2ad3fd696 100644 --- a/tests/test_data/test_data_parallel_sampler.py +++ b/tests/test_data/test_data_parallel_sampler.py @@ -2,20 +2,18 @@ # -*- encoding: utf-8 -*- import os -from functools import partial from pathlib import Path import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp +from torchvision import datasets, transforms import colossalai -from torchvision import transforms, datasets -from colossalai.context import ParallelMode, Config +from colossalai.context import Config, ParallelMode from colossalai.core import global_context as gpc -from colossalai.utils import get_dataloader, free_port -from colossalai.testing import rerun_if_address_is_in_use +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_dataloader CONFIG = Config(dict( parallel=dict( @@ -58,9 +56,7 @@ def run_data_sampler(rank, world_size, port): @pytest.mark.cpu @rerun_if_address_is_in_use() def test_data_sampler(): - world_size = 4 - test_func = partial(run_data_sampler, world_size=world_size, port=free_port()) - mp.spawn(test_func, nprocs=world_size) + spawn(run_data_sampler, 4) if __name__ == '__main__': diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py index 4d76e7f13..239e79dff 100644 --- a/tests/test_data/test_deterministic_dataloader.py +++ b/tests/test_data/test_deterministic_dataloader.py @@ -2,21 +2,18 @@ # -*- encoding: utf-8 -*- import os -from functools import partial from pathlib import Path import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp -from torchvision import transforms, datasets +from torchvision import datasets, transforms import colossalai -from colossalai.context import ParallelMode, Config +from colossalai.context import Config, ParallelMode from colossalai.core import global_context as gpc -from colossalai.utils import get_dataloader, free_port -from colossalai.testing import rerun_if_address_is_in_use -from torchvision import transforms +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_dataloader CONFIG = Config( dict( @@ -70,9 +67,7 @@ def run_data_sampler(rank, world_size, port): @pytest.mark.cpu @rerun_if_address_is_in_use() def 
test_data_sampler(): - world_size = 4 - test_func = partial(run_data_sampler, world_size=world_size, port=free_port()) - mp.spawn(test_func, nprocs=world_size) + spawn(run_data_sampler, 4) if __name__ == '__main__': diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py index 3c2390c92..4d63592f1 100644 --- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py +++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py @@ -1,25 +1,22 @@ import os - -from functools import partial from pathlib import Path -import colossalai import pytest import torch -import torch.multiprocessing as mp +from torchvision import transforms +from torchvision.datasets import CIFAR10 + +import colossalai from colossalai.amp import AMP_TYPE -from colossalai.trainer import Trainer, hooks from colossalai.context import ParallelMode -from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus -from colossalai.utils import free_port from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn import CrossEntropyLoss from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.utils import get_dataloader from colossalai.pipeline.pipelinable import PipelinableContext -from torchvision.datasets import CIFAR10 -from torchvision import transforms +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.trainer import Trainer, hooks +from colossalai.utils import get_dataloader BATCH_SIZE = 4 NUM_EPOCHS = 60 @@ -96,9 +93,7 @@ def run_trainer(rank, world_size, port): @skip_if_not_enough_gpus(min_gpus=8) @rerun_if_address_is_in_use() def test_hybrid_parallel(): - world_size = 8 - run_func = partial(run_trainer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_trainer, 8) if __name__ == '__main__': diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py index 2bafe0f7e..67d2ba5f5 100644 --- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py +++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py @@ -1,111 +1,104 @@ -import os - -from functools import partial -from pathlib import Path - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.amp import AMP_TYPE -from colossalai.trainer import Trainer, hooks -from colossalai.context import ParallelMode -from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus -from colossalai.utils import free_port -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.nn import CrossEntropyLoss -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.utils import get_dataloader -from colossalai.pipeline.pipelinable import PipelinableContext -from colossalai.logging import disable_existing_loggers -from torchvision.datasets import CIFAR10 -from torchvision import transforms - -from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2 - -disable_existing_loggers() -BATCH_SIZE = 4 -NUM_EPOCHS = 10 -WARMUP_EPOCHS = 5 -CONFIG = dict(NUM_MICRO_BATCHES=2, - 
parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')), - fp16=dict(mode=AMP_TYPE.NAIVE), - gradient_accumulation=2) - - -def run_trainer(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - disable_existing_loggers() - # get logger - logger = get_dist_logger() - - pipelinable = PipelinableContext() - try: - from titans.model.vit import vit_tiny_patch4_32 - except ImportError: - logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') - logger.warning('please install titan from https://github.com/hpcaitech/Titans') - return - with pipelinable: - model = vit_tiny_patch4_32() - pipelinable.to_layer_list() - pipelinable.policy = "uniform" - model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) - - # craete dataloaders - root = Path(os.environ['DATA']) - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4, pad_if_needed=True), - transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train) - train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True) - - # create loss function - criterion = CrossEntropyLoss(label_smoothing=0.1) - - # create optimizer - optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0) - - # create lr scheduler - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS) - - # intiailize - engine, train_dataloader, *_ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - - engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES) - - logger = get_dist_logger() - - trainer = Trainer(engine=engine, logger=logger) - - hook_list = [ - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), - ] - - trainer.fit(train_dataloader=train_dataloader, - max_steps=2, - epochs=NUM_EPOCHS, - hooks=hook_list, - display_progress=True) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_hybrid_parallel(): - world_size = 2 - run_func = partial(run_trainer, world_size=world_size, port=free_port()) - disable_existing_loggers() - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_hybrid_parallel() +import os +from pathlib import Path + +import pytest +import torch +from torchvision import transforms +from torchvision.datasets import CIFAR10 + +import colossalai +from colossalai.amp import AMP_TYPE +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2 +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.pipeline.pipelinable import PipelinableContext +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.trainer import Trainer, hooks +from colossalai.utils import get_dataloader + +disable_existing_loggers() +BATCH_SIZE = 4 +NUM_EPOCHS = 10 +WARMUP_EPOCHS = 5 +CONFIG = dict(NUM_MICRO_BATCHES=2, + 
parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')), + fp16=dict(mode=AMP_TYPE.NAIVE), + gradient_accumulation=2) + + +def run_trainer(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + disable_existing_loggers() + # get logger + logger = get_dist_logger() + + pipelinable = PipelinableContext() + try: + from titans.model.vit import vit_tiny_patch4_32 + except ImportError: + logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') + logger.warning('please install titan from https://github.com/hpcaitech/Titans') + return + with pipelinable: + model = vit_tiny_patch4_32() + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + + # create dataloaders + root = Path(os.environ['DATA']) + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4, pad_if_needed=True), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train) + train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS) + + # initialize + engine, train_dataloader, *_ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + + engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES) + + logger = get_dist_logger() + + trainer = Trainer(engine=engine, logger=logger) + + hook_list = [ + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), + ] + + trainer.fit(train_dataloader=train_dataloader, + max_steps=2, + epochs=NUM_EPOCHS, + hooks=hook_list, + display_progress=True) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_hybrid_parallel(): + spawn(run_trainer, 2) + disable_existing_loggers() + + +if __name__ == '__main__': + test_hybrid_parallel() diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py index 2ad20f6be..39efcd41a 100644 --- a/tests/test_ddp/test_ddp_ignore_params.py +++ b/tests/test_ddp/test_ddp_ignore_params.py @@ -1,19 +1,16 @@ import os import random -from functools import partial from typing import Callable, Type import numpy as np import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.nn.parallel import ColoDDP from colossalai.tensor import ProcessGroup -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration @@ -88,8 +85,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use() def test_ddp_ignore_params(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py index bd4742ff2..54f89f972 100644 --- a/tests/test_ddp/test_ddp_state_dict.py +++ b/tests/test_ddp/test_ddp_state_dict.py @@ -1,16 +1,12 @@ -import copy from collections import OrderedDict -from functools import partial import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.nn.parallel import ColoDDP from colossalai.tensor import ColoParameter, ProcessGroup -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs @@ -64,8 +60,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_state_dict(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ddp/test_reducer.py b/tests/test_ddp/test_reducer.py index 5b302d99f..e8d3a112c 100644 --- a/tests/test_ddp/test_reducer.py +++ b/tests/test_ddp/test_reducer.py @@ -1,15 +1,15 @@ +from functools import partial + import pytest -import colossalai import torch -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from functools import partial -from colossalai.nn.parallel.reducer import Reducer import torch.distributed as dist from torch.distributed.distributed_c10d import _get_default_group +import colossalai +from colossalai.nn.parallel.reducer import Reducer +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device + REDUCE_CNT = 0 @@ -40,8 +40,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_reducer(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_device/test_alpha_beta.py b/tests/test_device/test_alpha_beta.py index 99abacd13..ab933ed57 100644 --- a/tests/test_device/test_alpha_beta.py +++ b/tests/test_device/test_alpha_beta.py @@ -1,16 +1,12 @@ -from functools import partial - import pytest -import torch.multiprocessing as mp from colossalai.device import AlphaBetaProfiler from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -def check_alpha_beta(rank, physical_devices, world_size, port): +def check_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') profiler = AlphaBetaProfiler(physical_devices) @@ 
-24,9 +20,7 @@ def check_alpha_beta(rank, physical_devices, world_size, port): @parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): - world_size = 4 - run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_alpha_beta, 4, physical_devices=physical_devices) if __name__ == '__main__': diff --git a/tests/test_device/test_extract_alpha_beta.py b/tests/test_device/test_extract_alpha_beta.py index e32bebdd9..52604b9c6 100644 --- a/tests/test_device/test_extract_alpha_beta.py +++ b/tests/test_device/test_extract_alpha_beta.py @@ -1,16 +1,12 @@ -from functools import partial - import pytest -import torch.multiprocessing as mp from colossalai.device import AlphaBetaProfiler from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -def check_extract_alpha_beta(rank, physical_devices, world_size, port): +def check_extract_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') profiler = AlphaBetaProfiler(physical_devices) @@ -27,12 +23,7 @@ def check_extract_alpha_beta(rank, physical_devices, world_size, port): @parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): - world_size = 4 - run_func = partial(check_extract_alpha_beta, - physical_devices=physical_devices, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_extract_alpha_beta, 4, physical_devices=physical_devices) if __name__ == '__main__': diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index 3172897fb..2b7060c48 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -1,15 +1,12 @@ -import torch -from functools import partial import pytest +import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.distributed import ReduceOp from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_layer(rank, world_size, port): @@ -40,9 +37,7 @@ def check_layer(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_logical_pg(): - world_size = 4 - run_func = partial(check_layer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, 4) if __name__ == '__main__': diff --git a/tests/test_device/test_search_logical_device_mesh.py b/tests/test_device/test_search_logical_device_mesh.py index 591eafb2a..b22a76eab 100644 --- a/tests/test_device/test_search_logical_device_mesh.py +++ b/tests/test_device/test_search_logical_device_mesh.py @@ -1,16 +1,12 @@ -from functools import partial - import pytest -import torch.multiprocessing as mp from colossalai.device import AlphaBetaProfiler from 
colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -def check_alpha_beta(rank, physical_devices, world_size, port): +def check_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') profiler = AlphaBetaProfiler(physical_devices) @@ -27,9 +23,7 @@ def check_alpha_beta(rank, physical_devices, world_size, port): @parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): - world_size = 4 - run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_alpha_beta, 4, physical_devices=physical_devices) if __name__ == '__main__': diff --git a/tests/test_engine/test_engine.py b/tests/test_engine/test_engine.py index fb5bd1e16..62493cf37 100644 --- a/tests/test_engine/test_engine.py +++ b/tests/test_engine/test_engine.py @@ -1,13 +1,10 @@ -from functools import partial +import pytest import colossalai -import pytest -import torch.multiprocessing as mp from colossalai.amp import AMP_TYPE from colossalai.core import global_context as gpc -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.testing import parameterize, rerun_if_address_is_in_use CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), @@ -58,9 +55,7 @@ def run_engine(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_engine(): - world_size = 2 - run_func = partial(run_engine, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_engine, 2) if __name__ == '__main__': diff --git a/tests/test_engine/test_gradient_accumluation.py b/tests/test_engine/test_gradient_accumluation.py index 7f5ee47be..7783827c7 100644 --- a/tests/test_engine/test_gradient_accumluation.py +++ b/tests/test_engine/test_gradient_accumluation.py @@ -1,22 +1,20 @@ import os -from functools import partial from pathlib import Path -import colossalai -from colossalai.testing.utils import rerun_if_address_is_in_use import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.utils import free_port, get_dataloader -from colossalai.testing import rerun_if_address_is_in_use from torch.optim import Adam from torchvision import transforms from torchvision.datasets import CIFAR10 from torchvision.models import resnet18 +import colossalai +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_dataloader + # Config BATCH_SIZE = 2 NUM_CLASSES = 10 @@ -90,9 +88,7 @@ def run_no_pipeline(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_engine(): - world_size = 4 - func = partial(run_no_pipeline, world_size=world_size, port=free_port()) - mp.spawn(func, 
nprocs=world_size) + spawn(run_no_pipeline, 4) if __name__ == '__main__': diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index 83df1bb5e..ab483f7e4 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -1,15 +1,13 @@ import pytest import torch -import torch.multiprocessing as mp import torch.nn.functional as F -from torch.fx import GraphModule from torch.utils.checkpoint import checkpoint import colossalai from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn try: from colossalai.fx.codegen import ActivationCheckpointCodeGen @@ -65,9 +63,9 @@ class MyModule(torch.nn.Module): return y1 + y2 + y3 + y4 + y5 + y6 -def _run_act_ckpt_codegen(rank): +def _run_act_ckpt_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model and run forward model = MyModule() @@ -118,13 +116,14 @@ def _run_act_ckpt_codegen(rank): @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@rerun_if_address_is_in_use() def test_act_ckpt_codegen(): - mp.spawn(_run_act_ckpt_codegen, nprocs=1) + spawn(_run_act_ckpt_codegen, 1) -def _run_act_ckpt_python_code_torch11(rank): +def _run_act_ckpt_python_code_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model and run forward model = MyModule() @@ -174,8 +173,9 @@ def _run_act_ckpt_python_code_torch11(rank): @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") +@rerun_if_address_is_in_use() def test_act_ckpt_python_code_torch11(): - mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1) + spawn(_run_act_ckpt_python_code_torch11, 1) if __name__ == '__main__': diff --git a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py index 6b3a49d18..9064023d4 100644 --- a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py @@ -1,15 +1,11 @@ import pytest import torch -import torch.multiprocessing as mp -import torch.nn.functional as F -from torch.fx import GraphModule -from torch.utils.checkpoint import checkpoint import colossalai from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn try: from colossalai.fx.codegen import ActivationCheckpointCodeGen @@ -35,9 +31,9 @@ class MyModule(torch.nn.Module): 
return self.linear6(self.linear5(self.linear4(self.linear3(self.linear2(self.linear1(x)))))) -def _run_act_ckpt_codegen(rank): +def _run_act_ckpt_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model and run forward model = MyModule() @@ -89,12 +85,12 @@ def _run_act_ckpt_codegen(rank): @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') def test_act_ckpt_codegen(): - mp.spawn(_run_act_ckpt_codegen, nprocs=1) + spawn(_run_act_ckpt_codegen, 1) -def _run_act_ckpt_python_code_torch11(rank): +def _run_act_ckpt_python_code_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model and run forward model = MyModule() @@ -146,8 +142,9 @@ def _run_act_ckpt_python_code_torch11(rank): @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") +@rerun_if_address_is_in_use() def test_act_ckpt_python_code_torch11(): - mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1) + spawn(_run_act_ckpt_python_code_torch11, 1) if __name__ == '__main__': diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py index 5d090066c..96e88eb92 100644 --- a/tests/test_fx/test_codegen/test_offload_codegen.py +++ b/tests/test_fx/test_codegen/test_offload_codegen.py @@ -2,15 +2,13 @@ import copy import pytest import torch -import torch.multiprocessing as mp -import torch.nn.functional as F from torch.fx import GraphModule import colossalai from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn try: from colossalai.fx.codegen import ActivationCheckpointCodeGen @@ -66,9 +64,9 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one" -def _run_offload_codegen(rank): +def _run_offload_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model and input model = MyNet().cuda() @@ -116,13 +114,14 @@ def _run_offload_codegen(rank): @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@rerun_if_address_is_in_use() def test_act_ckpt_codegen(): - mp.spawn(_run_offload_codegen, nprocs=1) + spawn(_run_offload_codegen, 1) -def _run_offload_codegen_torch11(rank): +def _run_offload_codegen_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, 
rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model and input model = MyNet().cuda() @@ -171,8 +170,9 @@ def _run_offload_codegen_torch11(rank): @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented") +@rerun_if_address_is_in_use() def test_act_ckpt_python_code_torch11(): - mp.spawn(_run_offload_codegen_torch11, nprocs=1) + spawn(_run_offload_codegen_torch11, 1) if __name__ == "__main__": diff --git a/tests/test_fx/test_coloproxy.py b/tests/test_fx/test_coloproxy.py index 2bb6cf864..96cf5198d 100644 --- a/tests/test_fx/test_coloproxy.py +++ b/tests/test_fx/test_coloproxy.py @@ -1,9 +1,11 @@ +import pytest import torch import torch.nn as nn +from torch.fx import GraphModule + from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer -from torch.fx import GraphModule -import pytest +from colossalai.testing import clear_cache_before_run class Conv1D(nn.Module): @@ -23,6 +25,7 @@ class Conv1D(nn.Module): return x +@clear_cache_before_run() def test_coloproxy(): tracer = ColoTracer() diff --git a/tests/test_fx/test_comm_size_compute.py b/tests/test_fx/test_comm_size_compute.py index 8825bbb46..d3daadd71 100644 --- a/tests/test_fx/test_comm_size_compute.py +++ b/tests/test_fx/test_comm_size_compute.py @@ -1,13 +1,11 @@ -import colossalai -import colossalai.nn as col_nn -import pytest import torch -import torch.nn as nn +from torch.fx import symbolic_trace + from colossalai.fx._compatibility import is_compatible_with_meta -from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass) +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.utils import get_comm_size -from torch.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run is_compatible = is_compatible_with_meta() if is_compatible: @@ -35,6 +33,7 @@ class MLP(torch.nn.Module): return x +@clear_cache_before_run() def test_comm_size_compute(): model = MLP(MODEL_DIM) input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta') diff --git a/tests/test_fx/test_complete_workflow.py b/tests/test_fx/test_complete_workflow.py deleted file mode 100644 index a21a351f8..000000000 --- a/tests/test_fx/test_complete_workflow.py +++ /dev/null @@ -1,87 +0,0 @@ -from functools import partial - -import pytest -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -import torch.nn as nn - -import colossalai -from colossalai.fx import ColoTracer -from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass -from colossalai.tensor import ProcessGroup -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.utils.model.lazy_init_context import LazyInitContext - - -class MLP(torch.nn.Module): - - def __init__(self, dim: int): - super().__init__() - self.linear1 = torch.nn.Linear(dim, dim) - self.linear2 = torch.nn.Linear(dim, dim) - self.dropout = torch.nn.Dropout(0) - self.relu = torch.nn.ReLU() - - def forward(self, x): - x = self.linear1(x) - x = self.dropout(x) - x = self.relu(x) - x = self.linear2(x) - return x - - -def run_workflow(world_size, dev): - # initailization - with LazyInitContext() as ctx: - model = MLP(16) - - for param in model.parameters(): 
- assert param.is_meta - - # tracing - tracer = ColoTracer() - graph = tracer.trace(model) - gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) - - # annotate - annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup(tp_degree=world_size)) - annotated_gm.recompile() - - # materialization and sharding - ctx.lazy_init_parameters(annotated_gm, device=dev) - for param in model.parameters(): - assert not param.is_meta - - # # check sharding - assert list(model.linear1.weight.shape) == [16 // world_size, 16] - assert list(model.linear1.bias.shape) == [16 // world_size] - assert list(model.linear2.weight.shape) == [16, 16 // world_size] - - # test forward to make sure that IR transform will produce the same results - # like how ColoTensor would do it normally - data = torch.rand(4, 16, device=dev) - non_fx_out = model(data) - fx_out = annotated_gm(data) - assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}' - - -def run_dist(rank, world_size, dev, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_workflow(world_size, dev) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@pytest.mark.parametrize('dev', ['cuda', 'cpu']) -@rerun_if_address_is_in_use() -def test_complete_workflow(world_size, dev): - if dev == 'cpu' and world_size > 1: - return - run_func = partial(run_dist, world_size=world_size, dev=dev, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_complete_workflow(1, 'cuda') diff --git a/tests/test_fx/test_graph_manipulation.py b/tests/test_fx/test_graph_manipulation.py index fb33e58a7..175b69dd9 100644 --- a/tests/test_fx/test_graph_manipulation.py +++ b/tests/test_fx/test_graph_manipulation.py @@ -1,9 +1,11 @@ -import colossalai import torch -from colossalai.fx.passes.utils import get_leaf, get_top, assign_bfs_level_to_nodes -from colossalai.fx import ColoTracer from torch.fx import GraphModule + +import colossalai +from colossalai.fx import ColoTracer from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata +from colossalai.fx.passes.utils import assign_bfs_level_to_nodes, get_leaf, get_top +from colossalai.testing import clear_cache_before_run class MLP(torch.nn.Module): @@ -25,6 +27,7 @@ class MLP(torch.nn.Module): return l4, l5 +@clear_cache_before_run() def test_graph_manipulation(): model = MLP(4) tracer = ColoTracer() diff --git a/tests/test_fx/test_meta/test_aten.py b/tests/test_fx/test_meta/test_aten.py index 209ded89c..e490522db 100644 --- a/tests/test_fx/test_meta/test_aten.py +++ b/tests/test_fx/test_meta/test_aten.py @@ -3,7 +3,9 @@ from typing import Any, Callable, Union import pytest import torch import torch.nn as nn + from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.testing import clear_cache_before_run if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor @@ -71,6 +73,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() def test_meta_aten(): for (aten_op, requires_backward), v in registered_meta.items(): for f, x in v: diff --git a/tests/test_fx/test_meta/test_backward.py b/tests/test_fx/test_meta/test_backward.py index 351c02c57..7aed6fd45 100644 --- a/tests/test_fx/test_meta/test_backward.py +++ b/tests/test_fx/test_meta/test_backward.py @@ -2,11 +2,14 @@ import 
pytest import timm.models as tmm import torch import torchvision.models as tm + from colossalai.fx._compatibility import is_compatible_with_meta if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor +from colossalai.testing import clear_cache_before_run + tm_models = [ tm.vgg11, tm.resnet18, @@ -28,6 +31,7 @@ tmm_models = [ @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() def test_torchvision_models(): for m in tm_models: model = m() @@ -36,6 +40,7 @@ def test_torchvision_models(): @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() def test_timm_models(): for m in tmm_models: model = m() diff --git a/tests/test_fx/test_meta/test_meta_trace.py b/tests/test_fx/test_meta/test_meta_trace.py index 404b6d27d..61614f8a6 100644 --- a/tests/test_fx/test_meta/test_meta_trace.py +++ b/tests/test_fx/test_meta/test_meta_trace.py @@ -2,11 +2,14 @@ import pytest import timm.models as tmm import torch import torchvision.models as tm + from colossalai.fx._compatibility import is_compatible_with_meta if is_compatible_with_meta(): from colossalai.fx import meta_trace +from colossalai.testing import clear_cache_before_run + tm_models = [ tm.vgg11, tm.resnet18, @@ -28,6 +31,7 @@ tmm_models = [ @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() def test_torchvision_models_trace(): for m in tm_models: model = m() @@ -36,6 +40,7 @@ def test_torchvision_models_trace(): @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() def test_timm_models_trace(): for m in tmm_models: model = m() diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index 6fac180d8..a12512696 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -1,7 +1,9 @@ import torch +from torch.fx import symbolic_trace + from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata -from torch.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor @@ -18,6 +20,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): assert meta_info_spec.numel == orig_tensor.numel() +@clear_cache_before_run() def test_meta_info_prop(): model = torch.nn.Linear(DIM_IN, DIM_OUT) input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') diff --git a/tests/test_fx/test_parallel_1d.py b/tests/test_fx/test_parallel_1d.py index 8963ba29c..1044be7db 100644 --- a/tests/test_fx/test_parallel_1d.py +++ b/tests/test_fx/test_parallel_1d.py @@ -1,18 +1,15 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch -import torch.multiprocessing as mp -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers -from colossalai.initialize import launch -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use from torch.fx import symbolic_trace + +from colossalai.core import global_context as gpc from colossalai.fx.passes import column_shard_linear_pass +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from 
colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn class MLP(torch.nn.Module): @@ -52,11 +49,10 @@ def check_layer(rank, world_size, port): @pytest.mark.dist +@clear_cache_before_run() @rerun_if_address_is_in_use() def test_1d(): - world_size = 2 - run_func = partial(check_layer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, 2) if __name__ == '__main__': diff --git a/tests/test_fx/test_pipeline_passes.py b/tests/test_fx/test_pipeline_passes.py index de8a9402b..1078dac9d 100644 --- a/tests/test_fx/test_pipeline_passes.py +++ b/tests/test_fx/test_pipeline_passes.py @@ -1,12 +1,17 @@ +import pytest import torch import torch.nn as nn -import colossalai -import colossalai.nn as col_nn from torch.fx import symbolic_trace -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass, \ - uniform_split_pass, balanced_split_pass_v2 -import pytest +import colossalai +import colossalai.nn as col_nn +from colossalai.fx.passes.adding_split_node_pass import ( + balanced_split_pass, + balanced_split_pass_v2, + split_with_split_nodes_pass, + uniform_split_pass, +) +from colossalai.testing import clear_cache_before_run MODEL_DIM = 16 BATCH_SIZE = 8 @@ -39,6 +44,7 @@ def pipeline_pass_test_helper(model, data, pass_func): assert output.equal(origin_output) +@clear_cache_before_run() def test_pipeline_passes(): model = MLP(MODEL_DIM) data = torch.rand(BATCH_SIZE, MODEL_DIM) diff --git a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py index c71796018..b5a6bbe8b 100644 --- a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py +++ b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py @@ -9,7 +9,7 @@ from torch.fx import symbolic_trace from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor @@ -126,6 +126,7 @@ def run_gpt_forward(gm: torch.fx.GraphModule): @run_on_environment_flag(name='FX_PROFILER') +@clear_cache_before_run() def test_meta_info_prop(): for m in [ tm.alexnet, tm.resnet18, tm.resnet34, tm.resnet50, tm.resnet101, tm.resnet152, tm.densenet121, @@ -155,6 +156,7 @@ def test_meta_info_prop(): @run_on_environment_flag(name='FX_PROFILER') +@clear_cache_before_run() def test_gpt_meta_info_prop(): for m in [gpt2_medium]: model = m().cuda() diff --git a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py index a834951bb..632ab8c09 100644 --- a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py +++ b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py @@ -4,6 +4,7 @@ from torch.fx import GraphModule from torch.utils.checkpoint import checkpoint from colossalai.fx import ColoTracer +from colossalai.testing import clear_cache_before_run class MLP(torch.nn.Module): @@ -35,6 +36,7 @@ class MyModule(torch.nn.Module): return x +@clear_cache_before_run() def test_activation_checkpoint_annotation(): module = MyModule() diff --git a/tests/test_fx/test_tracer/test_bias_addition_module.py 
b/tests/test_fx/test_tracer/test_bias_addition_module.py index afa30a217..2f88d8c78 100644 --- a/tests/test_fx/test_tracer/test_bias_addition_module.py +++ b/tests/test_fx/test_tracer/test_bias_addition_module.py @@ -1,6 +1,7 @@ import torch from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import clear_cache_before_run class LinearModel(torch.nn.Module): @@ -32,6 +33,7 @@ class ConvModel(torch.nn.Module): return x +@clear_cache_before_run() def test_linear_module(): model = LinearModel(3, 6) tracer = ColoTracer() @@ -68,6 +70,7 @@ def test_linear_module(): assert add_node._meta_data.shape == (3, 6) +@clear_cache_before_run() def test_conv_module(): model = ConvModel(3, 6, 2) tracer = ColoTracer() diff --git a/tests/test_fx/test_tracer/test_control_flow.py b/tests/test_fx/test_tracer/test_control_flow.py index ed842cff2..820729dad 100644 --- a/tests/test_fx/test_tracer/test_control_flow.py +++ b/tests/test_fx/test_tracer/test_control_flow.py @@ -1,7 +1,9 @@ import torch import torch.nn as nn from torch.fx import GraphModule + from colossalai.fx import ColoTracer as Tracer +from colossalai.testing import clear_cache_before_run class ControlFlowModel(nn.Module): @@ -21,6 +23,7 @@ class ControlFlowModel(nn.Module): return x1 - y1 +@clear_cache_before_run() def test_control_flow(): model = ControlFlowModel() tracer = Tracer() diff --git a/tests/test_fx/test_tracer/test_functional_conv.py b/tests/test_fx/test_tracer/test_functional_conv.py index 95670b85f..a552e9052 100644 --- a/tests/test_fx/test_tracer/test_functional_conv.py +++ b/tests/test_fx/test_tracer/test_functional_conv.py @@ -1,8 +1,11 @@ import torch from torch.nn import functional as F + from colossalai.fx.tracer.meta_patch import patched_function +from colossalai.testing import clear_cache_before_run +@clear_cache_before_run() def test_conv(): # test F.conv_1d data_1d = torch.rand(3, 16, 10) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py index 31ba2290e..f4d681221 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -3,6 +3,7 @@ import torch from hf_tracer_utils import trace_model_and_compare_output from packaging import version +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo BATCH_SIZE = 2 @@ -10,6 +11,7 @@ SEQ_LENGTH = 16 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() def test_albert(): sub_registry = model_zoo.get_sub_registry('transformers_albert') diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index 8db6817c6..a833bb30c 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -3,10 +3,12 @@ import torch from hf_tracer_utils import trace_model_and_compare_output from packaging import version +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() def test_bert(): sub_registry = model_zoo.get_sub_registry('transformers_bert') diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py index 
92ece357b..0cbea82e0 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -2,6 +2,7 @@ import pytest import torch from colossalai.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run from colossalai.testing.random import seed_all from tests.kit.model_zoo import model_zoo @@ -40,6 +41,7 @@ def trace_and_compare(model_cls, data, output_fn): @pytest.mark.skip(reason='cannot pass this test yet') +@clear_cache_before_run() def test_diffusers(): seed_all(9091, cuda_deterministic=True) @@ -52,6 +54,7 @@ def test_diffusers(): print(f"{name:40s} √") +@clear_cache_before_run() def test_torch_diffusers(): seed_all(65535, cuda_deterministic=True) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index 796c17e39..67107469d 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -3,10 +3,12 @@ import torch from hf_tracer_utils import trace_model_and_compare_output from packaging import version +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() def test_gpt(): sub_registry = model_zoo.get_sub_registry('transformers_gpt') diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index e7bfa6070..369545b03 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -3,10 +3,12 @@ import torch from hf_tracer_utils import trace_model_and_compare_output from packaging import version +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() def test_opt(): sub_registry = model_zoo.get_sub_registry('transformers_opt') diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index 5f7e4f81c..811cf3b21 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -3,10 +3,12 @@ import torch from hf_tracer_utils import trace_model_and_compare_output from packaging import version +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() def test_t5(): sub_registry = model_zoo.get_sub_registry('transformers_t5') diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py index 94a93e16f..ef778e218 100644 --- a/tests/test_fx/test_tracer/test_patched_module.py +++ b/tests/test_fx/test_tracer/test_patched_module.py @@ -1,5 +1,7 @@ import torch + from colossalai.fx.tracer.meta_patch import patched_module +from colossalai.testing import clear_cache_before_run def _run(data, module, patch_fn): @@ -31,6 +33,7 @@ def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape) assert output.shape == output_shape +@clear_cache_before_run() def test_linear(): # test linear patch can produce the meta output 
with correct shape data = torch.rand(2, 4, device='meta') @@ -42,6 +45,7 @@ def test_linear(): _assert_output_shape(data, module, patched_module.torch_nn_linear, True, None) +@clear_cache_before_run() def test_rnn(): # test rnn patch can produce the meta output with correct shape data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) @@ -58,6 +62,7 @@ def test_rnn(): _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, True, None) +@clear_cache_before_run() def test_embedding(): data = torch.rand(2, 4, device='meta') @@ -134,6 +139,7 @@ def test_embedding(): output_shape=None) +@clear_cache_before_run() def test_conv1d(): # test conv 1d data = torch.rand(2, 3, 4) @@ -212,6 +218,7 @@ def test_conv2d(): output_shape=materialized_output.shape) +@clear_cache_before_run() def test_conv3d(): # test conv 3d data = torch.rand(2, 3, 4, 4, 4) @@ -253,6 +260,7 @@ def test_conv3d(): output_shape=materialized_output.shape) +@clear_cache_before_run() def test_conv_transpose1d(): # test conv transpose1d data = torch.rand(2, 3, 4) @@ -276,6 +284,7 @@ def test_conv_transpose1d(): output_shape=materialized_output.shape) +@clear_cache_before_run() def test_conv_transpose2d(): # test conv transpose2d data = torch.rand(2, 3, 4, 4) @@ -299,6 +308,7 @@ def test_conv_transpose2d(): output_shape=materialized_output.shape) +@clear_cache_before_run() def test_conv_transpose3d(): # test conv transpose2d data = torch.rand(2, 3, 4, 4, 4) @@ -322,6 +332,7 @@ def test_conv_transpose3d(): output_shape=materialized_output.shape) +@clear_cache_before_run() def test_pool1d(): combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d], [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]] @@ -349,6 +360,7 @@ def test_pool1d(): _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) +@clear_cache_before_run() def test_pool2d(): combinations = [[torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d], [torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d]] @@ -379,6 +391,7 @@ def test_pool2d(): _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) +@clear_cache_before_run() def test_pool3d(): combinations = [[torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d], [torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d]] @@ -410,6 +423,7 @@ def test_pool3d(): # adapative pooling is different from other pooling, so test it individually +@clear_cache_before_run() def test_adaptive_pooling_1d(): pooler = torch.nn.AdaptiveAvgPool1d(output_size=3) patch_func = patched_module.torch_nn_adapative_pooling_1d @@ -434,6 +448,7 @@ def test_adaptive_pooling_1d(): _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) +@clear_cache_before_run() def test_adaptive_pooling_2d(): pooler = torch.nn.AdaptiveAvgPool2d(output_size=3) patch_func = patched_module.torch_nn_adapative_pooling_2d @@ -458,6 +473,7 @@ def test_adaptive_pooling_2d(): output_shape=output.shape) +@clear_cache_before_run() def test_adaptive_pooling_3d(): pooler = torch.nn.AdaptiveAvgPool3d(output_size=3) patch_func = patched_module.torch_nn_adapative_pooling_3d diff --git a/tests/test_fx/test_tracer/test_patched_op.py b/tests/test_fx/test_tracer/test_patched_op.py index 4406f02db..e0c5f560c 100644 --- a/tests/test_fx/test_tracer/test_patched_op.py +++ b/tests/test_fx/test_tracer/test_patched_op.py @@ -1,6 +1,9 @@ +from functools import partial + import torch + from 
colossalai.fx.tracer.meta_patch import patched_function -from functools import partial +from colossalai.testing import clear_cache_before_run def _run(data, patch_fn): @@ -22,6 +25,7 @@ def _assert_output_shape(data, patch_fn, expect_exception, output_shape): assert output.shape == output_shape +@clear_cache_before_run() def test_repeat_interleave(): patch_fn = patched_function.torch_repeat_interleave @@ -63,6 +67,7 @@ def test_repeat_interleave(): output_shape=materialized_output.shape) +@clear_cache_before_run() def test_torch_max(): data = torch.rand(4, 3) out = torch.max(data) diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index b175d8b10..aa14f514c 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -3,6 +3,7 @@ import torch from packaging import version from colossalai._analyzer.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo @@ -43,6 +44,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() def test_timm_models(): torch.backends.cudnn.deterministic = True diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py index 66f4be5a6..eafcaca10 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py @@ -3,12 +3,14 @@ import torch from packaging import version from torchaudio_utils import trace_and_compare +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo # We cannot handle the tensors constructed with constant during forward, such as ``torch.empty(0).to(device=Proxy.device)`` # TODO: We could handle this case by hijacking torch.Tensor.to function. 
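# For reference only (not part of this patch): many of the tracer tests in this diff gain a
# `@clear_cache_before_run()` decorator from `colossalai.testing`. The decorator sketched
# here only illustrates the behaviour the name suggests (collect garbage and release cached
# CUDA memory before the wrapped test runs); the library's actual implementation may reset
# additional internal caches. `clear_cache_before_run_sketch` is a hypothetical name.
import gc
from functools import wraps

import torch


def clear_cache_before_run_sketch():

    def decorator(test_fn):

        @wraps(test_fn)
        def wrapper(*args, **kwargs):
            # Drop unreachable Python objects and cached CUDA blocks so leftovers from a
            # previously traced model cannot skew this test.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return test_fn(*args, **kwargs)

        return wrapper

    return decorator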
@pytest.mark.skip +@clear_cache_before_run() def test_torchaudio_models(): torch.backends.cudnn.deterministic = True diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py index 40f83d47a..df02568c0 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -2,6 +2,7 @@ import pytest import torch from colossalai._analyzer.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo BATCH = 2 @@ -47,6 +48,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' +@clear_cache_before_run() def test_torchrec_deepfm_models(): deepfm_models = model_zoo.get_sub_registry('deepfm') torch.backends.cudnn.deterministic = True diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py index 6d4b6ab81..9776452be 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -2,6 +2,7 @@ import pytest import torch from colossalai._analyzer.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo BATCH = 2 @@ -47,6 +48,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' +@clear_cache_before_run() def test_torchrec_dlrm_models(): torch.backends.cudnn.deterministic = True dlrm_models = model_zoo.get_sub_registry('dlrm') diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py index 8dbbf9f5a..bd259475a 100644 --- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -1,9 +1,11 @@ import torch from colossalai._analyzer.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo +@clear_cache_before_run() def test_torchvision_models(): torch.backends.cudnn.deterministic = True tv_sub_registry = model_zoo.get_sub_registry('torchvision') diff --git a/tests/test_layers/test_1d/test_1d.py b/tests/test_layers/test_1d/test_1d.py index 897590f0d..891512542 100644 --- a/tests/test_layers/test_1d/test_1d.py +++ b/tests/test_layers/test_1d/test_1d.py @@ -1,18 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from checks_1d.check_layer_1d import * from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),) @@ -40,9 +36,7 @@ def check_layer(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_1d(): - world_size = 4 - run_func = partial(check_layer, world_size=world_size, port=free_port()) 
- mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, 4) if __name__ == '__main__': diff --git a/tests/test_layers/test_2d/test_2d.py b/tests/test_layers/test_2d/test_2d.py index da235d0cf..bcea5ce7b 100644 --- a/tests/test_layers/test_2d/test_2d.py +++ b/tests/test_layers/test_2d/test_2d.py @@ -1,22 +1,27 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch -import torch.multiprocessing as mp +from checks_2d.check_layer_2d import ( + check_classifier_given_embed_weight, + check_classifier_no_given_weight, + check_embed, + check_layernorm, + check_linear, + check_loss, + check_patch_embed, + check_vocab_parallel_classifier_given_embed_weight, + check_vocab_parallel_classifier_no_given_weight, + check_vocab_parallel_embed, + check_vocab_parallel_loss, +) +from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB + from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight, - check_embed, check_layernorm, check_linear, check_loss, check_patch_embed, - check_vocab_parallel_classifier_given_embed_weight, - check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed, - check_vocab_parallel_loss) -from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB +from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),) @@ -57,9 +62,7 @@ def check_layer_and_operation(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_2d(): - world_size = 4 - run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer_and_operation, 4) if __name__ == '__main__': diff --git a/tests/test_layers/test_2p5d/test_2p5d.py b/tests/test_layers/test_2p5d/test_2p5d.py index 365e2d934..373d834d0 100644 --- a/tests/test_layers/test_2p5d/test_2p5d.py +++ b/tests/test_layers/test_2p5d/test_2p5d.py @@ -1,15 +1,12 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp +from checks_2p5d.check_layer_2p5d import * +from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB + from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from checks_2p5d.check_layer_2p5d import * -from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB +from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict( pipeline=dict(size=1), @@ -53,9 +50,7 @@ def check_layer_and_operation(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_2p5d(): - world_size = 4 - run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer_and_operation, 4) if __name__ == '__main__': diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_layers/test_3d/test_3d.py index 29a8b3aea..fde71a4a0 100644 --- a/tests/test_layers/test_3d/test_3d.py +++ 
b/tests/test_layers/test_3d/test_3d.py @@ -1,19 +1,24 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch -import torch.multiprocessing as mp +from checks_3d.check_layer_3d import ( + check_classifier_no_given_weight, + check_embed, + check_layernorm, + check_linear, + check_loss, + check_patch_embed, + check_vocab_parallel_classifier_given_embed_weight, + check_vocab_parallel_classifier_no_given_weight, + check_vocab_parallel_embed, + check_vocab_parallel_loss, +) + from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus -from checks_3d.check_layer_3d import (check_classifier_no_given_weight, check_embed, check_layernorm, check_linear, - check_loss, check_patch_embed, check_vocab_parallel_classifier_given_embed_weight, - check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed, - check_vocab_parallel_loss) +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn CONFIG = dict( parallel=dict( @@ -52,9 +57,7 @@ def check_layer_and_operation(rank, world_size, port): @skip_if_not_enough_gpus(min_gpus=8) @rerun_if_address_is_in_use() def test_3d(): - world_size = 8 - run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer_and_operation, 8) if __name__ == '__main__': diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py index cff9072c7..22d4f02a4 100644 --- a/tests/test_layers/test_cache_embedding.py +++ b/tests/test_layers/test_cache_embedding.py @@ -1,20 +1,21 @@ -import pytest -from functools import partial - -import numpy as np import random +from typing import List +import numpy as np +import pytest import torch -import torch.multiprocessing as mp import colossalai -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \ - ColoTensor, ColoTensorSpec -from colossalai.nn.parallel.layers import CachedParamMgr, CachedEmbeddingBag, ParallelCachedEmbeddingBag, EvictionStrategy, \ - ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig -from typing import List +from colossalai.nn.parallel.layers import ( + CachedEmbeddingBag, + CachedParamMgr, + EvictionStrategy, + ParallelCachedEmbeddingBag, + ParallelCachedEmbeddingBagTablewise, + TablewiseEmbeddingBagConfig, +) +from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn NUM_EMBED, EMBED_DIM = 10, 8 BATCH_SIZE = 8 @@ -44,6 +45,7 @@ def synthesize_1d_sparse_feature( @pytest.mark.skip +@clear_cache_before_run() def test_cachemgr(): model = torch.nn.EmbeddingBag(10000, 128) # 10 chunks, 5 in cuda @@ -72,6 +74,7 @@ def test_cachemgr(): assert mgr.cuda_available_chunk_num == 5 +@clear_cache_before_run() def test_reorder_with_freq(): num_embed = 100 chunk_size = 1 @@ -102,7 +105,8 @@ def test_reorder_with_freq(): f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}" -@pytest.mark.parametrize('use_LFU', [True, False]) +@clear_cache_before_run() +@parameterize('use_LFU', [True, False]) def 
test_freq_aware_embed(use_LFU: bool): device = torch.device('cuda', 0) evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET @@ -148,7 +152,8 @@ def test_freq_aware_embed(use_LFU: bool): f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}" -@pytest.mark.parametrize('init_freq', [True, False]) +@clear_cache_before_run() +@parameterize('init_freq', [True, False]) def test_lfu_strategy(init_freq: bool): # minimal test to check behavior Bag = CachedEmbeddingBag(5, @@ -248,7 +253,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): input0 [1,2,3] [6,7] [] input1 [] [9] [13,15] input2 [1,5] [6,8] [11] - ↑ ↑ ↑ + ↑ ↑ ↑ rank 0 rank 0 rank 1 in KJT format ''' @@ -363,8 +368,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_parallel_freq_aware_embed(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_layers/test_sequence/test_sequence.py b/tests/test_layers/test_sequence/test_sequence.py index 3862c4ccd..aac192d7e 100644 --- a/tests/test_layers/test_sequence/test_sequence.py +++ b/tests/test_layers/test_sequence/test_sequence.py @@ -1,14 +1,11 @@ -import colossalai -import colossalai.nn as col_nn +import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp -import pytest -from colossalai.core import global_context as gpc +import colossalai from colossalai.context import ParallelMode -from colossalai.testing import rerun_if_address_is_in_use -from functools import partial +from colossalai.core import global_context as gpc +from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence'))) @@ -121,8 +118,8 @@ def check_ring_av(rank, world_size): 'attention output cannot match' -def run_test(rank, world_size): - colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=29500) +def run_test(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port) # check_ring_qk(rank, world_size) check_ring_av(rank, world_size) @@ -134,9 +131,7 @@ def run_test(rank, world_size): @pytest.mark.dist @rerun_if_address_is_in_use() def test_sequence(): - world_size = 4 - run_func = partial(run_test, world_size=world_size) - mp.spawn(run_func, nprocs=world_size) + spawn(run_test, 4) if __name__ == '__main__': diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index e7b9a5527..e7002a75f 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -1,16 +1,15 @@ -from functools import partial import pytest import torch -import torch.nn as nn -import torch.multiprocessing as mp import torch.distributed as dist +import torch.nn as nn + import colossalai -from colossalai.utils import free_port, get_current_device -from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, Experts from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.utils.moe import sync_moe_model_param from colossalai.engine.gradient_handler import MoeGradientHandler -from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use +from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator +from colossalai.testing import 
assert_equal_in_group, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.utils.moe import sync_moe_model_param BATCH_SIZE = 4 DIM = 16 @@ -65,9 +64,7 @@ def run_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_grad_handler(): - world_size = 4 - run_func = partial(run_test, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_test, 4) if __name__ == '__main__': diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 62f924164..ad9a172b7 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -1,15 +1,14 @@ -from functools import partial import pytest import torch import torch.nn as nn -import torch.multiprocessing as mp + import colossalai from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import free_port, get_current_device -from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.testing import rerun_if_address_is_in_use +from colossalai.core import global_context as gpc +from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device BATCH_SIZE = 16 NUM_EXPERTS = 4 @@ -90,15 +89,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f @pytest.mark.parametrize("router", [Top1Router, Top2Router]) @rerun_if_address_is_in_use() def test_moe_kernel(rs, hidden_size, data_type, router): - world_size = 4 - run_func = partial(run_routing, - world_size=world_size, - port=free_port(), - rs=rs, - hidden_size=hidden_size, - data_type=data_type, - router=router) - mp.spawn(run_func, nprocs=world_size) + spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, router=router) if __name__ == '__main__': diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index d2cff44ad..8a0283ba7 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -1,19 +1,16 @@ import os -from functools import partial import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.context import MOE_CONTEXT from colossalai.nn.layer.moe import load_moe_model, save_moe_model -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_tensor.common_utils import debug_print from tests.test_zero.test_legacy.common import CONFIG @@ -46,8 +43,7 @@ def _run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [2, 4]) @rerun_if_address_is_in_use() def test_moe_checkpoint(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(_run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py index 4826d87ac..555338fcf 100644 --- a/tests/test_moe/test_moe_colo_init.py +++ b/tests/test_moe/test_moe_colo_init.py @@ -1,15 +1,12 @@ -from 
functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.context import MOE_CONTEXT from colossalai.tensor import ColoParameter -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_tensor.common_utils import debug_print @@ -52,8 +49,7 @@ def _run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() def test_moe_colo_init(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(_run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 3126f59e2..6dc3f5f18 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -1,21 +1,20 @@ -from functools import partial import pytest -import torch.nn as nn -import torch.multiprocessing as mp import torch.distributed as dist +import torch.nn as nn + import colossalai -from colossalai.utils import free_port, get_current_device -from colossalai.nn.layer.moe import Experts from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.nn.layer.moe import Experts +from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.utils.moe import sync_moe_model_param -from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use D_MODEL = 4 D_FF = 8 CONFIG = dict() -def run_test(rank, port): +def run_test(rank, world_size, port): world_size = 4 colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') expert_module = nn.Linear @@ -62,9 +61,7 @@ def run_test(rank, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_moe_initialization(): - world_size = 4 - run_func = partial(run_test, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_test, 4) if __name__ == '__main__': diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py index 18b50eb5c..79722f9f4 100644 --- a/tests/test_moe/test_moe_zero_init.py +++ b/tests/test_moe/test_moe_zero_init.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn import colossalai @@ -10,8 +7,8 @@ from colossalai.context import MOE_CONTEXT from colossalai.logging import get_dist_logger from colossalai.nn import CheckpointModule from colossalai.nn.layer import MoeModule -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from tests.test_zero.test_legacy.common import CONFIG @@ -104,8 +101,7 @@ def _run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [2, 4]) @rerun_if_address_is_in_use() def 
test_moe_zero_init(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(_run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py index 49c452938..ec37967f1 100644 --- a/tests/test_moe/test_moe_zero_model.py +++ b/tests/test_moe/test_moe_zero_model.py @@ -1,15 +1,11 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.context import MOE_CONTEXT from colossalai.engine.gradient_handler import MoeGradientHandler from colossalai.nn import MoeLoss -from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.legacy.sharded_model import ShardedModelV2 @@ -67,8 +63,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_moe_zero_model(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index b43e52bb4..efc6e9dda 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.amp import convert_to_apex_amp @@ -10,8 +7,8 @@ from colossalai.context import MOE_CONTEXT from colossalai.engine.gradient_handler import MoeGradientHandler from colossalai.nn import MoeLoss from colossalai.nn.optimizer import CPUAdam -from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device +from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.legacy.sharded_model import ShardedModelV2 @@ -116,8 +113,7 @@ def _run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_moe_zero_optim(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(_run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ops/test_addmm_tp.py b/tests/test_ops/test_addmm_tp.py index 5182868b5..ecd3721b9 100644 --- a/tests/test_ops/test_addmm_tp.py +++ b/tests/test_ops/test_addmm_tp.py @@ -1,14 +1,11 @@ -import colossalai -import torch import pytest +import torch import torch.nn as nn -import torch.multiprocessing as mp -from colossalai.tensor import ColoTensor, ProcessGroup -from colossalai.tensor import ColoTensorSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from functools import partial -from tests.test_tensor.common_utils import 
tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal class Conv1D(nn.Module): @@ -69,8 +66,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_addmm_1d(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ops/test_embedding_bag_tp.py b/tests/test_ops/test_embedding_bag_tp.py index c7a1604e5..d3d3dcf7e 100644 --- a/tests/test_ops/test_embedding_bag_tp.py +++ b/tests/test_ops/test_embedding_bag_tp.py @@ -1,14 +1,11 @@ +import pytest +import torch from torch.nn import functional as F -from functools import partial import colossalai -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup -from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_tensor.common_utils import split_param_col_tp1d, tensor_equal, tensor_shard_equal def run_with_spec(spec_init_func): @@ -39,8 +36,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_embedding_bag_1d(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ops/test_embedding_tp.py b/tests/test_ops/test_embedding_tp.py index 541dc5c09..c0b376e2c 100644 --- a/tests/test_ops/test_embedding_tp.py +++ b/tests/test_ops/test_embedding_tp.py @@ -1,14 +1,11 @@ +import pytest +import torch from torch.nn import functional as F -from functools import partial import colossalai -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor -from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal def run_with_spec(spec_init_func, pg: ProcessGroup): @@ -40,8 +37,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_embedding_1d(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ops/test_linear_tp.py b/tests/test_ops/test_linear_tp.py index 603e98564..c88adfdd9 100644 --- a/tests/test_ops/test_linear_tp.py +++ b/tests/test_ops/test_linear_tp.py @@ -1,14 +1,11 @@ -from functools import partial - -import colossalai 
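# For reference only (not part of this patch): the converted tests keep (or gain) the
# `@rerun_if_address_is_in_use()` decorator. A plausible reading, given the
# `rerun_on_exception` helper already present in `colossalai/testing/utils.py`, is a retry
# wrapper that reacts only to the "address already in use" failure caused by a stale
# rendezvous port; the sketch below is an assumption, not the library's actual code.
import re
from functools import wraps


def rerun_if_address_is_in_use_sketch(max_try: int = 3):

    def decorator(test_fn):

        @wraps(test_fn)
        def wrapper(*args, **kwargs):
            for attempt in range(max_try):
                try:
                    return test_fn(*args, **kwargs)
                except Exception as e:
                    # Re-raise immediately on unrelated errors or once the retry budget is spent.
                    if attempt + 1 == max_try or not re.search('[Aa]ddress already in use', str(e)):
                        raise

        return wrapper

    return decorator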
import pytest import torch -import torch.multiprocessing as mp import torch.nn.functional as F -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor -from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal def run_with_spec(spec_init_func, split_bias): @@ -44,8 +41,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_linear_1d(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ops/test_loss_func.py b/tests/test_ops/test_loss_func.py index 9210242a0..fc55c7f77 100644 --- a/tests/test_ops/test_loss_func.py +++ b/tests/test_ops/test_loss_func.py @@ -1,52 +1,48 @@ -import torch -import pytest -import colossalai -import torch.nn.functional as F -import torch.multiprocessing as mp -from functools import partial -from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec -from colossalai.utils import get_current_device -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern - - -def check_cross_entropy(): - input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True) - input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True) - with torch.no_grad(): - input_ct.copy_(input_t) - - target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device()) - - world_size = torch.distributed.get_world_size() - pg = ProcessGroup(tp_degree=world_size) - input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg)) - input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()])) - input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D)) - - output = F.cross_entropy(input_t, target) - output_colo = F.cross_entropy(input_shard, target) - assert torch.allclose(output_colo, output) - - output.backward() - output_colo.backward() - - assert torch.allclose(input_t.grad, input_ct.grad) - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_cross_entropy() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@rerun_if_address_is_in_use() -def test_loss_func(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_loss_func(1) +import pytest +import torch +import torch.nn.functional as F + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + + +def check_cross_entropy(): + input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True) + input_ct = torch.randn(4, 4, 
device=get_current_device(), requires_grad=True) + with torch.no_grad(): + input_ct.copy_(input_t) + + target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device()) + + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg)) + input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()])) + input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D)) + + output = F.cross_entropy(input_t, target) + output_colo = F.cross_entropy(input_shard, target) + assert torch.allclose(output_colo, output) + + output.backward() + output_colo.backward() + + assert torch.allclose(input_t.grad, input_ct.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_cross_entropy() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_loss_func(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_loss_func(1) diff --git a/tests/test_ops/test_op.py b/tests/test_ops/test_op.py index 8d3cf50ff..4176d3b64 100644 --- a/tests/test_ops/test_op.py +++ b/tests/test_ops/test_op.py @@ -1,14 +1,12 @@ -import torch import pytest -import colossalai +import torch import torch.nn.functional as F -import torch.multiprocessing as mp -from functools import partial -from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec -from colossalai.utils import get_current_device from torch.nn import Parameter -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device def _run_layer_norm(): @@ -66,8 +64,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [2]) @rerun_if_address_is_in_use() def test_element_wise_ops(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) def run_dist2(rank, world_size, port): @@ -79,8 +76,7 @@ def run_dist2(rank, world_size, port): @pytest.mark.parametrize('world_size', [1]) @rerun_if_address_is_in_use() def test_ln(world_size): - run_func = partial(run_dist2, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist2, world_size) def check_all(): diff --git a/tests/test_ops/test_view.py b/tests/test_ops/test_view.py index fc6fc2d3c..a9f203320 100644 --- a/tests/test_ops/test_view.py +++ b/tests/test_ops/test_view.py @@ -1,100 +1,97 @@ -from functools import partial - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device -from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor, ShardSpec -from colossalai.tensor.distspec import DistPlacementPattern -from tests.test_tensor.common_utils import split_param_row_tp1d, split_param_col_tp1d, debug_print - - -def exam_view_core(pg): - # the case of replicated ColoTensors - x = torch.randn(4, 4).cuda() - x_colo = ColoTensor(x, ColoTensorSpec(pg)) - - y = 
x.view(2, -1, 2) - y_colo = x_colo.view(2, -1, 2) - - assert torch.all(y == y_colo) - assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE - # the perfect case of col-sliced ColoTensors - split_param_col_tp1d(x_colo, pg) - - z = x.view(torch.Size((2, 1, 2, -1))) - z_colo = x_colo.view(torch.Size((2, 1, 2, -1))) - if dist.get_rank() == 0: - z = z[:, :, :, 0:2] - else: - z = z[:, :, :, 2:] - assert torch.all(z == z_colo) - assert z_colo.dist_spec == x_colo.dist_spec - # the perfect case of row-sliced ColoTensors - split_param_row_tp1d(x_colo, pg) - - z = x.view(torch.Size((-1, 2, 2))) - z_colo = x_colo.view(torch.Size((-1, 2, 2))) - if dist.get_rank() == 0: - z = z[0:2, :, :] - else: - z = z[2:, :, :] - assert torch.all(z == z_colo) - assert z_colo.dist_spec == x_colo.dist_spec - # the normal case of row-sliced ColoTensors - z = x.view(-1, 2, 2, 2) - z_colo = x_colo.view(-1, 2, 2, 2) - assert torch.all(z == z_colo) - assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE - - -def exam_view_autograd(pg): - x = torch.randn(8, 2, device=get_current_device(), requires_grad=True) - y = torch.randn(8, 2, device=get_current_device(), requires_grad=True) - with torch.no_grad(): - y.copy_(x) - y = ColoTensor(y, ColoTensorSpec(pg)) - y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()])) - - xx = x.view(2, 2, -1) - yy_slice = y_slice.view(2, 2, -1) - yy = yy_slice.to_replicate() - grad = torch.randn(2, 2, 4, device=get_current_device()) - - xx.backward(grad) - yy.backward(grad) - assert torch.all(x.grad == y.grad) - - -def exam_view_errors(pg): - x = torch.randn(8, 2, device=get_current_device()) - x = ColoTensor(x, ColoTensorSpec(pg)) - split_param_row_tp1d(x, pg) - - x.view('a', 'b', 'c') - x.view(8, -1) - x.view([-2, -2, -2]) - x.view((-1, -1, -1)) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) - exam_view_core(pg) - exam_view_autograd(pg) - # exam_view_errors(pg) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) -@rerun_if_address_is_in_use() -def test_view(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_view(2) +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec +from colossalai.tensor.distspec import DistPlacementPattern +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from tests.test_tensor.common_utils import debug_print, split_param_col_tp1d, split_param_row_tp1d + + +def exam_view_core(pg): + # the case of replicated ColoTensors + x = torch.randn(4, 4).cuda() + x_colo = ColoTensor(x, ColoTensorSpec(pg)) + + y = x.view(2, -1, 2) + y_colo = x_colo.view(2, -1, 2) + + assert torch.all(y == y_colo) + assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE + # the perfect case of col-sliced ColoTensors + split_param_col_tp1d(x_colo, pg) + + z = x.view(torch.Size((2, 1, 2, -1))) + z_colo = x_colo.view(torch.Size((2, 1, 2, -1))) + if dist.get_rank() == 0: + z = z[:, :, :, 0:2] + else: + z = z[:, :, :, 2:] + assert torch.all(z == z_colo) + assert z_colo.dist_spec == x_colo.dist_spec + # the perfect case of row-sliced ColoTensors + split_param_row_tp1d(x_colo, pg) + + z 
= x.view(torch.Size((-1, 2, 2))) + z_colo = x_colo.view(torch.Size((-1, 2, 2))) + if dist.get_rank() == 0: + z = z[0:2, :, :] + else: + z = z[2:, :, :] + assert torch.all(z == z_colo) + assert z_colo.dist_spec == x_colo.dist_spec + # the normal case of row-sliced ColoTensors + z = x.view(-1, 2, 2, 2) + z_colo = x_colo.view(-1, 2, 2, 2) + assert torch.all(z == z_colo) + assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE + + +def exam_view_autograd(pg): + x = torch.randn(8, 2, device=get_current_device(), requires_grad=True) + y = torch.randn(8, 2, device=get_current_device(), requires_grad=True) + with torch.no_grad(): + y.copy_(x) + y = ColoTensor(y, ColoTensorSpec(pg)) + y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()])) + + xx = x.view(2, 2, -1) + yy_slice = y_slice.view(2, 2, -1) + yy = yy_slice.to_replicate() + grad = torch.randn(2, 2, 4, device=get_current_device()) + + xx.backward(grad) + yy.backward(grad) + assert torch.all(x.grad == y.grad) + + +def exam_view_errors(pg): + x = torch.randn(8, 2, device=get_current_device()) + x = ColoTensor(x, ColoTensorSpec(pg)) + split_param_row_tp1d(x, pg) + + x.view('a', 'b', 'c') + x.view(8, -1) + x.view([-2, -2, -2]) + x.view((-1, -1, -1)) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) + exam_view_core(pg) + exam_view_autograd(pg) + # exam_view_errors(pg) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_view(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_view(2) diff --git a/tests/test_optimizer/test_cpu_adam.py b/tests/test_optimizer/test_cpu_adam.py index ea1c044f5..8b3ecf851 100644 --- a/tests/test_optimizer/test_cpu_adam.py +++ b/tests/test_optimizer/test_cpu_adam.py @@ -2,7 +2,7 @@ import math import torch -from colossalai.testing import parameterize +from colossalai.testing import clear_cache_before_run, parameterize def torch_adam_update( @@ -46,6 +46,7 @@ def assertTrue(condition, msg): assert condition, msg +@clear_cache_before_run() @parameterize('adamw', [True, False]) @parameterize('step', [1, 2]) @parameterize('p_dtype', [torch.float, torch.half]) diff --git a/tests/test_optimizer/test_fused_adam.py b/tests/test_optimizer/test_fused_adam.py index f7227c2d5..114d5293d 100644 --- a/tests/test_optimizer/test_fused_adam.py +++ b/tests/test_optimizer/test_fused_adam.py @@ -1,10 +1,10 @@ import torch import torch.nn as nn -from torch.optim.adam import Adam from torch.optim import AdamW +from torch.optim.adam import Adam from colossalai.nn.optimizer.fused_adam import FusedAdam -from colossalai.testing import parameterize +from colossalai.testing import clear_cache_before_run, parameterize class FC(nn.Module): @@ -17,6 +17,7 @@ class FC(nn.Module): return self.fc(x) +@clear_cache_before_run() @parameterize('adamw', [False, True]) @parameterize('p_dtype', [torch.float, torch.half]) @parameterize('g_dtype', [torch.float, torch.half]) diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py index 8ff6618ae..4afa13349 100644 --- a/tests/test_optimizer/test_fused_adam_kernel.py +++ b/tests/test_optimizer/test_fused_adam_kernel.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn from numpy import dtype -from colossalai.testing import parameterize +from colossalai.testing import 
clear_cache_before_run, parameterize from colossalai.utils import multi_tensor_applier @@ -41,6 +41,7 @@ def torch_adam_update( param.addcdiv_(exp_avg, denom, value=-step_size) +@clear_cache_before_run() @parameterize('adamw', [False, True]) @parameterize('step', [1, 2]) @parameterize('p_dtype', [torch.float, torch.half]) diff --git a/tests/test_optimizer/test_hybrid_adam.py b/tests/test_optimizer/test_hybrid_adam.py index 2576d8ffe..d075149df 100644 --- a/tests/test_optimizer/test_hybrid_adam.py +++ b/tests/test_optimizer/test_hybrid_adam.py @@ -4,11 +4,12 @@ from torch.optim import AdamW from torch.optim.adam import Adam from colossalai.nn.optimizer.hybrid_adam import HybridAdam -from colossalai.testing import parameterize +from colossalai.testing import clear_cache_before_run, parameterize RE = 3 +@clear_cache_before_run() @parameterize('adamw', [False, True]) @parameterize('device', ['cpu', 'cuda:0']) @parameterize('p_dtype', [torch.float]) diff --git a/tests/test_optimizer/test_nvme.py b/tests/test_optimizer/test_nvme.py index 243f785ad..5d794ac2d 100644 --- a/tests/test_optimizer/test_nvme.py +++ b/tests/test_optimizer/test_nvme.py @@ -1,7 +1,9 @@ import pytest import torch -from tests.components_to_test.registry import non_distributed_component_funcs + from colossalai.nn.optimizer import CPUAdam, HybridAdam +from colossalai.testing import clear_cache_before_run, parameterize +from tests.components_to_test.registry import non_distributed_component_funcs def move_some_params_to_cuda(model, torch_model): @@ -16,9 +18,10 @@ def check_params_equal(model, torch_model): assert torch.allclose(p, torch_p, atol=1e-3), f'diff: {torch.abs(p - torch_p)}' -@pytest.mark.parametrize('nvme_offload_fraction', [0.0, 0.5, 1.0]) -@pytest.mark.parametrize('nvme_offload_dir', ['./offload', None]) -@pytest.mark.parametrize('adam_cls', [CPUAdam, HybridAdam]) +@clear_cache_before_run() +@parameterize('nvme_offload_fraction', [0.0, 0.5, 1.0]) +@parameterize('nvme_offload_dir', ['./offload', None]) +@parameterize('adam_cls', [CPUAdam, HybridAdam]) def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): get_components_func = non_distributed_component_funcs.get_callable('simple_net') model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py index 7ce2cd433..dab474a4e 100644 --- a/tests/test_pipeline/rpc_test_utils.py +++ b/tests/test_pipeline/rpc_test_utils.py @@ -6,13 +6,14 @@ import torch import torch.distributed as dist import torch.distributed.rpc as rpc import torch.multiprocessing as mp -from colossalai import launch -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.pipeline_process_group import ppg from torch import nn from torch._C._distributed_rpc import _is_current_rpc_agent_set from torch.optim import SGD, Adam, Optimizer, RMSprop +from colossalai import launch +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.pipeline_process_group import ppg + rpc_is_initialized = _is_current_rpc_agent_set @@ -20,7 +21,9 @@ def color_debug(text, prefix=' ', color='blue'): color = color.upper() print(getattr(Back, color), prefix, Style.RESET_ALL, text) + class MLP(nn.Module): + def __init__(self, dim: int, layers: int): super().__init__() self.layers = torch.nn.ModuleList() @@ -32,8 +35,10 @@ class MLP(nn.Module): for layer in self.layers: x = layer(x) return x.sum() - + + class DAG_MLP(nn.Module): + def 
__init__(self, dim: int, layers: int): super().__init__() self.layers = torch.nn.ModuleList() @@ -48,6 +53,7 @@ class DAG_MLP(nn.Module): y = self.dag_layer(y) return x.sum(), y.sum() + class RpcTestModel(nn.Module): def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None: diff --git a/tests/test_pipeline/test_middleware_1f1b.py b/tests/test_pipeline/test_middleware_1f1b.py index c4dc617b1..5b3aad703 100644 --- a/tests/test_pipeline/test_middleware_1f1b.py +++ b/tests/test_pipeline/test_middleware_1f1b.py @@ -1,27 +1,27 @@ -import torch -import pytest import os -import torch.multiprocessing as mp -import torch.distributed.rpc as rpc +from functools import partial -from torch import nn +import pytest +import torch +import torch.distributed.rpc as rpc +from rpc_test_utils import DAG_MLP, MLP from torch._C._distributed_rpc import _is_current_rpc_agent_set + from colossalai import launch +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.middleware.adaptor import get_fx_topology from colossalai.pipeline.pipeline_process_group import ppg from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass -from colossalai.fx import ColoTracer -from colossalai.pipeline.middleware.adaptor import get_fx_topology -from rpc_test_utils import MLP, DAG_MLP -from functools import partial -from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn # global variable for model created batch_size = 16 dim = 10 rpc_is_initialized = _is_current_rpc_agent_set + def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): model.eval() tracer = ColoTracer() @@ -34,13 +34,15 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): setattr(submodule, '_topo', topo) - return split_submodules[pp_rank+1] + return split_submodules[pp_rank + 1] + def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int): torch.manual_seed(1024) partition = create_partition_module(pp_rank, stage_num, model, data_kwargs) return partition + def run_master(model_cls, world_size, forward_only): torch.manual_seed(100) @@ -50,23 +52,27 @@ def run_master(model_cls, world_size, forward_only): chunk = 1 num_microbatches = 8 use_checkpoint = 'store_true' - + if model_cls == MLP: + def data_gen(): x = torch.zeros((batch_size, dim)) kwargs = dict(x=x) return kwargs + model = model_cls(dim, stage_num * 3) if forward_only: labels = None else: labels = 1 elif model_cls == DAG_MLP: + def data_gen(): x = torch.zeros((batch_size, dim)) y = torch.zeros((batch_size, dim)) kwargs = dict(x=x, y=y) return kwargs + model = model_cls(dim, stage_num * 3) if forward_only: labels = None @@ -74,15 +80,17 @@ def run_master(model_cls, world_size, forward_only): labels = 1 else: pass - + data_kwargs = data_gen() - - engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs), - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=chunk, - checkpoint=use_checkpoint,) + + engine = OneFOneBPipelineEngine( + partition_fn=partial(partition, model, data_kwargs), + stage_num=stage_num, 
+ num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint, + ) if not forward_only: engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3) @@ -90,13 +98,14 @@ def run_master(model_cls, world_size, forward_only): input_x = torch.randn((batch_size, dim), device=device) input_y = torch.randn((batch_size, dim), device=device) logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only) - -def run_worker(rank, model_cls, world_size, forward_only, master_func): + + +def run_worker(rank, world_size, port, model_cls, forward_only, master_func): master_addr = 'localhost' master_port = 29020 os.environ['MASTER_ADDR'] = master_addr os.environ['MASTER_PORT'] = str(master_port) - + disable_existing_loggers() launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False) @@ -113,7 +122,8 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func): # barrier here if rpc_is_initialized(): rpc.shutdown() - + + @pytest.mark.skip("skip due to CI torch version 1.11") @parameterize('model_cls', [MLP, DAG_MLP]) @parameterize('forward_only', [True, False]) @@ -122,7 +132,14 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func): def test_pp_middleware_fwd(model_cls, forward_only): world_size = 4 master_func = run_master - mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size) + spawn( + run_worker, + world_size, + model_cls=model_cls, + forward_only=forward_only, + master_func=master_func, + ) + if __name__ == "__main__": - test_pp_middleware_fwd() \ No newline at end of file + test_pp_middleware_fwd() diff --git a/tests/test_pipeline/test_pipelinable.py b/tests/test_pipeline/test_pipelinable.py index c99a88550..627cb5ac6 100644 --- a/tests/test_pipeline/test_pipelinable.py +++ b/tests/test_pipeline/test_pipelinable.py @@ -1,9 +1,7 @@ import torch -import torch.multiprocessing as mp from colossalai.pipeline.pipelinable import PipelinableContext - -from colossalai.testing import rerun_on_exception +from colossalai.testing import rerun_if_address_is_in_use, rerun_on_exception, spawn NUM_CHUNKS = 1 PIPELINE_SIZE = 2 @@ -27,7 +25,7 @@ class MLP(torch.nn.Module): return x -def run_pipelinable(rank): +def run_pipelinable(rank, world_size, port): pipelinable = PipelinableContext() with pipelinable: model = MLP() @@ -50,9 +48,9 @@ def run_pipelinable(rank): assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +@rerun_if_address_is_in_use() def test_pipelinable(): - mp.spawn(run_pipelinable, nprocs=1) + spawn(run_pipelinable, 1) if __name__ == '__main__': diff --git a/tests/test_pipeline/test_pipeline_process_group.py b/tests/test_pipeline/test_pipeline_process_group.py index c67e4175d..2a00e3ac5 100644 --- a/tests/test_pipeline/test_pipeline_process_group.py +++ b/tests/test_pipeline/test_pipeline_process_group.py @@ -1,13 +1,12 @@ import os import torch.distributed.rpc as rpc -import torch.multiprocessing as mp -import pytest +from rpc_test_utils import pg_parse_args, rpc_is_initialized -from colossalai.pipeline.pipeline_process_group import ppg from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from rpc_test_utils import pg_parse_args, rpc_is_initialized +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.testing import spawn def 
run_worker(rank, args): @@ -40,4 +39,4 @@ def run_worker(rank, args): if __name__ == "__main__": args = pg_parse_args() world_size = args.world_size - mp.spawn(run_worker, args=(args,), nprocs=world_size) \ No newline at end of file + spawn(run_worker, world_size, args=args) diff --git a/tests/test_tensor/core/test_dist_spec_mgr.py b/tests/test_tensor/core/test_dist_spec_mgr.py index e02f4e797..89476a35b 100644 --- a/tests/test_tensor/core/test_dist_spec_mgr.py +++ b/tests/test_tensor/core/test_dist_spec_mgr.py @@ -1,13 +1,12 @@ import math + +import pytest import torch import torch.distributed as dist -import pytest + import colossalai -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec -from functools import partial +from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn def run(): @@ -58,8 +57,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_dist_spec_mgr(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/core/test_tensor.py b/tests/test_tensor/core/test_tensor.py index b48d9e9a2..64d198b35 100644 --- a/tests/test_tensor/core/test_tensor.py +++ b/tests/test_tensor/core/test_tensor.py @@ -1,17 +1,11 @@ -import torch import pytest -from colossalai.tensor import ColoTensor +import torch from numpy import allclose import colossalai -from colossalai.utils import free_port -from colossalai.tensor import ColoTensorSpec from colossalai.core import global_context as gpc -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec -from functools import partial +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec, distspec +from colossalai.testing import rerun_if_address_is_in_use, spawn def _run_tensor_indexing(): @@ -152,8 +146,7 @@ def run_dist_tests(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_dist_cases(world_size): - run_func = partial(run_dist_tests, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist_tests, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/model/test_gpt2.py b/tests/test_tensor/model/test_gpt2.py index 0d6a3fe26..337bfa840 100644 --- a/tests/test_tensor/model/test_gpt2.py +++ b/tests/test_tensor/model/test_gpt2.py @@ -1,15 +1,11 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext from 
tests.components_to_test.registry import non_distributed_component_funcs @@ -145,8 +141,7 @@ def run_dist(rank, world_size, port, use_ddp): @pytest.mark.parametrize('use_ddp', [False, True]) @rerun_if_address_is_in_use() def test_gpt(world_size, use_ddp): - run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size, use_ddp=use_ddp) if __name__ == '__main__': diff --git a/tests/test_tensor/model/test_model.py b/tests/test_tensor/model/test_model.py index 83abc641c..79d70e53c 100644 --- a/tests/test_tensor/model/test_model.py +++ b/tests/test_tensor/model/test_model.py @@ -1,15 +1,11 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.tensor import ColoTensor, ProcessGroup from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs @@ -313,8 +309,7 @@ def run_model_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_model(world_size): - run_func = partial(run_model_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_model_dist, world_size) def run_pretrain_load_dist(rank, world_size, port): @@ -329,8 +324,7 @@ def run_pretrain_load_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_pretrain_load(world_size): - run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_pretrain_load_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/model/test_module_spec.py b/tests/test_tensor/model/test_module_spec.py index 739bf2b0a..b50851e5e 100644 --- a/tests/test_tensor/model/test_module_spec.py +++ b/tests/test_tensor/model/test_module_spec.py @@ -1,9 +1,7 @@ from copy import deepcopy -from functools import partial import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.nn.parallel.layers import check_colo_module, init_colo_module @@ -17,8 +15,7 @@ from colossalai.tensor import ( ShardSpec, distspec, ) -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs @@ -207,8 +204,7 @@ def run_dist_check(rank, world_size, port): @pytest.mark.skip("for higher testing speed") @rerun_if_address_is_in_use() def test_module_linear_1d(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) @pytest.mark.dist @@ -216,8 +212,7 @@ def test_module_linear_1d(world_size): @pytest.mark.skip("for higher testing speed") @rerun_if_address_is_in_use() def test_module_model(world_size): - run_func = partial(run_dist_model, world_size=world_size, port=free_port()) - 
mp.spawn(run_func, nprocs=world_size) + spawn(run_dist_model, world_size) @pytest.mark.dist @@ -225,8 +220,7 @@ def test_module_model(world_size): @pytest.mark.skip("for higher testing speed") @rerun_if_address_is_in_use() def test_module_check(world_size): - run_func = partial(run_dist_check, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist_check, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_colo_checkpoint_tools.py b/tests/test_tensor/test_colo_checkpoint_tools.py index aa333d552..a53a3f37a 100644 --- a/tests/test_tensor/test_colo_checkpoint_tools.py +++ b/tests/test_tensor/test_colo_checkpoint_tools.py @@ -1,47 +1,41 @@ -import torch -import pytest -from functools import partial - -import torch.multiprocessing as mp -import torch.distributed as dist - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, ColoTensorSpec -from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor -from tests.test_tensor.common_utils import tensor_shard_equal - - -def run_dist(rank, world_size, port, dp_degree, tp_degree): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree) - x = torch.randn(4, 4) - param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg)) - spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D) - param.set_tensor_spec(*spec) - - gather_tensor(param) - if dist.get_rank() == 0: - assert torch.all(x == param) - else: - assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) - dist.barrier() - - scatter_tensor(param, spec[0]) - assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) - assert param.requires_grad is True - dist.barrier() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [4]) -@rerun_if_address_is_in_use() -def test_checkpoint(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port(), dp_degree=2, tp_degree=world_size // 2) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_checkpoint(world_size=4) +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor +from tests.test_tensor.common_utils import tensor_shard_equal + + +def run_dist(rank, world_size, port, dp_degree, tp_degree): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree) + x = torch.randn(4, 4) + param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg)) + spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D) + param.set_tensor_spec(*spec) + + gather_tensor(param) + if dist.get_rank() == 0: + assert torch.all(x == param) + else: + assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) + dist.barrier() + + scatter_tensor(param, spec[0]) + assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), 
pg.tp_world_size()) + assert param.requires_grad is True + dist.barrier() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@rerun_if_address_is_in_use() +def test_checkpoint(world_size): + spawn(run_dist, world_size, dp_degree=2, tp_degree=world_size // 2) + + +if __name__ == '__main__': + test_checkpoint(world_size=4) diff --git a/tests/test_tensor/test_comm_spec_apply.py b/tests/test_tensor/test_comm_spec_apply.py index 46eee61f1..2c68633aa 100644 --- a/tests/test_tensor/test_comm_spec_apply.py +++ b/tests/test_tensor/test_comm_spec_apply.py @@ -1,10 +1,5 @@ -from functools import partial - import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from torch.distributed import ReduceOp from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh @@ -12,8 +7,7 @@ from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_all_gather(device_mesh, rank): @@ -218,8 +212,7 @@ def check_comm(rank, world_size, port): @rerun_if_address_is_in_use() def test_comm_spec(): world_size = 4 - run_func = partial(check_comm, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_comm, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_context.py b/tests/test_tensor/test_context.py index 047371f45..45def034b 100644 --- a/tests/test_tensor/test_context.py +++ b/tests/test_tensor/test_context.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.tensor import ( @@ -14,8 +11,7 @@ from colossalai.tensor import ( ReplicaSpec, ShardSpec, ) -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs @@ -61,8 +57,7 @@ def run_colo_init_context(rank: int, world_size: int, port: int): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_colo_init_context(world_size): - run_func = partial(run_colo_init_context, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_colo_init_context, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index 547a96b26..d1f5b9299 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -1,9 +1,6 @@ -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.distributed import ReduceOp from colossalai.core import global_context as gpc @@ -12,8 +9,7 @@ from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec 
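# [Editor's note, not part of the patch] Several hunks in this section (e.g. test_p2p.py, test_pipeline_schedule.py and the test_checkpoint_* files further below) also replace the verbose retry decorator
#     @rerun_on_exception(exception_type=mp.ProcessRaisedException,
#                         pattern=".*Address already in use.*")
# with `@rerun_if_address_is_in_use()`. Judging only from those hunks, the new decorator is assumed to be shorthand for that same exception/pattern pair, roughly:
import torch.multiprocessing as mp

from colossalai.testing import rerun_on_exception


def rerun_if_address_is_in_use_sketch():
    # assumed equivalence, inferred from the decorators this PR removes
    return rerun_on_exception(exception_type=mp.ProcessRaisedException,
                              pattern=".*Address already in use.*")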
-from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_all_gather(process_groups_dict, rank): @@ -182,8 +178,7 @@ def check_comm(rank, world_size, port): @rerun_if_address_is_in_use() def test_comm_spec(): world_size = 4 - run_func = partial(check_comm, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_comm, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index a99ac6e41..3ca369acb 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -1,7 +1,4 @@ -from functools import partial - import torch -import torch.multiprocessing as mp from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch @@ -9,7 +6,7 @@ from colossalai.logging import disable_existing_loggers from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn class TestModel(torch.nn.Module): @@ -92,10 +89,10 @@ def check_dtensor(rank, world_size, port): raise ValueError(f'rank {rank} is not in the device mesh') +@rerun_if_address_is_in_use() def test_dtensor(): world_size = 4 - run_func = partial(check_dtensor, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_dtensor, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index 70cf8726d..5f56decb5 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -1,9 +1,7 @@ import math -from functools import partial import pytest import torch -import torch.multiprocessing as mp from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch @@ -12,8 +10,7 @@ from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.layout_converter import LayoutConverter from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn entire_shape = torch.Size((64, 32, 16)) layout_converter = LayoutConverter() @@ -192,14 +189,9 @@ def check_layout_converting_apply(rank, world_size, port): @rerun_if_address_is_in_use() def test_layout_converter(): world_size = 4 - run_func = partial(check_one_step_transform, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - run_func = partial(check_layout_converting, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - run_func = partial(check_layout_converting_apply, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_one_step_transform, world_size) + spawn(check_layout_converting, world_size) + spawn(check_layout_converting_apply, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_mix_gather.py 
b/tests/test_tensor/test_mix_gather.py index c1ab30601..9122808eb 100644 --- a/tests/test_tensor/test_mix_gather.py +++ b/tests/test_tensor/test_mix_gather.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh @@ -11,7 +8,7 @@ from colossalai.logging import disable_existing_loggers from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.utils import mix_gather_simulator -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_mix_gather_S0S1(device_mesh, rank): @@ -323,10 +320,10 @@ def check_comm(rank, world_size, port): @pytest.mark.skip(reason="Skip because the check functions assume 8 GPUS but CI only have 4 GPUs") +@rerun_if_address_is_in_use() def test_mix_gather(): world_size = 8 - run_func = partial(check_comm, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_comm, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_parameter.py b/tests/test_tensor/test_parameter.py index 7c3c4b213..9c3f05da1 100644 --- a/tests/test_tensor/test_parameter.py +++ b/tests/test_tensor/test_parameter.py @@ -1,9 +1,10 @@ -from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup -import torch import pytest +import torch from common_utils import tensor_equal + import colossalai -from colossalai.utils import free_port +from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import free_port @pytest.mark.skip diff --git a/tests/test_tensor/test_shape_consistency_apply.py b/tests/test_tensor/test_shape_consistency_apply.py index 4c838bc83..b57952df4 100644 --- a/tests/test_tensor/test_shape_consistency_apply.py +++ b/tests/test_tensor/test_shape_consistency_apply.py @@ -1,16 +1,12 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_apply(rank, world_size, port): @@ -73,8 +69,7 @@ def check_apply(rank, world_size, port): @rerun_if_address_is_in_use() def test_apply(): world_size = 4 - run_func = partial(check_apply, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_apply, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py index 85008c67a..d66d4fec1 100644 --- a/tests/test_tensor/test_sharded_linear.py +++ b/tests/test_tensor/test_sharded_linear.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn.functional as F import colossalai @@ -10,8 +7,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.nn._ops._utils import gather_forward_split_backward from colossalai.tensor import ColoParameter, ColoTensor, 
ProcessGroup from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn def run_dist(rank, world_size, port): @@ -229,8 +225,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [4]) @rerun_if_address_is_in_use() def test_sharded_mlp(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py index 94e39e5d1..c636d9442 100644 --- a/tests/test_tensor/test_tp_with_zero.py +++ b/tests/test_tensor/test_tp_with_zero.py @@ -1,15 +1,11 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.amp import convert_to_apex_amp from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP from colossalai.zero.gemini import search_chunk_configuration @@ -140,8 +136,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_gpt(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_trainer/test_pipeline/test_p2p.py b/tests/test_trainer/test_pipeline/test_p2p.py index 72820c6a1..cb7a193d2 100644 --- a/tests/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_trainer/test_pipeline/test_p2p.py @@ -1,21 +1,26 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp -from colossalai.communication import (recv_backward, recv_forward, recv_obj_meta, send_backward, - send_backward_recv_forward, send_forward, send_forward_recv_backward, - send_obj_meta) + +from colossalai.communication import ( + recv_backward, + recv_forward, + recv_obj_meta, + send_backward, + send_backward_recv_forward, + send_forward, + send_forward_recv_backward, + send_obj_meta, +) from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import get_dist_logger -from colossalai.utils import free_port, get_current_device -from colossalai.testing import rerun_on_exception +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device BATCH_SIZE = 4 SEQ_LENGTH = 2 @@ -93,11 +98,10 @@ def run_check(rank, world_size, port): @pytest.mark.dist -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +@rerun_if_address_is_in_use() def test_p2p(): world_size = 4 - run_func = partial(run_check, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_check, 
world_size) if __name__ == '__main__': diff --git a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py index 48f729658..6d7bf6b3d 100644 --- a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py +++ b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py @@ -1,34 +1,26 @@ # referenced from Megatron and used to testify communication import os -import os.path as osp -from functools import partial from pathlib import Path -import colossalai import pytest import torch import torch.nn as nn -import torch.multiprocessing as mp -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode -from colossalai.initialize import launch -from colossalai.utils import free_port, get_dataloader, print_rank_0 -from colossalai.testing import rerun_on_exception from torchvision import transforms from torchvision.datasets import CIFAR10 from torchvision.models import resnet18 +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_dataloader, print_rank_0 BATCH_SIZE = 8 -CONFIG=dict( - NUM_MICRO_BATCHES=2, - parallel = dict( - pipeline=dict(size=2), - tensor=dict(size=1, mode=None) - ) -) +CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=dict(size=2), tensor=dict(size=1, mode=None))) + def run_schedule(rank, world_size, port): launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -85,11 +77,10 @@ def run_schedule(rank, world_size, port): @pytest.mark.dist -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +@rerun_if_address_is_in_use() def test_pipeline_schedule(): world_size = 2 - run_func = partial(run_schedule, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_schedule, world_size) if __name__ == '__main__': diff --git a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py index b01343329..753f82222 100644 --- a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py +++ b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py @@ -1,15 +1,13 @@ -from functools import partial - -import colossalai import pytest import torch -import torch.multiprocessing as mp + +import colossalai from colossalai.amp.amp_type import AMP_TYPE from colossalai.logging import get_dist_logger +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.trainer import Trainer -from colossalai.utils import MultiTimer, free_port +from colossalai.utils import MultiTimer from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.testing import parameterize, rerun_if_address_is_in_use BATCH_SIZE = 4 IMG_SIZE = 32 @@ -54,8 +52,7 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_trainer_no_pipeline(): world_size = 4 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_trainer/test_trainer_with_pipe_schedule.py index 3698526a8..bb63d51a0 100644 --- a/tests/test_trainer/test_trainer_with_pipe_schedule.py +++ 
b/tests/test_trainer/test_trainer_with_pipe_schedule.py @@ -1,23 +1,21 @@ import os -from functools import partial from pathlib import Path -import colossalai import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.engine.schedule import PipelineSchedule -from colossalai.logging import get_dist_logger -from colossalai.trainer import Trainer -from colossalai.utils import MultiTimer, free_port, get_dataloader from torch.optim import Adam from torchvision import transforms from torchvision.datasets import CIFAR10 from torchvision.models import resnet18 -from colossalai.testing import rerun_if_address_is_in_use + +import colossalai +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.trainer import Trainer +from colossalai.utils import MultiTimer, get_dataloader BATCH_SIZE = 4 IMG_SIZE = 32 @@ -91,8 +89,7 @@ def run_trainer_with_pipeline(rank, world_size, port): @rerun_if_address_is_in_use() def test_trainer_with_pipeline(): world_size = 4 - run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_trainer_with_pipeline, world_size) if __name__ == '__main__': diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index 3ac75fb00..59a8acd4b 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -4,8 +4,10 @@ import pytest import torch import torch.nn.functional as F + from colossalai.context.parallel_mode import ParallelMode -from colossalai.context.random import add_seed, seed, set_mode, reset_seeds +from colossalai.context.random import add_seed, reset_seeds, seed, set_mode +from colossalai.testing import clear_cache_before_run, parameterize from colossalai.utils.activation_checkpoint import checkpoint @@ -39,8 +41,9 @@ def forward_inplace(x, weight): @pytest.mark.gpu -@pytest.mark.parametrize("use_reentrant", [True, False]) -@pytest.mark.parametrize("cpu_offload", [True, False]) +@clear_cache_before_run() +@parameterize("use_reentrant", [True, False]) +@parameterize("cpu_offload", [True, False]) def test_activation_checkpointing(cpu_offload, use_reentrant): # as seed manager is singleton diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py index 8a0fea9ae..335be6135 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py @@ -1,80 +1,77 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pprint -from functools import partial - -import colossalai.nn as col_nn -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint -from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus - - -def 
build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform - - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - depth = len(model) - start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] - layers = [] - for i in range(depth): - if start <= i < end: - layers.append(model[i]) - else: - layers.append(nn.Identity()) - return nn.Sequential(*tuple(layers)) - - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) - - -def check_checkpoint_1d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),) - - disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) - sd1 = m1.state_dict() - if gpc.get_global_rank() == 0: - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") - save_checkpoint("test.pt", 0, m1) - - m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) - if is_using_pp(): - m2 = build_pipeline(m2) - - load_checkpoint("test.pt", m2) - sd2 = m2.state_dict() - if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: - sd2 = gather_pipeline_parallel_state_dict(sd2) - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") - - if gpc.get_global_rank() == 0: - for k, v in sd1.items(): - assert k in sd2 - check_equal(v, sd2[k].to(torch.device("cpu"))) - - -@pytest.mark.dist -@pytest.mark.skip("takes too long") -@skip_if_not_enough_gpus(min_gpus=8) -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") -def test_checkpoint_1d(): - world_size = 8 - run_func = partial(check_checkpoint_1d, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == "__main__": - test_checkpoint_1d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint + +import pytest +import torch +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.utils import is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_1d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank 
{gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_checkpoint_1d(): + spawn(check_checkpoint_1d, 8) + + +if __name__ == "__main__": + test_checkpoint_1d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py index 26314290d..175d9ef6c 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py @@ -1,80 +1,77 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pprint -from functools import partial - -import colossalai.nn as col_nn -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, get_current_device, is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint -from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus - - -def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform - - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - depth = len(model) - start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] - layers = [] - for i in range(depth): - if start <= i < end: - layers.append(model[i]) - else: - layers.append(nn.Identity()) - return nn.Sequential(*tuple(layers)) - - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) - - -def check_checkpoint_2d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),) - - disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) - sd1 = m1.state_dict() - if gpc.get_global_rank() == 0: - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") - save_checkpoint("test.pt", 0, m1) - - m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) - if is_using_pp(): - m2 = build_pipeline(m2) - - load_checkpoint("test.pt", m2) - sd2 = m2.state_dict() - if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: - sd2 = gather_pipeline_parallel_state_dict(sd2) - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") - - if gpc.get_global_rank() == 0: - for k, v in sd1.items(): - assert k in sd2 - check_equal(v, sd2[k].to(torch.device("cpu"))) - - -@pytest.mark.dist -@pytest.mark.skip("takes too long") -@skip_if_not_enough_gpus(min_gpus=8) -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") -def 
test_checkpoint_2d(): - world_size = 8 - run_func = partial(check_checkpoint_2d, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == "__main__": - test_checkpoint_2d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint + +import pytest +import torch +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.utils import is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_2d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_checkpoint_2d(): + spawn(check_checkpoint_2d, 8) + + +if __name__ == "__main__": + test_checkpoint_2d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py index 3dbd340fd..33cb3a65d 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py @@ -1,80 +1,77 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pprint -from functools import partial - -import colossalai.nn as col_nn -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, get_current_device, is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint -from colossalai.testing 
import rerun_on_exception, skip_if_not_enough_gpus - - -def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform - - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - depth = len(model) - start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] - layers = [] - for i in range(depth): - if start <= i < end: - layers.append(model[i]) - else: - layers.append(nn.Identity()) - return nn.Sequential(*tuple(layers)) - - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) - - -def check_checkpoint_2p5d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),) - - disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) - sd1 = m1.state_dict() - if gpc.get_global_rank() == 0: - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") - save_checkpoint("test.pt", 0, m1) - - m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) - if is_using_pp(): - m2 = build_pipeline(m2) - - load_checkpoint("test.pt", m2) - sd2 = m2.state_dict() - if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: - sd2 = gather_pipeline_parallel_state_dict(sd2) - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") - - if gpc.get_global_rank() == 0: - for k, v in sd1.items(): - assert k in sd2 - check_equal(v, sd2[k].to(torch.device("cpu"))) - - -@pytest.mark.dist -@pytest.mark.skip("takes too long") -@skip_if_not_enough_gpus(min_gpus=8) -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") -def test_checkpoint_2p5d(): - world_size = 8 - run_func = partial(check_checkpoint_2p5d, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == "__main__": - test_checkpoint_2p5d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint + +import pytest +import torch +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.utils import is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_2p5d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = 
m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_checkpoint_2p5d(): + spawn(check_checkpoint_2p5d, 8) + + +if __name__ == "__main__": + test_checkpoint_2p5d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py index 38f650547..73ac2dd5f 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py @@ -1,80 +1,77 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pprint -from functools import partial - -import colossalai.nn as col_nn -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, get_current_device, is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint -from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus - - -def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform - - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - depth = len(model) - start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] - layers = [] - for i in range(depth): - if start <= i < end: - layers.append(model[i]) - else: - layers.append(nn.Identity()) - return nn.Sequential(*tuple(layers)) - - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) - - -def check_checkpoint_3d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),) - - disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) - sd1 = m1.state_dict() - if gpc.get_global_rank() == 0: - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") - save_checkpoint("test.pt", 0, m1) - - m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) - if is_using_pp(): - m2 = build_pipeline(m2) - - load_checkpoint("test.pt", m2) - sd2 = m2.state_dict() - if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: - sd2 = gather_pipeline_parallel_state_dict(sd2) - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") - - if gpc.get_global_rank() == 0: - for k, v in sd1.items(): - assert k in sd2 - check_equal(v, sd2[k].to(torch.device("cpu"))) - - -@pytest.mark.dist -@pytest.mark.skip("takes too long") -@skip_if_not_enough_gpus(min_gpus=8) 
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") -def test_checkpoint_3d(): - world_size = 8 - run_func = partial(check_checkpoint_3d, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == "__main__": - test_checkpoint_3d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint + +import pytest +import torch +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.utils import is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_3d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_checkpoint_3d(): + spawn(check_checkpoint_3d, 8) + + +if __name__ == "__main__": + test_checkpoint_3d() diff --git a/tests/test_utils/test_checkpoint_io/test_load.py b/tests/test_utils/test_checkpoint_io/test_load.py index 780c13dc5..b1a741515 100644 --- a/tests/test_utils/test_checkpoint_io/test_load.py +++ b/tests/test_utils/test_checkpoint_io/test_load.py @@ -3,20 +3,19 @@ from functools import partial from tempfile import TemporaryDirectory from typing import Dict -import colossalai import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.utils.checkpoint_io.io import load, save -from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta) from torch import Tensor from torch.nn import Module from torch.optim 
import Adam, Optimizer +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint_io.io import load, save +from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta + def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: assert set(a.keys()) == set(b.keys()) @@ -120,14 +119,13 @@ def test_save_global_load_global(max_shard_size_gb: float): check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict()) -def run_dist(rank, world_size, port, func): +def run_dist(rank, world_size, port, test_fn): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - func() + test_fn() def launch_dist(fn, world_size: int): - proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) - mp.spawn(proc_fn, nprocs=world_size) + spawn(run_dist, world_size, test_fn=fn) def save_dist(dir_name: str, zero: bool): diff --git a/tests/test_utils/test_checkpoint_io/test_merge.py b/tests/test_utils/test_checkpoint_io/test_merge.py index 04e454dcb..255c74adf 100644 --- a/tests/test_utils/test_checkpoint_io/test_merge.py +++ b/tests/test_utils/test_checkpoint_io/test_merge.py @@ -1,18 +1,18 @@ -from colossalai.utils.checkpoint_io.meta import ParamDistMeta -from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME -from colossalai.utils.checkpoint_io.io import save, merge -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from tempfile import TemporaryDirectory -from torch.optim import Adam -from functools import partial -import torch import os +from functools import partial +from tempfile import TemporaryDirectory + import pytest -import colossalai -import torch.nn as nn +import torch import torch.distributed as dist -import torch.multiprocessing as mp +import torch.nn as nn +from torch.optim import Adam + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME +from colossalai.utils.checkpoint_io.io import merge, save +from colossalai.utils.checkpoint_io.meta import ParamDistMeta class DummyModel(nn.Module): @@ -52,7 +52,7 @@ def test_merge_global(): assert len(os.listdir(output_dir)) == 0 -def run_dist(rank, world_size, port, func): +def run_dist(rank, world_size, port, test_fn): colossalai.launch(config={'parallel': { 'tensor': { 'mode': '1d', @@ -64,7 +64,7 @@ def run_dist(rank, world_size, port, func): host='localhost', port=port, backend='nccl') - func() + test_fn() def run_save_dist(dir_name: str, zero: bool): @@ -100,8 +100,7 @@ def test_merge_tp_dp(zero: bool): with TemporaryDirectory() as dir_name: fn = partial(run_save_dist, dir_name, zero) world_size = 4 - proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) - mp.spawn(proc_fn, nprocs=world_size) + spawn(run_dist, world_size, test_fn=fn) with TemporaryDirectory() as output_dir: merge(dir_name, output_dir) assert len(os.listdir(output_dir)) == 5 diff --git a/tests/test_utils/test_checkpoint_io/test_redist.py b/tests/test_utils/test_checkpoint_io/test_redist.py index 6e76f3167..144715bdf 100644 --- a/tests/test_utils/test_checkpoint_io/test_redist.py +++ b/tests/test_utils/test_checkpoint_io/test_redist.py @@ -2,19 +2,23 @@ import os from functools import partial from tempfile import TemporaryDirectory -import colossalai import pytest import torch import 
torch.distributed as dist -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from torch.optim import Adam + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME from colossalai.utils.checkpoint_io.io import redist, save -from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, - RedistMeta) -from torch.optim import Adam +from colossalai.utils.checkpoint_io.meta import ( + ParamDistMeta, + ParamRedistMeta, + PipelineRedistMeta, + RankRedistMeta, + RedistMeta, +) class DummyModel(nn.Module): @@ -105,7 +109,7 @@ def test_global_to_dist(): check_checkpoint_shape(output_dir) -def run_dist(rank, world_size, port, func): +def run_dist(rank, world_size, port, test_fn): colossalai.launch(config={'parallel': { 'tensor': { 'mode': '1d', @@ -117,7 +121,7 @@ def run_dist(rank, world_size, port, func): host='localhost', port=port, backend='nccl') - func() + test_fn() def run_save_dist(dir_name: str, zero: bool): @@ -133,8 +137,7 @@ def test_dist_to_dist(zero: bool): with TemporaryDirectory() as dir_name: fn = partial(run_save_dist, dir_name, zero) world_size = 4 - proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) - mp.spawn(proc_fn, nprocs=world_size) + spawn(run_dist, world_size, test_fn=fn) with TemporaryDirectory() as output_dir: redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) if not zero: diff --git a/tests/test_utils/test_checkpoint_io/test_save.py b/tests/test_utils/test_checkpoint_io/test_save.py index 5ff9d0aa2..e35e566f6 100644 --- a/tests/test_utils/test_checkpoint_io/test_save.py +++ b/tests/test_utils/test_checkpoint_io/test_save.py @@ -3,21 +3,24 @@ from functools import partial from tempfile import TemporaryDirectory from typing import Dict -import colossalai import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.utils.checkpoint_io.constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME, - OTHER_CKPT_FILE_NAME) -from colossalai.utils.checkpoint_io.io import save -from colossalai.utils.checkpoint_io.meta import ParamDistMeta from torch import Tensor from torch.optim import Adam +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint_io.constant import ( + GLOBAL_META_FILE_NAME, + META_CKPT_FILE_NAME, + MODEL_CKPT_FILE_NAME, + OTHER_CKPT_FILE_NAME, +) +from colossalai.utils.checkpoint_io.io import save +from colossalai.utils.checkpoint_io.meta import ParamDistMeta + def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: assert set(a.keys()) == set(b.keys()) @@ -104,9 +107,9 @@ def test_save_global_shard(): }) -def run_dist(rank, world_size, port, func): +def run_dist(rank, world_size, port, test_fn): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - func() + test_fn() def run_save_dist(dir_name): @@ -124,8 +127,7 @@ def test_save_dist(): with TemporaryDirectory() as dir_name: fn = partial(run_save_dist, dir_name) world_size = 2 - proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) - mp.spawn(proc_fn, 
nprocs=world_size) + spawn(run_dist, world_size, test_fn=fn) assert len(os.listdir(dir_name)) == 8 global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) assert len(global_meta['meta']) == 2 diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py index 7c2ad9078..89760a545 100644 --- a/tests/test_utils/test_colo_checkpoint.py +++ b/tests/test_utils/test_colo_checkpoint.py @@ -1,20 +1,17 @@ import os import shutil from copy import deepcopy -from functools import partial import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR import colossalai from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext @@ -202,13 +199,7 @@ def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler): # @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda']) @rerun_if_address_is_in_use() def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None): - run_func = partial(run_dist, - world_size=world_size, - port=free_port(), - use_ddp=use_ddp, - use_mp_reload=use_mp_reload, - test_scheduler=test_scheduler) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size, use_ddp=use_ddp, use_mp_reload=use_mp_reload, test_scheduler=test_scheduler) if __name__ == '__main__': diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py index 6bfa6f33c..2633d7da2 100644 --- a/tests/test_utils/test_commons.py +++ b/tests/test_utils/test_commons.py @@ -1,15 +1,13 @@ import torch -import torch.multiprocessing as mp import colossalai -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline from colossalai.zero.legacy.sharded_param import ShardedTensor -def run_tensor_move(rank): - colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') +def run_tensor_move(rank, world_size, port): + colossalai.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl') src_t = torch.ones(2, 3).cuda() tgt_t = torch.zeros(2, 3) @@ -36,7 +34,7 @@ def run_tensor_move(rank): @rerun_if_address_is_in_use() def test_tensor_move(): - mp.spawn(run_tensor_move, nprocs=1) + spawn(run_tensor_move, 1) if __name__ == '__main__': diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index 441cbbb22..7a28b0157 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -5,6 +5,7 @@ import torch from einops import rearrange from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN +from colossalai.testing import clear_cache_before_run, parameterize if HAS_MEM_EFF_ATTN: from 
colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention @@ -22,7 +23,8 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") -@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +@clear_cache_before_run() +@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): D = H * D_HEAD @@ -42,7 +44,8 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") -@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +@clear_cache_before_run() +@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16): D = H * D_HEAD @@ -65,7 +68,8 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") -@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +@clear_cache_before_run() +@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16): D = H * D_HEAD @@ -84,7 +88,8 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") -@pytest.mark.parametrize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)]) +@clear_cache_before_run() +@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)]) def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16): D = H * D_HEAD diff --git a/tests/test_utils/test_lazy_init/test_distribute.py b/tests/test_utils/test_lazy_init/test_distribute.py index 1e32814ab..2c15ca84e 100644 --- a/tests/test_utils/test_lazy_init/test_distribute.py +++ b/tests/test_utils/test_lazy_init/test_distribute.py @@ -1,17 +1,14 @@ -from functools import partial from typing import Optional import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn import colossalai from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.common import print_rank_0 try: @@ -105,9 +102,7 @@ def run_dist(rank, world_size, port) -> None: @pytest.mark.dist @rerun_if_address_is_in_use() def test_dist_lazy_init(): - world_size = 4 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 4) if __name__ == '__main__': diff --git a/tests/test_utils/test_memory.py b/tests/test_utils/test_memory.py index 46a5aeba5..c88c2f8ec 100644 --- a/tests/test_utils/test_memory.py +++ b/tests/test_utils/test_memory.py @@ -1,12 +1,9 @@ import pytest import colossalai +from colossalai.testing import spawn from colossalai.utils.cuda import get_current_device -from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity -from colossalai.utils import free_port - -from functools import partial -import torch.multiprocessing as mp +from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): @@ -24,8 +21,7 @@ def 
run_dist(rank, world_size, port): @pytest.mark.dist @pytest.mark.parametrize("world_size", [3, 4]) def test_memory_utils(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_utils/test_norm_gradient_clipping.py b/tests/test_utils/test_norm_gradient_clipping.py index 259286663..c0d678026 100644 --- a/tests/test_utils/test_norm_gradient_clipping.py +++ b/tests/test_utils/test_norm_gradient_clipping.py @@ -1,16 +1,15 @@ -from colossalai.tensor import distspec, ColoTensorSpec, ProcessGroup -from colossalai.tensor.colo_parameter import ColoParameter -import colossalai import pytest import torch -import torch.multiprocessing as mp -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, get_current_device +from torch.nn.parameter import Parameter from torch.nn.utils import clip_grad_norm_ -from functools import partial -from colossalai.testing import parameterize, rerun_if_address_is_in_use + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.utils.common import clip_grad_norm -from torch.nn.parameter import Parameter def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8): @@ -71,8 +70,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_zero_clip_grad(world_size: int): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_utils/test_zero_gradient_clippling.py b/tests/test_utils/test_zero_gradient_clippling.py index 920656726..e99cf388e 100644 --- a/tests/test_utils/test_zero_gradient_clippling.py +++ b/tests/test_utils/test_zero_gradient_clippling.py @@ -1,21 +1,19 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import copy from functools import partial import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import checkpoint, clip_grad_norm_fp32 from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 @@ -106,8 +104,7 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_zero_clip_grad(): world_size = 4 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_chunk_mgrv2.py b/tests/test_zero/test_gemini/test_chunk_mgrv2.py index ba0945551..7ea063877 100644 --- a/tests/test_zero/test_gemini/test_chunk_mgrv2.py +++ 
b/tests/test_zero/test_gemini/test_chunk_mgrv2.py @@ -1,13 +1,9 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.zero.gemini.chunk import ChunkManager from tests.test_tensor.common_utils import debug_print @@ -64,8 +60,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [2]) @rerun_if_address_is_in_use() def test_chunk_manager(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py index 5f9ba5d3a..16764aa6b 100644 --- a/tests/test_zero/test_gemini/test_chunkv2.py +++ b/tests/test_zero/test_gemini/test_chunkv2.py @@ -1,15 +1,12 @@ -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.tensor import ColoParameter from colossalai.tensor import ProcessGroup as ColoProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.chunk import Chunk @@ -117,8 +114,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 2, 4]) @rerun_if_address_is_in_use() def test_chunk_function(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 8cfacd018..697595bc3 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close @@ -10,8 +7,7 @@ import colossalai from colossalai.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.tensor import ProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration @@ -103,8 +99,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_gpt(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_gemini_use_rmt.py 
b/tests/test_zero/test_gemini/test_gemini_use_rmt.py index 9d5419e94..dd580976d 100644 --- a/tests/test_zero/test_gemini/test_gemini_use_rmt.py +++ b/tests/test_zero/test_gemini/test_gemini_use_rmt.py @@ -1,14 +1,10 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.tensor import ProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.zero import ColoInitContext, ZeroDDP from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.zero.gemini.gemini_mgr import GeminiManager from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer @@ -98,8 +94,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_gemini_use_rmt(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_get_torch_model.py b/tests/test_zero/test_gemini/test_get_torch_model.py index c014ced97..b3e3b2b22 100644 --- a/tests/test_zero/test_gemini/test_get_torch_model.py +++ b/tests/test_zero/test_gemini/test_get_torch_model.py @@ -1,14 +1,9 @@ -import os -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.tensor import ColoParameter -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, GeminiDDP from colossalai.zero.gemini.utils import get_static_torch_model @@ -50,8 +45,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_convert_torch_module(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index 65f252c55..38b6e474e 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -1,25 +1,20 @@ -from functools import partial -from time import time - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai from colossalai.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test 
import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import debug_print, set_seed +from tests.test_tensor.common_utils import set_seed def check_param(model: ZeroDDP, torch_model: torch.nn.Module): @@ -105,8 +100,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_grad_clip(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 12392d6e5..790a0611c 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -1,18 +1,15 @@ -from functools import partial from typing import Callable import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai from colossalai.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx, zero_model_wrapper from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration @@ -128,8 +125,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_inference(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 7364e59d1..8ce20c16e 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -1,18 +1,13 @@ -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai from colossalai.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor import ColoParameter, ColoTensor -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration @@ -157,8 +152,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_optim(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py 
b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py index 9a3e93493..0e6f283aa 100644 --- a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py +++ b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py @@ -3,12 +3,14 @@ from copy import deepcopy import numpy as np import torch +from colossalai.testing import clear_cache_before_run from colossalai.zero import ColoInitContext from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs +@clear_cache_before_run() def test_runtime_mem_tracer(): test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert'] diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py index 71cdf9a18..35b3b93ad 100644 --- a/tests/test_zero/test_gemini/test_search.py +++ b/tests/test_zero/test_gemini/test_search.py @@ -1,14 +1,10 @@ -from functools import partial - import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs @@ -115,8 +111,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_search(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index 7e759808d..66e05f3ed 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -1,14 +1,9 @@ -from functools import partial - import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp from torch.testing import assert_close import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration @@ -105,8 +100,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_zero_ddp(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index 996dc4eb8..a8af176c5 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -1,14 +1,10 @@ -from functools import partial 
- import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration @@ -83,8 +79,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_zero_optim(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_found_inf.py b/tests/test_zero/test_legacy/test_found_inf.py index 03a1a609b..e90158e0a 100644 --- a/tests/test_zero/test_legacy/test_found_inf.py +++ b/tests/test_zero/test_legacy/test_found_inf.py @@ -1,15 +1,11 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from common import CONFIG from test_sharded_optim_v2 import _run_step import colossalai from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy @@ -64,8 +60,7 @@ def _run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_found_inf(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(_run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_gemini_manager.py b/tests/test_zero/test_legacy/test_gemini_manager.py index aee943253..0e956f7cc 100644 --- a/tests/test_zero/test_legacy/test_gemini_manager.py +++ b/tests/test_zero/test_legacy/test_gemini_manager.py @@ -1,10 +1,12 @@ import pytest import torch +from colossalai.testing import clear_cache_before_run from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState @pytest.mark.dist +@clear_cache_before_run() def test_gemini_manager(): # reset the manager, in case that there exists memory information left manager = StatefulTensor.GST_MGR diff --git a/tests/test_zero/test_legacy/test_init_context.py b/tests/test_zero/test_legacy/test_init_context.py index 0eb8842de..844938271 100644 --- a/tests/test_zero/test_legacy/test_init_context.py +++ b/tests/test_zero/test_legacy/test_init_context.py @@ -1,17 +1,13 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from common import CONFIG import colossalai from colossalai.logging import get_dist_logger -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.utils.memory import colo_device_memory_used from 
colossalai.zero.gemini.memory_tracer.utils import colo_model_mem_usage @@ -70,8 +66,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_zero_init_context(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_param_op.py b/tests/test_zero/test_legacy/test_param_op.py index 9ebacdb70..b91371b98 100644 --- a/tests/test_zero/test_legacy/test_param_op.py +++ b/tests/test_zero/test_legacy/test_param_op.py @@ -2,6 +2,7 @@ import copy import torch +from colossalai.testing import clear_cache_before_run from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr from tests.components_to_test.registry import non_distributed_component_funcs @@ -49,6 +50,7 @@ def run_model(model, inputs, label, criterion, use_param_hook=False): return hookwrapper.hook_triggered_times +@clear_cache_before_run() def test_base_param_hook(): test_models = ['repeated_computed_layers', 'resnet18', 'hanging_param_model', 'inline_op_model'] # test_models = ['bert'] diff --git a/tests/test_zero/test_legacy/test_shard_model_v2.py b/tests/test_zero/test_legacy/test_shard_model_v2.py index 884444adf..93d624aa2 100644 --- a/tests/test_zero/test_legacy/test_shard_model_v2.py +++ b/tests/test_zero/test_legacy/test_shard_model_v2.py @@ -1,17 +1,13 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from common import CONFIG, check_grads_padding, run_fwd_bwd from torch.nn.parallel import DistributedDataParallel as DDP import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy from colossalai.zero.legacy.sharded_model import ShardedModelV2 @@ -61,8 +57,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_shard_model_v2(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_shard_param.py b/tests/test_zero/test_legacy/test_shard_param.py index b76648321..4ba43edce 100644 --- a/tests/test_zero/test_legacy/test_shard_param.py +++ b/tests/test_zero/test_legacy/test_shard_param.py @@ -1,14 +1,11 @@ from copy import deepcopy -from functools import partial import pytest import torch -import torch.multiprocessing as mp from common import CONFIG, allclose import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.legacy.sharded_param import ShardedTensor @@ -39,8 +36,7 @@ def _run_shard_tensor(rank, world_size, port): @pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_shard_tensor(world_size): - run_func 
= partial(_run_shard_tensor, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(_run_shard_tensor, world_size)
 
 
 def _run_shard_param_v2(rank, world_size, port):
@@ -87,8 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
 @pytest.mark.parametrize("world_size", [1, 2])
 @rerun_if_address_is_in_use()
 def test_shard_param_v2(world_size):
-    run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(_run_shard_param_v2, world_size)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py b/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py
index d257a0285..1ca144662 100644
--- a/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py
+++ b/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py
@@ -1,14 +1,10 @@
-from functools import partial
-
 import pytest
 import torch
-import torch.multiprocessing as mp
 
 import colossalai
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor import ProcessGroup
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero.legacy.init_ctx import ZeroInitContext
 from colossalai.zero.legacy.shard_utils import TensorShardStrategy
@@ -86,8 +82,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
 def test_sharded_optim_state_dist(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_zero/test_legacy/test_sharded_optim_v2.py b/tests/test_zero/test_legacy/test_sharded_optim_v2.py
index 3eea13d5d..c6f77995e 100644
--- a/tests/test_zero/test_legacy/test_sharded_optim_v2.py
+++ b/tests/test_zero/test_legacy/test_sharded_optim_v2.py
@@ -1,17 +1,13 @@
-from functools import partial
-
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from common import CONFIG, check_sharded_model_params
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import CPUAdam
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero.legacy.init_ctx import ZeroInitContext
 from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
@@ -107,8 +103,7 @@ def _run_dist(rank, world_size, port):
 @pytest.mark.parametrize("world_size", [1, 2])
 @rerun_if_address_is_in_use()
 def test_sharded_optim_v2(world_size):
-    run_func = partial(_run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(_run_dist, world_size)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py
index 05512f59a..61d850d06 100644
--- a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py
+++ b/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py
@@ -1,19 +1,15 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from functools import partial
-
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from torchvision.models import resnet50
 
 import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.zero.legacy.init_ctx import ZeroInitContext
 from colossalai.zero.legacy.shard_utils import TensorShardStrategy
@@ -84,9 +80,7 @@ def test_sharded_optim_with_sync_bn():
     wanted if we are doing predictions.
 
     """
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 2)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_zero/test_legacy/test_state_dict.py b/tests/test_zero/test_legacy/test_state_dict.py
index 40d2820d8..5f76fff3e 100644
--- a/tests/test_zero/test_legacy/test_state_dict.py
+++ b/tests/test_zero/test_legacy/test_state_dict.py
@@ -5,12 +5,10 @@ from functools import partial
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 from common import CONFIG
 
 import colossalai
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.zero.legacy.init_ctx import ZeroInitContext
 from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
 from colossalai.zero.legacy.sharded_model import ShardedModelV2
@@ -50,8 +48,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize("world_size", [1, 2])
 @rerun_if_address_is_in_use()
 def test_zero_state_dict(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_zero/test_legacy/test_tensor_utils.py b/tests/test_zero/test_legacy/test_tensor_utils.py
index 311448170..238bc3fe1 100644
--- a/tests/test_zero/test_legacy/test_tensor_utils.py
+++ b/tests/test_zero/test_legacy/test_tensor_utils.py
@@ -1,12 +1,8 @@
-from functools import partial
-
 import pytest
 import torch
-import torch.multiprocessing as mp
 
 import colossalai
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
 from colossalai.zero.legacy.gemini.tensor_utils import (
@@ -91,8 +87,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize("world_size", [2, 4])
 @rerun_if_address_is_in_use()
 def test_zero_tensor_utils(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_zero/test_legacy/test_zero_engine.py b/tests/test_zero/test_legacy/test_zero_engine.py
index 1e7f53358..dc8847ce5 100644
--- a/tests/test_zero/test_legacy/test_zero_engine.py
+++ b/tests/test_zero/test_legacy/test_zero_engine.py
@@ -1,19 +1,15 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from functools import partial
-
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from common import MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
 from colossalai.core import global_context as gpc
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.zero.legacy.init_ctx import ZeroInitContext
 from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.low_level._utils import has_inf_or_nan
@@ -96,16 +92,14 @@ def run_dist(rank, world_size, port, parallel_config):
 @pytest.mark.parametrize("world_size", [2, 4])
 @rerun_if_address_is_in_use()
 def test_mp_engine(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size, parallel_config=MP_PARALLEL_CONFIG)
 
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
 @rerun_if_address_is_in_use()
 def test_zero_engine(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py
index 504df202e..2ae1f3a99 100644
--- a/tests/test_zero/test_low_level/test_grad_acc.py
+++ b/tests/test_zero/test_low_level/test_grad_acc.py
@@ -1,16 +1,14 @@
 import copy
-from functools import partial
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.testing import spawn
 from colossalai.testing.random import seed_all
-from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer
@@ -158,9 +156,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 def test_grad_accumulation():
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 2)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py
index ed76e0171..4086af9d8 100644
--- a/tests/test_zero/test_low_level/test_zero1_2.py
+++ b/tests/test_zero/test_low_level/test_zero1_2.py
@@ -1,17 +1,14 @@
 import copy
-from functools import partial
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
-from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer
@@ -179,9 +176,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_zero_1_2():
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 2)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_zero/test_low_level/test_zero_init.py b/tests/test_zero/test_low_level/test_zero_init.py
index 803d0021d..aeeaff5b5 100644
--- a/tests/test_zero/test_low_level/test_zero_init.py
+++ b/tests/test_zero/test_low_level/test_zero_init.py
@@ -1,14 +1,12 @@
-from functools import partial
-
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 import torch.nn as nn
 
 import colossalai
 from colossalai.tensor import ProcessGroup
-from colossalai.utils import free_port, get_current_device
+from colossalai.testing import spawn
+from colossalai.utils import get_current_device
 from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer
@@ -51,9 +49,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 def test_zero_init():
-    world_size = 4
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 4)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_zero/test_low_level/test_zero_tp.py b/tests/test_zero/test_low_level/test_zero_tp.py
index bb7495583..f0804f4bb 100644
--- a/tests/test_zero/test_low_level/test_zero_tp.py
+++ b/tests/test_zero/test_low_level/test_zero_tp.py
@@ -1,16 +1,13 @@
-from functools import partial
-
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
 from colossalai.tensor import ProcessGroup
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port, get_current_device
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device
 from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer
 from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal
@@ -89,9 +86,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_zero_with_tp():
-    world_size = 4
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 4)
 
 
 if __name__ == '__main__':
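For reference, below is a minimal sketch of the test layout these hunks converge on, assuming only what is visible in the patch: colossalai.testing.spawn(fn, nprocs, **kwargs) allocates a free port itself and forwards rank, world_size and port (plus any extra keyword arguments) to the worker, so the former functools.partial + free_port() + mp.spawn boilerplate disappears. The colossalai.launch arguments and the parallel_config keyword in the sketch are illustrative assumptions, not part of this patch.

import pytest
import torch.distributed as dist

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port, parallel_config=None):
    # Hypothetical worker: the launch arguments below are assumed for illustration,
    # not taken from this patch.
    colossalai.launch(config=parallel_config or {},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    assert dist.get_world_size() == world_size


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_spawn_pattern(world_size):
    # Old pattern: run_func = partial(run_dist, world_size=world_size, port=free_port())
    #              mp.spawn(run_func, nprocs=world_size)
    # New pattern: spawn picks the port and forwards extra kwargs to run_dist.
    spawn(run_dist, world_size, parallel_config=None)


if __name__ == '__main__':
    test_spawn_pattern(2)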