Browse Source

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
pull/3343/head
Frank Lee 2 years ago committed by GitHub
parent
commit
80eba05b0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 19
      .github/workflows/build_on_pr.yml
  2. 8
      applications/Chat/tests/test_checkpoint.py
  3. 8
      applications/Chat/tests/test_data.py
  4. 3
      colossalai/cli/benchmark/benchmark.py
  5. 16
      colossalai/testing/__init__.py
  6. 78
      colossalai/testing/utils.py
  7. 2
      colossalai/utils/__init__.py
  8. 17
      colossalai/utils/common.py
  9. 1
      docs/requirements-doc-test.txt
  10. 7
      docs/source/en/basics/colotensor_concept.md
  11. 7
      docs/source/zh-Hans/basics/colotensor_concept.md
  12. 8
      examples/images/vit/test_vit.py
  13. 39
      examples/language/gpt/experiments/auto_offload/train_gpt_offload.py
  14. 7
      examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
  15. 11
      examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py
  16. 6
      examples/tutorial/auto_parallel/auto_ckpt_solver_test.py
  17. 1
      requirements/requirements-test.txt
  18. 10
      tests/test_amp/test_naive_fp16.py
  19. 10
      tests/test_amp/test_torch_fp16.py
  20. 3
      tests/test_analyzer/test_fx/test_bias_addition.py
  21. 11
      tests/test_analyzer/test_fx/test_mod_dir.py
  22. 5
      tests/test_analyzer/test_fx/test_nested_ckpt.py
  23. 4
      tests/test_analyzer/test_fx/test_shape_prop.py
  24. 4
      tests/test_analyzer/test_fx/test_symbolic_profile.py
  25. 5
      tests/test_analyzer/test_subclasses/test_aten.py
  26. 4
      tests/test_analyzer/test_subclasses/test_flop_tensor.py
  27. 5
      tests/test_analyzer/test_subclasses/test_meta_mode.py
  28. 10
      tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py
  29. 17
      tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py
  30. 3
      tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py
  31. 10
      tests/test_auto_parallel/test_offload/test_perf.py
  32. 21
      tests/test_auto_parallel/test_offload/test_solver.py
  33. 2
      tests/test_auto_parallel/test_pass/test_node_converting_pass.py
  34. 2
      tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py
  35. 12
      tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py
  36. 12
      tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
  37. 10
      tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py
  38. 14
      tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
  39. 4
      tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py
  40. 9
      tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
  41. 6
      tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
  42. 3
      tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py
  43. 15
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py
  44. 10
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py
  45. 17
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py
  46. 27
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py
  47. 20
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py
  48. 25
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py
  49. 22
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py
  50. 15
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py
  51. 22
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py
  52. 23
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py
  53. 39
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py
  54. 17
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
  55. 11
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py
  56. 12
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py
  57. 16
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py
  58. 39
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
  59. 14
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
  60. 19
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
  61. 3
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py
  62. 14
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py
  63. 2
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py
  64. 13
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
  65. 11
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py
  66. 35
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
  67. 3
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py
  68. 5
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py
  69. 4
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py
  70. 21
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py
  71. 4
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py
  72. 4
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py
  73. 17
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py
  74. 18
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py
  75. 13
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py
  76. 3
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py
  77. 3
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py
  78. 14
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py
  79. 2
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py
  80. 3
      tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
  81. 2
      tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py
  82. 2
      tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py
  83. 12
      tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py
  84. 12
      tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py
  85. 12
      tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py
  86. 2
      tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
  87. 14
      tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
  88. 14
      tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py
  89. 2
      tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py
  90. 12
      tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py
  91. 2
      tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py
  92. 19
      tests/test_booster/test_accelerator.py
  93. 10
      tests/test_booster/test_mixed_precision/test_fp16_torch.py
  94. 11
      tests/test_booster/test_plugin/test_gemini_plugin.py
  95. 10
      tests/test_booster/test_plugin/test_torch_ddp_plugin.py
  96. 4
      tests/test_checkpoint_io/test_general_checkpoint_io.py
  97. 11
      tests/test_cluster/test_device_mesh_manager.py
  98. 15
      tests/test_comm/test_boardcast_send_recv_v2.py
  99. 12
      tests/test_comm/test_comm.py
  100. 21
      tests/test_comm/test_object_list_p2p.py
  101. Some files were not shown because too many files have changed in this diff Show More

19
.github/workflows/build_on_pr.yml

@ -8,10 +8,10 @@ jobs:
detect:
name: Detect file change
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' &&
contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' &&
contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
outputs:
changedExtenisonFiles: ${{ steps.find-extension-change.outputs.all_changed_files }}
anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }}
@ -27,10 +27,10 @@ jobs:
- name: Locate base commit
id: locate-base-sha
run: |
curBranch=$(git rev-parse --abbrev-ref HEAD)
commonCommit=$(git merge-base origin/main $curBranch)
echo $commonCommit
echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
curBranch=$(git rev-parse --abbrev-ref HEAD)
commonCommit=$(git merge-base origin/main $curBranch)
echo $commonCommit
echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
- name: Find the changed extension-related files
id: find-extension-change
@ -63,7 +63,6 @@ jobs:
echo "$file was changed"
done
build:
name: Build and Test Colossal-AI
needs: detect
@ -124,7 +123,7 @@ jobs:
- name: Execute Unit Testing
if: needs.detect.outputs.anyLibraryFileChanged == 'true'
run: |
PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/
CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/
env:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1

8
applications/Chat/tests/test_checkpoint.py

@ -1,19 +1,16 @@
import os
import tempfile
from contextlib import nullcontext
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from coati.models.gpt import GPTActor
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
@ -90,8 +87,7 @@ def run_dist(rank, world_size, port, strategy):
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, strategy):
run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size, strategy=strategy)
if __name__ == '__main__':

8
applications/Chat/tests/test_data.py

@ -1,11 +1,9 @@
import os
from copy import deepcopy
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic
@ -13,8 +11,7 @@ from coati.replay_buffer import NaiveReplayBuffer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
@ -114,8 +111,7 @@ def run_dist(rank, world_size, port, strategy):
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])
@rerun_if_address_is_in_use()
def test_data(world_size, strategy):
run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size, strategy=strategy)
if __name__ == '__main__':

3
colossalai/cli/benchmark/benchmark.py

@ -10,7 +10,8 @@ from colossalai.context import Config
from colossalai.context.random import reset_seeds
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import MultiTimer, free_port
from colossalai.testing import free_port
from colossalai.utils import MultiTimer
from .models import MLP

16
colossalai/testing/__init__.py

@ -1,7 +1,17 @@
from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use, skip_if_not_enough_gpus
from .comparison import assert_close, assert_close_loose, assert_equal, assert_equal_in_group, assert_not_equal
from .pytest_wrapper import run_on_environment_flag
from .utils import (
clear_cache_before_run,
free_port,
parameterize,
rerun_if_address_is_in_use,
rerun_on_exception,
skip_if_not_enough_gpus,
spawn,
)
__all__ = [
'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus'
'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn',
'clear_cache_before_run', 'run_on_environment_flag'
]

