[refactor] add nn.parallel module (#1068)

pull/1069/head
Jiarui Fang 2022-06-06 15:34:41 +08:00 committed by GitHub
parent 6754f1b77f
commit 49832b2344
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 44 additions and 46 deletions

View File

@ -5,6 +5,3 @@ from .metric import *
from .model import * from .model import *
from .optimizer import * from .optimizer import *
from ._ops import * from ._ops import *
from .modules import ColoLinear, ColoEmbedding
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module

View File

@ -1,3 +0,0 @@
from .colo_module import ColoModule
from .linear import ColoLinear
from .embedding import ColoEmbedding

View File

@ -0,0 +1,3 @@
from .data_parallel import ColoDDP, ColoDDPV2
__all__ = ['ColoDDP', 'ColoDDPV2']

View File

@ -7,8 +7,6 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.chunk import ChunkManager, TensorState from colossalai.tensor.chunk import ChunkManager, TensorState
from colossalai.tensor.param_op_hook import use_param_op_hooks from colossalai.tensor.param_op_hook import use_param_op_hooks
__all__ = ['ColoDDP', 'ColoDDPV2']
def free_storage(data: torch.Tensor) -> None: def free_storage(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor.""" """Free underlying storage of a Tensor."""

View File

@ -0,0 +1,15 @@
from .colo_module import ColoModule
from .linear import ColoLinear
from .embedding import ColoEmbedding
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
__all__ = [
'ColoModule',
'register_colo_module',
'is_colo_module',
'get_colo_module',
'init_colo_module',
'check_colo_module',
'ColoLinear',
'ColoEmbedding',
]

View File

@ -1,6 +1,6 @@
from typing import Dict from typing import Dict
from colossalai.tensor import ColoParameter, ParallelAction, TensorSpec from colossalai.tensor import ColoParameter, ParallelAction, TensorSpec
from .modules import ColoModule from . import ColoModule
import torch import torch
_COLOSSAL_MODULES: Dict[type, ColoModule] = {} _COLOSSAL_MODULES: Dict[type, ColoModule] = {}

View File

@ -11,8 +11,6 @@ from .memory import (report_memory_usage, colo_device_memory_used, colo_set_proc
colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity) colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
from .timer import MultiTimer, Timer from .timer import MultiTimer, Timer
from .tensor_detector import TensorDetector from .tensor_detector import TensorDetector
from .model.utils import InsertPostInitMethodToModuleSubClasses
from .model.colo_init_context import ColoInitContext
__all__ = [ __all__ = [
'checkpoint', 'checkpoint',
@ -52,6 +50,4 @@ __all__ = [
'disposable', 'disposable',
'colo_set_cpu_memory_capacity', 'colo_set_cpu_memory_capacity',
'colo_get_cpu_memory_capacity', 'colo_get_cpu_memory_capacity',
'InsertPostInitMethodToModuleSubClasses',
'ColoInitContext',
] ]

View File

@ -2,7 +2,7 @@ from .utils import InsertPostInitMethodToModuleSubClasses
import torch import torch
from colossalai.tensor import ColoTensor, ColoParameter from colossalai.tensor import ColoTensor, ColoParameter
from colossalai.nn import register_colo_module, init_colo_module, \ from colossalai.nn.parallel.layers import register_colo_module, \
ColoLinear, ColoEmbedding ColoLinear, ColoEmbedding
from torch import nn from torch import nn

View File

@ -1,10 +1,7 @@
import torch import torch
import functools
import inspect import inspect
from colossalai.amp.naive_amp import NaiveAMPModel from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses, call_to_str
from colossalai.utils.model.utils import _substitute_init_recursively, InsertPostInitMethodToModuleSubClasses, call_to_str
from colossalai.builder.pipeline import partition_uniform, partition_balanced from colossalai.builder.pipeline import partition_uniform, partition_balanced
from colossalai.core import global_context as gpc
from colossalai.nn.layer.utils import CheckpointModule from colossalai.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoTensor from colossalai.tensor import ColoTensor

View File

@ -1,9 +1,12 @@
import contextlib import contextlib
import functools import functools
from typing import Optional from typing import Optional
from contextlib import AbstractContextManager
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.distributed as dist import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.context.singleton_meta import SingletonMeta from colossalai.context.singleton_meta import SingletonMeta
@ -12,8 +15,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_param import ShardedParamV2 from colossalai.zero.sharded_param import ShardedParamV2
from contextlib import AbstractContextManager from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.utils import InsertPostInitMethodToModuleSubClasses
class ZeroContextConfig(object): class ZeroContextConfig(object):

View File

@ -2,7 +2,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from enum import Enum from enum import Enum
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.nn.parallel import ColoDDPV2 from colossalai.nn.parallel.data_parallel import ColoDDPV2
from typing import Dict from typing import Dict
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc

View File

@ -1,10 +1,7 @@
import pytest import pytest
from colossalai.utils import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from numpy import allclose, require
import torch import torch
from colossalai.tensor import ColoTensor
from copy import deepcopy
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device

View File

@ -5,14 +5,14 @@ import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from functools import partial from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed from _utils import tensor_equal, tensor_shard_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDP from colossalai.nn.parallel.data_parallel import ColoDDP
def init_1d_row_spec(model): def init_1d_row_spec(model):

View File

@ -1,13 +1,14 @@
from colossalai.utils import free_port, ColoInitContext, get_current_device from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction from colossalai.tensor import ComputePattern, ParallelAction
from functools import partial from functools import partial
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.nn import init_colo_module from colossalai.nn.parallel.layers import init_colo_module
from colossalai.nn.parallel import ColoDDP from colossalai.nn.parallel.data_parallel import ColoDDP
import colossalai import colossalai
import torch import torch

View File

@ -5,11 +5,11 @@ import colossalai
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec, ComputePattern, \ from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
ParallelAction, ColoTensor, DistSpecManager ParallelAction, ColoTensor, DistSpecManager
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc

View File

@ -6,7 +6,7 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from colossalai.nn import init_colo_module, check_colo_module from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
from _utils import tensor_equal, tensor_shard_equal, set_seed from _utils import tensor_equal, tensor_shard_equal, set_seed
import colossalai import colossalai

View File

@ -5,14 +5,14 @@ import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec, ColoParameter, ChunkManager from colossalai.tensor import ChunkManager
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from functools import partial from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed from _utils import tensor_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDP, ColoDDPV2 from colossalai.nn.parallel import ColoDDPV2
from colossalai.testing import parameterize from colossalai.testing import parameterize

View File

@ -6,11 +6,11 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ChunkManager from colossalai.tensor import ChunkManager
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from functools import partial from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed from _utils import tensor_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDPV2 from colossalai.nn.parallel import ColoDDPV2

View File

@ -1,13 +1,8 @@
import os.path as osp
import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.utils.model.pipelinable import PipelinableContext from colossalai.utils.model.pipelinable import PipelinableContext
from functools import partial
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_on_exception
NUM_CHUNKS = 1 NUM_CHUNKS = 1