[refactory] add nn.parallel module (#1068)

pull/1069/head
Jiarui Fang 2022-06-06 15:34:41 +08:00 committed by GitHub
parent 6754f1b77f
commit 49832b2344
22 changed files with 44 additions and 46 deletions

View File

@@ -5,6 +5,3 @@ from .metric import *
from .model import *
from .optimizer import *
from ._ops import *
from .modules import ColoLinear, ColoEmbedding
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module

View File

@@ -1,3 +0,0 @@
from .colo_module import ColoModule
from .linear import ColoLinear
from .embedding import ColoEmbedding

View File

@@ -0,0 +1,3 @@
from .data_parallel import ColoDDP, ColoDDPV2
__all__ = ['ColoDDP', 'ColoDDPV2']
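With this package in place, the data-parallel wrappers are imported from colossalai.nn.parallel instead of colossalai.nn. A minimal usage sketch under the new path, assuming a distributed environment has already been launched and that ColoDDP simply wraps a module built inside ColoInitContext (constructor details are not shown in this commit):

# hedged sketch; assumes colossalai.launch*() has already initialized torch.distributed
import torch
from colossalai.nn.parallel import ColoDDP
from colossalai.utils.model.colo_init_context import ColoInitContext

with ColoInitContext(device=torch.device('cuda')):
    model = torch.nn.Linear(16, 16)

model = ColoDDP(model)                                   # assumed signature: wrap the module
out = model(torch.randn(4, 16, device='cuda'))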

View File

@@ -7,8 +7,6 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.chunk import ChunkManager, TensorState
from colossalai.tensor.param_op_hook import use_param_op_hooks
__all__ = ['ColoDDP', 'ColoDDPV2']
def free_storage(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""

View File

@@ -0,0 +1,15 @@
from .colo_module import ColoModule
from .linear import ColoLinear
from .embedding import ColoEmbedding
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
__all__ = [
'ColoModule',
'register_colo_module',
'is_colo_module',
'get_colo_module',
'init_colo_module',
'check_colo_module',
'ColoLinear',
'ColoEmbedding',
]
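With the helpers relocated, module specialization is driven from colossalai.nn.parallel.layers. A hedged sketch of registering and querying a replacement for torch.nn.Linear through the new path (ColoLinear is assumed to take no constructor arguments):

# sketch only; register_colo_module pairs a torch module class with a ColoModule instance
import torch.nn as nn
from colossalai.nn.parallel.layers import register_colo_module, is_colo_module, ColoLinear

register_colo_module(nn.Linear, ColoLinear())
print(is_colo_module(nn.Linear(8, 8)))    # expected True once nn.Linear is registered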

View File

@@ -1,6 +1,6 @@
from typing import Dict
from colossalai.tensor import ColoParameter, ParallelAction, TensorSpec
from .modules import ColoModule
from . import ColoModule
import torch
_COLOSSAL_MODULES: Dict[type, ColoModule] = {}

View File

@@ -11,8 +11,6 @@ from .memory import (report_memory_usage, colo_device_memory_used, colo_set_proc
colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
from .timer import MultiTimer, Timer
from .tensor_detector import TensorDetector
from .model.utils import InsertPostInitMethodToModuleSubClasses
from .model.colo_init_context import ColoInitContext
__all__ = [
'checkpoint',
@@ -52,6 +50,4 @@ __all__ = [
'disposable',
'colo_set_cpu_memory_capacity',
'colo_get_cpu_memory_capacity',
'InsertPostInitMethodToModuleSubClasses',
'ColoInitContext',
]
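Since ColoInitContext and InsertPostInitMethodToModuleSubClasses are no longer re-exported by colossalai.utils, callers switch to the defining modules, mirroring the import changes in the files below:

from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses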

View File

@@ -2,7 +2,7 @@ from .utils import InsertPostInitMethodToModuleSubClasses
import torch
from colossalai.tensor import ColoTensor, ColoParameter
from colossalai.nn import register_colo_module, init_colo_module, \
from colossalai.nn.parallel.layers import register_colo_module, \
ColoLinear, ColoEmbedding
from torch import nn

View File

@@ -1,10 +1,7 @@
import torch
import functools
import inspect
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.utils.model.utils import _substitute_init_recursively, InsertPostInitMethodToModuleSubClasses, call_to_str
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses, call_to_str
from colossalai.builder.pipeline import partition_uniform, partition_balanced
from colossalai.core import global_context as gpc
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoTensor

View File

@@ -1,9 +1,12 @@
import contextlib
import functools
from typing import Optional
from contextlib import AbstractContextManager
import torch
import torch.nn as nn
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.context.singleton_meta import SingletonMeta
@@ -12,8 +15,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_param import ShardedParamV2
from contextlib import AbstractContextManager
from colossalai.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
class ZeroContextConfig(object):

View File

@@ -2,7 +2,7 @@ import torch
import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
from colossalai.nn.parallel import ColoDDPV2
from colossalai.nn.parallel.data_parallel import ColoDDPV2
from typing import Dict
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.core import global_context as gpc

View File

@@ -1,10 +1,7 @@
import pytest
from colossalai.utils import ColoInitContext
from colossalai.utils.model.colo_init_context import ColoInitContext
from numpy import allclose, require
import torch
from colossalai.tensor import ColoTensor
from copy import deepcopy
from colossalai.utils.cuda import get_current_device

View File

@@ -5,14 +5,14 @@ import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
from colossalai.core import global_context as gpc
from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDP
from colossalai.nn.parallel.data_parallel import ColoDDP
def init_1d_row_spec(model):

View File

@@ -1,13 +1,14 @@
from colossalai.utils import free_port, ColoInitContext, get_current_device
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from colossalai.tensor import ComputePattern, ParallelAction
from functools import partial
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.nn import init_colo_module
from colossalai.nn.parallel import ColoDDP
from colossalai.nn.parallel.layers import init_colo_module
from colossalai.nn.parallel.data_parallel import ColoDDP
import colossalai
import torch
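In the tensor-parallel tests, the initialization helper now comes from colossalai.nn.parallel.layers. A hedged sketch of the call pattern implied by these imports, assuming a launched distributed environment; the ParallelAction constructor and the recursive/mode keywords are assumptions, not shown in this diff:

# sketch; exact signatures of ParallelAction and init_colo_module are assumed
import torch.nn as nn
from colossalai.tensor import ComputePattern, ParallelAction
from colossalai.nn.parallel.layers import init_colo_module
from colossalai.nn.parallel import ColoDDP

model = nn.Linear(16, 16)
parallel_action = ParallelAction(ComputePattern.TP1D)                  # assumed constructor
init_colo_module(model, parallel_action, recursive=True, mode='row')   # assumed keywords
model = ColoDDP(model)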

View File

@@ -5,11 +5,11 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec, ComputePattern, \
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
ParallelAction, ColoTensor, DistSpecManager
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc

View File

@@ -6,7 +6,7 @@ import torch
import torch.multiprocessing as mp
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from colossalai.nn import init_colo_module, check_colo_module
from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
from _utils import tensor_equal, tensor_shard_equal, set_seed
import colossalai

View File

@@ -5,14 +5,14 @@ import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec, ColoParameter, ChunkManager
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ChunkManager
from colossalai.core import global_context as gpc
from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed
from _utils import tensor_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDP, ColoDDPV2
from colossalai.nn.parallel import ColoDDPV2
from colossalai.testing import parameterize
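The ZeRO-style wrapper is still paired with the chunk-based memory manager; a hedged sketch of using ColoDDPV2 with a ChunkManager under the new import path, assuming a launched distributed environment (the ChunkManager and ColoDDPV2 arguments are assumptions suggested by the imports above):

# sketch; assumes ColoDDPV2(module, chunk_manager) and ChunkManager(chunk_size)
import torch
from colossalai.nn.parallel import ColoDDPV2
from colossalai.tensor import ChunkManager
from colossalai.utils.model.colo_init_context import ColoInitContext

with ColoInitContext(device=torch.device('cuda')):
    model = torch.nn.Linear(32, 32)

chunk_manager = ChunkManager(None)        # assumed: None disables chunking (one chunk per tensor)
model = ColoDDPV2(model, chunk_manager)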

View File

@@ -6,11 +6,11 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ChunkManager
from colossalai.core import global_context as gpc
from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed
from _utils import tensor_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDPV2

View File

@@ -1,13 +1,8 @@
import os.path as osp
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.utils.model.pipelinable import PipelinableContext
from functools import partial
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
NUM_CHUNKS = 1