78
colossalai/testing/utils.py

@ -1,8 +1,13 @@
import gc
import random
import re
import torch
from typing import Callable, List, Any
import socket
from functools import partial
from inspect import signature
from typing import Any, Callable, List
import torch
import torch.multiprocessing as mp
from packaging import version
@ -202,3 +207,72 @@ def skip_if_not_enough_gpus(min_gpus: int):
return _execute_by_gpu_num
return _wrap_func
def free_port() -> int:
    """Pick a random port on localhost that can currently be bound.

    Candidate ports are drawn at random from the range 20000-65000 and
    probed by binding a throw-away socket; the first candidate that binds
    successfully is returned. Candidates that raise ``OSError`` are skipped.

    Returns:
        int: A port number that was bindable on localhost at probe time.
    """
    while True:
        candidate = random.randint(20000, 65000)
        try:
            # The probe socket is closed as soon as the ``with`` block exits,
            # so the port is only known to have been free at this moment.
            with socket.socket() as probe:
                probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                probe.bind(("localhost", candidate))
        except OSError:
            # Port unavailable (or the probe failed): try another candidate.
            continue
        return candidate
def spawn(func, nprocs=1, **kwargs):
    """Launch ``func`` in ``nprocs`` processes for testing.

    ``func`` must accept the arguments ``rank``, ``world_size`` and ``port``.
    ``world_size`` is set to ``nprocs`` and ``port`` to a port obtained from
    ``free_port()``; both are bound as keyword arguments before the function
    is handed to ``torch.multiprocessing.spawn``.

    Usage:
        # must contain the arguments rank, world_size, port
        def do_something(rank, world_size, port):
            ...

        spawn(do_something, nprocs=8)

        # other keyword arguments are forwarded to the function as well
        def do_something(rank, world_size, port, arg1, arg2):
            ...

        spawn(do_something, nprocs=8, arg1=1, arg2=2)

    Args:
        func (Callable): The function to be spawned.
        nprocs (int, optional): The number of processes to spawn. Defaults to 1.
        **kwargs: Extra keyword arguments bound to ``func`` in every process.
    """
    worker = partial(func, world_size=nprocs, port=free_port(), **kwargs)
    mp.spawn(worker, nprocs=nprocs)
def clear_cache_before_run():
    """Build a decorator that clears CUDA and Python caches before each call.

    Before the decorated function runs, the CUDA caching allocator is
    emptied, the peak-memory statistics are reset, the device is
    synchronized and the Python garbage collector is run. The CUDA steps are
    skipped when CUDA is not available so the decorator is usable on
    CPU-only hosts.

    Usage:
        @clear_cache_before_run()
        def test_something():
            ...

    Returns:
        Callable: A decorator whose wrapper forwards all arguments to the
        wrapped function and returns its result.
    """
    # Imported locally so this block does not depend on a module-level import.
    from functools import wraps

    def _wrap_func(f):

        @wraps(f)    # keep the wrapped function's name, docstring and signature
        def _clear_cache(*args, **kwargs):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                # reset_peak_memory_stats() resets all peak counters, so the
                # deprecated reset_max_memory_allocated() /
                # reset_max_memory_cached() aliases are not called separately.
                torch.cuda.reset_peak_memory_stats()
                torch.cuda.synchronize()
            gc.collect()
            # Propagate the result instead of silently dropping it.
            return f(*args, **kwargs)

        return _clear_cache

    return _wrap_func

2
colossalai/utils/__init__.py

