[refactor] refactor ColoTensor's unit tests (#1340)

pull/1342/head
HELSON 2022-07-19 15:46:24 +08:00 committed by GitHub
parent f92c100ddd
commit bf5066fba7
16 changed files with 11 additions and 13 deletions

@@ -1 +0,0 @@
-from ._util import *

@@ -0,0 +1 @@
+from ._utils import *
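
Taken together, the first two hunks rename the shared helper module: the old star re-export of ._util is deleted and a new __init__.py re-exports ._utils, so every test reaches the helpers through one absolute import path. A minimal sketch of the assumed layout (file paths inferred from the new import path, not shown in this view):

# tests/test_tensor/common_utils/__init__.py (assumed location)
from ._utils import *

# any test module, regardless of where pytest is launched from
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed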

@@ -1,7 +1,7 @@
 import pytest
 from functools import partial
-from _utils import tensor_equal, tensor_shard_equal, set_seed
+from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed
 import torch
 from torch.nn.parallel import DistributedDataParallel as DDP
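
The helpers moved here are comparison and seeding utilities. Their real implementations live in the (unshown) _utils.py; the sketch below is a hypothetical reconstruction based only on the helper names and how the tests call them, assuming tensor_shard_equal checks a full tensor against one rank's 1-D shard.

import random

import numpy as np
import torch


def set_seed(seed):
    # Seed every RNG the tests rely on so distributed runs are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def tensor_equal(t_a, t_b):
    # Allow a small tolerance for floating-point accumulation differences.
    return torch.allclose(t_a, t_b, rtol=1e-3, atol=1e-1)


def tensor_shard_equal(tensor, shard, rank, world_size):
    # Compare a full tensor with the shard owned by `rank`, assuming an even
    # 1-D split along either the first or the last dimension.
    if tensor.size() == shard.size():
        return tensor_equal(tensor, shard)
    if tensor.size(0) == shard.size(0) * world_size:
        return tensor_equal(tensor.chunk(world_size, dim=0)[rank], shard)
    return tensor_equal(tensor.chunk(world_size, dim=-1)[rank], shard)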

@@ -1,7 +1,5 @@
 import pytest
 from functools import partial
-from _utils import tensor_shard_equal, set_seed
 import torch
 import torch.multiprocessing as mp
@@ -15,7 +13,8 @@ from colossalai.tensor import ColoTensor, ProcessGroup
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
-from _utils import split_param_row_tp1d, split_param_col_tp1d
+from tests.test_tensor.common_utils import tensor_shard_equal, check_equal, set_seed, \
+    split_param_row_tp1d, split_param_col_tp1d
 def run_1d_hybrid_tp(model_name):
@@ -264,7 +263,6 @@ def run_1d_row_tp(model_name: str):
 def _run_pretrain_load():
-    from _utils import check_equal
     from transformers import BertForMaskedLM
     set_seed(1)
     model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
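
This file also pulls split_param_row_tp1d and split_param_col_tp1d out of the shared module. Judging by their names, they shard a parameter along its rows or columns for 1-D tensor parallelism; the sketch below is a guess built from the spec API visible in this file's other imports, not the commit's actual code.

from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec


def _split_param_1d(param, pg: ProcessGroup, dim: int):
    # Attach a 1-D shard spec so the ColoTensor is partitioned along `dim`
    # across the tensor-parallel ranks of `pg`.
    param.set_process_group(pg)
    param.set_tensor_spec(ShardSpec([dim], [pg.tp_world_size()]),
                          ComputeSpec(ComputePattern.TP1D))


def split_param_row_tp1d(param, pg):
    _split_param_1d(param, pg, dim=0)   # shard rows


def split_param_col_tp1d(param, pg):
    _split_param_1d(param, pg, dim=-1)  # shard columns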

@@ -7,7 +7,7 @@ import torch.multiprocessing as mp
 from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ShardSpec, ColoTensorSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
-from _utils import tensor_equal, tensor_shard_equal, set_seed
+from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed
 import colossalai
 from colossalai.utils.cuda import get_current_device

@@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensorSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from functools import partial
-from _utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d
+from tests.test_tensor.common_utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d
 class Conv1D(nn.Module):

@@ -8,7 +8,7 @@ import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
-from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d
+from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d
 def run_with_spec(spec_init_func):

@@ -8,7 +8,7 @@ import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
-from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
+from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
 def run_with_spec(spec_init_func, pg: ProcessGroup):

@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
-from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
+from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
 def run_with_spec(spec_init_func, split_bias):

@@ -1,7 +1,7 @@
 from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup
 import torch
 import pytest
-from _utils import tensor_equal
+from common_utils import tensor_equal
 import colossalai
 from colossalai.utils import free_port
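
Unlike every other file in the commit, this one keeps a bare common_utils import rather than the absolute tests.test_tensor.common_utils path. That only resolves when the test file's own directory is on sys.path, which pytest's default "prepend" import mode arranges for test directories without an __init__.py. A defensive equivalent, purely illustrative and not part of the commit:

import os
import sys

# Make the sibling common_utils package importable regardless of how the
# test process was launched.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from common_utils import tensor_equal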

@@ -8,7 +8,7 @@ from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.gemini import ChunkManager
 from functools import partial
-from _utils import tensor_equal, set_seed, tensor_shard_equal
+from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 from colossalai.nn.parallel import ZeroDDP
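
A typical consumer of these helpers compares a plain PyTorch model against its sharded ColoTensor counterpart parameter by parameter. A hypothetical usage sketch, assuming the helper signatures sketched above (this is not code from the commit):

from tests.test_tensor.common_utils import tensor_shard_equal


def assert_params_match(torch_model, colo_model, pg):
    # A ColoTensor parameter may hold only this rank's shard, so compare
    # through tensor_shard_equal rather than plain equality.
    for (name, p_torch), p_colo in zip(torch_model.named_parameters(),
                                       colo_model.parameters()):
        assert tensor_shard_equal(p_torch, p_colo,
                                  pg.tp_local_rank(), pg.tp_world_size()), name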