[hotfix] skip auto checkpointing tests (#3029)

* [hotfix] skip auto checkpointing tests

* fix test name issue
pull/3032/head
YuliangLiu0306 2023-03-07 15:50:00 +08:00 committed by GitHub
parent 8fedc8766a
commit 4269196c79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 9 deletions

View File

@ -1,16 +1,17 @@
import copy import copy
import colossalai
import pytest import pytest
import torch import torch
import torch.fx import torch.fx
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torchvision.models as tm import torchvision.models as tm
import colossalai
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.algorithms import solver_rotor # from colossalai.fx.passes.algorithms import solver_rotor
from colossalai.fx.passes.algorithms.operation import Sequence # from colossalai.fx.passes.algorithms.operation import Sequence
from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port from colossalai.utils import free_port
@ -67,6 +68,7 @@ def _run_C_solver_consistency_test(rank=0):
gpc.destroy() gpc.destroy()
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0") @pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
def test_C_solver_consistency(): def test_C_solver_consistency():
mp.spawn(_run_C_solver_consistency_test, nprocs=1) mp.spawn(_run_C_solver_consistency_test, nprocs=1)

View File

@ -13,7 +13,7 @@ from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor # from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port from colossalai.utils import free_port
@ -28,7 +28,8 @@ except:
from colossalai.fx.codegen import python_code_with_activation_checkpoint from colossalai.fx.codegen import python_code_with_activation_checkpoint
with_codegen = False with_codegen = False
SOLVERS = [chen_greedy, solver_rotor] # SOLVERS = [chen_greedy, solver_rotor]
SOLVERS = []
def _is_activation_checkpoint_available(gm: GraphModule): def _is_activation_checkpoint_available(gm: GraphModule):

View File

@ -1,11 +1,12 @@
import pytest import pytest
import torch import torch
import torchvision.models as tm import torchvision.models as tm
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import linearize, solver_rotor # from colossalai.fx.passes.algorithms import linearize, solver_rotor
from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss) # from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.meta_info_prop import MetaInfoProp
if is_compatible_with_meta(): if is_compatible_with_meta():
@ -21,6 +22,7 @@ except:
@pytest.mark.skip(reason='TODO: modify the logger') @pytest.mark.skip(reason='TODO: modify the logger')
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
def test_linearize(): def test_linearize():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]} MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
@ -79,6 +81,7 @@ def test_linearize():
del node_list del node_list
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skip(reason="torch11 meta tensor not implemented") @pytest.mark.skip(reason="torch11 meta tensor not implemented")
@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") @pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
def test_linearize_torch11(): def test_linearize_torch11():

View File

@ -4,7 +4,7 @@ from functools import reduce
from colossalai.tensor.d_tensor.sharding_spec import ALLGATHER_COST, SHARD_COST, STEP_PENALTY, ShardingSpec from colossalai.tensor.d_tensor.sharding_spec import ALLGATHER_COST, SHARD_COST, STEP_PENALTY, ShardingSpec
def test_sharding_spec(): def test_dtensor_sharding_spec():
dims = 4 dims = 4
dim_partition_dict_0 = {0: [0, 1]} dim_partition_dict_0 = {0: [0, 1]}
# DistSpec: # DistSpec:
@ -31,4 +31,4 @@ def test_sharding_spec():
if __name__ == '__main__': if __name__ == '__main__':
test_sharding_spec() test_dtensor_sharding_spec()