mirror of https://github.com/hpcaitech/ColossalAI
[refactor] refactor ColoTensor's unit tests (#1340)
parent
f92c100ddd
commit
bf5066fba7
|
@ -1 +0,0 @@
|
|||
from ._util import *
|
|
@ -0,0 +1 @@
|
|||
from ._utils import *
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
|
||||
from functools import partial
|
||||
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed
|
||||
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
@ -1,7 +1,5 @@
|
|||
import pytest
|
||||
from functools import partial
|
||||
from _utils import tensor_shard_equal, set_seed
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
|
@ -15,7 +13,8 @@ from colossalai.tensor import ColoTensor, ProcessGroup
|
|||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from _utils import split_param_row_tp1d, split_param_col_tp1d
|
||||
from tests.test_tensor.common_utils import tensor_shard_equal, check_equal, set_seed, \
|
||||
split_param_row_tp1d, split_param_col_tp1d
|
||||
|
||||
|
||||
def run_1d_hybrid_tp(model_name):
|
||||
|
@ -264,7 +263,6 @@ def run_1d_row_tp(model_name: str):
|
|||
|
||||
|
||||
def _run_pretrain_load():
|
||||
from _utils import check_equal
|
||||
from transformers import BertForMaskedLM
|
||||
set_seed(1)
|
||||
model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
|
|
@ -7,7 +7,7 @@ import torch.multiprocessing as mp
|
|||
|
||||
from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ShardSpec, ColoTensorSpec
|
||||
from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
|
||||
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed
|
||||
|
||||
import colossalai
|
||||
from colossalai.utils.cuda import get_current_device
|
|
@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensorSpec
|
|||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from functools import partial
|
||||
from _utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d
|
||||
from tests.test_tensor.common_utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
|
@ -8,7 +8,7 @@ import torch.multiprocessing as mp
|
|||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
|
||||
from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d
|
||||
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
|
@ -8,7 +8,7 @@ import torch.multiprocessing as mp
|
|||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
|
||||
from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
|
||||
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func, pg: ProcessGroup):
|
|
@ -8,7 +8,7 @@ import torch.nn.functional as F
|
|||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
|
||||
from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
|
||||
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func, split_bias):
|
|
@ -1,7 +1,7 @@
|
|||
from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup
|
||||
import torch
|
||||
import pytest
|
||||
from _utils import tensor_equal
|
||||
from common_utils import tensor_equal
|
||||
import colossalai
|
||||
from colossalai.utils import free_port
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from colossalai.utils import free_port
|
|||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.gemini import ChunkManager
|
||||
from functools import partial
|
||||
from _utils import tensor_equal, set_seed, tensor_shard_equal
|
||||
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
|
|
Loading…
Reference in New Issue