mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] remove ParallelAction, use ComputeSpec instread (#1166)
parent
177c374401
commit
f4ef224358
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
|
||||
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor
|
||||
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor
|
||||
from colossalai.tensor import distspec
|
||||
from colossalai.context import ParallelMode
|
||||
from ._utils import GeneralTensor, Number, convert_to_colo_tensor
|
||||
|
@ -29,13 +29,13 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
|
|||
def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
|
||||
alpha: Number) -> ColoTensor:
|
||||
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
|
||||
parallel_action = mat2.spec.parallel_action
|
||||
parallel_action = mat2.spec.compute_spec
|
||||
mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
|
||||
mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
|
||||
|
||||
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
|
||||
output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
|
||||
|
||||
# TODO(jiaruifang) addam is special case
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
from colossalai.nn.layer.parallel_1d._utils import reduce_input
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
|
||||
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
|
||||
from colossalai.context import ParallelMode
|
||||
from ._utils import GeneralTensor, convert_to_colo_tensor
|
||||
|
||||
|
@ -28,7 +28,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
|
|||
sparse=sparse)
|
||||
output_spec = TensorSpec(
|
||||
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
|
||||
return output.to_replicate()
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch.nn.functional as F
|
|||
from typing import Optional
|
||||
from torch import Tensor
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
|
||||
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
|
||||
from ._utils import GeneralTensor, convert_to_colo_tensor
|
||||
|
||||
|
||||
|
@ -34,7 +34,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
|
|||
padding_idx=padding_idx)
|
||||
output_spec = TensorSpec(
|
||||
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
|
||||
|
||||
return output.to_replicate()
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||
from ._utils import GeneralTensor, convert_to_colo_tensor
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
|
||||
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
|
||||
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv
|
||||
|
||||
|
@ -32,7 +32,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
|
|||
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
|
||||
# All-Gather(Output)
|
||||
# Input:B
|
||||
parallel_action = weight.spec.parallel_action
|
||||
parallel_action = weight.spec.compute_spec
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
|
||||
input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)
|
||||
|
||||
|
@ -41,7 +41,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
|
|||
spec=TensorSpec(
|
||||
distspec.shard(weight.spec.get_process_group(), [-1],
|
||||
[weight.spec.get_process_group_size()]),
|
||||
ParallelAction(ComputePattern.TP1D)))
|
||||
ComputeSpec(ComputePattern.TP1D)))
|
||||
return output.to_replicate()
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Dict
|
||||
from colossalai.tensor import ColoParameter, ParallelAction, TensorSpec
|
||||
from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec
|
||||
from . import ColoModule
|
||||
import torch
|
||||
|
||||
|
@ -39,7 +39,7 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
|
|||
if not isinstance(param, ColoParameter):
|
||||
raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
|
||||
if param.has_spec():
|
||||
cur_compute_pattern = param.spec.parallel_action.compute_pattern
|
||||
cur_compute_pattern = param.spec.compute_spec.compute_pattern
|
||||
if compute_pattern is None:
|
||||
compute_pattern = cur_compute_pattern
|
||||
else:
|
||||
|
@ -79,11 +79,11 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
|
|||
check_colo_module(submodule, recursive=True)
|
||||
|
||||
|
||||
def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, mode='default'):
|
||||
def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recursive=True, mode='default'):
|
||||
compute_pattern = parallel_action.compute_pattern
|
||||
if is_colo_module(module):
|
||||
# for each param
|
||||
# set DistSpec and ParallelAction
|
||||
# set DistSpec and ComputeSpec
|
||||
colo_module = get_colo_module(module)
|
||||
colo_module.register(compute_pattern)
|
||||
if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
from .spec import ComputePattern, ParallelAction, TensorSpec
|
||||
|
||||
from .tensor_spec import TensorSpec
|
||||
from .compute_spec import ComputeSpec, ComputePattern
|
||||
from .colo_tensor import ColoTensor
|
||||
from .colo_parameter import ColoParameter
|
||||
from .utils import convert_parameter, named_params_with_colotensor
|
||||
from . import distspec
|
||||
from .dist_spec_mgr import DistSpecManager
|
||||
from .param_op_hook import ParamOpHook, ParamOpHookManager
|
||||
from .chunk import ChunkManager, TensorState
|
||||
from . import distspec
|
||||
|
||||
__all__ = [
|
||||
'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ParallelAction', 'named_params_with_colotensor',
|
||||
'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
|
||||
'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState'
|
||||
]
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
import torch
|
||||
|
||||
from typing import Optional
|
||||
from copy import copy
|
||||
|
||||
from colossalai.tensor.colo_tensor import ColoTensor
|
||||
from colossalai.tensor.const import TensorType
|
||||
import torch
|
||||
from colossalai.tensor import TensorSpec, distspec
|
||||
from copy import copy
|
||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def filter_args(func, *args):
|
||||
|
|
|
@ -66,7 +66,7 @@ class ColoTensor(torch.Tensor):
|
|||
self._tensor_spec = spec
|
||||
|
||||
def has_spec(self) -> bool:
|
||||
return self._tensor_spec.parallel_action is not None
|
||||
return self._tensor_spec.compute_spec is not None
|
||||
|
||||
def is_model_data(self) -> bool:
|
||||
return self._type == TensorType.MODEL
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class ComputePattern(Enum):
|
||||
TP1D = 0
|
||||
TP2D = 1
|
||||
TP2P5D = 2
|
||||
TP3D = 3
|
||||
|
||||
|
||||
class ComputeSpec(object):
|
||||
"""ComputeSpec
|
||||
The Specification for compuattion pattern
|
||||
Args:
|
||||
compute_pattern (ComputePattern): an Enum instance for compute pattern.
|
||||
"""
|
||||
|
||||
def __init__(self, compute_pattern: ComputePattern) -> None:
|
||||
assert isinstance(compute_pattern, ComputePattern)
|
||||
self.compute_pattern = compute_pattern
|
||||
|
||||
def __repr__(self):
|
||||
return f'compute pattern: {self.compute_pattern}'
|
|
@ -1,24 +1,7 @@
|
|||
import torch.distributed as dist
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
|
||||
|
||||
|
||||
class ComputePattern(Enum):
|
||||
TP1D = 0
|
||||
TP2D = 1
|
||||
TP2P5D = 2
|
||||
TP3D = 3
|
||||
|
||||
|
||||
class ParallelAction(object):
|
||||
|
||||
def __init__(self, compute_pattern: ComputePattern) -> None:
|
||||
assert isinstance(compute_pattern, ComputePattern)
|
||||
self.compute_pattern = compute_pattern
|
||||
|
||||
def __repr__(self):
|
||||
return f'compute pattern: {self.compute_pattern}'
|
||||
from .compute_spec import ComputeSpec, ComputePattern
|
||||
|
||||
|
||||
class TensorSpec(object):
|
||||
|
@ -26,12 +9,12 @@ class TensorSpec(object):
|
|||
The specification of the ColoTensor.
|
||||
Args:
|
||||
dist_spec (_DistSpec): descriping the layout among processes.
|
||||
parallel_action (Optional[ParallelAction], optional): actions conducted on the tensor after initialization if it's a model data tensor.
|
||||
parallel_action (Optional[ComputeSpec], optional): actions conducted on the tensor after initialization if it's a model data tensor.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, dist_spec: _DistSpec, parallel_action: Optional[ParallelAction] = None):
|
||||
self.parallel_action = parallel_action
|
||||
def __init__(self, dist_spec: _DistSpec, compute_spec: Optional[ComputeSpec] = None):
|
||||
self.compute_spec = compute_spec
|
||||
self.dist_spec = dist_spec
|
||||
|
||||
def get_process_group(self):
|
||||
|
@ -58,7 +41,7 @@ class TensorSpec(object):
|
|||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
|
||||
|
||||
def has_compute_pattern(self, compute_pattern: ComputePattern):
|
||||
return self.parallel_action.compute_pattern == compute_pattern
|
||||
return self.compute_spec.compute_pattern == compute_pattern
|
||||
|
||||
def __repr__(self):
|
||||
return f'parallel action: {self.parallel_action}, dist_spec: {self.dist_spec}'
|
||||
return f'parallel action: {self.compute_spec}, dist_spec: {self.dist_spec}'
|
|
@ -14,4 +14,4 @@ RUN git clone https://github.com/hpcaitech/ColossalAI.git \
|
|||
&& pip install -v --no-cache-dir .
|
||||
|
||||
# install titans
|
||||
RUN pip install -no-cache-dir titans
|
||||
RUN pip install --no-cache-dir titans
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
import torch.multiprocessing as mp
|
||||
from colossalai.tensor import ColoTensor
|
||||
from colossalai.tensor import distspec
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
|
@ -41,7 +41,7 @@ class Conv1D(nn.Module):
|
|||
def init_1d_row(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
@ -49,7 +49,7 @@ def init_1d_row(weight, bias):
|
|||
def init_1d_col(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
bias.set_spec(spec)
|
||||
|
|
|
@ -11,14 +11,14 @@ import torch.multiprocessing as mp
|
|||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
|
||||
from _utils import tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def init_1d_col(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
|
|
@ -11,14 +11,14 @@ import torch.multiprocessing as mp
|
|||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
|
||||
from _utils import tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def init_1d_row(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
@ -26,7 +26,7 @@ def init_1d_row(weight):
|
|||
def init_1d_col(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
|
|||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec
|
||||
from colossalai.core import global_context as gpc
|
||||
from functools import partial
|
||||
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||
|
@ -18,7 +18,7 @@ from colossalai.nn.parallel.data_parallel import ColoDDP
|
|||
def init_1d_row_spec(model):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'weight' in n and 'ln' not in n:
|
||||
|
@ -28,7 +28,7 @@ def init_1d_row_spec(model):
|
|||
def init_1d_col_spec(model):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'ln' not in n and ('weight' in n or 'bias' in n):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.tensor import ComputePattern, ParallelAction
|
||||
from colossalai.tensor import ComputePattern, ComputeSpec
|
||||
|
||||
from functools import partial
|
||||
from colossalai.core import global_context as gpc
|
||||
|
@ -46,7 +46,7 @@ def run_hybrid_device(use_ddp, mode):
|
|||
|
||||
print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
|
||||
#print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
|
||||
parallel_action = ParallelAction(ComputePattern.TP1D)
|
||||
parallel_action = ComputeSpec(ComputePattern.TP1D)
|
||||
init_colo_module(model, parallel_action, recursive=True, mode=mode)
|
||||
|
||||
# use cpu gloo to handle embedding
|
||||
|
@ -63,6 +63,7 @@ def run_hybrid_device(use_ddp, mode):
|
|||
out.sum().backward()
|
||||
optimizer.step()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, use_ddp, mode):
|
||||
if use_ddp and world_size == 1:
|
||||
return
|
||||
|
@ -71,6 +72,7 @@ def run_dist(rank, world_size, port, use_ddp, mode):
|
|||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_hybrid_device(use_ddp, mode)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@pytest.mark.parametrize('use_ddp', [False, True])
|
||||
|
@ -78,7 +80,7 @@ def run_dist(rank, world_size, port, use_ddp, mode):
|
|||
@rerun_if_address_is_in_use()
|
||||
# Working for simulate the embedding(CPU DP+TP) -> nn(GPU DP+TP)
|
||||
def _test_hybrid_device(world_size, use_ddp, mode):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp ,mode=mode)
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, mode=mode)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
|
@ -12,14 +12,14 @@ import torch.nn.functional as F
|
|||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
|
||||
from _utils import tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def init_1d_row(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
@ -27,7 +27,7 @@ def init_1d_row(weight, bias):
|
|||
def init_1d_col(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
bias.set_spec(spec)
|
||||
|
|
|
@ -10,7 +10,7 @@ from colossalai.utils.cuda import get_current_device
|
|||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
|
||||
ParallelAction, ColoTensor, DistSpecManager
|
||||
ComputeSpec, ColoTensor, DistSpecManager
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.optimizer import ColoOptimizer
|
||||
|
@ -21,7 +21,7 @@ from _utils import tensor_equal, tensor_shard_equal, set_seed
|
|||
def init_1d_row_linear(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
@ -29,7 +29,7 @@ def init_1d_row_linear(weight):
|
|||
def init_1d_col_linear(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
@ -37,7 +37,7 @@ def init_1d_col_linear(weight):
|
|||
def init_1d_row_embedding(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
@ -45,7 +45,7 @@ def init_1d_row_embedding(weight):
|
|||
def init_1d_col_embedding(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from functools import partial
|
|||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec
|
||||
from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
|
||||
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||
|
||||
|
@ -40,7 +40,7 @@ def run_model_with_spec(mode, model_name):
|
|||
for p1, p2 in zip(model.parameters(), model_seq.parameters()):
|
||||
p2.data.copy_(p1.data)
|
||||
|
||||
parallel_action = ParallelAction(ComputePattern.TP1D)
|
||||
parallel_action = ComputeSpec(ComputePattern.TP1D)
|
||||
# Not all layers in Bert can be mod by 4.
|
||||
# e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2.
|
||||
if 'bert' == model_name:
|
||||
|
@ -114,7 +114,7 @@ def run_linear_with_spec(mode):
|
|||
|
||||
model_handy = copy(model)
|
||||
|
||||
parallel_action = ParallelAction(ComputePattern.TP1D)
|
||||
parallel_action = ComputeSpec(ComputePattern.TP1D)
|
||||
init_colo_module(model, parallel_action, recursive=True, mode=mode)
|
||||
|
||||
x = torch.rand(2, 4).cuda()
|
||||
|
@ -148,7 +148,7 @@ def run_check_shared_param():
|
|||
model = BertForMaskedLM(config)
|
||||
|
||||
model = model.cuda()
|
||||
parallel_action = ParallelAction(ComputePattern.TP1D)
|
||||
parallel_action = ComputeSpec(ComputePattern.TP1D)
|
||||
# model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
|
||||
assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
|
||||
# They are all Linear, so both row is allowed. This should pass check.
|
||||
|
@ -156,7 +156,7 @@ def run_check_shared_param():
|
|||
# This should be detected by check because you can not set weight as row while set bias as col.
|
||||
col_spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
model.cls.predictions.bias.set_spec(col_spec)
|
||||
try:
|
||||
check_colo_module(model.cls.predictions.decoder, recursive=False)
|
||||
|
|
|
@ -19,7 +19,7 @@ from colossalai.zero import ZeroOptimizer
|
|||
from colossalai.testing import parameterize
|
||||
from colossalai.amp import convert_to_apex_amp
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec
|
||||
|
||||
|
||||
def check_param_equal(model, torch_model):
|
||||
|
@ -47,7 +47,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
|
|||
def init_1d_row_spec(model):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'weight' in n and 'ln' not in n:
|
||||
|
@ -57,7 +57,7 @@ def init_1d_row_spec(model):
|
|||
def init_1d_col_spec(model):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'ln' not in n and ('weight' in n or 'bias' in n):
|
||||
|
|
Loading…
Reference in New Issue