
[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
Frank Lee, 2 years ago, committed by GitHub
commit 80eba05b0a
1. .github/workflows/build_on_pr.yml (19)
2. applications/Chat/tests/test_checkpoint.py (8)
3. applications/Chat/tests/test_data.py (8)
4. colossalai/cli/benchmark/benchmark.py (3)
5. colossalai/testing/__init__.py (16)
6. colossalai/testing/utils.py (88)
7. colossalai/utils/__init__.py (2)
8. colossalai/utils/common.py (17)
9. docs/requirements-doc-test.txt (1)
10. docs/source/en/basics/colotensor_concept.md (7)
11. docs/source/zh-Hans/basics/colotensor_concept.md (7)
12. examples/images/vit/test_vit.py (8)
13. examples/language/gpt/experiments/auto_offload/train_gpt_offload.py (39)
14. examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py (7)
15. examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py (11)
16. examples/tutorial/auto_parallel/auto_ckpt_solver_test.py (6)
17. requirements/requirements-test.txt (1)
18. tests/test_amp/test_naive_fp16.py (10)
19. tests/test_amp/test_torch_fp16.py (10)
20. tests/test_analyzer/test_fx/test_bias_addition.py (3)
21. tests/test_analyzer/test_fx/test_mod_dir.py (11)
22. tests/test_analyzer/test_fx/test_nested_ckpt.py (5)
23. tests/test_analyzer/test_fx/test_shape_prop.py (4)
24. tests/test_analyzer/test_fx/test_symbolic_profile.py (4)
25. tests/test_analyzer/test_subclasses/test_aten.py (5)
26. tests/test_analyzer/test_subclasses/test_flop_tensor.py (4)
27. tests/test_analyzer/test_subclasses/test_meta_mode.py (5)
28. tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py (10)
29. tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py (17)
30. tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py (3)
31. tests/test_auto_parallel/test_offload/test_perf.py (10)
32. tests/test_auto_parallel/test_offload/test_solver.py (23)
33. tests/test_auto_parallel/test_pass/test_node_converting_pass.py (2)
34. tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py (2)
35. tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py (12)
36. tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py (12)
37. tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py (10)
38. tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py (14)
39. tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py (4)
40. tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py (9)
41. tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py (6)
42. tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py (3)
43. tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py (15)
44. tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py (10)
45. tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py (17)
46. tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py (27)
47. tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py (20)
48. tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py (25)
49. tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py (22)
50. tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py (15)
51. tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py (22)
52. tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py (23)
53. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py (39)
54. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py (17)
55. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py (11)
56. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py (12)
57. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py (16)
58. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py (39)
59. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py (14)
60. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py (19)
61. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py (3)
62. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py (14)
63. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py (2)
64. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py (13)
65. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py (11)
66. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py (35)
67. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py (3)
68. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py (5)
69. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py (4)
70. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py (21)
71. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py (4)
72. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py (4)
73. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py (17)
74. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py (18)
75. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py (13)
76. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py (3)
77. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py (3)
78. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py (14)
79. tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py (2)
80. tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py (3)
81. tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py (2)
82. tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py (2)
83. tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py (12)
84. tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py (12)
85. tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py (12)
86. tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py (2)
87. tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py (14)
88. tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py (14)
89. tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py (2)
90. tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py (12)
91. tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py (2)
92. tests/test_booster/test_accelerator.py (19)
93. tests/test_booster/test_mixed_precision/test_fp16_torch.py (10)
94. tests/test_booster/test_plugin/test_gemini_plugin.py (11)
95. tests/test_booster/test_plugin/test_torch_ddp_plugin.py (10)
96. tests/test_checkpoint_io/test_general_checkpoint_io.py (4)
97. tests/test_cluster/test_device_mesh_manager.py (11)
98. tests/test_comm/test_boardcast_send_recv_v2.py (15)
99. tests/test_comm/test_comm.py (12)
100. tests/test_comm/test_object_list_p2p.py (21)
Some files were not shown because too many files have changed in this diff.

.github/workflows/build_on_pr.yml (19)

@@ -8,10 +8,10 @@ jobs:
  detect:
    name: Detect file change
    if: |
      github.event.pull_request.draft == false &&
      github.base_ref == 'main' &&
      github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' &&
      contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
    outputs:
      changedExtenisonFiles: ${{ steps.find-extension-change.outputs.all_changed_files }}
      anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }}
@@ -27,10 +27,10 @@ jobs:
      - name: Locate base commit
        id: locate-base-sha
        run: |
          curBranch=$(git rev-parse --abbrev-ref HEAD)
          commonCommit=$(git merge-base origin/main $curBranch)
          echo $commonCommit
          echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
      - name: Find the changed extension-related files
        id: find-extension-change
@@ -63,7 +63,6 @@ jobs:
            echo "$file was changed"
          done
  build:
    name: Build and Test Colossal-AI
    needs: detect
@@ -124,7 +123,7 @@ jobs:
      - name: Execute Unit Testing
        if: needs.detect.outputs.anyLibraryFileChanged == 'true'
        run: |
-          PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/
+          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/
        env:
          DATA: /data/scratch/cifar-10
          NCCL_SHM_DISABLE: 1

applications/Chat/tests/test_checkpoint.py (8)

@@ -1,19 +1,16 @@
import os
import tempfile
from contextlib import nullcontext
-from functools import partial

import pytest
import torch
import torch.distributed as dist
-import torch.multiprocessing as mp
from coati.models.gpt import GPTActor
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
@@ -90,8 +87,7 @@ def run_dist(rank, world_size, port, strategy):
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, strategy):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size, strategy=strategy)

if __name__ == '__main__':
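This same migration repeats through almost every test touched by the PR: the old pattern built a `functools.partial` around the `run_dist` worker, picked a port with `free_port()`, and called `torch.multiprocessing.spawn` by hand; the new pattern hands all of that to `colossalai.testing.spawn`. A minimal before/after sketch of the pattern (the worker body and the test names are placeholders for illustration, not code from this PR):

```python
from functools import partial

import torch.multiprocessing as mp

from colossalai.testing import free_port, spawn


def run_dist(rank, world_size, port):
    ...  # placeholder worker: launch colossalai and run the actual checks


# old pattern, as removed by this PR
def test_something_old():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


# new pattern, as introduced by this PR
def test_something_new():
    spawn(run_dist, 4)
```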

applications/Chat/tests/test_data.py (8)

@@ -1,11 +1,9 @@
import os
from copy import deepcopy
-from functools import partial

import pytest
import torch
import torch.distributed as dist
-import torch.multiprocessing as mp
from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic
@@ -13,8 +11,7 @@ from coati.replay_buffer import NaiveReplayBuffer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
@@ -114,8 +111,7 @@ def run_dist(rank, world_size, port, strategy):
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])
@rerun_if_address_is_in_use()
def test_data(world_size, strategy):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size, strategy=strategy)

if __name__ == '__main__':

colossalai/cli/benchmark/benchmark.py (3)

@@ -10,7 +10,8 @@ from colossalai.context import Config
from colossalai.context.random import reset_seeds
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.utils import MultiTimer, free_port
+from colossalai.testing import free_port
+from colossalai.utils import MultiTimer

from .models import MLP

colossalai/testing/__init__.py (16)

@@ -1,7 +1,17 @@
-from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
-from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use, skip_if_not_enough_gpus
+from .comparison import assert_close, assert_close_loose, assert_equal, assert_equal_in_group, assert_not_equal
+from .pytest_wrapper import run_on_environment_flag
+from .utils import (
+    clear_cache_before_run,
+    free_port,
+    parameterize,
+    rerun_if_address_is_in_use,
+    rerun_on_exception,
+    skip_if_not_enough_gpus,
+    spawn,
+)

__all__ = [
    'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
-    'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus'
+    'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn',
+    'clear_cache_before_run', 'run_on_environment_flag'
]

colossalai/testing/utils.py (88)

@@ -1,8 +1,13 @@
+import gc
+import random
import re
-import torch
-from typing import Callable, List, Any
+import socket
from functools import partial
from inspect import signature
+from typing import Any, Callable, List
+
+import torch
+import torch.multiprocessing as mp
from packaging import version
@@ -43,7 +48,7 @@ def parameterize(argument: str, values: List[Any]) -> Callable:
        # > davis: hello
        # > davis: bye
        # > davis: stop

    Args:
        argument (str): the name of the argument to parameterize
        values (List[Any]): a list of values to iterate for this argument
@@ -85,13 +90,13 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')

        # rerun for infinite times if Runtime error occurs
        @rerun_on_exception(exception_type=RuntimeError, max_try=None)
        def test_method():
            print('hey')
            raise RuntimeError('Address already in use')

        # rerun only the exception message is matched with pattern
        # for infinite times if Runtime error occurs
        @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
@@ -101,10 +106,10 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
    Args:
        exception_type (Exception, Optional): The type of exception to detect for rerun
        pattern (str, Optional): The pattern to match the exception message.
            If the pattern is not None and matches the exception message,
            the exception will be detected for rerun
        max_try (int, Optional): Maximum reruns for this function. The default value is 5.
            If max_try is None, it will rerun foreven if exception keeps occurings
    """
@@ -202,3 +207,72 @@ def skip_if_not_enough_gpus(min_gpus: int):
        return _execute_by_gpu_num

    return _wrap_func
+
+
+def free_port() -> int:
+    """Get a free port on localhost.
+
+    Returns:
+        int: A free port on localhost.
+    """
+    while True:
+        port = random.randint(20000, 65000)
+        try:
+            with socket.socket() as sock:
+                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                sock.bind(("localhost", port))
+                return port
+        except OSError:
+            continue
+
+
+def spawn(func, nprocs=1, **kwargs):
+    """
+    This function is used to spawn processes for testing.
+
+    Usage:
+        # must contians arguments rank, world_size, port
+        def do_something(rank, world_size, port):
+            ...
+
+        spawn(do_something, nprocs=8)
+
+        # can also pass other arguments
+        def do_something(rank, world_size, port, arg1, arg2):
+            ...
+
+        spawn(do_something, nprocs=8, arg1=1, arg2=2)
+
+    Args:
+        func (Callable): The function to be spawned.
+        nprocs (int, optional): The number of processes to spawn. Defaults to 1.
+    """
+    port = free_port()
+    wrapped_func = partial(func, world_size=nprocs, port=port, **kwargs)
+    mp.spawn(wrapped_func, nprocs=nprocs)
+
+
+def clear_cache_before_run():
+    """
+    This function is a wrapper to clear CUDA and python cache before executing the function.
+
+    Usage:
+        @clear_cache_before_run()
+        def test_something():
+            ...
+    """
+
+    def _wrap_func(f):
+
+        def _clear_cache(*args, **kwargs):
+            torch.cuda.empty_cache()
+            torch.cuda.reset_peak_memory_stats()
+            torch.cuda.reset_max_memory_allocated()
+            torch.cuda.reset_max_memory_cached()
+            torch.cuda.synchronize()
+            gc.collect()
+            f(*args, **kwargs)
+
+        return _clear_cache
+
+    return _wrap_func
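Taken together, the three helpers added here cover the lifecycle of a distributed unit test: `free_port` picks an unused port, `spawn` forks `nprocs` workers and injects `world_size`, `port`, and any extra keyword arguments into each one, and `clear_cache_before_run` resets CUDA and Python caches before the test body runs. A hedged sketch of how a test might combine them, based on the docstrings above and the launch call used throughout the refactored tests (the test name and the all-reduce check are illustrative placeholders, not part of this PR):

```python
import torch

import colossalai
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # each worker gets its rank from mp.spawn; world_size and port are injected by spawn()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    x = torch.ones(4, device='cuda')
    torch.distributed.all_reduce(x)
    # summing ones across all ranks should give world_size on every rank
    assert torch.equal(x, torch.full_like(x, world_size))


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_all_reduce_example():    # hypothetical test, for illustration only
    spawn(run_dist, 2)


if __name__ == '__main__':
    test_all_reduce_example()
```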

colossalai/utils/__init__.py (2)

@@ -7,7 +7,6 @@ from .common import (
    count_zeros_fp32,
    disposable,
    ensure_path_exists,
-    free_port,
    is_ddp_ignored,
    is_dp_rank_0,
    is_model_parallel_parameter,
@@ -37,7 +36,6 @@ from .timer import MultiTimer, Timer
__all__ = [
    'checkpoint',
-    'free_port',
    'print_rank_0',
    'sync_model_param',
    'is_ddp_ignored',

colossalai/utils/common.py (17)

@@ -50,23 +50,6 @@ def ensure_path_exists(filename: str):
        Path(dirpath).mkdir(parents=True, exist_ok=True)

-def free_port() -> int:
-    """Get a free port on localhost.
-
-    Returns:
-        int: A free port on localhost.
-    """
-    while True:
-        port = random.randint(20000, 65000)
-        try:
-            with socket.socket() as sock:
-                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-                sock.bind(("localhost", port))
-                return port
-        except OSError:
-            continue

def sync_model_param(model, parallel_mode):
    r"""Make sure data parameters are consistent during Data Parallel Mode.

docs/requirements-doc-test.txt (1)

@@ -4,3 +4,4 @@ packaging
tensornvme
psutil
transformers
+pytest

docs/source/en/basics/colotensor_concept.md (7)

@@ -56,12 +56,12 @@ Let's see an example. A ColoTensor is initialized and sharded on 8 GPUs using tp
```python
import torch
import torch.multiprocessing as mp
-from colossalai.utils import free_port, print_rank_0
+from colossalai.utils import print_rank_0
from functools import partial

import colossalai
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
-from colossalai.utils import free_port
+from colossalai.testing import spawn

import torch
@@ -83,8 +83,7 @@ def run_dist_tests(rank, world_size, port):
    print_rank_0(f"shape {t1.shape}, {t1.data}")

def test_dist_cases(world_size):
-    run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist_tests, world_size)

if __name__ == '__main__':
    test_dist_cases(4)

docs/source/zh-Hans/basics/colotensor_concept.md (7)

@@ -57,12 +57,12 @@ ColoTensor 包含额外的属性[ColoTensorSpec](https://colossalai.readthedocs.
```python
import torch
import torch.multiprocessing as mp
-from colossalai.utils import free_port, print_rank_0
+from colossalai.utils import print_rank_0
from functools import partial

import colossalai
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
-from colossalai.utils import free_port
+from colossalai.testing import spawn

import torch
@@ -84,8 +84,7 @@ def run_dist_tests(rank, world_size, port):
    print_rank_0(f"shape {t1.shape}, {t1.data}")

def test_dist_cases(world_size):
-    run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist_tests, world_size)

if __name__ == '__main__':
    test_dist_cases(4)

examples/images/vit/test_vit.py (8)

@@ -1,11 +1,9 @@
import os
import random
-from functools import partial

import numpy as np
import pytest
import torch
-import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from vit import get_training_components
@@ -15,8 +13,7 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
@@ -156,8 +153,7 @@ def run_dist(rank, world_size, port, use_ddp):
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_vit(world_size, use_ddp):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size, use_ddp=use_ddp)

if __name__ == '__main__':

examples/language/gpt/experiments/auto_offload/train_gpt_offload.py (39)

@@ -1,20 +1,20 @@
-import time
-import pytest
import argparse
-from functools import partial
+import time

+import pytest
import torch
+from model_zoo import GPTLMLoss, get_gpt2_components
from torch.utils._pytree import tree_map
-import torch.multiprocessing as mp

import colossalai
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.fx.profiler import parameter_size
-from colossalai.utils import free_port, get_current_device
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
-from model_zoo import get_gpt2_components, GPTLMLoss
+from colossalai.fx.profiler import parameter_size
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.testing import spawn
+from colossalai.utils import get_current_device
@@ -24,6 +24,7 @@ def parse_args():
    parser.add_argument('--memory_budget', type=float, default=16)
    return parser.parse_args()

@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
@@ -33,13 +34,16 @@ def train_gpt(args):
    # build model
    model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size)
-    label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
+    label = torch.randint(low=0, high=128, size=(
+        64,
+        8,
+    ), device=get_current_device())
    criterion = GPTLMLoss()

    start_time = time.time()
    model = model_builder()
    model.train()
-    param_size = parameter_size(model) / 1024 ** 2 / 2
+    param_size = parameter_size(model) / 1024**2 / 2
    init_time = time.time() - start_time
    print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
@@ -74,21 +78,20 @@ def train_gpt(args):
    torch.cuda.synchronize()

    exec_time = sum(sorted(time_list)[:5]) / 5
-    runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
-    runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
+    runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
+    runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
    print(f'solver_type: {solver_type} | model_type: {model_type}')
-    print(
-        f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
-        f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
-    )
+    print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
+          f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|')
    print(time_list)

def run(rank, world_size, port, args):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    train_gpt(args)

if __name__ == '__main__':
    args = parse_args()
-    run_func = partial(run, world_size=1, port=free_port(), args=args)
-    mp.spawn(run_func, nprocs=1)
+    spawn(run, 1, args=args)

examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py (7)

@@ -1,18 +1,13 @@
from functools import partial
from time import time
-from typing import Dict, Optional, Tuple, Union

import psutil
import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
import transformers
from gpt_modules import GPT2LMHeadModel, GPTLMLoss
-from torch.fx import GraphModule

-from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model
+from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
from colossalai.core import global_context as gpc
-from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger

examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py (11)

@@ -1,19 +1,14 @@
-import time
-from argparse import ArgumentParser
from copy import deepcopy
from functools import partial

-import matplotlib.pyplot as plt
-import numpy as np
import torch
-import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import bench, data_gen_resnet

import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
-from colossalai.utils import free_port
+from colossalai.testing import spawn

def _benchmark(rank, world_size, port):
@@ -50,9 +45,7 @@ def _benchmark(rank, world_size, port):
def auto_activation_checkpoint_batchsize_benchmark():
-    world_size = 1
-    run_func_module = partial(_benchmark, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_benchmark, 1)

if __name__ == "__main__":

examples/tutorial/auto_parallel/auto_ckpt_solver_test.py (6)

@@ -4,14 +4,13 @@ from functools import partial
import matplotlib.pyplot as plt
import torch
-import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium

import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
-from colossalai.utils import free_port
+from colossalai.testing import spawn

def _benchmark(rank, world_size, port, args):
@@ -77,8 +76,7 @@ def _benchmark(rank, world_size, port, args):
def auto_activation_checkpoint_benchmark(args):
    world_size = 1
-    run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args)
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_benchmark, world_size, args=args)

if __name__ == "__main__":

requirements/requirements-test.txt (1)

@@ -12,3 +12,4 @@ contexttimer
einops
triton==2.0.0.dev20221202
git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
+requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611

tests/test_amp/test_naive_fp16.py (10)

@@ -1,14 +1,11 @@
import copy
-from functools import partial

import pytest
import torch
-import torch.multiprocessing as mp

import colossalai
from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp
-from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -87,10 +84,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
+@clear_cache_before_run()
def test_naive_amp():
-    world_size = 1
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 1)

if __name__ == '__main__':

tests/test_amp/test_torch_fp16.py (10)

@@ -1,14 +1,11 @@
import copy
-from functools import partial

import pytest
import torch
-import torch.multiprocessing as mp

import colossalai
from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp
-from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -87,10 +84,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
+@clear_cache_before_run()
def test_torch_amp():
-    world_size = 1
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 1)

if __name__ == '__main__':

tests/test_analyzer/test_fx/test_bias_addition.py (3)

@@ -3,7 +3,7 @@ import torch
from packaging import version
from torch.utils.checkpoint import checkpoint

-from colossalai.testing.utils import parameterize
+from colossalai.testing.utils import clear_cache_before_run, parameterize

try:
    from colossalai._analyzer.fx import symbolic_trace
@@ -81,6 +81,7 @@ class AddmmModel(torch.nn.Module):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@clear_cache_before_run()
@parameterize("bias", [True, False])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])

tests/test_analyzer/test_fx/test_mod_dir.py (11)

@@ -1,6 +1,8 @@
import pytest
import torch

+from colossalai.testing import clear_cache_before_run, parameterize
+
try:
    from colossalai._analyzer.fx import symbolic_trace
except:
@@ -62,9 +64,10 @@ class AModel(torch.nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
-@pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("bias_addition_split", [True, False])
-@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
+@clear_cache_before_run()
+@parameterize("bias", [True, False])
+@parameterize("bias_addition_split", [True, False])
+@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
def test_mod_dir(bias, bias_addition_split, shape):
    model = AModel(bias=bias)
    x = torch.rand(shape)
@@ -75,4 +78,4 @@ def test_mod_dir(bias, bias_addition_split, shape):
if __name__ == '__main__':
-    test_mod_dir(True, True, (3, 3, 3))
+    test_mod_dir(bias=True, bias_addition_split=True, shape=(3, 3, 3))

tests/test_analyzer/test_fx/test_nested_ckpt.py (5)

@@ -1,7 +1,9 @@
+import pytest
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
-import pytest
+
+from colossalai.testing import clear_cache_before_run

try:
    from colossalai._analyzer.fx import symbolic_trace
@@ -42,6 +44,7 @@ class MyModule(nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@clear_cache_before_run()
def test_nested_ckpt():
    model = MyModule()
    x = torch.rand(10, 10)

tests/test_analyzer/test_fx/test_shape_prop.py (4)

@@ -3,7 +3,7 @@ import torch
import torchvision.models as tm
from packaging import version

-from colossalai.testing.utils import parameterize
+from colossalai.testing.utils import clear_cache_before_run, parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models

try:
@@ -32,6 +32,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@clear_cache_before_run()
@parameterize('m', tm_models)
def test_torchvision_shape_prop(m):
    with MetaTensorMode():
@@ -46,6 +47,7 @@ def test_torchvision_shape_prop(m):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@clear_cache_before_run()
@parameterize('m', tmm_models)
def test_timm_shape_prop(m):
    with MetaTensorMode():

tests/test_analyzer/test_fx/test_symbolic_profile.py (4)

@@ -3,7 +3,7 @@ import torch
import torchvision.models as tm
from packaging import version

-from colossalai.testing.utils import parameterize
+from colossalai.testing.utils import clear_cache_before_run, parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models

try:
@@ -19,6 +19,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@clear_cache_before_run()
@parameterize('m', tm_models)
def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
    with MetaTensorMode():
@@ -33,6 +34,7 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@clear_cache_before_run()
@parameterize('m', tmm_models)
def test_timm_profile(m, verbose=False, bias_addition_split=False):
    with MetaTensorMode():

tests/test_analyzer/test_subclasses/test_aten.py (5)

@@ -1,9 +1,11 @@
from typing import Any, Callable, Union
-import pytest

+import pytest
import torch
import torch.nn as nn

+from colossalai.testing import clear_cache_before_run
+
try:
    from colossalai._analyzer._subclasses import MetaTensor
except:
@@ -72,6 +74,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@clear_cache_before_run()
def test_meta_aten():
    for (aten_op, requires_backward), v in registered_meta.items():
        for f, x in v:

tests/test_analyzer/test_subclasses/test_flop_tensor.py (4)

@@ -4,6 +4,7 @@ import torch.nn.functional as F
import torchvision.models as tm
from packaging import version

+from colossalai.testing import clear_cache_before_run, parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models

try:
@@ -39,7 +40,8 @@ odd_cases = [
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
-@pytest.mark.parametrize('func, args, kwargs', odd_cases)
+@clear_cache_before_run()
+@parameterize('func, args, kwargs', odd_cases)
def test_flop_count_function(func, args, kwargs):
    rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
    assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}'

tests/test_analyzer/test_subclasses/test_meta_mode.py (5)

@@ -3,6 +3,8 @@ import torch
import torchvision.models as tm
from packaging import version

+from colossalai.testing import clear_cache_before_run, parameterize
+
try:
    from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
except:
@@ -30,7 +32,8 @@ def run_and_compare(model):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
-@pytest.mark.parametrize('m', tm_models + tmm_models)
+@clear_cache_before_run()
+@parameterize('m', tm_models + tmm_models)
def test_meta_mode_shape(m):
    run_and_compare(m())

tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py (10)

@@ -3,7 +3,6 @@ import copy
import pytest
import torch
import torch.fx
-import torch.multiprocessing as mp
import torchvision.models as tm

import colossalai
@@ -13,7 +12,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
# from colossalai.fx.passes.algorithms import solver_rotor
# from colossalai.fx.passes.algorithms.operation import Sequence
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

if is_compatible_with_meta():
    from colossalai.fx.profiler.tensor import MetaTensor
@@ -26,8 +25,8 @@ except:
    withcodegen = False

-def _run_C_solver_consistency_test(rank=0):
-    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+def _run_C_solver_consistency_test(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
        model = M()
@@ -70,8 +69,9 @@ def _run_C_solver_consistency_test(rank=0):
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
+@rerun_if_address_is_in_use()
def test_C_solver_consistency():
-    mp.spawn(_run_C_solver_consistency_test, nprocs=1)
+    spawn(_run_C_solver_consistency_test, 1)

if __name__ == '__main__':

tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py (17)

@@ -4,7 +4,6 @@ from typing import Callable
import pytest
import torch
-import torch.multiprocessing as mp
import torchvision.models as tm
from torch.fx import GraphModule
@@ -15,7 +14,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
# from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

if is_compatible_with_meta():
    from colossalai.fx.profiler.tensor import MetaTensor
@@ -68,8 +67,8 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
    assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'

-def _run_ckpt_solver(rank):
-    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+def _run_ckpt_solver(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    MODEL_LIST = [tm.densenet121]

    torch.backends.cudnn.deterministic = True
@@ -98,12 +97,13 @@ def _run_ckpt_solver(rank):
@pytest.mark.skip("TODO(super-dainiu): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
+@rerun_if_address_is_in_use()
def test_ckpt_solver():
-    mp.spawn(_run_ckpt_solver, nprocs=1)
+    spawn(_run_ckpt_solver, 1)

-def _run_ckpt_solver_torch11(rank):
-    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+def _run_ckpt_solver_torch11(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    MODEL_LIST = [tm.densenet121]

    torch.backends.cudnn.deterministic = True
@@ -131,8 +131,9 @@ def _run_ckpt_solver_torch11(rank):
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
+@rerun_if_address_is_in_use()
def test_ckpt_solver_torch11():
-    mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
+    spawn(_run_ckpt_solver_torch11, 1)

if __name__ == '__main__':

tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py (3)

@@ -8,6 +8,7 @@ from colossalai.fx.graph_module import ColoGraphModule
# from colossalai.fx.passes.algorithms import linearize, solver_rotor
# from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.testing import clear_cache_before_run

if is_compatible_with_meta():
    from colossalai.fx.profiler.tensor import MetaTensor
@@ -24,6 +25,7 @@ except:
@pytest.mark.skip(reason='TODO: modify the logger')
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
+@clear_cache_before_run()
def test_linearize():
    MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
    tracer = ColoTracer()
@@ -84,6 +86,7 @@ def test_linearize():
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skip(reason="torch11 meta tensor not implemented")
@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
+@clear_cache_before_run()
def test_linearize_torch11():
    MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
    tracer = ColoTracer()

tests/test_auto_parallel/test_offload/test_perf.py (10)

@@ -1,9 +1,7 @@
import time
-from functools import partial

import pytest
import torch
-import torch.multiprocessing as mp
from torch.utils._pytree import tree_map

import colossalai
@@ -12,8 +10,8 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import parameterize
-from colossalai.utils import free_port, get_current_device
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
from tests.test_auto_parallel.test_offload.model_utils import *
from tests.test_tensor.common_utils import set_seed
@@ -140,9 +138,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.skip("this test failed")
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
+@rerun_if_address_is_in_use()
def test_perf():
-    run_func = partial(run_dist, world_size=1, port=free_port())
-    mp.spawn(run_func, nprocs=1)
+    spawn(run_dist, 1)

if __name__ == '__main__':

23
tests/test_auto_parallel/test_offload/test_solver.py

@@ -3,20 +3,20 @@ import torch.fx
 from torch.fx import GraphModule
 from torch.utils._pytree import tree_map
+from colossalai.auto_parallel.offload.region_manager import RegionManager
+from colossalai.auto_parallel.offload.solver import NOT_NVML, SolverFactory
 from colossalai.fx import ColoTracer, is_compatible_with_meta
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.auto_parallel.offload.region_manager import RegionManager
-from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML
-from colossalai.testing import parameterize
+from colossalai.testing import clear_cache_before_run, parameterize
 from tests.test_auto_parallel.test_offload.model_utils import *
 @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
+@clear_cache_before_run()
 @parameterize('model_name', ['gpt2_', 'bert_'])
 @parameterize('memory_budget', [4000])
 @parameterize('solver_name', ['syn', 'asyn'])
-def solver_test(model_name: str,
-                memory_budget: float,
-                solver_name: str):
+def solver_test(model_name: str, memory_budget: float, solver_name: str):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, data_gen = get_components_func()
@@ -52,11 +52,16 @@ def solver_test(model_name: str,
     for region in region_list:
         need_offload = region.need_offload
         to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None
-        print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
+        print(
+            f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
+        )
     for region in region_list.__reversed__():
         need_offload = region.need_offload
         to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None
-        print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
+        print(
+            f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
+        )
 if __name__ == '__main__':
     solver_test()
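
The @clear_cache_before_run() decorator added here is meant to give each parameterized case a clean slate. A rough sketch of the idea, assuming the main cost is the CUDA caching allocator; the real colossalai.testing implementation may reset more state than this.

import functools

import torch


def clear_cache_before_run_sketch():
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()    # release cached allocator blocks left by earlier tests
            return func(*args, **kwargs)
        return wrapper
    return decorator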

2
tests/test_auto_parallel/test_pass/test_node_converting_pass.py

@ -6,6 +6,7 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import clear_cache_before_run
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
@ -26,6 +27,7 @@ def insert_narrow(gm, x_node):
return gm return gm
@clear_cache_before_run()
def test_node_args_converting_pass(): def test_node_args_converting_pass():
model = TestModule() model = TestModule()
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)

2
tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py

@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import clear_cache_before_run
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
@ -36,6 +37,7 @@ def recover_narrow(gm, narrow_node):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@clear_cache_before_run()
def test_size_value_converting_pass(): def test_size_value_converting_pass():
model = TestModule() model = TestModule()
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)

12
tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py

@ -2,7 +2,6 @@ from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
try: try:
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
@ -13,9 +12,7 @@ except:
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, rerun_if_address_is_in_use from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
class LinearModel(torch.nn.Module): class LinearModel(torch.nn.Module):
@ -86,11 +83,8 @@ def check_conv_module(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_bias_addition_module(): def test_bias_addition_module():
world_size = 4 spawn(check_linear_module, 4)
run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port()) spawn(check_conv_module, 4)
mp.spawn(run_func_linear, nprocs=world_size)
run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port())
mp.spawn(run_func_conv, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
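
Both worker groups above now launch through spawn while keeping @rerun_if_address_is_in_use(). The decorator matters because back-to-back spawns can race for a rendezvous port that is still occupied; the sketch below is a simplified retry wrapper with that behaviour, and the real decorator's error matching and retry policy may differ.

import functools


def rerun_if_address_is_in_use_sketch(max_retries=3):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except RuntimeError as err:
                    # Retry only when process-group setup failed because the port was still taken.
                    if 'address already in use' not in str(err).lower() or attempt == max_retries - 1:
                        raise
        return wrapper
    return decorator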

12
tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py

@ -1,9 +1,7 @@
from functools import partial from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from transformers.pytorch_utils import Conv1D from transformers.pytorch_utils import Conv1D
@ -17,9 +15,7 @@ except:
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
HIDDEN_SIZE = 16 HIDDEN_SIZE = 16
@ -65,9 +61,7 @@ def check_act_ckpt(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_mlp_layer(): def test_mlp_layer():
world_size = 4 spawn(check_act_ckpt, 4)
run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':

10
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py

@ -1,9 +1,7 @@
import copy import copy
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
try: try:
@ -15,9 +13,7 @@ except:
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, rerun_if_address_is_in_use from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
class MLP(torch.nn.Module): class MLP(torch.nn.Module):
@ -102,9 +98,7 @@ def check_compatibility_with_ddp(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_compatibility_with_ddp(): def test_compatibility_with_ddp():
world_size = 4 spawn(check_compatibility_with_ddp, 4)
run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':

14
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py

@ -1,10 +1,7 @@
import copy import copy
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
try: try:
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
@ -17,10 +14,9 @@ from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.process_group import ProcessGroup from colossalai.tensor.process_group import ProcessGroup
from colossalai.testing import assert_close, rerun_if_address_is_in_use from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.utils import get_current_device
from colossalai.utils import free_port, get_current_device from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
class MLP(torch.nn.Module): class MLP(torch.nn.Module):
@ -110,9 +106,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_auto_parallel_with_gemini(): def test_auto_parallel_with_gemini():
world_size = 4 spawn(check_auto_parallel_with_gemini, 4)
run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':

4
tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py

@ -10,8 +10,7 @@ from colossalai._analyzer.fx.passes import shape_prop_pass
# from colossalai.fx.tracer.tracer import ColoTracer # from colossalai.fx.tracer.tracer import ColoTracer
from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.testing import parameterize from colossalai.testing import clear_cache_before_run, parameterize, run_on_environment_flag
from colossalai.testing.pytest_wrapper import run_on_environment_flag
NUM_REPEAT_BLOCKS = 4 NUM_REPEAT_BLOCKS = 4
BATCH_SIZE = 1 BATCH_SIZE = 1
@ -81,6 +80,7 @@ class NonRepeatModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
@parameterize('model_cls', [RepeatModel, NonRepeatModel]) @parameterize('model_cls', [RepeatModel, NonRepeatModel])
def test_repeat_blocks(model_cls): def test_repeat_blocks(model_cls):
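
run_on_environment_flag, now re-exported from colossalai.testing, gates the heavier AUTO_PARALLEL suites behind an environment variable. A plausible sketch of such a gate follows, assuming the flag counts as enabled when the variable equals '1'; the exact convention used by the real helper may differ. A test would then opt in with @run_on_environment_flag_sketch(name='AUTO_PARALLEL').

import os

import pytest


def run_on_environment_flag_sketch(name):
    # Skip unless the named environment variable opts the suite in.
    enabled = os.environ.get(name, '0') == '1'
    return pytest.mark.skipif(not enabled, reason=f'environment variable {name} is not set to 1')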

9
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py

@ -1,12 +1,10 @@
import copy import copy
import random import random
from functools import partial
from typing import Dict from typing import Dict
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import transformers import transformers
from torch.fx import GraphModule from torch.fx import GraphModule
@ -30,9 +28,8 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import to_global from colossalai.tensor.shape_consistency import to_global
from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
BATCH_SIZE = 1 BATCH_SIZE = 1
@ -190,9 +187,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model]) @parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_mlp_layer(model_cls): def test_mlp_layer(model_cls):
world_size = 4 spawn(check_attention_layer, 4, model_cls=model_cls)
run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':

6
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py

@ -1,5 +1,4 @@
import torch import torch
import torch.nn as nn
import transformers import transformers
from torch.fx import GraphModule from torch.fx import GraphModule
@ -7,10 +6,10 @@ from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import parameterize from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
@ -20,6 +19,7 @@ HIDDEN_DIM = 384
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) @parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls): def test_self_attention_block(model_cls):
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM)

3
tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py

@ -6,6 +6,8 @@ from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import clear_cache_before_run
class LinearModel(nn.Module): class LinearModel(nn.Module):
@ -26,6 +28,7 @@ class LinearModel(nn.Module):
@pytest.mark.skip('meta tensor has some bugs in 1.11') @pytest.mark.skip('meta tensor has some bugs in 1.11')
@clear_cache_before_run()
def test_liveness_analysis(): def test_liveness_analysis():
model = LinearModel() model = LinearModel()
tracer = ColoTracer(bias_addition_split=True) tracer = ColoTracer(bias_addition_split=True)

15
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py

@@ -1,23 +1,14 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
 from colossalai.auto_parallel.meta_profiler import meta_register
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
-from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
+from colossalai.testing.utils import clear_cache_before_run, parameterize
+from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@clear_cache_before_run()
 @parameterize('func', [
     torch.nn.functional.softmax,
     torch.nn.functional.relu,

10
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py

@ -1,8 +1,5 @@
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
@ -10,8 +7,7 @@ from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
@ -62,9 +58,7 @@ def _binary_elementwise_mem_test(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_binary_elementwise_meta_concrete_info_match(): def test_binary_elementwise_meta_concrete_info_match():
world_size = 4 spawn(_binary_elementwise_mem_test, 4)
run_func_module = partial(_binary_elementwise_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':

17
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py

@@ -1,17 +1,12 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
@@ -25,7 +20,7 @@ class ConvFunctionModule(nn.Module):
         return nn.functional.conv2d(input, self.conv_weight)
-def _conv_module_mem_test(rank, bias, world_size, port):
+def _conv_module_mem_test(rank, world_size, port, bias):
     """This function is for conv memory test
     Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
@@ -62,9 +57,7 @@ def _conv_module_mem_test(rank, bias, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_conv_meta_concrete_info_match(bias=False):
-    world_size = 4
-    run_func_module = partial(_conv_module_mem_test, bias=bias, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_conv_module_mem_test, 4, bias=bias)
 def _conv_function_mem_test(rank, world_size, port):
@@ -103,9 +96,7 @@ def _conv_function_mem_test(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_conv_function_concrete_info_match():
-    world_size = 4
-    run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_conv_function_mem_test, 4)
 if __name__ == '__main__':
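
Note the argument reorder above: the worker now takes (rank, world_size, port, bias) instead of (rank, bias, world_size, port). That is the calling convention this refactor standardizes on, with spawn filling the first three slots and every extra test parameter forwarded as a keyword. A condensed illustration of the resulting pattern, using hypothetical _worker/test_entry names:

from colossalai.testing import spawn


def _worker(rank, world_size, port, bias):
    # rank/world_size/port are injected by the launcher; `bias` arrives as a keyword argument.
    ...


def test_entry(bias=False):
    spawn(_worker, 4, bias=bias)    # 4 processes, `bias` forwarded to every worker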

27
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py

@@ -1,33 +1,16 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
+from colossalai.testing.utils import clear_cache_before_run
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
 if torch.__version__ >= '1.12.0':
-    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
+    from colossalai.auto_parallel.meta_profiler import meta_register
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@clear_cache_before_run()
 def test_embedding_meta_info():
     meta_func = meta_register.get(torch.nn.Embedding)

20
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py

@@ -1,24 +1,14 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
-if torch.__version__ >= '1.12.0':
-    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
 class MyModule(nn.Module):
@@ -63,9 +53,7 @@ def _linear_module_mem_test(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_module_meta_concrete_info_match():
-    world_size = 4
-    run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_linear_module_mem_test, 4)
 def _linear_function_mem_test(rank, world_size, port):
@@ -101,9 +89,7 @@ def _linear_function_mem_test(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_function_meta_concrete_info_match():
-    world_size = 4
-    run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_linear_function_mem_test, 4)
 if __name__ == '__main__':
if __name__ == '__main__': if __name__ == '__main__':

25
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py

@@ -1,26 +1,8 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
+from colossalai.testing.utils import clear_cache_before_run, parameterize
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
 if torch.__version__ >= '1.12.0':
@@ -28,6 +10,7 @@ if torch.__version__ >= '1.12.0':
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@clear_cache_before_run()
 @parameterize(
     'tensor_shapes',
     [

22
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py

@@ -1,29 +1,17 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
 if torch.__version__ >= '1.12.0':
-    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
+    from colossalai.auto_parallel.meta_profiler import meta_register
 def _batchnorm_module_mem_test(rank, world_size, port):
@@ -62,9 +50,7 @@ def _batchnorm_module_mem_test(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_batchnorm_meta_concrete_info_match():
-    world_size = 4
-    run_func_module = partial(_batchnorm_module_mem_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_batchnorm_module_mem_test, 4)
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations')

15
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py

@ -1,17 +1,12 @@
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
@ -51,9 +46,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_adaptiveavgpool_meta_concrete_info_match(): def test_adaptiveavgpool_meta_concrete_info_match():
world_size = 4 spawn(_adaptiveavgpool_module_mem_test, 4)
run_func_module = partial(_adaptiveavgpool_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
def _maxpool_module_mem_test(rank, world_size, port): def _maxpool_module_mem_test(rank, world_size, port):
@ -92,9 +85,7 @@ def _maxpool_module_mem_test(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_maxpool_meta_concrete_info_match(): def test_maxpool_meta_concrete_info_match():
world_size = 4 spawn(_maxpool_module_mem_test, 4)
run_func_module = partial(_maxpool_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':

22
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py

@@ -1,26 +1,9 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
+from colossalai.testing.utils import clear_cache_before_run
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
 if torch.__version__ >= '1.12.0':
@@ -37,6 +20,7 @@ class SplitModule(nn.Module):
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@clear_cache_before_run()
 def test_tensor_meta_info():
     """test tensor related meta information
     We will just use torch.Tensor.split for the test

23
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py

@@ -1,24 +1,8 @@
 import pytest
 import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
+from colossalai.testing.utils import clear_cache_before_run
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
 if torch.__version__ >= '1.12.0':
@@ -26,6 +10,7 @@ if torch.__version__ >= '1.12.0':
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@clear_cache_before_run()
 def test_where_meta_info():
     meta_func = meta_register.get(torch.where)

39
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py

@@ -1,8 +1,5 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
@@ -11,9 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -45,7 +40,7 @@ class AddBMMTorchFunctionModule(nn.Module):
         return output
-def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
+def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = module(using_kwargs).cuda()
@@ -249,14 +244,13 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por
 @parameterize('using_kwargs', [True, False])
 @rerun_if_address_is_in_use()
 def test_2d_device_mesh(module, bias_shape, using_kwargs):
-    world_size = 4
-    run_func = partial(check_2d_device_mesh,
-                       module=module,
-                       bias_shape=bias_shape,
-                       world_size=world_size,
-                       using_kwargs=using_kwargs,
-                       port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(
+        check_2d_device_mesh,
+        4,
+        module=module,
+        bias_shape=bias_shape,
+        using_kwargs=using_kwargs,
+    )
 @pytest.mark.skip("skip due to bias cases not ready")
@@ -267,14 +261,13 @@ def test_2d_device_mesh(module, bias_shape, using_kwargs):
 @parameterize('using_kwargs', [True, False])
 @rerun_if_address_is_in_use()
 def test_1d_device_mesh(module, bias_shape, using_kwargs):
-    world_size = 4
-    run_func = partial(check_1d_device_mesh,
-                       module=module,
-                       bias_shape=bias_shape,
-                       using_kwargs=using_kwargs,
-                       world_size=world_size,
-                       port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(
+        check_1d_device_mesh,
+        4,
+        module=module,
+        bias_shape=bias_shape,
+        using_kwargs=using_kwargs,
+    )
 if __name__ == '__main__':

17
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py

@@ -1,8 +1,5 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -17,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -45,7 +40,7 @@ class AddmmModel_with_param(nn.Module):
         return x
-def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
+def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     if model_cls == AddmmModel:
@@ -189,13 +184,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
 @parameterize('model_cls', [AddmmModel, AddmmModel_with_param])
 @rerun_if_address_is_in_use()
 def test_addmm_handler(input_shape, model_cls):
-    world_size = 4
-    run_func_function = partial(check_addmm_function_handler,
-                                input_shape=input_shape,
-                                model_cls=model_cls,
-                                world_size=world_size,
-                                port=free_port())
-    mp.spawn(run_func_function, nprocs=world_size)
+    spawn(check_addmm_function_handler, 4, input_shape=input_shape, model_cls=model_cls)
 if __name__ == '__main__':

11
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -114,9 +109,7 @@ def check_bn_module_handler(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_bn_module_handler(): def test_bn_module_handler():
world_size = 4 spawn(check_bn_module_handler, 4)
run_func = partial(check_bn_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':

12
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py

@ -1,9 +1,5 @@
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -19,9 +15,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
WEIGHT_SHAPE = (32, 16) WEIGHT_SHAPE = (32, 16)
@ -168,9 +162,7 @@ def check_linear_module_handler(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_linear_handler(): def test_linear_handler():
world_size = 4 spawn(check_linear_module_handler)
run_func_module = partial(check_linear_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':

16
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py

@ -1,14 +1,10 @@
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData, OperationData,
OperationDataType, OperationDataType,
@ -18,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -35,7 +29,7 @@ class LinearModule(torch.nn.Module):
return x return x
def check_linear_module_handler(rank, bias, world_size, port): def check_linear_module_handler(rank, world_size, port, bias):
disable_existing_loggers() disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearModule(16, 32, bias=bias).cuda() model = LinearModule(16, 32, bias=bias).cuda()
@ -157,9 +151,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_linear_handler(bias=True): def test_linear_handler(bias=True):
world_size = 4 spawn(check_linear_module_handler, bias=bias)
run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':

39
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py

@@ -1,8 +1,5 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
-def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port):
+def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -149,7 +144,7 @@ class BEOpModelWithIntConst(nn.Module):
         return out
-def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
+def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -236,13 +231,12 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, wo
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_binary_elementwise_handler_with_tensor(op, other_dim):
-    world_size = 4
-    run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
-                              op=op,
-                              other_dim=other_dim,
-                              world_size=world_size,
-                              port=free_port())
-    mp.spawn(run_func_tensor, nprocs=world_size)
+    spawn(
+        check_binary_elementwise_handler_with_tensor,
+        4,
+        op=op,
+        other_dim=other_dim,
+    )
 @run_on_environment_flag(name='AUTO_PARALLEL')
@@ -252,14 +246,13 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
-    world_size = 4
-    run_func_int = partial(check_binary_elementwise_handler_with_int,
-                           op=op,
-                           model_cls=model_cls,
-                           other_dim=other_dim,
-                           world_size=world_size,
-                           port=free_port())
-    mp.spawn(run_func_int, nprocs=world_size)
+    spawn(
+        check_binary_elementwise_handler_with_int,
+        4,
+        op=op,
+        model_cls=model_cls,
+        other_dim=other_dim,
+    )
 if __name__ == '__main__':

14
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -207,11 +202,8 @@ def check_1d_device_mesh(rank, module, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_bmm_handler(module):
-    world_size = 4
-    run_func_2d = partial(check_2d_device_mesh, module=module, world_size=world_size, port=free_port())
-    mp.spawn(run_func_2d, nprocs=world_size)
-    run_func_1d = partial(check_1d_device_mesh, module=module, world_size=world_size, port=free_port())
-    mp.spawn(run_func_1d, nprocs=world_size)
+    spawn(check_2d_device_mesh, 4, module=module)
+    spawn(check_1d_device_mesh, 4, module=module)
if __name__ == '__main__': if __name__ == '__main__':

19
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py

@@ -1,8 +1,5 @@
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
-def check_conv_module_handler(rank, bias, world_size, port):
+def check_conv_module_handler(rank, world_size, port, bias):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda()
@@ -155,7 +150,7 @@ class ConvModel(nn.Module):
        return x
-def check_conv_function_handler(rank, bias, world_size, port):
+def check_conv_function_handler(rank, world_size, port, bias):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = ConvModel().cuda()
@@ -302,9 +297,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
# @parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_module_handler(bias=False):
-    world_size = 4
-    run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_conv_module_handler, 4, bias=bias)
@run_on_environment_flag(name='AUTO_PARALLEL')
@@ -314,9 +307,7 @@ def test_conv_module_handler(bias=False):
# @parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_function_handler(bias=False):
-    world_size = 4
-    run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_conv_function_handler, 4, bias=bias)
if __name__ == '__main__':

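Note the companion change to every worker: the launcher-provided arguments now come first and the test-specific ones follow as keyword arguments, which is what lets the test pass them straight through spawn. Assuming spawn forwards keywords as in the sketch above, a call such as the following reaches the worker unchanged:

from colossalai.testing import spawn

# bias is supplied by the test; rank, world_size and port come from the launcher
def check_conv_module_handler(rank, world_size, port, bias):
    ...

spawn(check_conv_module_handler, 4, bias=False)
# -> check_conv_module_handler(rank, world_size=4, port=<free port>, bias=False) on each of the 4 ranks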
3
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py

@@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHan
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class ReshapeModel(nn.Module):
@@ -23,6 +23,7 @@ class ReshapeModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_reshape_handler():
    model = ReshapeModel()
    tracer = ColoTracer(bias_addition_split=True)

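The single-process handler tests gain a @clear_cache_before_run() decorator instead of the spawn machinery. Its definition also lives in colossalai.testing and is not shown here; conceptually it just resets cached state (most importantly the CUDA allocator cache) before the test body runs. A minimal stand-in under that assumption:

# hypothetical stand-in for colossalai.testing.clear_cache_before_run
import functools

import torch


def clear_cache_before_run():
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()    # drop cached CUDA blocks left over from earlier tests
            return func(*args, **kwargs)
        return wrapper
    return decorator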
14
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py

@@ -1,8 +1,5 @@
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -16,9 +13,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
NUM_EMBEDDINGS = 16
@@ -272,18 +268,14 @@ def check_embedding_function_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_module_handler():
-    world_size = 4
-    run_func = partial(check_embedding_module_handler, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_embedding_module_handler, 4)
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_function_handler():
-    world_size = 4
-    run_func = partial(check_embedding_function_handler, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_embedding_function_handler, 4)
if __name__ == '__main__':

2
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py

@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
+from colossalai.testing import clear_cache_before_run
class GetattrModel(nn.Module):
@@ -22,6 +23,7 @@ class GetattrModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
+@clear_cache_before_run()
def test_getattr_handler():
    model = GetattrModel()
    tracer = ColoTracer(bias_addition_split=True)

13
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py

@@ -2,7 +2,6 @@ from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -14,12 +13,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import Li
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -103,12 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
# @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))])
@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))])
def test_getitem_from_tensor_handler(getitem_index):
-    world_size = 4
-    run_func = partial(check_getitem_from_tensor_handler,
-        getitem_index=getitem_index,
-        world_size=world_size,
-        port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_getitem_from_tensor_handler, 4)
class GetItemFromTupleModel(nn.Module):
@@ -123,6 +115,7 @@ class GetItemFromTupleModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_getitem_from_tuple_handler():
    model = GetItemFromTupleModel()
    tracer = ColoTracer()

11
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py

@@ -1,8 +1,5 @@
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -11,12 +8,10 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -104,9 +99,7 @@ def check_ln_module_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_ln_module_handler():
-    world_size = 4
-    run_func = partial(check_ln_module_handler, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_ln_module_handler, 4)
if __name__ == '__main__':

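@rerun_if_address_is_in_use() stays on every multi-process test. It predates this commit and is not defined in this excerpt; the idea is simply to retry a test whose rendezvous port turned out to be taken. A rough stand-in, under that assumption:

# hypothetical stand-in for colossalai.testing.rerun_if_address_is_in_use
import functools


def rerun_if_address_is_in_use(max_retries: int = 3):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except RuntimeError as err:
                    # retry only on port clashes, and give up on the last attempt
                    if 'address already in use' not in str(err).lower() or attempt == max_retries - 1:
                        raise
        return wrapper
    return decorator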
35
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py

@@ -1,8 +1,5 @@
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -18,14 +15,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
-from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
-def check_linear_module_handler(rank, bias, input_shape, world_size, port):
+def check_linear_module_handler(rank, world_size, port, bias, input_shape):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
@@ -172,7 +168,7 @@ class LinearModel(nn.Module):
        return x
-def check_linear_function_handler(rank, bias, input_shape, world_size, port):
+def check_linear_function_handler(rank, world_size, port, bias, input_shape):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = LinearModel().cuda()
@@ -313,19 +309,18 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(input_shape, bias=False):
-    world_size = 4
-    run_func_module = partial(check_linear_module_handler,
-        bias=bias,
-        input_shape=input_shape,
-        world_size=world_size,
-        port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
-    run_func_function = partial(check_linear_function_handler,
-        bias=bias,
-        input_shape=input_shape,
-        world_size=world_size,
-        port=free_port())
-    mp.spawn(run_func_function, nprocs=world_size)
+    spawn(
+        check_linear_module_handler,
+        4,
+        bias=bias,
+        input_shape=input_shape,
+    )
+    spawn(
+        check_linear_function_handler,
+        4,
+        bias=bias,
+        input_shape=input_shape,
+    )
if __name__ == '__main__':

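As in the bmm test above, a single pytest item now launches two independent process groups back to back, one per handler variant. Condensed to one line per call, the resulting test body is just:

# two 4-process groups launched sequentially from one test body;
# each spawn call presumably picks its own fresh rendezvous port,
# which is also why the retry decorator remains useful
spawn(check_linear_module_handler, 4, bias=bias, input_shape=input_shape)
spawn(check_linear_function_handler, 4, bias=bias, input_shape=input_shape)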
3
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py

@@ -18,7 +18,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing.utils import parameterize
+from colossalai.testing.utils import clear_cache_before_run, parameterize
class MatMulModule(nn.Module):
@@ -28,6 +28,7 @@ class MatMulModule(nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@clear_cache_before_run()
@parameterize(
    'tensor_shapes',
    [

5
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py

@@ -1,4 +1,3 @@
-import pytest
import torch
import torch.nn as nn
@@ -8,11 +7,11 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.meta_patch.patched_module import linear
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_norm_pool_handler():
    model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
    tracer = ColoTracer(bias_addition_split=True)

4
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py

@@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import clear_cache_before_run, parameterize
class OutputModel(nn.Module):
@@ -23,7 +23,7 @@ class OutputModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('output_option', ['distributed', 'replicated'])
-@rerun_if_address_is_in_use()
+@clear_cache_before_run()
def test_output_handler(output_option):
    model = OutputModel()
    tracer = ColoTracer(bias_addition_split=True)

21
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py

@@ -2,7 +2,6 @@ from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -15,9 +14,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -55,7 +53,7 @@ class LinearReshapeModel(nn.Module):
        return permute_node
-def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port):
+def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    if call_function == torch.permute:
@@ -328,14 +326,13 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))])
@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel])
def test_view_handler(call_function, reshape_dims, model_cls):
-    world_size = 4
-    run_func = partial(check_view_handler,
-        call_function=call_function,
-        reshape_dims=reshape_dims,
-        model_cls=model_cls,
-        world_size=world_size,
-        port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(
+        check_view_handler,
+        4,
+        call_function=call_function,
+        reshape_dims=reshape_dims,
+        model_cls=model_cls,
+    )
if __name__ == '__main__':

4
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py

@@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import clear_cache_before_run, parameterize
class PlaceholderModel(nn.Module):
@@ -22,7 +22,7 @@ class PlaceholderModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('placeholder_option', ['distributed', 'replicated'])
-@rerun_if_address_is_in_use()
+@clear_cache_before_run()
def test_placeholder_handler(placeholder_option):
    model = PlaceholderModel()
    tracer = ColoTracer(bias_addition_split=True)

4
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py

@@ -1,5 +1,4 @@
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -9,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHan
from colossalai.auto_parallel.tensor_shard.options import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class LinearModel(nn.Module):
@@ -108,6 +107,7 @@ def check_shard_option(shard_option):
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_shard_option():
    # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]:
    for shard_option in [ShardOption.SHARD_LAST_AXIS]:

17
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py

@@ -1,8 +1,5 @@
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
@@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -33,7 +28,7 @@ class LinearSplitModel(nn.Module):
        return softmax_node
-def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
+def check_split_handler(rank, world_size, port, softmax_dim, model_cls):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = model_cls(softmax_dim=softmax_dim).cuda()
@@ -176,13 +171,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
@parameterize('softmax_dim', [0, 1, 2, 3])
@parameterize('model_cls', [LinearSplitModel])
def test_split_handler(softmax_dim, model_cls):
-    world_size = 4
-    run_func = partial(check_split_handler,
-        softmax_dim=softmax_dim,
-        model_cls=model_cls,
-        world_size=world_size,
-        port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_split_handler, 4, softmax_dim=softmax_dim, model_cls=model_cls)
if __name__ == '__main__':

18
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py

@@ -1,8 +1,5 @@
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -47,7 +42,7 @@ class LinearSplitModel(nn.Module):
        return split_node
-def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port):
+def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = model_cls(split_size=split_size, split_dim=split_dim).cuda()
@@ -258,14 +253,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
@parameterize('split_dim', [0, 1, 2])
@parameterize('model_cls', [ConvSplitModel, LinearSplitModel])
def test_split_handler(split_size, split_dim, model_cls):
-    world_size = 4
-    run_func = partial(check_split_handler,
-        split_size=split_size,
-        split_dim=split_dim,
-        model_cls=model_cls,
-        world_size=world_size,
-        port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_split_handler, 4, split_size=split_size, split_dim=split_dim, model_cls=model_cls)
if __name__ == '__main__':

13
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py

@@ -1,8 +1,5 @@
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -14,9 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -36,7 +31,7 @@ class LinearSumModel(nn.Module):
        return sum_node
-def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
+def check_sum_handler(rank, world_size, port, sum_dims, keepdim):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda()
@@ -228,9 +223,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
@parameterize('sum_dims', [(0, 2), 1])
@parameterize('keepdim', [False, True])
def test_sum_handler(sum_dims, keepdim):
-    world_size = 4
-    run_func = partial(check_sum_handler, sum_dims=sum_dims, keepdim=keepdim, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim)
if __name__ == '__main__':

3
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py

@@ -7,7 +7,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class TensorConstructorModel(nn.Module):
@@ -22,6 +22,7 @@ class TensorConstructorModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_where_handler():
    model = TensorConstructorModel()
    tracer = ColoTracer(bias_addition_split=True)

3
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py

@@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class ReLuModel(nn.Module):
@@ -24,6 +24,7 @@ class ReLuModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_elementwise_handler():
    model = ReLuModel()
    tracer = ColoTracer(bias_addition_split=True)

14
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py

@@ -1,8 +1,5 @@
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -15,9 +12,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -255,13 +251,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)])
@parameterize('model_cls', [ConvViewModel, LinearViewModel])
def test_view_handler(tgt_shape, model_cls):
-    world_size = 4
-    run_func = partial(check_view_handler,
-        tgt_shape=tgt_shape,
-        model_cls=model_cls,
-        world_size=world_size,
-        port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_view_handler, 4, tgt_shape=tgt_shape, model_cls=model_cls)
if __name__ == '__main__':

2
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py

@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
+from colossalai.testing import clear_cache_before_run
class ConvModel(nn.Module):
@@ -21,6 +22,7 @@ class ConvModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
+@clear_cache_before_run()
def test_where_handler():
    model = ConvModel()
    tracer = ColoTracer(bias_addition_split=True)

3
tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py

@@ -10,10 +10,11 @@ from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_cost_graph():
    physical_mesh_id = torch.arange(0, 8)
    mesh_shape = (2, 4)

2
tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py

@@ -8,7 +8,7 @@ import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

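free_port itself moves with this commit: the autochunk utilities now import it from colossalai.testing rather than colossalai.utils, keeping the test-only helpers under one package. As far as this diff shows, usage is unchanged:

from colossalai.testing import free_port

port = free_port()    # pick an unused local port for the rendezvous, as before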
2
tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py

@@ -9,7 +9,7 @@ from colossalai.autochunk.utils import flat_list
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

12
tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py

@@ -1,10 +1,8 @@
-from functools import partial
from typing import Dict, List, Tuple
import pytest
import torch
import torch.fx
-import torch.multiprocessing as mp
try:
    from fastfold.model.nn.evoformer import EvoformerBlock
@@ -15,6 +13,7 @@ except:
from test_autochunk_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.testing import clear_cache_before_run, parameterize, spawn
def get_model():
@@ -66,18 +65,19 @@ def get_chunk_target() -> Dict:
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
-@pytest.mark.parametrize("max_memory", [None, 20, 24])
-@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
+@clear_cache_before_run()
+@parameterize("max_memory", [None, 20, 24])
+@parameterize("data_args", [(32, 64)])
def test_evoformer_block(data_args, max_memory):
-    run_func = partial(
+    spawn(
        run_test,
+        1,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
        get_chunk_target=get_chunk_target,
    )
-    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":

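The autochunk tests also swap @pytest.mark.parametrize for colossalai.testing.parameterize. The two are not equivalent: pytest generates a separate test item per value, while the Colossal-AI decorator, as far as this diff shows, is a plain Python decorator and so presumably loops over the values inside a single item. A minimal stand-in with that behaviour:

# hypothetical stand-in for colossalai.testing.parameterize
import functools


def parameterize(argument, values):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for value in values:
                kwargs[argument] = value
                func(*args, **kwargs)    # run the body once per value, all within one pytest item
        return wrapper
    return decorator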
12
tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py

@@ -1,10 +1,8 @@
-from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.fx
-import torch.multiprocessing as mp
try:
    from fastfold.model.nn.evoformer import EvoformerStack
@@ -15,6 +13,7 @@ except:
from test_autochunk_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.testing import clear_cache_before_run, parameterize, spawn
def get_model():
@@ -61,17 +60,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
-@pytest.mark.parametrize("max_memory", [None, 20, 24])
-@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
+@clear_cache_before_run()
+@parameterize("max_memory", [None, 20, 24])
+@parameterize("data_args", [(32, 64)])    # (msa_len, pair_len)
def test_evoformer_stack(data_args, max_memory):
-    run_func = partial(
+    spawn(
        run_test,
+        1,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
    )
-    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":

12
tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py

@@ -1,10 +1,8 @@
-from functools import partial
from typing import Dict, List, Tuple
import pytest
import torch
import torch.fx
-import torch.multiprocessing as mp
try:
    from fastfold.model.nn.evoformer import ExtraMSABlock
@@ -14,6 +12,7 @@ except:
from test_autochunk_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.testing import clear_cache_before_run, parameterize, spawn
def get_model():
@@ -57,17 +56,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
-@pytest.mark.parametrize("max_memory", [None, 20, 24])
-@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
+@clear_cache_before_run()
+@parameterize("max_memory", [None, 20, 24])
+@parameterize("data_args", [(32, 64)])    # (msa_len, pair_len)
def test_extramsa_block(data_args, max_memory):
-    run_func = partial(
+    spawn(
        run_test,
+        1,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
    )
-    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":

2
tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py

@@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

14
tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py

@@ -1,9 +1,7 @@
-from functools import partial
from typing import List, Tuple
import pytest
import torch
-import torch.multiprocessing as mp
try:
    from diffusers import UNet2DModel
@@ -16,6 +14,7 @@ except:
from test_autochunk_diffuser_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.testing import clear_cache_before_run, parameterize, spawn
BATCH_SIZE = 1
HEIGHT = 448
@@ -37,17 +36,18 @@ def get_data(shape: tuple) -> Tuple[List, List]:
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
-@pytest.mark.parametrize("max_memory", [None, 150, 300])
+@clear_cache_before_run()
+@parameterize("model", MODELS)
+@parameterize("shape", [LATENTS_SHAPE])
+@parameterize("max_memory", [None, 150, 300])
def test_evoformer_block(model, shape, max_memory):
-    run_func = partial(
+    spawn(
        run_test,
+        1,
        max_memory=max_memory,
        model=model,
        data=get_data(shape),
    )
-    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":

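Even these single-process autochunk tests go through spawn with a world size of 1, presumably because run_test still initialises the distributed context and therefore still needs the rank, world_size and port arguments the launcher provides:

# one worker, but the same calling convention as the multi-GPU tests
spawn(run_test, 1, max_memory=max_memory, model=model, data=get_data(shape))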
14
tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py

@@ -1,9 +1,7 @@
-from functools import partial
from typing import List, Tuple
import pytest
import torch
-import torch.multiprocessing as mp
try:
    from transformers import GPT2Config, GPT2Model
@@ -16,6 +14,7 @@ except:
from test_autochunk_transformer_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.testing import clear_cache_before_run, parameterize, spawn
BATCH_SIZE = 1
SEQ_LENGTH = 512
@@ -35,18 +34,19 @@ def get_data(shape: tuple) -> Tuple[List, List]:
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
-@pytest.mark.parametrize("max_memory", [None, 6, 8])
+@clear_cache_before_run()
+@parameterize("model", MODELS)
+@parameterize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
+@parameterize("max_memory", [None, 6, 8])
def test_autochunk_gpt(model, shape, max_memory):
-    run_func = partial(
+    spawn(
        run_test,
+        1,
        data=get_data(shape),
        max_memory=max_memory,
        model=model,
        config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
    )
-    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":

2
tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py

@@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

12
tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py

@@ -1,9 +1,7 @@
-from functools import partial
from typing import List, Tuple
import pytest
import torch
-import torch.multiprocessing as mp
try:
    from timm.models.vision_transformer import vit_large_patch16_384 as vit
@@ -16,6 +14,7 @@ except:
from test_autochunk_vit_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.testing import clear_cache_before_run, parameterize, spawn
def get_data() -> Tuple[List, List]:
@@ -28,16 +27,17 @@ def get_data() -> Tuple[List, List]:
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("max_memory", [None, 32, 40])
+@clear_cache_before_run()
+@parameterize("model", MODELS)
+@parameterize("max_memory", [None, 32, 40])
def test_evoformer_block(model, max_memory):
-    run_func = partial(
+    spawn(
        run_test,
+        1,
        max_memory=max_memory,
        model=model,
        data=get_data(),
    )
-    mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":

2
tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py

@@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

19
tests/test_booster/test_accelerator.py

@@ -1,27 +1,14 @@
-from functools import partial
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.booster.accelerator import Accelerator
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.testing import clear_cache_before_run, parameterize
+@clear_cache_before_run()
@parameterize('device', ['cpu', 'cuda'])
-def run_accelerator(device):
+def test_accelerator(device):
    acceleartor = Accelerator(device)
    model = nn.Linear(8, 8)
    model = acceleartor.configure_model(model)
    assert next(model.parameters()).device.type == device
    del model, acceleartor
-def run_dist(rank):
-    run_accelerator()
-@rerun_if_address_is_in_use()
-def test_accelerator():
-    world_size = 1
-    run_func = partial(run_dist)
-    mp.spawn(run_func, nprocs=world_size)

10
tests/test_booster/test_mixed_precision/test_fp16_torch.py

@@ -1,13 +1,9 @@
-from functools import partial
import torch
-import torch.multiprocessing as mp
from torch.optim import Adam
import colossalai
from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@@ -41,6 +37,4 @@ def run_torch_amp(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_torch_ddp_plugin():
-    world_size = 1
-    run_func = partial(run_torch_amp, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_torch_amp, 1)

11
tests/test_booster/test_plugin/test_gemini_plugin.py

@@ -1,17 +1,12 @@
-from functools import partial
-
-import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp

 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor.colo_parameter import ColoParameter
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -119,9 +114,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
 @rerun_if_address_is_in_use()
 def test_gemini_plugin(early_stop: bool = True):
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), early_stop=early_stop)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 2, early_stop=early_stop)


 if __name__ == '__main__':

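The Gemini plugin test also forwards early_stop through spawn, which suggests extra keyword arguments are passed straight to the worker. A hedged sketch of the resulting pattern, with an illustrative worker body, is below; the launch arguments are assumptions, not the exact contents of run_dist in this file.

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port, early_stop: bool = True):
    # Each spawned worker joins the same process group on the port chosen by spawn.
    colossalai.launch(config={}, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')
    # ... exercise the plugin here ...


@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
    # world_size=2; early_stop is forwarded to run_dist by the spawn helper.
    spawn(run_dist, 2, early_stop=early_stop)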
10
tests/test_booster/test_plugin/test_torch_ddp_plugin.py

@@ -1,8 +1,5 @@
-from functools import partial
-
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
@@ -10,8 +7,7 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import TorchDDPPlugin
 from colossalai.interface import OptimizerWrapper
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -103,6 +99,4 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_torch_ddp_plugin():
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 2)

4
tests/test_checkpoint_io/test_general_checkpoint_io.py

@@ -6,6 +6,7 @@ from torch.optim import Adam
 from torchvision.models import resnet18

 from colossalai.checkpoint_io import GeneralCheckpointIO
+from colossalai.testing import clear_cache_before_run, parameterize

 # ========
 # Note:
@@ -15,7 +16,8 @@ from colossalai.checkpoint_io import GeneralCheckpointIO
 # ========


-@pytest.mark.parametrize('use_safetensors', [True, False])
+@clear_cache_before_run()
+@parameterize('use_safetensors', [True, False])
 def test_unsharded_checkpoint(use_safetensors: bool):
     # create a model and optimizer
     model = resnet18()

11
tests/test_cluster/test_device_mesh_manager.py

@@ -1,14 +1,9 @@
-from functools import partial
-
 import torch
-import torch.multiprocessing as mp

 from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.utils import free_port
+from colossalai.testing import spawn


 def check_device_mesh_manager(rank, world_size, port):
@@ -31,9 +26,7 @@ def check_device_mesh_manager(rank, world_size, port):
 def test_device_mesh_manager():
-    world_size = 4
-    run_func = partial(check_device_mesh_manager, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_device_mesh_manager, 4)


 if __name__ == '__main__':

15
tests/test_comm/test_boardcast_send_recv_v2.py

@@ -1,17 +1,12 @@
-from functools import partial
-from typing import List
-
 import pytest
 import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-from colossalai.communication.p2p_v2 import _send_object, _recv_object, init_process_group
+
+from colossalai.communication.p2p_v2 import _recv_object, _send_object
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
-from colossalai.utils import free_port, get_current_device
-from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.logging import disable_existing_loggers
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 disable_existing_loggers()
 world_size = 4
@@ -45,9 +40,7 @@ def check_layer(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_object_list_p2p():
-    disable_existing_loggers()
-    run_func = partial(check_layer, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_layer, world_size)


 if __name__ == '__main__':

12
tests/test_comm/test_comm.py

@@ -1,15 +1,13 @@
-from functools import partial
-
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp

 from colossalai.communication import all_gather, all_reduce, reduce_scatter
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
-from colossalai.utils import free_port, get_current_device
-from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device

 CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
@@ -66,9 +64,7 @@ def check_layer(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_comm():
-    world_size = 4
-    run_func = partial(check_layer, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_layer, 4)


 if __name__ == '__main__':

21
tests/test_comm/test_object_list_p2p.py

@@ -1,15 +1,18 @@
-from functools import partial
-
 import pytest
 import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-from colossalai.communication.p2p import send_forward, recv_forward, send_backward, recv_backward, send_forward_recv_backward, send_backward_recv_forward
+
+from colossalai.communication.p2p import (
+    recv_backward,
+    recv_forward,
+    send_backward,
+    send_backward_recv_forward,
+    send_forward,
+    send_forward_recv_backward,
+)
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
-from colossalai.utils import free_port, get_current_device
-from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 CONFIG = dict(parallel=dict(pipeline=2))
 torch.manual_seed(123)
@@ -96,9 +99,7 @@ def check_layer(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_object_list_p2p():
-    world_size = 2
-    run_func = partial(check_layer, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_layer, 2)


 if __name__ == '__main__':

Some files were not shown because too many files have changed in this diff.