@ -7,7 +7,6 @@ from .common import (
count_zeros_fp32,
disposable,
ensure_path_exists,
free_port,
is_ddp_ignored,
is_dp_rank_0,
is_model_parallel_parameter,
@ -37,7 +36,6 @@ from .timer import MultiTimer, Timer
__all__ = [
'checkpoint',
'free_port',
'print_rank_0',
'sync_model_param',
'is_ddp_ignored',

17
colossalai/utils/common.py

@ -50,23 +50,6 @@ def ensure_path_exists(filename: str):
Path(dirpath).mkdir(parents=True, exist_ok=True)
def free_port() -> int:
    """Pick a random port on localhost that can currently be bound.

    Candidate ports are drawn at random from the range 20000-65000 and
    probed by binding a throw-away socket; the first candidate that binds
    successfully is returned. Candidates that raise ``OSError`` are skipped.

    Returns:
        int: A port number that was bindable on localhost at probe time.
    """
    while True:
        candidate = random.randint(20000, 65000)
        try:
            # The probe socket is closed as soon as the ``with`` block exits,
            # so the port is only known to have been free at this moment.
            with socket.socket() as probe:
                probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                probe.bind(("localhost", candidate))
        except OSError:
            # Port unavailable (or the probe failed): try another candidate.
            continue
        return candidate
def sync_model_param(model, parallel_mode):
r"""Make sure data parameters are consistent during Data Parallel Mode.

1
docs/requirements-doc-test.txt

@ -4,3 +4,4 @@ packaging
tensornvme
psutil
transformers
pytest

7
docs/source/en/basics/colotensor_concept.md

@ -56,12 +56,12 @@ Let's see an example. A ColoTensor is initialized and sharded on 8 GPUs using tp
```python
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port, print_rank_0
from colossalai.utils import print_rank_0
from functools import partial
import colossalai
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
from colossalai.utils import free_port
from colossalai.testing import spawn
import torch
@ -83,8 +83,7 @@ def run_dist_tests(rank, world_size, port):
print_rank_0(f"shape {t1.shape}, {t1.data}")
def test_dist_cases(world_size):
run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist_tests, world_size)
if __name__ == '__main__':
test_dist_cases(4)

7
docs/source/zh-Hans/basics/colotensor_concept.md

@ -57,12 +57,12 @@ ColoTensor 包含额外的属性[ColoTensorSpec](https://colossalai.readthedocs.
```python
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port, print_rank_0
from colossalai.utils import print_rank_0
from functools import partial
import colossalai
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
from colossalai.utils import free_port
from colossalai.testing import spawn
import torch
@ -84,8 +84,7 @@ def run_dist_tests(rank, world_size, port):
print_rank_0(f"shape {t1.shape}, {t1.data}")
def test_dist_cases(world_size):
run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist_tests, world_size)
if __name__ == '__main__':
test_dist_cases(4)

8
examples/images/vit/test_vit.py

@ -1,11 +1,9 @@
import os
import random
from functools import partial
import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from vit import get_training_components
@ -15,8 +13,7 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
@ -156,8 +153,7 @@ def run_dist(rank, world_size, port, use_ddp):
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_vit(world_size, use_ddp):
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size, use_ddp=use_ddp)
if __name__ == '__main__':

39
examples/language/gpt/experiments/auto_offload/train_gpt_offload.py

@ -1,20 +1,20 @@
import time
import pytest
import argparse
from functools import partial
import time
import pytest
import torch
from model_zoo import GPTLMLoss, get_gpt2_components
from torch.utils._pytree import tree_map
import torch.multiprocessing as mp
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.fx.profiler import parameter_size
from colossalai.utils import free_port, get_current_device
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from model_zoo import get_gpt2_components, GPTLMLoss
from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import spawn
from colossalai.utils import get_current_device
def parse_args():
parser = argparse.ArgumentParser()
@ -24,6 +24,7 @@ def parse_args():
parser.add_argument('--memory_budget', type=float, default=16)
return parser.parse_args()
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
def train_gpt(args):
memory_budget = args.memory_budget * 1024 * 1024 * 1024
@ -33,13 +34,16 @@ def train_gpt(args):
# build model
model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size)
label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
label = torch.randint(low=0, high=128, size=(
64,
8,
), device=get_current_device())
criterion = GPTLMLoss()
start_time = time.time()
model = model_builder()
model.train()
param_size = parameter_size(model) / 1024 ** 2 / 2
param_size = parameter_size(model) / 1024**2 / 2
init_time = time.time() - start_time
print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
@ -74,21 +78,20 @@ def train_gpt(args):
torch.cuda.synchronize()
exec_time = sum(sorted(time_list)[:5]) / 5
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
print(f'solver_type: {solver_type} | model_type: {model_type}')
print(
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
)
print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|')
print(time_list)
def run(rank, world_size, port, args):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
train_gpt(args)
if __name__ == '__main__':
args = parse_args()
run_func = partial(run, world_size=1, port=free_port(), args=args)
mp.spawn(run_func, nprocs=1)
spawn(run, 1, args=args)

7
examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py

@ -1,18 +1,13 @@
from functools import partial
from time import time
from typing import Dict, Optional, Tuple, Union
import psutil
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from gpt_modules import GPT2LMHeadModel, GPTLMLoss
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger

11
examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py

@ -1,19 +1,14 @@
import time
from argparse import ArgumentParser
from copy import deepcopy
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import bench, data_gen_resnet
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
from colossalai.testing import spawn
def _benchmark(rank, world_size, port):
@ -50,9 +45,7 @@ def _benchmark(rank, world_size, port):
def auto_activation_checkpoint_batchsize_benchmark():
world_size = 1
run_func_module = partial(_benchmark, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_benchmark, 1)
if __name__ == "__main__":

6
examples/tutorial/auto_parallel/auto_ckpt_solver_test.py

@ -4,14 +4,13 @@ from functools import partial
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
from colossalai.testing import spawn
def _benchmark(rank, world_size, port, args):
@ -77,8 +76,7 @@ def _benchmark(rank, world_size, port, args):
def auto_activation_checkpoint_benchmark(args):
world_size = 1
run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args)
mp.spawn(run_func_module, nprocs=world_size)
spawn(_benchmark, world_size, args=args)
if __name__ == "__main__":

1
requirements/requirements-test.txt

@ -12,3 +12,4 @@ contexttimer
einops
triton==2.0.0.dev20221202
git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611

10
tests/test_amp/test_naive_fp16.py

@ -1,14 +1,11 @@
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import colossalai
from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp
from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
@ -87,10 +84,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_naive_amp():
world_size = 1
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, 1)
if __name__ == '__main__':

10
tests/test_amp/test_torch_fp16.py

@ -1,14 +1,11 @@
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import colossalai
from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp
from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
@ -87,10 +84,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_torch_amp():
world_size = 1
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, 1)
if __name__ == '__main__':

3
tests/test_analyzer/test_fx/test_bias_addition.py

@ -3,7 +3,7 @@ import torch
from packaging import version
from torch.utils.checkpoint import checkpoint
from colossalai.testing.utils import parameterize
from colossalai.testing.utils import clear_cache_before_run, parameterize
try:
from colossalai._analyzer.fx import symbolic_trace
@ -81,6 +81,7 @@ class AddmmModel(torch.nn.Module):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@clear_cache_before_run()
@parameterize("bias", [True, False])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])

11
tests/test_analyzer/test_fx/test_mod_dir.py

@ -1,6 +1,8 @@
import pytest
import torch
from colossalai.testing import clear_cache_before_run, parameterize
try:
from colossalai._analyzer.fx import symbolic_trace
except:
@ -62,9 +64,10 @@ class AModel(torch.nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@clear_cache_before_run()
@parameterize("bias", [True, False])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
def test_mod_dir(bias, bias_addition_split, shape):
model = AModel(bias=bias)
x = torch.rand(shape)
@ -75,4 +78,4 @@ def test_mod_dir(bias, bias_addition_split, shape):
if __name__ == '__main__':
test_mod_dir(True, True, (3, 3, 3))
test_mod_dir(bias=True, bias_addition_split=True, shape=(3, 3, 3))

5
tests/test_analyzer/test_fx/test_nested_ckpt.py

@ -1,7 +1,9 @@
import pytest
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import pytest
from colossalai.testing import clear_cache_before_run
try:
from colossalai._analyzer.fx import symbolic_trace
@ -42,6 +44,7 @@ class MyModule(nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@clear_cache_before_run()
def test_nested_ckpt():
model = MyModule()
x = torch.rand(10, 10)

4
tests/test_analyzer/test_fx/test_shape_prop.py

@ -3,7 +3,7 @@ import torch
import torchvision.models as tm
from packaging import version
from colossalai.testing.utils import parameterize
from colossalai.testing.utils import clear_cache_before_run, parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
try:
@ -32,6 +32,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@clear_cache_before_run()
@parameterize('m', tm_models)
def test_torchvision_shape_prop(m):
with MetaTensorMode():
@ -46,6 +47,7 @@ def test_torchvision_shape_prop(m):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@clear_cache_before_run()
@parameterize('m', tmm_models)
def test_timm_shape_prop(m):
with MetaTensorMode():

4
tests/test_analyzer/test_fx/test_symbolic_profile.py

@ -3,7 +3,7 @@ import torch
import torchvision.models as tm
from packaging import version
from colossalai.testing.utils import parameterize
from colossalai.testing.utils import clear_cache_before_run, parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
try:
@ -19,6 +19,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@clear_cache_before_run()
@parameterize('m', tm_models)
def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
with MetaTensorMode():
@ -33,6 +34,7 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@clear_cache_before_run()
@parameterize('m', tmm_models)
def test_timm_profile(m, verbose=False, bias_addition_split=False):
with MetaTensorMode():

5
tests/test_analyzer/test_subclasses/test_aten.py

@ -1,9 +1,11 @@
from typing import Any, Callable, Union
import pytest
import pytest
import torch
import torch.nn as nn
from colossalai.testing import clear_cache_before_run
try:
from colossalai._analyzer._subclasses import MetaTensor
except:
@ -72,6 +74,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@clear_cache_before_run()
def test_meta_aten():
for (aten_op, requires_backward), v in registered_meta.items():
for f, x in v:

4
tests/test_analyzer/test_subclasses/test_flop_tensor.py

@ -4,6 +4,7 @@ import torch.nn.functional as F
import torchvision.models as tm
from packaging import version
from colossalai.testing import clear_cache_before_run, parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
try:
@ -39,7 +40,8 @@ odd_cases = [
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
@clear_cache_before_run()
@parameterize('func, args, kwargs', odd_cases)
def test_flop_count_function(func, args, kwargs):
rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}'

5
tests/test_analyzer/test_subclasses/test_meta_mode.py

@ -3,6 +3,8 @@ import torch
import torchvision.models as tm
from packaging import version
from colossalai.testing import clear_cache_before_run, parameterize
try:
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
except:
@ -30,7 +32,8 @@ def run_and_compare(model):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models + tmm_models)
@clear_cache_before_run()
@parameterize('m', tm_models + tmm_models)
def test_meta_mode_shape(m):
run_and_compare(m())

10
tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py

@ -3,7 +3,6 @@ import copy
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
import torchvision.models as tm
import colossalai
@ -13,7 +12,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
# from colossalai.fx.passes.algorithms import solver_rotor
# from colossalai.fx.passes.algorithms.operation import Sequence
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
@ -26,8 +25,8 @@ except:
withcodegen = False
def _run_C_solver_consistency_test(rank=0):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
def _run_C_solver_consistency_test(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
model = M()
@ -70,8 +69,9 @@ def _run_C_solver_consistency_test(rank=0):
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
@rerun_if_address_is_in_use()
def test_C_solver_consistency():
mp.spawn(_run_C_solver_consistency_test, nprocs=1)
spawn(_run_C_solver_consistency_test, 1)
if __name__ == '__main__':

17
tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py

@ -4,7 +4,6 @@ from typing import Callable
import pytest
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from torch.fx import GraphModule
@ -15,7 +14,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
# from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
@ -68,8 +67,8 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
def _run_ckpt_solver(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
def _run_ckpt_solver(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
@ -98,12 +97,13 @@ def _run_ckpt_solver(rank):
@pytest.mark.skip("TODO(super-dainiu): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
@rerun_if_address_is_in_use()
def test_ckpt_solver():
mp.spawn(_run_ckpt_solver, nprocs=1)
spawn(_run_ckpt_solver, 1)
def _run_ckpt_solver_torch11(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
def _run_ckpt_solver_torch11(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
@ -131,8 +131,9 @@ def _run_ckpt_solver_torch11(rank):
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
@rerun_if_address_is_in_use()
def test_ckpt_solver_torch11():
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
spawn(_run_ckpt_solver_torch11, 1)
if __name__ == '__main__':

3
tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py

@ -8,6 +8,7 @@ from colossalai.fx.graph_module import ColoGraphModule
# from colossalai.fx.passes.algorithms import linearize, solver_rotor
# from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.testing import clear_cache_before_run
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
@ -24,6 +25,7 @@ except:
@pytest.mark.skip(reason='TODO: modify the logger')
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
@clear_cache_before_run()
def test_linearize():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
tracer = ColoTracer()
@ -84,6 +86,7 @@ def test_linearize():
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skip(reason="torch11 meta tensor not implemented")
@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
@clear_cache_before_run()
def test_linearize_torch11():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
tracer = ColoTracer()

10
tests/test_auto_parallel/test_offload/test_perf.py

@ -1,9 +1,7 @@
import time
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.utils._pytree import tree_map
import colossalai
@ -12,8 +10,8 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize
from colossalai.utils import free_port, get_current_device
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
from tests.test_auto_parallel.test_offload.model_utils import *
from tests.test_tensor.common_utils import set_seed
@ -140,9 +138,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.skip("this test failed")
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
@rerun_if_address_is_in_use()
def test_perf():
run_func = partial(run_dist, world_size=1, port=free_port())
mp.spawn(run_func, nprocs=1)
spawn(run_dist, 1)
if __name__ == '__main__':

21
tests/test_auto_parallel/test_offload/test_solver.py

@ -3,20 +3,20 @@ import torch.fx
from torch.fx import GraphModule
from torch.utils._pytree import tree_map
from colossalai.auto_parallel.offload.region_manager import RegionManager
from colossalai.auto_parallel.offload.solver import NOT_NVML, SolverFactory
from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.auto_parallel.offload.region_manager import RegionManager
from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML
from colossalai.testing import parameterize
from colossalai.testing import clear_cache_before_run, parameterize
from tests.test_auto_parallel.test_offload.model_utils import *
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
@clear_cache_before_run()
@parameterize('model_name', ['gpt2_', 'bert_'])
@parameterize('memory_budget', [4000])
@parameterize('solver_name', ['syn', 'asyn'])
def solver_test(model_name: str,
memory_budget: float,
solver_name: str):
def solver_test(model_name: str, memory_budget: float, solver_name: str):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, data_gen = get_components_func()
@ -52,11 +52,16 @@ def solver_test(model_name: str,
for region in region_list:
need_offload = region.need_offload
to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None
print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
print(
f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
)
for region in region_list.__reversed__():
need_offload = region.need_offload
to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None
print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
print(
f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
)
if __name__ == '__main__':
solver_test()

2
tests/test_auto_parallel/test_pass/test_node_converting_pass.py

@ -6,6 +6,7 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import clear_cache_before_run
class TestModule(torch.nn.Module):
@ -26,6 +27,7 @@ def insert_narrow(gm, x_node):
return gm
@clear_cache_before_run()
def test_node_args_converting_pass():
model = TestModule()
physical_mesh_id = torch.arange(0, 4)

2
tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py

@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import clear_cache_before_run
class TestModule(torch.nn.Module):
@ -36,6 +37,7 @@ def recover_narrow(gm, narrow_node):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@clear_cache_before_run()
def test_size_value_converting_pass():
model = TestModule()
physical_mesh_id = torch.arange(0, 4)

12
tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py

@ -2,7 +2,6 @@ from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
try:
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
@ -13,9 +12,7 @@ except:
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
class LinearModel(torch.nn.Module):
@ -86,11 +83,8 @@ def check_conv_module(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bias_addition_module():
world_size = 4
run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port())
mp.spawn(run_func_linear, nprocs=world_size)
run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port())
mp.spawn(run_func_conv, nprocs=world_size)
spawn(check_linear_module, 4)
spawn(check_conv_module, 4)
if __name__ == '__main__':

12
tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py

@ -1,9 +1,7 @@
from functools import partial
from typing import Optional, Tuple, Union
from typing import Optional, Tuple
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers.pytorch_utils import Conv1D
@ -17,9 +15,7 @@ except:
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
HIDDEN_SIZE = 16
@ -65,9 +61,7 @@ def check_act_ckpt(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mlp_layer():
world_size = 4
run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_act_ckpt, 4)
if __name__ == '__main__':

10
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py

@ -1,9 +1,7 @@
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
try:
@ -15,9 +13,7 @@ except:
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
class MLP(torch.nn.Module):
@ -102,9 +98,7 @@ def check_compatibility_with_ddp(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_compatibility_with_ddp():
world_size = 4
run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_compatibility_with_ddp, 4)
if __name__ == '__main__':

14
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py

@ -1,10 +1,7 @@
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
try:
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
@ -17,10 +14,9 @@ from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.process_group import ProcessGroup
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port, get_current_device
from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.utils import get_current_device
from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
class MLP(torch.nn.Module):
@ -110,9 +106,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_auto_parallel_with_gemini():
world_size = 4
run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_auto_parallel_with_gemini, 4)
if __name__ == '__main__':

4
tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py

@ -10,8 +10,7 @@ from colossalai._analyzer.fx.passes import shape_prop_pass
# from colossalai.fx.tracer.tracer import ColoTracer
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, parameterize, run_on_environment_flag
NUM_REPEAT_BLOCKS = 4
BATCH_SIZE = 1
@ -81,6 +80,7 @@ class NonRepeatModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
@parameterize('model_cls', [RepeatModel, NonRepeatModel])
def test_repeat_blocks(model_cls):

9
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py

@ -1,12 +1,10 @@
import copy
import random
from functools import partial
from typing import Dict
import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
import transformers
from torch.fx import GraphModule
@ -30,9 +28,8 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import to_global
from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
BATCH_SIZE = 1
@ -190,9 +187,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model])
@rerun_if_address_is_in_use()
def test_mlp_layer(model_cls):
world_size = 4
run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_attention_layer, 4, model_cls=model_cls)
if __name__ == '__main__':

6
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py

@ -1,5 +1,4 @@
import torch
import torch.nn as nn
import transformers
from torch.fx import GraphModule
@ -7,10 +6,10 @@ from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import parameterize
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
@ -20,6 +19,7 @@ HIDDEN_DIM = 384
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls):
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM)

3
tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py

@ -6,6 +6,8 @@ from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import clear_cache_before_run
class LinearModel(nn.Module):
@ -26,6 +28,7 @@ class LinearModel(nn.Module):
@pytest.mark.skip('meta tensor has some bugs in 1.11')
@clear_cache_before_run()
def test_liveness_analysis():
model = LinearModel()
tracer = ColoTracer(bias_addition_split=True)

15
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py

@ -1,23 +1,14 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.meta_profiler import meta_register
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
from colossalai.testing.utils import clear_cache_before_run, parameterize
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
@parameterize('func', [
torch.nn.functional.softmax,
torch.nn.functional.relu,

10
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh
@ -10,8 +7,7 @@ from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
@ -62,9 +58,7 @@ def _binary_elementwise_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_binary_elementwise_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_binary_elementwise_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_binary_elementwise_mem_test, 4)
if __name__ == '__main__':

17
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py

@ -1,17 +1,12 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
@ -25,7 +20,7 @@ class ConvFunctionModule(nn.Module):
return nn.functional.conv2d(input, self.conv_weight)
def _conv_module_mem_test(rank, bias, world_size, port):
def _conv_module_mem_test(rank, world_size, port, bias):
"""This function is for conv memory test
Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
@ -62,9 +57,7 @@ def _conv_module_mem_test(rank, bias, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_conv_meta_concrete_info_match(bias=False):
world_size = 4
run_func_module = partial(_conv_module_mem_test, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_conv_module_mem_test, 4, bias=bias)
def _conv_function_mem_test(rank, world_size, port):
@ -103,9 +96,7 @@ def _conv_function_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_conv_function_concrete_info_match():
world_size = 4
run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_conv_function_mem_test, 4)
if __name__ == '__main__':

27
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py

@ -1,33 +1,16 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
from colossalai.testing.utils import clear_cache_before_run
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
from colossalai.auto_parallel.meta_profiler import meta_register
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
def test_embedding_meta_info():
meta_func = meta_register.get(torch.nn.Embedding)

20
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py

@ -1,24 +1,14 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
class MyModule(nn.Module):
@ -63,9 +53,7 @@ def _linear_module_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_module_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_linear_module_mem_test, 4)
def _linear_function_mem_test(rank, world_size, port):
@ -101,9 +89,7 @@ def _linear_function_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_function_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_linear_function_mem_test, 4)
if __name__ == '__main__':

