[Tensor] remove ParallelAction, use ComputeSpec instead (#1166)

pull/1168/head
Jiarui Fang 2022-06-23 17:34:59 +08:00 committed by GitHub
parent 177c374401
commit f4ef224358
20 changed files with 87 additions and 77 deletions
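
Taken together, the change is essentially a rename: ParallelAction becomes ComputeSpec (moved out of colossalai/tensor/spec.py into a new compute_spec.py), and the TensorSpec field parallel_action becomes compute_spec. A hedged summary sketch of the mapping (the left-hand names are the ones this commit removes):

# colossalai.tensor.ParallelAction            -> colossalai.tensor.ComputeSpec
# TensorSpec(dist_spec, parallel_action=...)  -> TensorSpec(dist_spec, compute_spec=...)
# tensor.spec.parallel_action                 -> tensor.spec.compute_spec
# The ComputePattern enum (TP1D/TP2D/TP2P5D/TP3D) is unchanged, now defined in compute_spec.py.
from colossalai.tensor import ComputePattern, ComputeSpec   # both re-exported by colossalai.tensor
ComputeSpec(ComputePattern.TP1D)                            # drop-in replacement for ParallelAction(ComputePattern.TP1D)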

View File

@@ -1,7 +1,7 @@
 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
-from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor
+from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor
 from colossalai.tensor import distspec
 from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, Number, convert_to_colo_tensor
@@ -29,13 +29,13 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
 def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
-    parallel_action = mat2.spec.parallel_action
+    parallel_action = mat2.spec.compute_spec
     mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
     mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
     output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
-                             ParallelAction(ComputePattern.TP1D))
+                             ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
     # TODO(jiaruifang) addam is special case

View File

@@ -3,7 +3,7 @@ from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input
 from colossalai.core import global_context as gpc
-from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
+from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
 from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, convert_to_colo_tensor
@@ -28,7 +28,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                sparse=sparse)
     output_spec = TensorSpec(
         distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
     return output.to_replicate()

View File

@@ -2,7 +2,7 @@ import torch.nn.functional as F
 from typing import Optional
 from torch import Tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
+from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
 from ._utils import GeneralTensor, convert_to_colo_tensor
@@ -34,7 +34,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                                    padding_idx=padding_idx)
     output_spec = TensorSpec(
         distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
     return output.to_replicate()

View File

@@ -3,7 +3,7 @@ from typing import Optional
 from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
-from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
+from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
 from colossalai.context import ParallelMode
 from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv
@@ -32,7 +32,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
-    parallel_action = weight.spec.parallel_action
+    parallel_action = weight.spec.compute_spec
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
     input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)
@@ -41,7 +41,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
                                            spec=TensorSpec(
                                                distspec.shard(weight.spec.get_process_group(), [-1],
                                                               [weight.spec.get_process_group_size()]),
-                                               ParallelAction(ComputePattern.TP1D)))
+                                               ComputeSpec(ComputePattern.TP1D)))
     return output.to_replicate()

View File

@@ -1,5 +1,5 @@
 from typing import Dict
-from colossalai.tensor import ColoParameter, ParallelAction, TensorSpec
+from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec
 from . import ColoModule
 import torch
@@ -39,7 +39,7 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
             if not isinstance(param, ColoParameter):
                 raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
             if param.has_spec():
-                cur_compute_pattern = param.spec.parallel_action.compute_pattern
+                cur_compute_pattern = param.spec.compute_spec.compute_pattern
                 if compute_pattern is None:
                     compute_pattern = cur_compute_pattern
                 else:
@@ -79,11 +79,11 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
             check_colo_module(submodule, recursive=True)
-def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, mode='default'):
+def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recursive=True, mode='default'):
     compute_pattern = parallel_action.compute_pattern
     if is_colo_module(module):
         # for each param
-        # set DistSpec and ParallelAction
+        # set DistSpec and ComputeSpec
        colo_module = get_colo_module(module)
        colo_module.register(compute_pattern)
        if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
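
init_colo_module now accepts a ComputeSpec where it previously took a ParallelAction (the parameter keeps the old name parallel_action). A hedged usage sketch, assuming colossalai.launch has already set up the process groups and the model was created under ColoInitContext so its parameters are ColoParameters; the toy model and device handling are illustrative:

import torch.nn as nn
from colossalai.nn.parallel.layers import init_colo_module
from colossalai.tensor import ComputePattern, ComputeSpec
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

with ColoInitContext(device=get_current_device()):
    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))   # illustrative model

# ComputeSpec(ComputePattern.TP1D) is a drop-in replacement for ParallelAction(ComputePattern.TP1D)
compute_spec = ComputeSpec(ComputePattern.TP1D)
init_colo_module(model, compute_spec, recursive=True, mode='default')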

View File

@@ -1,14 +1,14 @@
-from .spec import ComputePattern, ParallelAction, TensorSpec
+from .tensor_spec import TensorSpec
+from .compute_spec import ComputeSpec, ComputePattern
 from .colo_tensor import ColoTensor
 from .colo_parameter import ColoParameter
 from .utils import convert_parameter, named_params_with_colotensor
-from . import distspec
 from .dist_spec_mgr import DistSpecManager
 from .param_op_hook import ParamOpHook, ParamOpHookManager
 from .chunk import ChunkManager, TensorState
+from . import distspec
 __all__ = [
-    'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ParallelAction', 'named_params_with_colotensor',
+    'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
     'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState'
 ]

View File

@@ -1,10 +1,12 @@
+import torch
+from typing import Optional
+from copy import copy
 from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.const import TensorType
-import torch
 from colossalai.tensor import TensorSpec, distspec
-from copy import copy
 from colossalai.tensor.param_op_hook import ParamOpHookManager
-from typing import Optional
 def filter_args(func, *args):

View File

@@ -66,7 +66,7 @@ class ColoTensor(torch.Tensor):
         self._tensor_spec = spec
     def has_spec(self) -> bool:
-        return self._tensor_spec.parallel_action is not None
+        return self._tensor_spec.compute_spec is not None
     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL

View File

@@ -0,0 +1,23 @@
+from enum import Enum
+class ComputePattern(Enum):
+    TP1D = 0
+    TP2D = 1
+    TP2P5D = 2
+    TP3D = 3
+class ComputeSpec(object):
+    """ComputeSpec
+    The Specification for compuattion pattern
+    Args:
+        compute_pattern (ComputePattern): an Enum instance for compute pattern.
+    """
+    def __init__(self, compute_pattern: ComputePattern) -> None:
+        assert isinstance(compute_pattern, ComputePattern)
+        self.compute_pattern = compute_pattern
+    def __repr__(self):
+        return f'compute pattern: {self.compute_pattern}'
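
The new class carries no state beyond the pattern, so a minimal usage sketch needs no distributed setup:

from colossalai.tensor import ComputePattern, ComputeSpec

cs = ComputeSpec(ComputePattern.TP1D)                # records only the tensor-parallel compute pattern
print(cs)                                            # compute pattern: ComputePattern.TP1D
print(cs.compute_pattern is ComputePattern.TP1D)     # True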

View File

@@ -1,24 +1,7 @@
 import torch.distributed as dist
-from enum import Enum
-from typing import List, Optional
+from typing import Optional
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
+from .compute_spec import ComputeSpec, ComputePattern
-class ComputePattern(Enum):
-    TP1D = 0
-    TP2D = 1
-    TP2P5D = 2
-    TP3D = 3
-class ParallelAction(object):
-    def __init__(self, compute_pattern: ComputePattern) -> None:
-        assert isinstance(compute_pattern, ComputePattern)
-        self.compute_pattern = compute_pattern
-    def __repr__(self):
-        return f'compute pattern: {self.compute_pattern}'
 class TensorSpec(object):
@@ -26,12 +9,12 @@ class TensorSpec(object):
     The specification of the ColoTensor.
     Args:
         dist_spec (_DistSpec): descriping the layout among processes.
-        parallel_action (Optional[ParallelAction], optional): actions conducted on the tensor after initialization if it's a model data tensor.
+        parallel_action (Optional[ComputeSpec], optional): actions conducted on the tensor after initialization if it's a model data tensor.
             Defaults to None.
     """
-    def __init__(self, dist_spec: _DistSpec, parallel_action: Optional[ParallelAction] = None):
-        self.parallel_action = parallel_action
+    def __init__(self, dist_spec: _DistSpec, compute_spec: Optional[ComputeSpec] = None):
+        self.compute_spec = compute_spec
         self.dist_spec = dist_spec
     def get_process_group(self):
@@ -58,7 +41,7 @@ class TensorSpec(object):
                and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
     def has_compute_pattern(self, compute_pattern: ComputePattern):
-        return self.parallel_action.compute_pattern == compute_pattern
+        return self.compute_spec.compute_pattern == compute_pattern
     def __repr__(self):
-        return f'parallel action: {self.parallel_action}, dist_spec: {self.dist_spec}'
+        return f'parallel action: {self.compute_spec}, dist_spec: {self.dist_spec}'
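
Putting the pieces together, a hedged sketch of constructing the renamed spec, mirroring the test fixtures below and assuming a 1D tensor-parallel group has been initialized via colossalai.launch:

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import ComputePattern, ComputeSpec, TensorSpec, distspec

# shard the last dim across the 1D tensor-parallel group and tag the tensor for TP1D compute
spec = TensorSpec(
    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
    ComputeSpec(ComputePattern.TP1D))
print(spec.compute_spec)                              # compute pattern: ComputePattern.TP1D
print(spec.has_compute_pattern(ComputePattern.TP1D))  # True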

View File

@@ -14,4 +14,4 @@ RUN git clone https://github.com/hpcaitech/ColossalAI.git \
     && pip install -v --no-cache-dir .
 # install titans
-RUN pip install -no-cache-dir titans
+RUN pip install --no-cache-dir titans

View File

@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.multiprocessing as mp
 from colossalai.tensor import ColoTensor
 from colossalai.tensor import distspec
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
 from colossalai.context import ParallelMode
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -41,7 +41,7 @@ class Conv1D(nn.Module):
 def init_1d_row(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -49,7 +49,7 @@ def init_1d_row(weight, bias):
 def init_1d_col(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
         bias.set_spec(spec)

View File

@@ -11,14 +11,14 @@ import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
 from _utils import tensor_equal, tensor_shard_equal
 def init_1d_col(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

View File

@@ -11,14 +11,14 @@ import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
 from _utils import tensor_equal, tensor_shard_equal
 def init_1d_row(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -26,7 +26,7 @@ def init_1d_row(weight):
 def init_1d_col(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

View File

@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec
 from colossalai.core import global_context as gpc
 from functools import partial
 from _utils import tensor_equal, tensor_shard_equal, set_seed
@@ -18,7 +18,7 @@ from colossalai.nn.parallel.data_parallel import ColoDDP
 def init_1d_row_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
@@ -28,7 +28,7 @@ def init_1d_row_spec(model):
 def init_1d_col_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):

View File

@@ -1,7 +1,7 @@
 from colossalai.utils import free_port, get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.tensor import ComputePattern, ParallelAction
+from colossalai.tensor import ComputePattern, ComputeSpec
 from functools import partial
 from colossalai.core import global_context as gpc
@@ -46,7 +46,7 @@ def run_hybrid_device(use_ddp, mode):
     print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
     #print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
-    parallel_action = ParallelAction(ComputePattern.TP1D)
+    parallel_action = ComputeSpec(ComputePattern.TP1D)
     init_colo_module(model, parallel_action, recursive=True, mode=mode)
     # use cpu gloo to handle embedding
@@ -63,6 +63,7 @@ def run_hybrid_device(use_ddp, mode):
     out.sum().backward()
     optimizer.step()
+
 def run_dist(rank, world_size, port, use_ddp, mode):
     if use_ddp and world_size == 1:
         return
@@ -71,6 +72,7 @@ def run_dist(rank, world_size, port, use_ddp, mode):
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_hybrid_device(use_ddp, mode)
+
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @pytest.mark.parametrize('use_ddp', [False, True])
@@ -78,7 +80,7 @@ def run_dist(rank, world_size, port, use_ddp, mode):
 @rerun_if_address_is_in_use()
 # Working for simulate the embedding(CPU DP+TP) -> nn(GPU DP+TP)
 def _test_hybrid_device(world_size, use_ddp, mode):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp ,mode=mode)
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, mode=mode)
     mp.spawn(run_func, nprocs=world_size)

View File

@@ -12,14 +12,14 @@ import torch.nn.functional as F
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
 from _utils import tensor_equal, tensor_shard_equal
 def init_1d_row(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -27,7 +27,7 @@ def init_1d_row(weight, bias):
 def init_1d_col(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
         bias.set_spec(spec)

View File

@@ -10,7 +10,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
-    ParallelAction, ColoTensor, DistSpecManager
+    ComputeSpec, ColoTensor, DistSpecManager
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColoOptimizer
@@ -21,7 +21,7 @@ from _utils import tensor_equal, tensor_shard_equal, set_seed
 def init_1d_row_linear(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -29,7 +29,7 @@ def init_1d_row_linear(weight):
 def init_1d_col_linear(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -37,7 +37,7 @@ def init_1d_col_linear(weight):
 def init_1d_row_embedding(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -45,7 +45,7 @@ def init_1d_row_embedding(weight):
 def init_1d_col_embedding(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

View File

@@ -5,7 +5,7 @@ from functools import partial
 import torch
 import torch.multiprocessing as mp
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed
@@ -40,7 +40,7 @@ def run_model_with_spec(mode, model_name):
     for p1, p2 in zip(model.parameters(), model_seq.parameters()):
         p2.data.copy_(p1.data)
-    parallel_action = ParallelAction(ComputePattern.TP1D)
+    parallel_action = ComputeSpec(ComputePattern.TP1D)
     # Not all layers in Bert can be mod by 4.
     # e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2.
     if 'bert' == model_name:
@@ -114,7 +114,7 @@ def run_linear_with_spec(mode):
     model_handy = copy(model)
-    parallel_action = ParallelAction(ComputePattern.TP1D)
+    parallel_action = ComputeSpec(ComputePattern.TP1D)
     init_colo_module(model, parallel_action, recursive=True, mode=mode)
     x = torch.rand(2, 4).cuda()
@@ -148,7 +148,7 @@ def run_check_shared_param():
     model = BertForMaskedLM(config)
     model = model.cuda()
-    parallel_action = ParallelAction(ComputePattern.TP1D)
+    parallel_action = ComputeSpec(ComputePattern.TP1D)
     # model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
     assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
     # They are all Linear, so both row is allowed. This should pass check.
@@ -156,7 +156,7 @@ def run_check_shared_param():
     # This should be detected by check because you can not set weight as row while set bias as col.
     col_spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     model.cls.predictions.bias.set_spec(col_spec)
     try:
         check_colo_module(model.cls.predictions.decoder, recursive=False)

View File

@@ -19,7 +19,7 @@ from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec
 def check_param_equal(model, torch_model):
@@ -47,7 +47,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
 def init_1d_row_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
@@ -57,7 +57,7 @@ def init_1d_row_spec(model):
 def init_1d_col_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D))
+        ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):