mirror of https://github.com/hpcaitech/ColossalAI
[test] reorganize zero/gemini tests (#3445)
parent
72cb4dd433
commit
933048ad3e
|
@ -14,7 +14,7 @@ from colossalai.utils import free_port, get_current_device
|
|||
from colossalai.zero import ColoInitContext
|
||||
from tests.test_moe.test_moe_zero_init import MoeModel
|
||||
from tests.test_tensor.common_utils import debug_print
|
||||
from tests.test_zero.common import CONFIG
|
||||
from tests.test_zero.test_legacy.common import CONFIG
|
||||
|
||||
|
||||
def exam_moe_checkpoint():
|
||||
|
|
|
@ -13,7 +13,7 @@ from colossalai.utils import free_port, get_current_device
|
|||
from colossalai.zero import ColoInitContext
|
||||
from tests.test_moe.test_moe_zero_init import MoeModel
|
||||
from tests.test_tensor.common_utils import debug_print
|
||||
from tests.test_zero.common import CONFIG
|
||||
from tests.test_zero.test_legacy.common import CONFIG
|
||||
|
||||
|
||||
@parameterize("init_device_type", ['cpu', 'cuda'])
|
||||
|
|
|
@ -14,7 +14,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
|||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.zero.legacy.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
|
||||
from tests.test_zero.common import CONFIG
|
||||
from tests.test_zero.test_legacy.common import CONFIG
|
||||
|
||||
|
||||
class MoeModel(nn.Module):
|
||||
|
|
|
@ -17,7 +17,7 @@ from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
|
|||
from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_moe.test_moe_zero_init import MoeModel
|
||||
from tests.test_zero.common import CONFIG, check_grads_padding, run_fwd_bwd
|
||||
from tests.test_zero.test_legacy.common import CONFIG, check_grads_padding, run_fwd_bwd
|
||||
|
||||
|
||||
@parameterize("enable_autocast", [False])
|
||||
|
|
|
@ -20,7 +20,7 @@ from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2
|
|||
from colossalai.zero.low_level._utils import has_inf_or_nan
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_moe.test_moe_zero_init import MoeModel
|
||||
from tests.test_zero.common import CONFIG, check_sharded_model_params
|
||||
from tests.test_zero.test_legacy.common import CONFIG, check_sharded_model_params
|
||||
|
||||
|
||||
def _run_step(model, optimizer, data, label, criterion, grad_handler):
|
||||
|
|
|
@ -4,6 +4,7 @@ import pytest
|
|||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from common import CONFIG
|
||||
from test_sharded_optim_v2 import _run_step
|
||||
|
||||
import colossalai
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
@ -16,7 +17,6 @@ from colossalai.zero.legacy.sharded_model import ShardedModelV2
|
|||
from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2
|
||||
from colossalai.zero.low_level._utils import has_inf_or_nan
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_zero.test_sharded_optim_v2 import _run_step
|
||||
|
||||
|
||||
@parameterize("cpu_offload", [True, False])
|
|
@ -4,6 +4,7 @@ from functools import partial
|
|||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from common import CONFIG, allclose
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
|
@ -12,7 +13,6 @@ from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
|
|||
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
|
||||
from colossalai.zero.legacy.sharded_param import ShardedTensor
|
||||
from colossalai.zero.legacy.sharded_param.sharded_param import ShardedParamV2
|
||||
from tests.test_zero.common import CONFIG, allclose
|
||||
|
||||
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
|
@ -1,7 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
Loading…
Reference in New Issue