25
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py

@ -1,26 +1,8 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
from colossalai.testing.utils import clear_cache_before_run, parameterize
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
@ -28,6 +10,7 @@ if torch.__version__ >= '1.12.0':
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
@parameterize(
'tensor_shapes',
[

22
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py

@ -1,29 +1,17 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
from colossalai.auto_parallel.meta_profiler import meta_register
def _batchnorm_module_mem_test(rank, world_size, port):
@ -62,9 +50,7 @@ def _batchnorm_module_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_batchnorm_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_batchnorm_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_batchnorm_module_mem_test, 4)
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations')

15
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py

@ -1,17 +1,12 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
@ -51,9 +46,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_adaptiveavgpool_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_adaptiveavgpool_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_adaptiveavgpool_module_mem_test, 4)
def _maxpool_module_mem_test(rank, world_size, port):
@ -92,9 +85,7 @@ def _maxpool_module_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_maxpool_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_maxpool_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_maxpool_module_mem_test, 4)
if __name__ == '__main__':

22
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py

@ -1,26 +1,9 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
from colossalai.testing.utils import clear_cache_before_run
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
@ -37,6 +20,7 @@ class SplitModule(nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
def test_tensor_meta_info():
"""test tensor related meta information
We will just use torch.Tensor.split for the test

23
tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py

@ -1,24 +1,8 @@
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
from colossalai.testing.utils import clear_cache_before_run
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
@ -26,6 +10,7 @@ if torch.__version__ >= '1.12.0':
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
def test_where_meta_info():
meta_func = meta_register.get(torch.where)

39
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
@ -11,9 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -45,7 +40,7 @@ class AddBMMTorchFunctionModule(nn.Module):
return output
def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = module(using_kwargs).cuda()
@ -249,14 +244,13 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por
@parameterize('using_kwargs', [True, False])
@rerun_if_address_is_in_use()
def test_2d_device_mesh(module, bias_shape, using_kwargs):
world_size = 4
run_func = partial(check_2d_device_mesh,
module=module,
bias_shape=bias_shape,
world_size=world_size,
using_kwargs=using_kwargs,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(
check_2d_device_mesh,
4,
module=module,
bias_shape=bias_shape,
using_kwargs=using_kwargs,
)
@pytest.mark.skip("skip due to bias cases not ready")
@ -267,14 +261,13 @@ def test_2d_device_mesh(module, bias_shape, using_kwargs):
@parameterize('using_kwargs', [True, False])
@rerun_if_address_is_in_use()
def test_1d_device_mesh(module, bias_shape, using_kwargs):
world_size = 4
run_func = partial(check_1d_device_mesh,
module=module,
bias_shape=bias_shape,
using_kwargs=using_kwargs,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(
check_1d_device_mesh,
4,
module=module,
bias_shape=bias_shape,
using_kwargs=using_kwargs,
)
if __name__ == '__main__':

17
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -17,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -45,7 +40,7 @@ class AddmmModel_with_param(nn.Module):
return x
def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
if model_cls == AddmmModel:
@ -189,13 +184,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
@parameterize('model_cls', [AddmmModel, AddmmModel_with_param])
@rerun_if_address_is_in_use()
def test_addmm_handler(input_shape, model_cls):
world_size = 4
run_func_function = partial(check_addmm_function_handler,
input_shape=input_shape,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func_function, nprocs=world_size)
spawn(check_addmm_function_handler, 4, input_shape=input_shape, model_cls=model_cls)
if __name__ == '__main__':

11
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -114,9 +109,7 @@ def check_bn_module_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bn_module_handler():
world_size = 4
run_func = partial(check_bn_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_bn_module_handler, 4)
if __name__ == '__main__':

12
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py

@ -1,9 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -19,9 +15,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
WEIGHT_SHAPE = (32, 16)
@ -168,9 +162,7 @@ def check_linear_module_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler():
world_size = 4
run_func_module = partial(check_linear_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
# Keep the 4-process world size the pre-refactor code used (world_size = 4).
spawn(check_linear_module_handler, 4)
if __name__ == '__main__':

16
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py

@ -1,14 +1,10 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
@ -18,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -35,7 +29,7 @@ class LinearModule(torch.nn.Module):
return x
def check_linear_module_handler(rank, bias, world_size, port):
def check_linear_module_handler(rank, world_size, port, bias):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearModule(16, 32, bias=bias).cuda()
@ -157,9 +151,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(bias=True):
world_size = 4
run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
# Keep the 4-process world size the pre-refactor code used (world_size = 4).
spawn(check_linear_module_handler, 4, bias=bias)
if __name__ == '__main__':

39
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port):
def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@ -149,7 +144,7 @@ class BEOpModelWithIntConst(nn.Module):
return out
def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@ -236,13 +231,12 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, wo
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_binary_elementwise_handler_with_tensor(op, other_dim):
world_size = 4
run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
op=op,
other_dim=other_dim,
world_size=world_size,
port=free_port())
mp.spawn(run_func_tensor, nprocs=world_size)
spawn(
check_binary_elementwise_handler_with_tensor,
4,
op=op,
other_dim=other_dim,
)
@run_on_environment_flag(name='AUTO_PARALLEL')
@ -252,14 +246,13 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
world_size = 4
run_func_int = partial(check_binary_elementwise_handler_with_int,
op=op,
model_cls=model_cls,
other_dim=other_dim,
world_size=world_size,
port=free_port())
mp.spawn(run_func_int, nprocs=world_size)
spawn(
check_binary_elementwise_handler_with_int,
4,
op=op,
model_cls=model_cls,
other_dim=other_dim,
)
if __name__ == '__main__':

14
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -207,11 +202,8 @@ def check_1d_device_mesh(rank, module, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bmm_handler(module):
world_size = 4
run_func_2d = partial(check_2d_device_mesh, module=module, world_size=world_size, port=free_port())
mp.spawn(run_func_2d, nprocs=world_size)
run_func_1d = partial(check_1d_device_mesh, module=module, world_size=world_size, port=free_port())
mp.spawn(run_func_1d, nprocs=world_size)
spawn(check_2d_device_mesh, 4, module=module)
spawn(check_1d_device_mesh, 4, module=module)
if __name__ == '__main__':

19
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
def check_conv_module_handler(rank, bias, world_size, port):
def check_conv_module_handler(rank, world_size, port, bias):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda()
@ -155,7 +150,7 @@ class ConvModel(nn.Module):
return x
def check_conv_function_handler(rank, bias, world_size, port):
def check_conv_function_handler(rank, world_size, port, bias):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = ConvModel().cuda()
@ -302,9 +297,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
# @parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_module_handler(bias=False):
world_size = 4
run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_conv_module_handler, 4, bias=bias)
@run_on_environment_flag(name='AUTO_PARALLEL')
@ -314,9 +307,7 @@ def test_conv_module_handler(bias=False):
# @parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_function_handler(bias=False):
world_size = 4
run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_conv_function_handler, 4, bias=bias)
if __name__ == '__main__':

3
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py

@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHan
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class ReshapeModel(nn.Module):
@ -23,6 +23,7 @@ class ReshapeModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_reshape_handler():
model = ReshapeModel()
tracer = ColoTracer(bias_addition_split=True)

14
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -16,9 +13,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
NUM_EMBEDDINGS = 16
@ -272,18 +268,14 @@ def check_embedding_function_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_module_handler():
world_size = 4
run_func = partial(check_embedding_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_embedding_module_handler, 4)
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_function_handler():
world_size = 4
run_func = partial(check_embedding_function_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_embedding_function_handler, 4)
if __name__ == '__main__':

2
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py

@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import clear_cache_before_run
class GetattrModel(nn.Module):
@ -22,6 +23,7 @@ class GetattrModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@clear_cache_before_run()
def test_getattr_handler():
model = GetattrModel()
tracer = ColoTracer(bias_addition_split=True)

13
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py

@ -2,7 +2,6 @@ from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -14,12 +13,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import Li
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -103,12 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
# @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))])
@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))])
def test_getitem_from_tensor_handler(getitem_index):
world_size = 4
run_func = partial(check_getitem_from_tensor_handler,
getitem_index=getitem_index,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
# Forward the parameterized index: the worker signature still requires
# `getitem_index`, exactly as the removed partial(...) call supplied it.
spawn(check_getitem_from_tensor_handler, 4, getitem_index=getitem_index)
class GetItemFromTupleModel(nn.Module):
@ -123,6 +115,7 @@ class GetItemFromTupleModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_getitem_from_tuple_handler():
model = GetItemFromTupleModel()
tracer = ColoTracer()

11
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -11,12 +8,10 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -104,9 +99,7 @@ def check_ln_module_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_ln_module_handler():
world_size = 4
run_func = partial(check_ln_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_ln_module_handler, 4)
if __name__ == '__main__':

35
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -18,14 +15,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
def check_linear_module_handler(rank, bias, input_shape, world_size, port):
def check_linear_module_handler(rank, world_size, port, bias, input_shape):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
@ -172,7 +168,7 @@ class LinearModel(nn.Module):
return x
def check_linear_function_handler(rank, bias, input_shape, world_size, port):
def check_linear_function_handler(rank, world_size, port, bias, input_shape):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearModel().cuda()
@ -313,19 +309,18 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(input_shape, bias=False):
world_size = 4
run_func_module = partial(check_linear_module_handler,
bias=bias,
input_shape=input_shape,
world_size=world_size,
port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
run_func_function = partial(check_linear_function_handler,
bias=bias,
input_shape=input_shape,
world_size=world_size,
port=free_port())
mp.spawn(run_func_function, nprocs=world_size)
spawn(
check_linear_module_handler,
4,
bias=bias,
input_shape=input_shape,
)
spawn(
check_linear_function_handler,
4,
bias=bias,
input_shape=input_shape,
)
if __name__ == '__main__':

3
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py

@ -18,7 +18,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.utils import parameterize
from colossalai.testing.utils import clear_cache_before_run, parameterize
class MatMulModule(nn.Module):
@ -28,6 +28,7 @@ class MatMulModule(nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
@parameterize(
'tensor_shapes',
[

5
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py

@ -1,4 +1,3 @@
import pytest
import torch
import torch.nn as nn
@ -8,11 +7,11 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer(bias_addition_split=True)

4
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py

@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import clear_cache_before_run, parameterize
class OutputModel(nn.Module):
@ -23,7 +23,7 @@ class OutputModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('output_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_output_handler(output_option):
model = OutputModel()
tracer = ColoTracer(bias_addition_split=True)

21
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py

@ -2,7 +2,6 @@ from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -15,9 +14,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -55,7 +53,7 @@ class LinearReshapeModel(nn.Module):
return permute_node
def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port):
def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
if call_function == torch.permute:
@ -328,14 +326,13 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))])
@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel])
def test_view_handler(call_function, reshape_dims, model_cls):
world_size = 4
run_func = partial(check_view_handler,
call_function=call_function,
reshape_dims=reshape_dims,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(
check_view_handler,
4,
call_function=call_function,
reshape_dims=reshape_dims,
model_cls=model_cls,
)
if __name__ == '__main__':

4
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py

@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import clear_cache_before_run, parameterize
class PlaceholderModel(nn.Module):
@ -22,7 +22,7 @@ class PlaceholderModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('placeholder_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_placeholder_handler(placeholder_option):
model = PlaceholderModel()
tracer = ColoTracer(bias_addition_split=True)

4
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py

@ -1,5 +1,4 @@
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -9,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHan
from colossalai.auto_parallel.tensor_shard.options import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class LinearModel(nn.Module):
@ -108,6 +107,7 @@ def check_shard_option(shard_option):
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_shard_option():
# for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]:
for shard_option in [ShardOption.SHARD_LAST_AXIS]:

17
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -33,7 +28,7 @@ class LinearSplitModel(nn.Module):
return softmax_node
def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
def check_split_handler(rank, world_size, port, softmax_dim, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = model_cls(softmax_dim=softmax_dim).cuda()
@ -176,13 +171,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
@parameterize('softmax_dim', [0, 1, 2, 3])
@parameterize('model_cls', [LinearSplitModel])
def test_split_handler(softmax_dim, model_cls):
world_size = 4
run_func = partial(check_split_handler,
softmax_dim=softmax_dim,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_split_handler, 4, softmax_dim=softmax_dim, model_cls=model_cls)
if __name__ == '__main__':

18
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -47,7 +42,7 @@ class LinearSplitModel(nn.Module):
return split_node
def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port):
def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = model_cls(split_size=split_size, split_dim=split_dim).cuda()
@ -258,14 +253,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
@parameterize('split_dim', [0, 1, 2])
@parameterize('model_cls', [ConvSplitModel, LinearSplitModel])
def test_split_handler(split_size, split_dim, model_cls):
world_size = 4
run_func = partial(check_split_handler,
split_size=split_size,
split_dim=split_dim,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_split_handler, 4, split_size=split_size, split_dim=split_dim, model_cls=model_cls)
if __name__ == '__main__':

13
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -14,9 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -36,7 +31,7 @@ class LinearSumModel(nn.Module):
return sum_node
def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
def check_sum_handler(rank, world_size, port, sum_dims, keepdim):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda()
@ -228,9 +223,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
@parameterize('sum_dims', [(0, 2), 1])
@parameterize('keepdim', [False, True])
def test_sum_handler(sum_dims, keepdim):
world_size = 4
run_func = partial(check_sum_handler, sum_dims=sum_dims, keepdim=keepdim, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim)
if __name__ == '__main__':

3
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py

@ -7,7 +7,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class TensorConstructorModel(nn.Module):
@ -22,6 +22,7 @@ class TensorConstructorModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_where_handler():
model = TensorConstructorModel()
tracer = ColoTracer(bias_addition_split=True)

3
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py

@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class ReLuModel(nn.Module):
@ -24,6 +24,7 @@ class ReLuModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_elementwise_handler():
model = ReLuModel()
tracer = ColoTracer(bias_addition_split=True)

14
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py

@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@ -15,9 +12,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@ -255,13 +251,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)])
@parameterize('model_cls', [ConvViewModel, LinearViewModel])
def test_view_handler(tgt_shape, model_cls):
world_size = 4
run_func = partial(check_view_handler,
tgt_shape=tgt_shape,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_view_handler, 4, tgt_shape=tgt_shape, model_cls=model_cls)
if __name__ == '__main__':

2
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py

@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import clear_cache_before_run
class ConvModel(nn.Module):
@ -21,6 +22,7 @@ class ConvModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@clear_cache_before_run()
def test_where_handler():
model = ConvModel()
tracer = ColoTracer(bias_addition_split=True)

3
tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py

@ -10,10 +10,11 @@ from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_cost_graph():
physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4)

2
tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py

@ -8,7 +8,7 @@ import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

2
tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py

@ -9,7 +9,7 @@ from colossalai.autochunk.utils import flat_list
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

12
tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py

@ -1,10 +1,8 @@
from functools import partial
from typing import Dict, List, Tuple
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import EvoformerBlock
@ -15,6 +13,7 @@ except:
from test_autochunk_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.testing import clear_cache_before_run, parameterize, spawn
def get_model():
@ -66,18 +65,19 @@ def get_chunk_target() -> Dict:
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
@clear_cache_before_run()
@parameterize("max_memory", [None, 20, 24])
@parameterize("data_args", [(32, 64)])
def test_evoformer_block(data_args, max_memory):
run_func = partial(
spawn(
run_test,
1,
data_args=data_args,
max_memory=max_memory,
get_model=get_model,
get_data=get_data,
get_chunk_target=get_chunk_target,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":

12
tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py

@ -1,10 +1,8 @@
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import EvoformerStack
@ -15,6 +13,7 @@ except:
from test_autochunk_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.testing import clear_cache_before_run, parameterize, spawn
def get_model():
@ -61,17 +60,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
@clear_cache_before_run()
@parameterize("max_memory", [None, 20, 24])
@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len)
def test_evoformer_stack(data_args, max_memory):
run_func = partial(
spawn(
run_test,
1,
data_args=data_args,
max_memory=max_memory,
get_model=get_model,
get_data=get_data,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":

12
tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py

@ -1,10 +1,8 @@
from functools import partial
from typing import Dict, List, Tuple
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import ExtraMSABlock
@ -14,6 +12,7 @@ except:
from test_autochunk_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.testing import clear_cache_before_run, parameterize, spawn
def get_model():
@ -57,17 +56,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
@clear_cache_before_run()
@parameterize("max_memory", [None, 20, 24])
@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len)
def test_extramsa_block(data_args, max_memory):
run_func = partial(
spawn(
run_test,
1,
data_args=data_args,
max_memory=max_memory,
get_model=get_model,
get_data=get_data,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":

2
tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py

@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

14
tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py

@ -1,9 +1,7 @@
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.multiprocessing as mp
try:
from diffusers import UNet2DModel
@ -16,6 +14,7 @@ except:
from test_autochunk_diffuser_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.testing import clear_cache_before_run, parameterize, spawn
BATCH_SIZE = 1
HEIGHT = 448
@ -37,17 +36,18 @@ def get_data(shape: tuple) -> Tuple[List, List]:
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
@pytest.mark.parametrize("max_memory", [None, 150, 300])
@clear_cache_before_run()
@parameterize("model", MODELS)
@parameterize("shape", [LATENTS_SHAPE])
@parameterize("max_memory", [None, 150, 300])
def test_evoformer_block(model, shape, max_memory):
run_func = partial(
spawn(
run_test,
1,
max_memory=max_memory,
model=model,
data=get_data(shape),
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":

14
tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py

@ -1,9 +1,7 @@
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.multiprocessing as mp
try:
from transformers import GPT2Config, GPT2Model
@ -16,6 +14,7 @@ except:
from test_autochunk_transformer_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.testing import clear_cache_before_run, parameterize, spawn
BATCH_SIZE = 1
SEQ_LENGTH = 512
@ -35,18 +34,19 @@ def get_data(shape: tuple) -> Tuple[List, List]:
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
@pytest.mark.parametrize("max_memory", [None, 6, 8])
@clear_cache_before_run()
@parameterize("model", MODELS)
@parameterize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
@parameterize("max_memory", [None, 6, 8])
def test_autochunk_gpt(model, shape, max_memory):
run_func = partial(
spawn(
run_test,
1,
data=get_data(shape),
max_memory=max_memory,
model=model,
config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":

2
tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py

@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

12
tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py

@ -1,9 +1,7 @@
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.multiprocessing as mp
try:
from timm.models.vision_transformer import vit_large_patch16_384 as vit
@ -16,6 +14,7 @@ except:
from test_autochunk_vit_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.testing import clear_cache_before_run, parameterize, spawn
def get_data() -> Tuple[List, List]:
@ -28,16 +27,17 @@ def get_data() -> Tuple[List, List]:
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_memory", [None, 32, 40])
@clear_cache_before_run()
@parameterize("model", MODELS)
@parameterize("max_memory", [None, 32, 40])
def test_evoformer_block(model, max_memory):
run_func = partial(
spawn(
run_test,
1,
max_memory=max_memory,
model=model,
data=get_data(),
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":

2
tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py

@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

19
tests/test_booster/test_accelerator.py

@ -1,27 +1,14 @@
from functools import partial
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.booster.accelerator import Accelerator
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.testing import clear_cache_before_run, parameterize
@clear_cache_before_run()
@parameterize('device', ['cpu', 'cuda'])
def run_accelerator(device):
def test_accelerator(device):
acceleartor = Accelerator(device)
model = nn.Linear(8, 8)
model = acceleartor.configure_model(model)
assert next(model.parameters()).device.type == device
del model, acceleartor
def run_dist(rank):
run_accelerator()
@rerun_if_address_is_in_use()
def test_accelerator():
world_size = 1
run_func = partial(run_dist)
mp.spawn(run_func, nprocs=world_size)

10
tests/test_booster/test_mixed_precision/test_fp16_torch.py

@ -1,13 +1,9 @@
from functools import partial
import torch
import torch.multiprocessing as mp
from torch.optim import Adam
import colossalai
from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@ -41,6 +37,4 @@ def run_torch_amp(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_torch_ddp_plugin():
world_size = 1
run_func = partial(run_torch_amp, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_torch_amp, 1)

11
tests/test_booster/test_plugin/test_gemini_plugin.py

@ -1,17 +1,12 @@
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@ -119,9 +114,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port(), early_stop=early_stop)
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, 2, early_stop=early_stop)
if __name__ == '__main__':

10
tests/test_booster/test_plugin/test_torch_ddp_plugin.py

@ -1,8 +1,5 @@
from functools import partial
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
@ -10,8 +7,7 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@ -103,6 +99,4 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_torch_ddp_plugin():
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, 2)

4
tests/test_checkpoint_io/test_general_checkpoint_io.py

@ -6,6 +6,7 @@ from torch.optim import Adam
from torchvision.models import resnet18
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.testing import clear_cache_before_run, parameterize
# ========
# Note:
@ -15,7 +16,8 @@ from colossalai.checkpoint_io import GeneralCheckpointIO
# ========
@pytest.mark.parametrize('use_safetensors', [True, False])
@clear_cache_before_run()
@parameterize('use_safetensors', [True, False])
def test_unsharded_checkpoint(use_safetensors: bool):
# create a model and optimizer
model = resnet18()

11
tests/test_cluster/test_device_mesh_manager.py

@ -1,14 +1,9 @@
from functools import partial
import torch
import torch.multiprocessing as mp
from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.testing import spawn
def check_device_mesh_manager(rank, world_size, port):
@ -31,9 +26,7 @@ def check_device_mesh_manager(rank, world_size, port):
def test_device_mesh_manager():
world_size = 4
run_func = partial(check_device_mesh_manager, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_device_mesh_manager, 4)
if __name__ == '__main__':

15
tests/test_comm/test_boardcast_send_recv_v2.py

@ -1,17 +1,12 @@
from functools import partial
from typing import List
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication.p2p_v2 import _send_object, _recv_object, init_process_group
from colossalai.communication.p2p_v2 import _recv_object, _send_object
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
disable_existing_loggers()
world_size = 4
@ -45,9 +40,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_object_list_p2p():
disable_existing_loggers()
run_func = partial(check_layer, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_layer, world_size)
if __name__ == '__main__':

12
tests/test_comm/test_comm.py

@ -1,15 +1,13 @@
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication import all_gather, all_reduce, reduce_scatter
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
@ -66,9 +64,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_comm():
world_size = 4
run_func = partial(check_layer, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_layer, 4)
if __name__ == '__main__':

21
tests/test_comm/test_object_list_p2p.py

@ -1,15 +1,18 @@
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication.p2p import send_forward, recv_forward, send_backward, recv_backward, send_forward_recv_backward, send_backward_recv_forward
from colossalai.communication.p2p import (
recv_backward,
recv_forward,
send_backward,
send_backward_recv_forward,
send_forward,
send_forward_recv_backward,
)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(pipeline=2))
torch.manual_seed(123)
@ -96,9 +99,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_object_list_p2p():
world_size = 2
run_func = partial(check_layer, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_layer, 2)
if __name__ == '__main__':

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save