mirror of https://github.com/hpcaitech/ColossalAI
[test] refactor tests with spawn (#3452)
* [test] added spawn decorator * polish code * polish code * polish code * polish code * polish code * polish codepull/3343/head
parent
62f4e2eb07
commit
80eba05b0a
|
@ -8,10 +8,10 @@ jobs:
|
|||
detect:
|
||||
name: Detect file change
|
||||
if: |
|
||||
github.event.pull_request.draft == false &&
|
||||
github.base_ref == 'main' &&
|
||||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' &&
|
||||
contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
|
||||
github.event.pull_request.draft == false &&
|
||||
github.base_ref == 'main' &&
|
||||
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' &&
|
||||
contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
|
||||
outputs:
|
||||
changedExtenisonFiles: ${{ steps.find-extension-change.outputs.all_changed_files }}
|
||||
anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }}
|
||||
|
@ -27,10 +27,10 @@ jobs:
|
|||
- name: Locate base commit
|
||||
id: locate-base-sha
|
||||
run: |
|
||||
curBranch=$(git rev-parse --abbrev-ref HEAD)
|
||||
commonCommit=$(git merge-base origin/main $curBranch)
|
||||
echo $commonCommit
|
||||
echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
|
||||
curBranch=$(git rev-parse --abbrev-ref HEAD)
|
||||
commonCommit=$(git merge-base origin/main $curBranch)
|
||||
echo $commonCommit
|
||||
echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Find the changed extension-related files
|
||||
id: find-extension-change
|
||||
|
@ -63,7 +63,6 @@ jobs:
|
|||
echo "$file was changed"
|
||||
done
|
||||
|
||||
|
||||
build:
|
||||
name: Build and Test Colossal-AI
|
||||
needs: detect
|
||||
|
@ -124,7 +123,7 @@ jobs:
|
|||
- name: Execute Unit Testing
|
||||
if: needs.detect.outputs.anyLibraryFileChanged == 'true'
|
||||
run: |
|
||||
PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/
|
||||
CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/
|
||||
env:
|
||||
DATA: /data/scratch/cifar-10
|
||||
NCCL_SHM_DISABLE: 1
|
||||
|
|
|
@ -1,19 +1,16 @@
|
|||
import os
|
||||
import tempfile
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from coati.models.gpt import GPTActor
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
|
||||
|
||||
|
@ -90,8 +87,7 @@ def run_dist(rank, world_size, port, strategy):
|
|||
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_checkpoint(world_size, strategy):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist, world_size, strategy=strategy)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
import os
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from coati.experience_maker import NaiveExperienceMaker
|
||||
from coati.models.base import RewardModel
|
||||
from coati.models.gpt import GPTActor, GPTCritic
|
||||
|
@ -13,8 +11,7 @@ from coati.replay_buffer import NaiveReplayBuffer
|
|||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
|
||||
|
||||
|
@ -114,8 +111,7 @@ def run_dist(rank, world_size, port, strategy):
|
|||
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_data(world_size, strategy):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist, world_size, strategy=strategy)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -10,7 +10,8 @@ from colossalai.context import Config
|
|||
from colossalai.context.random import reset_seeds
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.utils import MultiTimer, free_port
|
||||
from colossalai.testing import free_port
|
||||
from colossalai.utils import MultiTimer
|
||||
|
||||
from .models import MLP
|
||||
|
||||
|
|
|
@ -1,7 +1,17 @@
|
|||
from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
|
||||
from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use, skip_if_not_enough_gpus
|
||||
from .comparison import assert_close, assert_close_loose, assert_equal, assert_equal_in_group, assert_not_equal
|
||||
from .pytest_wrapper import run_on_environment_flag
|
||||
from .utils import (
|
||||
clear_cache_before_run,
|
||||
free_port,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
rerun_on_exception,
|
||||
skip_if_not_enough_gpus,
|
||||
spawn,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
|
||||
'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus'
|
||||
'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn',
|
||||
'clear_cache_before_run', 'run_on_environment_flag'
|
||||
]
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
import gc
|
||||
import random
|
||||
import re
|
||||
import torch
|
||||
from typing import Callable, List, Any
|
||||
import socket
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, List
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from packaging import version
|
||||
|
||||
|
||||
|
@ -43,7 +48,7 @@ def parameterize(argument: str, values: List[Any]) -> Callable:
|
|||
# > davis: hello
|
||||
# > davis: bye
|
||||
# > davis: stop
|
||||
|
||||
|
||||
Args:
|
||||
argument (str): the name of the argument to parameterize
|
||||
values (List[Any]): a list of values to iterate for this argument
|
||||
|
@ -85,13 +90,13 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
|
|||
def test_method():
|
||||
print('hey')
|
||||
raise RuntimeError('Address already in use')
|
||||
|
||||
|
||||
# rerun for infinite times if Runtime error occurs
|
||||
@rerun_on_exception(exception_type=RuntimeError, max_try=None)
|
||||
def test_method():
|
||||
print('hey')
|
||||
raise RuntimeError('Address already in use')
|
||||
|
||||
|
||||
# rerun only the exception message is matched with pattern
|
||||
# for infinite times if Runtime error occurs
|
||||
@rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
|
||||
|
@ -101,10 +106,10 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
|
|||
|
||||
Args:
|
||||
exception_type (Exception, Optional): The type of exception to detect for rerun
|
||||
pattern (str, Optional): The pattern to match the exception message.
|
||||
pattern (str, Optional): The pattern to match the exception message.
|
||||
If the pattern is not None and matches the exception message,
|
||||
the exception will be detected for rerun
|
||||
max_try (int, Optional): Maximum reruns for this function. The default value is 5.
|
||||
max_try (int, Optional): Maximum reruns for this function. The default value is 5.
|
||||
If max_try is None, it will rerun foreven if exception keeps occurings
|
||||
"""
|
||||
|
||||
|
@ -202,3 +207,72 @@ def skip_if_not_enough_gpus(min_gpus: int):
|
|||
return _execute_by_gpu_num
|
||||
|
||||
return _wrap_func
|
||||
|
||||
|
||||
def free_port() -> int:
|
||||
"""Get a free port on localhost.
|
||||
|
||||
Returns:
|
||||
int: A free port on localhost.
|
||||
"""
|
||||
while True:
|
||||
port = random.randint(20000, 65000)
|
||||
try:
|
||||
with socket.socket() as sock:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(("localhost", port))
|
||||
return port
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
|
||||
def spawn(func, nprocs=1, **kwargs):
|
||||
"""
|
||||
This function is used to spawn processes for testing.
|
||||
|
||||
Usage:
|
||||
# must contians arguments rank, world_size, port
|
||||
def do_something(rank, world_size, port):
|
||||
...
|
||||
|
||||
spawn(do_something, nprocs=8)
|
||||
|
||||
# can also pass other arguments
|
||||
def do_something(rank, world_size, port, arg1, arg2):
|
||||
...
|
||||
|
||||
spawn(do_something, nprocs=8, arg1=1, arg2=2)
|
||||
|
||||
Args:
|
||||
func (Callable): The function to be spawned.
|
||||
nprocs (int, optional): The number of processes to spawn. Defaults to 1.
|
||||
"""
|
||||
port = free_port()
|
||||
wrapped_func = partial(func, world_size=nprocs, port=port, **kwargs)
|
||||
mp.spawn(wrapped_func, nprocs=nprocs)
|
||||
|
||||
|
||||
def clear_cache_before_run():
|
||||
"""
|
||||
This function is a wrapper to clear CUDA and python cache before executing the function.
|
||||
|
||||
Usage:
|
||||
@clear_cache_before_run()
|
||||
def test_something():
|
||||
...
|
||||
"""
|
||||
|
||||
def _wrap_func(f):
|
||||
|
||||
def _clear_cache(*args, **kwargs):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_max_memory_cached()
|
||||
torch.cuda.synchronize()
|
||||
gc.collect()
|
||||
f(*args, **kwargs)
|
||||
|
||||
return _clear_cache
|
||||
|
||||
return _wrap_func
|
||||
|
|
|
@ -7,7 +7,6 @@ from .common import (
|
|||
count_zeros_fp32,
|
||||
disposable,
|
||||
ensure_path_exists,
|
||||
free_port,
|
||||
is_ddp_ignored,
|
||||
is_dp_rank_0,
|
||||
is_model_parallel_parameter,
|
||||
|
@ -37,7 +36,6 @@ from .timer import MultiTimer, Timer
|
|||
|
||||
__all__ = [
|
||||
'checkpoint',
|
||||
'free_port',
|
||||
'print_rank_0',
|
||||
'sync_model_param',
|
||||
'is_ddp_ignored',
|
||||
|
|
|
@ -50,23 +50,6 @@ def ensure_path_exists(filename: str):
|
|||
Path(dirpath).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def free_port() -> int:
|
||||
"""Get a free port on localhost.
|
||||
|
||||
Returns:
|
||||
int: A free port on localhost.
|
||||
"""
|
||||
while True:
|
||||
port = random.randint(20000, 65000)
|
||||
try:
|
||||
with socket.socket() as sock:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(("localhost", port))
|
||||
return port
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
|
||||
def sync_model_param(model, parallel_mode):
|
||||
r"""Make sure data parameters are consistent during Data Parallel Mode.
|
||||
|
||||
|
|
|
@ -4,3 +4,4 @@ packaging
|
|||
tensornvme
|
||||
psutil
|
||||
transformers
|
||||
pytest
|
||||
|
|
|
@ -56,12 +56,12 @@ Let's see an example. A ColoTensor is initialized and sharded on 8 GPUs using tp
|
|||
```python
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.utils import free_port, print_rank_0
|
||||
from colossalai.utils import print_rank_0
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import spawn
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -83,8 +83,7 @@ def run_dist_tests(rank, world_size, port):
|
|||
print_rank_0(f"shape {t1.shape}, {t1.data}")
|
||||
|
||||
def test_dist_cases(world_size):
|
||||
run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist_tests, world_size)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_dist_cases(4)
|
||||
|
|
|
@ -57,12 +57,12 @@ ColoTensor 包含额外的属性[ColoTensorSpec](https://colossalai.readthedocs.
|
|||
```python
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.utils import free_port, print_rank_0
|
||||
from colossalai.utils import print_rank_0
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import spawn
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -84,8 +84,7 @@ def run_dist_tests(rank, world_size, port):
|
|||
print_rank_0(f"shape {t1.shape}, {t1.data}")
|
||||
|
||||
def test_dist_cases(world_size):
|
||||
run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist_tests, world_size)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_dist_cases(4)
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
import os
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from vit import get_training_components
|
||||
|
||||
|
@ -15,8 +13,7 @@ from colossalai.context.parallel_mode import ParallelMode
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.parallel.data_parallel import ColoDDP
|
||||
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext
|
||||
|
||||
|
@ -156,8 +153,7 @@ def run_dist(rank, world_size, port, use_ddp):
|
|||
@pytest.mark.parametrize('use_ddp', [False, True])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_vit(world_size, use_ddp):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist, world_size, use_ddp=use_ddp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,20 +1,20 @@
|
|||
import time
|
||||
import pytest
|
||||
import argparse
|
||||
from functools import partial
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from model_zoo import GPTLMLoss, get_gpt2_components
|
||||
from torch.utils._pytree import tree_map
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.fx.profiler import parameter_size
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
|
||||
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
|
||||
from colossalai.auto_parallel.offload.solver import NOT_NVML
|
||||
from model_zoo import get_gpt2_components, GPTLMLoss
|
||||
from colossalai.fx.profiler import parameter_size
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import spawn
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -24,6 +24,7 @@ def parse_args():
|
|||
parser.add_argument('--memory_budget', type=float, default=16)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
|
||||
def train_gpt(args):
|
||||
memory_budget = args.memory_budget * 1024 * 1024 * 1024
|
||||
|
@ -33,13 +34,16 @@ def train_gpt(args):
|
|||
|
||||
# build model
|
||||
model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size)
|
||||
label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
|
||||
label = torch.randint(low=0, high=128, size=(
|
||||
64,
|
||||
8,
|
||||
), device=get_current_device())
|
||||
criterion = GPTLMLoss()
|
||||
|
||||
start_time = time.time()
|
||||
model = model_builder()
|
||||
model.train()
|
||||
param_size = parameter_size(model) / 1024 ** 2 / 2
|
||||
param_size = parameter_size(model) / 1024**2 / 2
|
||||
init_time = time.time() - start_time
|
||||
print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
|
||||
|
||||
|
@ -74,21 +78,20 @@ def train_gpt(args):
|
|||
torch.cuda.synchronize()
|
||||
|
||||
exec_time = sum(sorted(time_list)[:5]) / 5
|
||||
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
|
||||
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
|
||||
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
|
||||
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
|
||||
print(f'solver_type: {solver_type} | model_type: {model_type}')
|
||||
print(
|
||||
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
|
||||
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
|
||||
)
|
||||
print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
|
||||
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|')
|
||||
print(time_list)
|
||||
|
||||
|
||||
def run(rank, world_size, port, args):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
train_gpt(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
run_func = partial(run, world_size=1, port=free_port(), args=args)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
spawn(run, 1, args=args)
|
||||
|
|
|
@ -1,18 +1,13 @@
|
|||
from functools import partial
|
||||
from time import time
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from gpt_modules import GPT2LMHeadModel, GPTLMLoss
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model
|
||||
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch_from_torch
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
|
||||
|
|
|
@ -1,19 +1,14 @@
|
|||
import time
|
||||
from argparse import ArgumentParser
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torchvision.models as tm
|
||||
from bench_utils import bench, data_gen_resnet
|
||||
|
||||
import colossalai
|
||||
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
|
||||
from colossalai.fx import metainfo_trace, symbolic_trace
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import spawn
|
||||
|
||||
|
||||
def _benchmark(rank, world_size, port):
|
||||
|
@ -50,9 +45,7 @@ def _benchmark(rank, world_size, port):
|
|||
|
||||
|
||||
def auto_activation_checkpoint_batchsize_benchmark():
|
||||
world_size = 1
|
||||
run_func_module = partial(_benchmark, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_benchmark, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -4,14 +4,13 @@ from functools import partial
|
|||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torchvision.models as tm
|
||||
from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium
|
||||
|
||||
import colossalai
|
||||
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
|
||||
from colossalai.fx import metainfo_trace, symbolic_trace
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import spawn
|
||||
|
||||
|
||||
def _benchmark(rank, world_size, port, args):
|
||||
|
@ -77,8 +76,7 @@ def _benchmark(rank, world_size, port, args):
|
|||
|
||||
def auto_activation_checkpoint_benchmark(args):
|
||||
world_size = 1
|
||||
run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args)
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_benchmark, world_size, args=args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -12,3 +12,4 @@ contexttimer
|
|||
einops
|
||||
triton==2.0.0.dev20221202
|
||||
git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
|
||||
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
|
||||
|
|
|
@ -1,14 +1,11 @@
|
|||
import copy
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import colossalai
|
||||
from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp
|
||||
from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
||||
|
@ -87,10 +84,9 @@ def run_dist(rank, world_size, port):
|
|||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_naive_amp():
|
||||
world_size = 1
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,14 +1,11 @@
|
|||
import copy
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import colossalai
|
||||
from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp
|
||||
from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
||||
|
@ -87,10 +84,9 @@ def run_dist(rank, world_size, port):
|
|||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_torch_amp():
|
||||
world_size = 1
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
|||
from packaging import version
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from colossalai.testing.utils import parameterize
|
||||
from colossalai.testing.utils import clear_cache_before_run, parameterize
|
||||
|
||||
try:
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
|
@ -81,6 +81,7 @@ class AddmmModel(torch.nn.Module):
|
|||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@clear_cache_before_run()
|
||||
@parameterize("bias", [True, False])
|
||||
@parameterize("bias_addition_split", [True, False])
|
||||
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
|
||||
try:
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
except:
|
||||
|
@ -62,9 +64,10 @@ class AModel(torch.nn.Module):
|
|||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("bias_addition_split", [True, False])
|
||||
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
|
||||
@clear_cache_before_run()
|
||||
@parameterize("bias", [True, False])
|
||||
@parameterize("bias_addition_split", [True, False])
|
||||
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
|
||||
def test_mod_dir(bias, bias_addition_split, shape):
|
||||
model = AModel(bias=bias)
|
||||
x = torch.rand(shape)
|
||||
|
@ -75,4 +78,4 @@ def test_mod_dir(bias, bias_addition_split, shape):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_mod_dir(True, True, (3, 3, 3))
|
||||
test_mod_dir(bias=True, bias_addition_split=True, shape=(3, 3, 3))
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import pytest
|
||||
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
try:
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
|
@ -42,6 +44,7 @@ class MyModule(nn.Module):
|
|||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@clear_cache_before_run()
|
||||
def test_nested_ckpt():
|
||||
model = MyModule()
|
||||
x = torch.rand(10, 10)
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
|||
import torchvision.models as tm
|
||||
from packaging import version
|
||||
|
||||
from colossalai.testing.utils import parameterize
|
||||
from colossalai.testing.utils import clear_cache_before_run, parameterize
|
||||
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
|
||||
|
||||
try:
|
||||
|
@ -32,6 +32,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
|
|||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@clear_cache_before_run()
|
||||
@parameterize('m', tm_models)
|
||||
def test_torchvision_shape_prop(m):
|
||||
with MetaTensorMode():
|
||||
|
@ -46,6 +47,7 @@ def test_torchvision_shape_prop(m):
|
|||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@clear_cache_before_run()
|
||||
@parameterize('m', tmm_models)
|
||||
def test_timm_shape_prop(m):
|
||||
with MetaTensorMode():
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
|||
import torchvision.models as tm
|
||||
from packaging import version
|
||||
|
||||
from colossalai.testing.utils import parameterize
|
||||
from colossalai.testing.utils import clear_cache_before_run, parameterize
|
||||
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
|
||||
|
||||
try:
|
||||
|
@ -19,6 +19,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
|
|||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@clear_cache_before_run()
|
||||
@parameterize('m', tm_models)
|
||||
def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
|
||||
with MetaTensorMode():
|
||||
|
@ -33,6 +34,7 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
|
|||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@clear_cache_before_run()
|
||||
@parameterize('m', tmm_models)
|
||||
def test_timm_profile(m, verbose=False, bias_addition_split=False):
|
||||
with MetaTensorMode():
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
from typing import Any, Callable, Union
|
||||
import pytest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
try:
|
||||
from colossalai._analyzer._subclasses import MetaTensor
|
||||
except:
|
||||
|
@ -72,6 +74,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
|
|||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@clear_cache_before_run()
|
||||
def test_meta_aten():
|
||||
for (aten_op, requires_backward), v in registered_meta.items():
|
||||
for f, x in v:
|
||||
|
|
|
@ -4,6 +4,7 @@ import torch.nn.functional as F
|
|||
import torchvision.models as tm
|
||||
from packaging import version
|
||||
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
|
||||
|
||||
try:
|
||||
|
@ -39,7 +40,8 @@ odd_cases = [
|
|||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
|
||||
@clear_cache_before_run()
|
||||
@parameterize('func, args, kwargs', odd_cases)
|
||||
def test_flop_count_function(func, args, kwargs):
|
||||
rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
|
||||
assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}'
|
||||
|
|
|
@ -3,6 +3,8 @@ import torch
|
|||
import torchvision.models as tm
|
||||
from packaging import version
|
||||
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
|
||||
try:
|
||||
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
|
||||
except:
|
||||
|
@ -30,7 +32,8 @@ def run_and_compare(model):
|
|||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.parametrize('m', tm_models + tmm_models)
|
||||
@clear_cache_before_run()
|
||||
@parameterize('m', tm_models + tmm_models)
|
||||
def test_meta_mode_shape(m):
|
||||
run_and_compare(m())
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ import copy
|
|||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
import torchvision.models as tm
|
||||
|
||||
import colossalai
|
||||
|
@ -13,7 +12,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
|
|||
# from colossalai.fx.passes.algorithms import solver_rotor
|
||||
# from colossalai.fx.passes.algorithms.operation import Sequence
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
|
@ -26,8 +25,8 @@ except:
|
|||
withcodegen = False
|
||||
|
||||
|
||||
def _run_C_solver_consistency_test(rank=0):
|
||||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
def _run_C_solver_consistency_test(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
|
||||
model = M()
|
||||
|
@ -70,8 +69,9 @@ def _run_C_solver_consistency_test(rank=0):
|
|||
|
||||
@pytest.mark.skip("TODO(lyl): refactor all tests.")
|
||||
@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_C_solver_consistency():
|
||||
mp.spawn(_run_C_solver_consistency_test, nprocs=1)
|
||||
spawn(_run_C_solver_consistency_test, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -4,7 +4,6 @@ from typing import Callable
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torchvision.models as tm
|
||||
from torch.fx import GraphModule
|
||||
|
||||
|
@ -15,7 +14,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
|
|||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
# from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
|
@ -68,8 +67,8 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
|
|||
assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
|
||||
|
||||
|
||||
def _run_ckpt_solver(rank):
|
||||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
def _run_ckpt_solver(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
MODEL_LIST = [tm.densenet121]
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
@ -98,12 +97,13 @@ def _run_ckpt_solver(rank):
|
|||
|
||||
@pytest.mark.skip("TODO(super-dainiu): refactor all tests.")
|
||||
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_ckpt_solver():
|
||||
mp.spawn(_run_ckpt_solver, nprocs=1)
|
||||
spawn(_run_ckpt_solver, 1)
|
||||
|
||||
|
||||
def _run_ckpt_solver_torch11(rank):
|
||||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
def _run_ckpt_solver_torch11(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
MODEL_LIST = [tm.densenet121]
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
@ -131,8 +131,9 @@ def _run_ckpt_solver_torch11(rank):
|
|||
|
||||
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
|
||||
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_ckpt_solver_torch11():
|
||||
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
|
||||
spawn(_run_ckpt_solver_torch11, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -8,6 +8,7 @@ from colossalai.fx.graph_module import ColoGraphModule
|
|||
# from colossalai.fx.passes.algorithms import linearize, solver_rotor
|
||||
# from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
|
@ -24,6 +25,7 @@ except:
|
|||
@pytest.mark.skip(reason='TODO: modify the logger')
|
||||
@pytest.mark.skip("TODO(lyl): refactor all tests.")
|
||||
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
|
||||
@clear_cache_before_run()
|
||||
def test_linearize():
|
||||
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
|
||||
tracer = ColoTracer()
|
||||
|
@ -84,6 +86,7 @@ def test_linearize():
|
|||
@pytest.mark.skip("TODO(lyl): refactor all tests.")
|
||||
@pytest.mark.skip(reason="torch11 meta tensor not implemented")
|
||||
@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
|
||||
@clear_cache_before_run()
|
||||
def test_linearize_torch11():
|
||||
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
|
||||
tracer = ColoTracer()
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
import time
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
import colossalai
|
||||
|
@ -12,8 +10,8 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
|
|||
from colossalai.auto_parallel.offload.solver import NOT_NVML
|
||||
from colossalai.fx.profiler import parameter_size
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
|
||||
from tests.test_auto_parallel.test_offload.model_utils import *
|
||||
from tests.test_tensor.common_utils import set_seed
|
||||
|
@ -140,9 +138,9 @@ def run_dist(rank, world_size, port):
|
|||
|
||||
@pytest.mark.skip("this test failed")
|
||||
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_perf():
|
||||
run_func = partial(run_dist, world_size=1, port=free_port())
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
spawn(run_dist, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -3,20 +3,20 @@ import torch.fx
|
|||
from torch.fx import GraphModule
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.auto_parallel.offload.region_manager import RegionManager
|
||||
from colossalai.auto_parallel.offload.solver import NOT_NVML, SolverFactory
|
||||
from colossalai.fx import ColoTracer, is_compatible_with_meta
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.auto_parallel.offload.region_manager import RegionManager
|
||||
from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
from tests.test_auto_parallel.test_offload.model_utils import *
|
||||
|
||||
|
||||
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
|
||||
@clear_cache_before_run()
|
||||
@parameterize('model_name', ['gpt2_', 'bert_'])
|
||||
@parameterize('memory_budget', [4000])
|
||||
@parameterize('solver_name', ['syn', 'asyn'])
|
||||
def solver_test(model_name: str,
|
||||
memory_budget: float,
|
||||
solver_name: str):
|
||||
def solver_test(model_name: str, memory_budget: float, solver_name: str):
|
||||
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, data_gen = get_components_func()
|
||||
|
@ -52,11 +52,16 @@ def solver_test(model_name: str,
|
|||
for region in region_list:
|
||||
need_offload = region.need_offload
|
||||
to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None
|
||||
print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
|
||||
print(
|
||||
f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
|
||||
)
|
||||
for region in region_list.__reversed__():
|
||||
need_offload = region.need_offload
|
||||
to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None
|
||||
print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
|
||||
print(
|
||||
f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
solver_test()
|
||||
solver_test()
|
||||
|
|
|
@ -6,6 +6,7 @@ from colossalai.device.device_mesh import DeviceMesh
|
|||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.tracer import ColoTracer
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
|
@ -26,6 +27,7 @@ def insert_narrow(gm, x_node):
|
|||
return gm
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def test_node_args_converting_pass():
|
||||
model = TestModule()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
|
|
@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
|||
from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
|
@ -36,6 +37,7 @@ def recover_narrow(gm, narrow_node):
|
|||
|
||||
|
||||
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
|
||||
@clear_cache_before_run()
|
||||
def test_size_value_converting_pass():
|
||||
model = TestModule()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
|
|
@ -2,7 +2,6 @@ from functools import partial
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
|
||||
|
@ -13,9 +12,7 @@ except:
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
|
||||
|
||||
class LinearModel(torch.nn.Module):
|
||||
|
@ -86,11 +83,8 @@ def check_conv_module(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bias_addition_module():
|
||||
world_size = 4
|
||||
run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_linear, nprocs=world_size)
|
||||
run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_conv, nprocs=world_size)
|
||||
spawn(check_linear_module, 4)
|
||||
spawn(check_conv_module, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
from functools import partial
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
@ -17,9 +15,7 @@ except:
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
|
||||
HIDDEN_SIZE = 16
|
||||
|
||||
|
@ -65,9 +61,7 @@ def check_act_ckpt(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_mlp_layer():
|
||||
world_size = 4
|
||||
run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_act_ckpt, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
import copy
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
try:
|
||||
|
@ -15,9 +13,7 @@ except:
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
|
@ -102,9 +98,7 @@ def check_compatibility_with_ddp(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_compatibility_with_ddp():
|
||||
world_size = 4
|
||||
run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_compatibility_with_ddp, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
import copy
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
try:
|
||||
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
|
||||
|
@ -17,10 +14,9 @@ from colossalai.initialize import launch
|
|||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.tensor.process_group import ProcessGroup
|
||||
from colossalai.testing import assert_close, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
|
||||
from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
|
@ -110,9 +106,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_auto_parallel_with_gemini():
|
||||
world_size = 4
|
||||
run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_auto_parallel_with_gemini, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -10,8 +10,7 @@ from colossalai._analyzer.fx.passes import shape_prop_pass
|
|||
# from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, run_on_environment_flag
|
||||
|
||||
NUM_REPEAT_BLOCKS = 4
|
||||
BATCH_SIZE = 1
|
||||
|
@ -81,6 +80,7 @@ class NonRepeatModel(nn.Module):
|
|||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@clear_cache_before_run()
|
||||
@parameterize('model_cls', [RepeatModel, NonRepeatModel])
|
||||
def test_repeat_blocks(model_cls):
|
||||
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
import copy
|
||||
import random
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import transformers
|
||||
from torch.fx import GraphModule
|
||||
|
||||
|
@ -30,9 +28,8 @@ from colossalai.device.device_mesh import DeviceMesh
|
|||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.shape_consistency import to_global
|
||||
from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
|
||||
|
||||
BATCH_SIZE = 1
|
||||
|
@ -190,9 +187,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
|
|||
@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_mlp_layer(model_cls):
|
||||
world_size = 4
|
||||
run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_attention_layer, 4, model_cls=model_cls)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from torch.fx import GraphModule
|
||||
|
||||
|
@ -7,10 +6,10 @@ from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
|
|||
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
|
||||
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
|
||||
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
|
||||
|
||||
|
@ -20,6 +19,7 @@ HIDDEN_DIM = 384
|
|||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@clear_cache_before_run()
|
||||
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
|
||||
def test_self_attention_block(model_cls):
|
||||
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM)
|
||||
|
|
|
@ -6,6 +6,8 @@ from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
|||
from colossalai._analyzer.fx.passes import shape_prop_pass
|
||||
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
|
||||
class LinearModel(nn.Module):
|
||||
|
@ -26,6 +28,7 @@ class LinearModel(nn.Module):
|
|||
|
||||
|
||||
@pytest.mark.skip('meta tensor has some bugs in 1.11')
|
||||
@clear_cache_before_run()
|
||||
def test_liveness_analysis():
|
||||
model = LinearModel()
|
||||
tracer = ColoTracer(bias_addition_split=True)
|
||||
|
|
|
@ -1,23 +1,14 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.meta_profiler import meta_register
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
|
||||
from colossalai.testing.utils import clear_cache_before_run, parameterize
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||
@clear_cache_before_run()
|
||||
@parameterize('func', [
|
||||
torch.nn.functional.softmax,
|
||||
torch.nn.functional.relu,
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
@ -10,8 +7,7 @@ from colossalai.fx import ColoGraphModule, ColoTracer
|
|||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -62,9 +58,7 @@ def _binary_elementwise_mem_test(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_binary_elementwise_meta_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_binary_elementwise_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_binary_elementwise_mem_test, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,17 +1,12 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -25,7 +20,7 @@ class ConvFunctionModule(nn.Module):
|
|||
return nn.functional.conv2d(input, self.conv_weight)
|
||||
|
||||
|
||||
def _conv_module_mem_test(rank, bias, world_size, port):
|
||||
def _conv_module_mem_test(rank, world_size, port, bias):
|
||||
"""This function is for conv memory test
|
||||
Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
|
||||
|
||||
|
@ -62,9 +57,7 @@ def _conv_module_mem_test(rank, bias, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_conv_meta_concrete_info_match(bias=False):
|
||||
world_size = 4
|
||||
run_func_module = partial(_conv_module_mem_test, bias=bias, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_conv_module_mem_test, 4, bias=bias)
|
||||
|
||||
|
||||
def _conv_function_mem_test(rank, world_size, port):
|
||||
|
@ -103,9 +96,7 @@ def _conv_function_mem_test(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_conv_function_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_conv_function_mem_test, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,33 +1,16 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
|
||||
from colossalai.testing.utils import clear_cache_before_run
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
|
||||
|
||||
if torch.__version__ >= '1.12.0':
|
||||
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
|
||||
from colossalai.auto_parallel.meta_profiler import meta_register
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||
@clear_cache_before_run()
|
||||
def test_embedding_meta_info():
|
||||
meta_func = meta_register.get(torch.nn.Embedding)
|
||||
|
||||
|
|
|
@ -1,24 +1,14 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
|
||||
|
||||
if torch.__version__ >= '1.12.0':
|
||||
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
|
||||
|
||||
|
||||
class MyModule(nn.Module):
|
||||
|
||||
|
@ -63,9 +53,7 @@ def _linear_module_mem_test(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_module_meta_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_linear_module_mem_test, 4)
|
||||
|
||||
|
||||
def _linear_function_mem_test(rank, world_size, port):
|
||||
|
@ -101,9 +89,7 @@ def _linear_function_mem_test(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_function_meta_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_linear_function_mem_test, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,26 +1,8 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
|
||||
from colossalai.testing.utils import clear_cache_before_run, parameterize
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
|
||||
|
||||
if torch.__version__ >= '1.12.0':
|
||||
|
@ -28,6 +10,7 @@ if torch.__version__ >= '1.12.0':
|
|||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||
@clear_cache_before_run()
|
||||
@parameterize(
|
||||
'tensor_shapes',
|
||||
[
|
||||
|
|
|
@ -1,29 +1,17 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
|
||||
|
||||
if torch.__version__ >= '1.12.0':
|
||||
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
|
||||
from colossalai.auto_parallel.meta_profiler import meta_register
|
||||
|
||||
|
||||
def _batchnorm_module_mem_test(rank, world_size, port):
|
||||
|
@ -62,9 +50,7 @@ def _batchnorm_module_mem_test(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_batchnorm_meta_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_batchnorm_module_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_batchnorm_module_mem_test, 4)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations')
|
||||
|
|
|
@ -1,17 +1,12 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -51,9 +46,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_adaptiveavgpool_meta_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_adaptiveavgpool_module_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_adaptiveavgpool_module_mem_test, 4)
|
||||
|
||||
|
||||
def _maxpool_module_mem_test(rank, world_size, port):
|
||||
|
@ -92,9 +85,7 @@ def _maxpool_module_mem_test(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_maxpool_meta_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_maxpool_module_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_maxpool_module_mem_test, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,26 +1,9 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
|
||||
from colossalai.testing.utils import clear_cache_before_run
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
|
||||
|
||||
if torch.__version__ >= '1.12.0':
|
||||
|
@ -37,6 +20,7 @@ class SplitModule(nn.Module):
|
|||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||
@clear_cache_before_run()
|
||||
def test_tensor_meta_info():
|
||||
"""test tensor related meta information
|
||||
We will just use torch.Tensor.split for the test
|
||||
|
|
|
@ -1,24 +1,8 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
|
||||
from colossalai.testing.utils import clear_cache_before_run
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
|
||||
|
||||
if torch.__version__ >= '1.12.0':
|
||||
|
@ -26,6 +10,7 @@ if torch.__version__ >= '1.12.0':
|
|||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||
@clear_cache_before_run()
|
||||
def test_where_meta_info():
|
||||
meta_func = meta_register.get(torch.where)
|
||||
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
|
||||
|
@ -11,9 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh
|
|||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -45,7 +40,7 @@ class AddBMMTorchFunctionModule(nn.Module):
|
|||
return output
|
||||
|
||||
|
||||
def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
|
||||
def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = module(using_kwargs).cuda()
|
||||
|
@ -249,14 +244,13 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por
|
|||
@parameterize('using_kwargs', [True, False])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_2d_device_mesh(module, bias_shape, using_kwargs):
|
||||
world_size = 4
|
||||
run_func = partial(check_2d_device_mesh,
|
||||
module=module,
|
||||
bias_shape=bias_shape,
|
||||
world_size=world_size,
|
||||
using_kwargs=using_kwargs,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(
|
||||
check_2d_device_mesh,
|
||||
4,
|
||||
module=module,
|
||||
bias_shape=bias_shape,
|
||||
using_kwargs=using_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip("skip due to bias cases not ready")
|
||||
|
@ -267,14 +261,13 @@ def test_2d_device_mesh(module, bias_shape, using_kwargs):
|
|||
@parameterize('using_kwargs', [True, False])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_1d_device_mesh(module, bias_shape, using_kwargs):
|
||||
world_size = 4
|
||||
run_func = partial(check_1d_device_mesh,
|
||||
module=module,
|
||||
bias_shape=bias_shape,
|
||||
using_kwargs=using_kwargs,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(
|
||||
check_1d_device_mesh,
|
||||
4,
|
||||
module=module,
|
||||
bias_shape=bias_shape,
|
||||
using_kwargs=using_kwargs,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -17,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -45,7 +40,7 @@ class AddmmModel_with_param(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
|
||||
def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
if model_cls == AddmmModel:
|
||||
|
@ -189,13 +184,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
|
|||
@parameterize('model_cls', [AddmmModel, AddmmModel_with_param])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_addmm_handler(input_shape, model_cls):
|
||||
world_size = 4
|
||||
run_func_function = partial(check_addmm_function_handler,
|
||||
input_shape=input_shape,
|
||||
model_cls=model_cls,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func_function, nprocs=world_size)
|
||||
spawn(check_addmm_function_handler, 4, input_shape=input_shape, model_cls=model_cls)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -114,9 +109,7 @@ def check_bn_module_handler(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bn_module_handler():
|
||||
world_size = 4
|
||||
run_func = partial(check_bn_module_handler, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_bn_module_handler, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,9 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -19,9 +15,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
WEIGHT_SHAPE = (32, 16)
|
||||
|
@ -168,9 +162,7 @@ def check_linear_module_handler(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_handler():
|
||||
world_size = 4
|
||||
run_func_module = partial(check_linear_module_handler, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(check_linear_module_handler)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
|
||||
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
|
@ -18,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -35,7 +29,7 @@ class LinearModule(torch.nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def check_linear_module_handler(rank, bias, world_size, port):
|
||||
def check_linear_module_handler(rank, world_size, port, bias):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = LinearModule(16, 32, bias=bias).cuda()
|
||||
|
@ -157,9 +151,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_handler(bias=True):
|
||||
world_size = 4
|
||||
run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(check_linear_module_handler, bias=bias)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port):
|
||||
def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
|
@ -149,7 +144,7 @@ class BEOpModelWithIntConst(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
|
||||
def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
|
@ -236,13 +231,12 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, wo
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_binary_elementwise_handler_with_tensor(op, other_dim):
|
||||
world_size = 4
|
||||
run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
|
||||
op=op,
|
||||
other_dim=other_dim,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func_tensor, nprocs=world_size)
|
||||
spawn(
|
||||
check_binary_elementwise_handler_with_tensor,
|
||||
4,
|
||||
op=op,
|
||||
other_dim=other_dim,
|
||||
)
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
|
@ -252,14 +246,13 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
|
||||
world_size = 4
|
||||
run_func_int = partial(check_binary_elementwise_handler_with_int,
|
||||
op=op,
|
||||
model_cls=model_cls,
|
||||
other_dim=other_dim,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func_int, nprocs=world_size)
|
||||
spawn(
|
||||
check_binary_elementwise_handler_with_int,
|
||||
4,
|
||||
op=op,
|
||||
model_cls=model_cls,
|
||||
other_dim=other_dim,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -207,11 +202,8 @@ def check_1d_device_mesh(rank, module, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bmm_handler(module):
|
||||
world_size = 4
|
||||
run_func_2d = partial(check_2d_device_mesh, module=module, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_2d, nprocs=world_size)
|
||||
run_func_1d = partial(check_1d_device_mesh, module=module, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_1d, nprocs=world_size)
|
||||
spawn(check_2d_device_mesh, 4, module=module)
|
||||
spawn(check_1d_device_mesh, 4, module=module)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
def check_conv_module_handler(rank, bias, world_size, port):
|
||||
def check_conv_module_handler(rank, world_size, port, bias):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda()
|
||||
|
@ -155,7 +150,7 @@ class ConvModel(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def check_conv_function_handler(rank, bias, world_size, port):
|
||||
def check_conv_function_handler(rank, world_size, port, bias):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = ConvModel().cuda()
|
||||
|
@ -302,9 +297,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
|
|||
# @parameterize('bias', [True, False])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_conv_module_handler(bias=False):
|
||||
world_size = 4
|
||||
run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_conv_module_handler, 4, bias=bias)
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
|
@ -314,9 +307,7 @@ def test_conv_module_handler(bias=False):
|
|||
# @parameterize('bias', [True, False])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_conv_function_handler(bias=False):
|
||||
world_size = 4
|
||||
run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_conv_function_handler, 4, bias=bias)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHan
|
|||
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
|
||||
|
||||
|
||||
class ReshapeModel(nn.Module):
|
||||
|
@ -23,6 +23,7 @@ class ReshapeModel(nn.Module):
|
|||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@clear_cache_before_run()
|
||||
def test_reshape_handler():
|
||||
model = ReshapeModel()
|
||||
tracer = ColoTracer(bias_addition_split=True)
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -16,9 +13,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
NUM_EMBEDDINGS = 16
|
||||
|
@ -272,18 +268,14 @@ def check_embedding_function_handler(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_embedding_module_handler():
|
||||
world_size = 4
|
||||
run_func = partial(check_embedding_module_handler, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_embedding_module_handler, 4)
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_embedding_function_handler():
|
||||
world_size = 4
|
||||
run_func = partial(check_embedding_function_handler, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_embedding_function_handler, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
|||
from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
|
||||
class GetattrModel(nn.Module):
|
||||
|
@ -22,6 +23,7 @@ class GetattrModel(nn.Module):
|
|||
|
||||
|
||||
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
|
||||
@clear_cache_before_run()
|
||||
def test_getattr_handler():
|
||||
model = GetattrModel()
|
||||
tracer = ColoTracer(bias_addition_split=True)
|
||||
|
|
|
@ -2,7 +2,6 @@ from functools import partial
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -14,12 +13,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import Li
|
|||
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -103,12 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
|
|||
# @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))])
|
||||
@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))])
|
||||
def test_getitem_from_tensor_handler(getitem_index):
|
||||
world_size = 4
|
||||
run_func = partial(check_getitem_from_tensor_handler,
|
||||
getitem_index=getitem_index,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_getitem_from_tensor_handler, 4)
|
||||
|
||||
|
||||
class GetItemFromTupleModel(nn.Module):
|
||||
|
@ -123,6 +115,7 @@ class GetItemFromTupleModel(nn.Module):
|
|||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@clear_cache_before_run()
|
||||
def test_getitem_from_tuple_handler():
|
||||
model = GetItemFromTupleModel()
|
||||
tracer = ColoTracer()
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -11,12 +8,10 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
|||
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -104,9 +99,7 @@ def check_ln_module_handler(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_ln_module_handler():
|
||||
world_size = 4
|
||||
run_func = partial(check_ln_module_handler, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_ln_module_handler, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -18,14 +15,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
def check_linear_module_handler(rank, bias, input_shape, world_size, port):
|
||||
def check_linear_module_handler(rank, world_size, port, bias, input_shape):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
|
||||
|
@ -172,7 +168,7 @@ class LinearModel(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def check_linear_function_handler(rank, bias, input_shape, world_size, port):
|
||||
def check_linear_function_handler(rank, world_size, port, bias, input_shape):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = LinearModel().cuda()
|
||||
|
@ -313,19 +309,18 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_handler(input_shape, bias=False):
|
||||
world_size = 4
|
||||
run_func_module = partial(check_linear_module_handler,
|
||||
bias=bias,
|
||||
input_shape=input_shape,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
run_func_function = partial(check_linear_function_handler,
|
||||
bias=bias,
|
||||
input_shape=input_shape,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func_function, nprocs=world_size)
|
||||
spawn(
|
||||
check_linear_module_handler,
|
||||
4,
|
||||
bias=bias,
|
||||
input_shape=input_shape,
|
||||
)
|
||||
spawn(
|
||||
check_linear_function_handler,
|
||||
4,
|
||||
bias=bias,
|
||||
input_shape=input_shape,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -18,7 +18,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
|||
StrategiesVector,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.testing.utils import parameterize
|
||||
from colossalai.testing.utils import clear_cache_before_run, parameterize
|
||||
|
||||
|
||||
class MatMulModule(nn.Module):
|
||||
|
@ -28,6 +28,7 @@ class MatMulModule(nn.Module):
|
|||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||
@clear_cache_before_run()
|
||||
@parameterize(
|
||||
'tensor_shapes',
|
||||
[
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -8,11 +7,11 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
|||
from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@clear_cache_before_run()
|
||||
def test_norm_pool_handler():
|
||||
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
|
||||
tracer = ColoTracer(bias_addition_split=True)
|
||||
|
|
|
@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
|||
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
|
||||
|
||||
class OutputModel(nn.Module):
|
||||
|
@ -23,7 +23,7 @@ class OutputModel(nn.Module):
|
|||
|
||||
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
|
||||
@parameterize('output_option', ['distributed', 'replicated'])
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_output_handler(output_option):
|
||||
model = OutputModel()
|
||||
tracer = ColoTracer(bias_addition_split=True)
|
||||
|
|
|
@ -2,7 +2,6 @@ from functools import partial
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -15,9 +14,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -55,7 +53,7 @@ class LinearReshapeModel(nn.Module):
|
|||
return permute_node
|
||||
|
||||
|
||||
def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port):
|
||||
def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
if call_function == torch.permute:
|
||||
|
@ -328,14 +326,13 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
|
|||
@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))])
|
||||
@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel])
|
||||
def test_view_handler(call_function, reshape_dims, model_cls):
|
||||
world_size = 4
|
||||
run_func = partial(check_view_handler,
|
||||
call_function=call_function,
|
||||
reshape_dims=reshape_dims,
|
||||
model_cls=model_cls,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(
|
||||
check_view_handler,
|
||||
4,
|
||||
call_function=call_function,
|
||||
reshape_dims=reshape_dims,
|
||||
model_cls=model_cls,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
|||
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
|
||||
|
||||
class PlaceholderModel(nn.Module):
|
||||
|
@ -22,7 +22,7 @@ class PlaceholderModel(nn.Module):
|
|||
|
||||
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
|
||||
@parameterize('placeholder_option', ['distributed', 'replicated'])
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_placeholder_handler(placeholder_option):
|
||||
model = PlaceholderModel()
|
||||
tracer = ColoTracer(bias_addition_split=True)
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -9,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHan
|
|||
from colossalai.auto_parallel.tensor_shard.options import ShardOption
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
|
||||
|
||||
|
||||
class LinearModel(nn.Module):
|
||||
|
@ -108,6 +107,7 @@ def check_shard_option(shard_option):
|
|||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@clear_cache_before_run()
|
||||
def test_shard_option():
|
||||
# for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]:
|
||||
for shard_option in [ShardOption.SHARD_LAST_AXIS]:
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -33,7 +28,7 @@ class LinearSplitModel(nn.Module):
|
|||
return softmax_node
|
||||
|
||||
|
||||
def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
|
||||
def check_split_handler(rank, world_size, port, softmax_dim, model_cls):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = model_cls(softmax_dim=softmax_dim).cuda()
|
||||
|
@ -176,13 +171,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
|
|||
@parameterize('softmax_dim', [0, 1, 2, 3])
|
||||
@parameterize('model_cls', [LinearSplitModel])
|
||||
def test_split_handler(softmax_dim, model_cls):
|
||||
world_size = 4
|
||||
run_func = partial(check_split_handler,
|
||||
softmax_dim=softmax_dim,
|
||||
model_cls=model_cls,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_split_handler, 4, softmax_dim=softmax_dim, model_cls=model_cls)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -47,7 +42,7 @@ class LinearSplitModel(nn.Module):
|
|||
return split_node
|
||||
|
||||
|
||||
def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port):
|
||||
def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = model_cls(split_size=split_size, split_dim=split_dim).cuda()
|
||||
|
@ -258,14 +253,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
|
|||
@parameterize('split_dim', [0, 1, 2])
|
||||
@parameterize('model_cls', [ConvSplitModel, LinearSplitModel])
|
||||
def test_split_handler(split_size, split_dim, model_cls):
|
||||
world_size = 4
|
||||
run_func = partial(check_split_handler,
|
||||
split_size=split_size,
|
||||
split_dim=split_dim,
|
||||
model_cls=model_cls,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_split_handler, 4, split_size=split_size, split_dim=split_dim, model_cls=model_cls)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -14,9 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -36,7 +31,7 @@ class LinearSumModel(nn.Module):
|
|||
return sum_node
|
||||
|
||||
|
||||
def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
|
||||
def check_sum_handler(rank, world_size, port, sum_dims, keepdim):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda()
|
||||
|
@ -228,9 +223,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
|
|||
@parameterize('sum_dims', [(0, 2), 1])
|
||||
@parameterize('keepdim', [False, True])
|
||||
def test_sum_handler(sum_dims, keepdim):
|
||||
world_size = 4
|
||||
run_func = partial(check_sum_handler, sum_dims=sum_dims, keepdim=keepdim, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -7,7 +7,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
|||
from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
|
||||
|
||||
|
||||
class TensorConstructorModel(nn.Module):
|
||||
|
@ -22,6 +22,7 @@ class TensorConstructorModel(nn.Module):
|
|||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@clear_cache_before_run()
|
||||
def test_where_handler():
|
||||
model = TensorConstructorModel()
|
||||
tracer = ColoTracer(bias_addition_split=True)
|
||||
|
|
|
@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
|
|||
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
|
||||
|
||||
|
||||
class ReLuModel(nn.Module):
|
||||
|
@ -24,6 +24,7 @@ class ReLuModel(nn.Module):
|
|||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@clear_cache_before_run()
|
||||
def test_elementwise_handler():
|
||||
model = ReLuModel()
|
||||
tracer = ColoTracer(bias_addition_split=True)
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
|
@ -15,9 +12,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
|
@ -255,13 +251,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
|
|||
@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)])
|
||||
@parameterize('model_cls', [ConvViewModel, LinearViewModel])
|
||||
def test_view_handler(tgt_shape, model_cls):
|
||||
world_size = 4
|
||||
run_func = partial(check_view_handler,
|
||||
tgt_shape=tgt_shape,
|
||||
model_cls=model_cls,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_view_handler, 4, tgt_shape=tgt_shape, model_cls=model_cls)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
|||
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
@ -21,6 +22,7 @@ class ConvModel(nn.Module):
|
|||
|
||||
|
||||
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
|
||||
@clear_cache_before_run()
|
||||
def test_where_handler():
|
||||
model = ConvModel()
|
||||
tracer = ColoTracer(bias_addition_split=True)
|
||||
|
|
|
@ -10,10 +10,11 @@ from colossalai.auto_parallel.tensor_shard.options import SolverOptions
|
|||
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@clear_cache_before_run()
|
||||
def test_cost_graph():
|
||||
physical_mesh_id = torch.arange(0, 8)
|
||||
mesh_shape = (2, 4)
|
||||
|
|
|
@ -8,7 +8,7 @@ import colossalai
|
|||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import free_port
|
||||
|
||||
if AUTOCHUNK_AVAILABLE:
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
|
|
|
@ -9,7 +9,7 @@ from colossalai.autochunk.utils import flat_list
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import free_port
|
||||
|
||||
if AUTOCHUNK_AVAILABLE:
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
from functools import partial
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from fastfold.model.nn.evoformer import EvoformerBlock
|
||||
|
@ -15,6 +13,7 @@ except:
|
|||
from test_autochunk_alphafold_utils import run_test
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, spawn
|
||||
|
||||
|
||||
def get_model():
|
||||
|
@ -66,18 +65,19 @@ def get_chunk_target() -> Dict:
|
|||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("max_memory", [None, 20, 24])
|
||||
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
|
||||
@clear_cache_before_run()
|
||||
@parameterize("max_memory", [None, 20, 24])
|
||||
@parameterize("data_args", [(32, 64)])
|
||||
def test_evoformer_block(data_args, max_memory):
|
||||
run_func = partial(
|
||||
spawn(
|
||||
run_test,
|
||||
1,
|
||||
data_args=data_args,
|
||||
max_memory=max_memory,
|
||||
get_model=get_model,
|
||||
get_data=get_data,
|
||||
get_chunk_target=get_chunk_target,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from fastfold.model.nn.evoformer import EvoformerStack
|
||||
|
@ -15,6 +13,7 @@ except:
|
|||
from test_autochunk_alphafold_utils import run_test
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, spawn
|
||||
|
||||
|
||||
def get_model():
|
||||
|
@ -61,17 +60,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
|
|||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("max_memory", [None, 20, 24])
|
||||
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
|
||||
@clear_cache_before_run()
|
||||
@parameterize("max_memory", [None, 20, 24])
|
||||
@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len)
|
||||
def test_evoformer_stack(data_args, max_memory):
|
||||
run_func = partial(
|
||||
spawn(
|
||||
run_test,
|
||||
1,
|
||||
data_args=data_args,
|
||||
max_memory=max_memory,
|
||||
get_model=get_model,
|
||||
get_data=get_data,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
from functools import partial
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from fastfold.model.nn.evoformer import ExtraMSABlock
|
||||
|
@ -14,6 +12,7 @@ except:
|
|||
from test_autochunk_alphafold_utils import run_test
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, spawn
|
||||
|
||||
|
||||
def get_model():
|
||||
|
@ -57,17 +56,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
|
|||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("max_memory", [None, 20, 24])
|
||||
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
|
||||
@clear_cache_before_run()
|
||||
@parameterize("max_memory", [None, 20, 24])
|
||||
@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len)
|
||||
def test_extramsa_block(data_args, max_memory):
|
||||
run_func = partial(
|
||||
spawn(
|
||||
run_test,
|
||||
1,
|
||||
data_args=data_args,
|
||||
max_memory=max_memory,
|
||||
get_model=get_model,
|
||||
get_data=get_data,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import free_port
|
||||
|
||||
if AUTOCHUNK_AVAILABLE:
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from diffusers import UNet2DModel
|
||||
|
@ -16,6 +14,7 @@ except:
|
|||
from test_autochunk_diffuser_utils import run_test
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, spawn
|
||||
|
||||
BATCH_SIZE = 1
|
||||
HEIGHT = 448
|
||||
|
@ -37,17 +36,18 @@ def get_data(shape: tuple) -> Tuple[List, List]:
|
|||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
|
||||
@pytest.mark.parametrize("max_memory", [None, 150, 300])
|
||||
@clear_cache_before_run()
|
||||
@parameterize("model", MODELS)
|
||||
@parameterize("shape", [LATENTS_SHAPE])
|
||||
@parameterize("max_memory", [None, 150, 300])
|
||||
def test_evoformer_block(model, shape, max_memory):
|
||||
run_func = partial(
|
||||
spawn(
|
||||
run_test,
|
||||
1,
|
||||
max_memory=max_memory,
|
||||
model=model,
|
||||
data=get_data(shape),
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
|
@ -16,6 +14,7 @@ except:
|
|||
from test_autochunk_transformer_utils import run_test
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, spawn
|
||||
|
||||
BATCH_SIZE = 1
|
||||
SEQ_LENGTH = 512
|
||||
|
@ -35,18 +34,19 @@ def get_data(shape: tuple) -> Tuple[List, List]:
|
|||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
|
||||
@pytest.mark.parametrize("max_memory", [None, 6, 8])
|
||||
@clear_cache_before_run()
|
||||
@parameterize("model", MODELS)
|
||||
@parameterize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
|
||||
@parameterize("max_memory", [None, 6, 8])
|
||||
def test_autochunk_gpt(model, shape, max_memory):
|
||||
run_func = partial(
|
||||
spawn(
|
||||
run_test,
|
||||
1,
|
||||
data=get_data(shape),
|
||||
max_memory=max_memory,
|
||||
model=model,
|
||||
config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import free_port
|
||||
|
||||
if AUTOCHUNK_AVAILABLE:
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from timm.models.vision_transformer import vit_large_patch16_384 as vit
|
||||
|
@ -16,6 +14,7 @@ except:
|
|||
from test_autochunk_vit_utils import run_test
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, spawn
|
||||
|
||||
|
||||
def get_data() -> Tuple[List, List]:
|
||||
|
@ -28,16 +27,17 @@ def get_data() -> Tuple[List, List]:
|
|||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("max_memory", [None, 32, 40])
|
||||
@clear_cache_before_run()
|
||||
@parameterize("model", MODELS)
|
||||
@parameterize("max_memory", [None, 32, 40])
|
||||
def test_evoformer_block(model, max_memory):
|
||||
run_func = partial(
|
||||
spawn(
|
||||
run_test,
|
||||
1,
|
||||
max_memory=max_memory,
|
||||
model=model,
|
||||
data=get_data(),
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import free_port
|
||||
|
||||
if AUTOCHUNK_AVAILABLE:
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
|
|
|
@ -1,27 +1,14 @@
|
|||
from functools import partial
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.booster.accelerator import Accelerator
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
@parameterize('device', ['cpu', 'cuda'])
|
||||
def run_accelerator(device):
|
||||
def test_accelerator(device):
|
||||
acceleartor = Accelerator(device)
|
||||
model = nn.Linear(8, 8)
|
||||
model = acceleartor.configure_model(model)
|
||||
assert next(model.parameters()).device.type == device
|
||||
del model, acceleartor
|
||||
|
||||
|
||||
def run_dist(rank):
|
||||
run_accelerator()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_accelerator():
|
||||
world_size = 1
|
||||
run_func = partial(run_dist)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
|
|
@ -1,13 +1,9 @@
|
|||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from torch.optim import Adam
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
|
@ -41,6 +37,4 @@ def run_torch_amp(rank, world_size, port):
|
|||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_torch_ddp_plugin():
|
||||
world_size = 1
|
||||
run_func = partial(run_torch_amp, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_torch_amp, 1)
|
||||
|
|
|
@ -1,17 +1,12 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
|
@ -119,9 +114,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
|
|||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_gemini_plugin(early_stop: bool = True):
|
||||
world_size = 2
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), early_stop=early_stop)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist, 2, early_stop=early_stop)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import SGD
|
||||
|
||||
|
@ -10,8 +7,7 @@ import colossalai
|
|||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import TorchDDPPlugin
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
|
@ -103,6 +99,4 @@ def run_dist(rank, world_size, port):
|
|||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_torch_ddp_plugin():
|
||||
world_size = 2
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist, 2)
|
||||
|
|
|
@ -6,6 +6,7 @@ from torch.optim import Adam
|
|||
from torchvision.models import resnet18
|
||||
|
||||
from colossalai.checkpoint_io import GeneralCheckpointIO
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
|
||||
# ========
|
||||
# Note:
|
||||
|
@ -15,7 +16,8 @@ from colossalai.checkpoint_io import GeneralCheckpointIO
|
|||
# ========
|
||||
|
||||
|
||||
@pytest.mark.parametrize('use_safetensors', [True, False])
|
||||
@clear_cache_before_run()
|
||||
@parameterize('use_safetensors', [True, False])
|
||||
def test_unsharded_checkpoint(use_safetensors: bool):
|
||||
# create a model and optimizer
|
||||
model = resnet18()
|
||||
|
|
|
@ -1,14 +1,9 @@
|
|||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer import ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import spawn
|
||||
|
||||
|
||||
def check_device_mesh_manager(rank, world_size, port):
|
||||
|
@ -31,9 +26,7 @@ def check_device_mesh_manager(rank, world_size, port):
|
|||
|
||||
|
||||
def test_device_mesh_manager():
|
||||
world_size = 4
|
||||
run_func = partial(check_device_mesh_manager, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_device_mesh_manager, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,17 +1,12 @@
|
|||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.communication.p2p_v2 import _send_object, _recv_object, init_process_group
|
||||
|
||||
from colossalai.communication.p2p_v2 import _recv_object, _send_object
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
disable_existing_loggers()
|
||||
world_size = 4
|
||||
|
@ -45,9 +40,7 @@ def check_layer(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_object_list_p2p():
|
||||
disable_existing_loggers()
|
||||
run_func = partial(check_layer, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_layer, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,15 +1,13 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.communication import all_gather, all_reduce, reduce_scatter
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
|
||||
|
||||
|
@ -66,9 +64,7 @@ def check_layer(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_comm():
|
||||
world_size = 4
|
||||
run_func = partial(check_layer, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_layer, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.communication.p2p import send_forward, recv_forward, send_backward, recv_backward, send_forward_recv_backward, send_backward_recv_forward
|
||||
|
||||
from colossalai.communication.p2p import (
|
||||
recv_backward,
|
||||
recv_forward,
|
||||
send_backward,
|
||||
send_backward_recv_forward,
|
||||
send_forward,
|
||||
send_forward_recv_backward,
|
||||
)
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
CONFIG = dict(parallel=dict(pipeline=2))
|
||||
torch.manual_seed(123)
|
||||
|
@ -96,9 +99,7 @@ def check_layer(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_object_list_p2p():
|
||||
world_size = 2
|
||||
run_func = partial(check_layer, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_layer, 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue