[tensor] a shorter shard and replicate spec (#1245)

pull/1241/head
Jiarui Fang 2022-07-11 15:51:48 +08:00 committed by GitHub
parent 2699dfbbfd
commit 9bcd2fd4af
25 changed files with 91 additions and 98 deletions
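
In short, this commit swaps the verbose `distspec.shard(...)` / `distspec.replicate()` factory calls for the shorter `ShardSpec` / `ReplicaSpec` aliases now exported from `colossalai.tensor`. A minimal before/after sketch, assuming a process group has already been launched (e.g. via `colossalai.launch`); the tensor shapes and the tp degree are illustrative only:

```python
import torch
from colossalai.tensor import (ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec,
                               ProcessGroup, ReplicaSpec, ShardSpec, distspec)

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())

# Before this commit: spell out the factory functions on the distspec module.
old_spec = ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()]),
                          ComputeSpec(ComputePattern.TP1D))

# After this commit: the shorter aliases, which point at the same factories.
new_spec = ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()]),
                          ComputeSpec(ComputePattern.TP1D))

# Replicated tensors follow the same pattern with ReplicaSpec().
colo_t = ColoTensor.from_torch_tensor(torch.randn(2, 3),
                                      spec=ColoTensorSpec(pg, ReplicaSpec()))
```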

View File

@ -1,6 +1,5 @@
import torch
from torch.fx.node import map_arg
from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec
def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
@ -25,7 +24,7 @@ def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
spec = ColoTensorSpec(pg, distspec.shard([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
# Since a spec has already been constructed, the tensor could be converted to a ColoTensor directly here.
setattr(weight, "fx_attr", spec)
return weight
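
The comment above hints that the annotated weight could be promoted to a ColoTensor on the spot. A hypothetical sketch of such a consumer, reusing `ColoTensor.from_torch_tensor` from this repo; the `weight_split` call refers to the function shown above, the shapes are illustrative, and torch.distributed is assumed to be initialized:

```python
import torch
from colossalai.tensor import ColoTensor

# Hypothetical consumer of the "fx_attr" annotation set by weight_split
# (sketch only, not part of this commit).
annotated = weight_split(torch.randn(8, 8), dim=0)
spec = getattr(annotated, "fx_attr", None)          # the ColoTensorSpec attached above
colo_weight = ColoTensor.from_torch_tensor(annotated, spec=spec)
```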

View File

@ -1,7 +1,7 @@
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor
from colossalai.tensor import distspec, ColoTensorSpec
from colossalai.tensor import distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
from ._utils import GeneralTensor, Number, convert_to_colo_tensor
from ._utils import reduce_input, reduce_grad
@ -11,7 +11,8 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res
mat1 = mat1.redistribute(distspec.shard([-1], [mat2.get_tp_world_size()]))
mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
# Output:P
partial_output = torch.mm(mat1, mat2)
@ -20,7 +21,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# input
assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
output = beta * input_tensor + alpha * output
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(distspec.replicate()))
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(ReplicaSpec()))
return output
@ -28,11 +29,11 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
alpha: Number) -> ColoTensor:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
compute_spec = mat2.compute_spec
mat1 = mat1.redistribute(distspec.replicate())
mat1 = mat1.redistribute(ReplicaSpec())
mat1 = reduce_grad(mat1, mat1.get_process_group())
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
output_spec = ColoTensorSpec(input_tensor.get_process_group(), distspec.shard([-1], [mat2.get_tp_world_size()]),
output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
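
The comments in the 1Drow path above encode the tensor-parallel addmm math: with `mat1` sharded along its last dim and `mat2` along dim 0, each rank's partial `mm` sums (via all-reduce) to the full product, and `beta * input + alpha * reduced` matches the unsharded result. A single-process sketch of that identity, with `world_size = 2` and the shapes chosen purely for illustration:

```python
import torch

world_size = 2                       # illustrative tensor-parallel degree
mat1 = torch.randn(4, 6)             # operand sharded as S[-1] per rank
mat2 = torch.randn(6, 5)             # weight sharded as S[0] per rank
inp, beta, alpha = torch.randn(4, 5), 0.5, 2.0

# Each "rank" holds one column-chunk of mat1 and the matching row-chunk of mat2.
partials = [torch.mm(m1, m2) for m1, m2 in zip(mat1.chunk(world_size, dim=-1),
                                               mat2.chunk(world_size, dim=0))]
# All-Reduce(Output): summing the partial products recovers the full mm ...
reduced = torch.stack(partials).sum(dim=0)
# ... and beta * input + alpha * reduced matches the unsharded addmm.
assert torch.allclose(beta * inp + alpha * reduced,
                      torch.addmm(inp, mat1, mat2, beta=beta, alpha=alpha), atol=1e-5)
```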

View File

@ -1,7 +1,7 @@
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input
@ -14,7 +14,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
sparse: bool = False) -> ColoTensor:
# embedding_1Dcol splits the weight (lookup table) into (num_embeddings, embedding_dim/P) shards
# Gather the split lookup table
input_tensor = input_tensor.redistribute(distspec.replicate())
input_tensor = input_tensor.redistribute(ReplicaSpec())
output_parallel = F.embedding(input_tensor,
weight,
@ -23,7 +24,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
output_spec = ColoTensorSpec(weight.get_process_group(), distspec.shard([-1], [weight.get_tp_world_size()]),
output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
@ -46,7 +47,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
# Find the indices that fall in this shard and mask out those that do not
# Reduce all
pg = weight.get_process_group()
input_tensor = input_tensor.redistribute(distspec.replicate())
input_tensor = input_tensor.redistribute(ReplicaSpec())
# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = weight.get_process_group().tp_local_rank()
@ -74,7 +76,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
partial_output[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(partial_output, weight.get_process_group())
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), distspec.replicate()))
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec()))
return output
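
The 1Drow embedding path above keeps a row-slice of the lookup table on every rank, zeroes the rows whose indices fall outside the local slice, and all-reduces the partial lookups (the 1Dcol path earlier in this file is the concatenation counterpart). A single-process sketch of that identity, assuming an even split of the table across ranks and illustrative sizes:

```python
import torch
import torch.nn.functional as F

world_size, num_embeddings, dim = 2, 8, 4     # illustrative sizes
weight = torch.randn(num_embeddings, dim)     # full lookup table
ids = torch.tensor([0, 3, 5, 7])              # token indices

per_part = num_embeddings // world_size
partials = []
for rank in range(world_size):                # emulate each TP rank's shard
    shard = weight[rank * per_part:(rank + 1) * per_part]      # local rows
    local_ids = ids - rank * per_part                          # shift into local range
    outside = (local_ids < 0) | (local_ids >= per_part)        # indices not in this shard
    out = F.embedding(local_ids.clamp(0, per_part - 1), shard)
    out[outside, :] = 0.                                       # zero out-of-shard rows
    partials.append(out)

# "Reduce all": summing the partial outputs recovers the full (replicated) lookup.
assert torch.equal(sum(partials), F.embedding(ids, weight))
```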

View File

@ -2,7 +2,7 @@ import torch.nn.functional as F
from typing import Optional
from torch import Tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
from ._utils import GeneralTensor, convert_to_colo_tensor
@ -20,7 +20,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
# embedding_bag_1Dcol splits the weight (lookup table) into (num_embeddings, embedding_dim/P) shards
# Gather the split lookup table
pg = weight.get_process_group()
input_tensor = input_tensor.redistribute(distspec.replicate())
input_tensor = input_tensor.redistribute(ReplicaSpec())
output_parallel = F.embedding_bag(input_tensor,
weight,
@ -33,8 +33,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx)
output_spec = ColoTensorSpec(pg, distspec.shard([-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
if weight.compute_spec.output_replicate:

View File

@ -1,7 +1,7 @@
from typing import List, Optional
import torch.nn.functional as F
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec
from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec, ReplicaSpec
from ._utils import GeneralTensor, convert_to_colo_tensor
@ -16,7 +16,7 @@ def colo_layernorm(
assert isinstance(weight, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
bias = convert_to_colo_tensor(bias, weight.get_process_group())
input_tensor = input_tensor.redistribute(distspec.replicate())
input_tensor = input_tensor.redistribute(ReplicaSpec())
output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))

View File

@ -3,8 +3,7 @@ from typing import Optional
from ._utils import GeneralTensor, convert_to_colo_tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
@ -12,7 +11,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# All-Reduce(Output) + bias = res
# Input:S[1]
pg = weight.get_process_group()
input_tensor = input_tensor.redistribute(distspec.shard([-1], [weight.get_tp_world_size()]))
input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]))
# Output:P
partial_output = F.linear(input_tensor, weight)
@ -24,7 +23,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
output = output + bias
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, distspec.replicate()))
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec()))
return output
@ -33,13 +32,15 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# All-Gather(Output)
# Input:B
compute_spec = weight.compute_spec
input_tensor = input_tensor.redistribute(distspec.replicate())
input_tensor = input_tensor.redistribute(ReplicaSpec())
input_parallel = reduce_grad(input_tensor, weight.get_process_group())
output_parallel = F.linear(input_parallel, weight, bias)
output = ColoTensor.from_torch_tensor(output_parallel,
spec=ColoTensorSpec(weight.get_process_group(),
distspec.shard([-1], [weight.get_tp_world_size()]),
ShardSpec([-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D)))
if compute_spec.output_replicate:
return output.to_replicate()
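
In the 1Dcol path above, every rank holds an output-feature slice of the weight, so its local `F.linear` result is the `ShardSpec([-1], ...)` shard of the full output, and an all-gather (what `to_replicate()` triggers) restores the replicated result. A single-process sketch of that identity with illustrative sizes:

```python
import torch
import torch.nn.functional as F

world_size = 2                                 # illustrative TP degree
x = torch.randn(3, 8)                          # Input:B (replicated on every rank)
weight = torch.randn(10, 8)                    # full (out_features, in_features) weight
bias = torch.randn(10)

# 1Dcol: each "rank" holds a row-chunk of the weight, i.e. an output-feature slice ...
partials = [F.linear(x, w, b) for w, b in zip(weight.chunk(world_size, dim=0),
                                              bias.chunk(world_size, dim=0))]
# ... so concatenating the local outputs (emulating an all-gather) reproduces the full linear.
assert torch.allclose(torch.cat(partials, dim=-1), F.linear(x, weight, bias), atol=1e-6)
```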

View File

@ -7,16 +7,6 @@ class ColoModule(object):
def __init__(self):
self._shard_params: List[str] = []
# Example:
# {ComputePattern.TP1D:
# 'default':
# 'weight':
# distspec.shard(xxxxx)
# 'bias':
# distspec.shard(xxxxx)
# 'row': ...
# 'col': ...
# }
self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
def _register_shard_params(self, params: List[str]):

View File

@ -1,7 +1,5 @@
from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec, ProcessGroup
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
class ColoEmbedding(ColoModule):
@ -21,7 +19,7 @@ class ColoEmbedding(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': distspec.shard([0], [pg.tp_world_size()]),
'weight': ShardSpec([0], [pg.tp_world_size()]),
},
mode='row',
)
@ -30,7 +28,7 @@ class ColoEmbedding(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': distspec.shard([-1], [pg.tp_world_size()]),
'weight': ShardSpec([-1], [pg.tp_world_size()]),
},
mode='col',
)

View File

@ -1,5 +1,5 @@
from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec, ProcessGroup
from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
class ColoLinear(ColoModule):
@ -19,7 +19,7 @@ class ColoLinear(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': distspec.shard([-1], [pg.tp_world_size()]),
'weight': ShardSpec([-1], [pg.tp_world_size()]),
'bias': None
},
mode='row',
@ -29,8 +29,8 @@ class ColoLinear(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
'weight': distspec.shard([0], [pg.tp_world_size()]),
'bias': distspec.shard([0], [pg.tp_world_size()])
'weight': ShardSpec([0], [pg.tp_world_size()]),
'bias': ShardSpec([0], [pg.tp_world_size()])
},
mode='col',
)

View File

@ -1,5 +1,8 @@
from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec
from .distspec import shard as ShardSpec
from .distspec import replicate as ReplicaSpec
from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor
from .colo_parameter import ColoParameter
@ -11,5 +14,5 @@ from . import distspec
__all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
'ColoTensorSpec', 'TensorSpec'
'ColoTensorSpec', 'TensorSpec', 'ShardSpec', 'ReplicaSpec'
]
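
Because `ShardSpec` and `ReplicaSpec` are plain aliases of `distspec.shard` and `distspec.replicate`, both spellings stay importable from `colossalai.tensor`. A small sketch of the equivalence; the partition sizes are illustrative:

```python
from colossalai.tensor import ReplicaSpec, ShardSpec, distspec

# The new names are aliases, so they are the very factories they replace in this commit.
assert ShardSpec is distspec.shard
assert ReplicaSpec is distspec.replicate

row_spec = ShardSpec([0], [4])        # equivalent to distspec.shard([0], [4])
replica_spec = ReplicaSpec()          # equivalent to distspec.replicate()
```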

View File

@ -5,7 +5,7 @@ import torch
from functools import lru_cache
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor import distspec, ProcessGroup
from colossalai.tensor import ProcessGroup, ReplicaSpec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from typing import Optional, Set, Callable
@ -51,21 +51,21 @@ class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(distspec.replicate()).
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
The signature of the function has to be consistent with that of __new__ except for the 1st arg.
The class should be initialized with a torch tensor in the following ways.
1. directly init.
>>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
>>> # If initialized in a sharded model, the tensor passed in is one shard of the global tensor.
>>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
>>> dims=[0],
>>> num_partitions=[world_size])
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
2. use static method from_torch_tensor
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
"""
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
@ -85,7 +85,7 @@ class ColoTensor(torch.Tensor):
# If no spec is given, use a DP process group and a replicated dist spec
if spec is None:
self.has_initialized = False
self.dist_spec = distspec.replicate()
self.dist_spec = ReplicaSpec()
self.compute_spec = None
self.process_group = ProcessGroup()
else:
@ -194,13 +194,14 @@ class ColoTensor(torch.Tensor):
"""to_replicate_
an inline member function, converting dist spec of the tensor to REPLICATE
"""
self._redistribute(dist_spec=distspec.replicate())
self._redistribute(dist_spec=ReplicaSpec())
def to_replicate(self) -> 'ColoTensor':
"""to_replicate
converting dist spec of the tensor to REPLICATE
"""
return self.redistribute(distspec.replicate())
return self.redistribute(ReplicaSpec())
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
@ -234,7 +235,7 @@ class ColoTensor(torch.Tensor):
"""
if self.is_replicate():
return super().view(*args)
replicated_t = self.redistribute(dist_spec=distspec.replicate())
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
return replicated_t.view(*args)
def size_global(self, args: Optional[int] = None):

View File

@ -1,6 +1,6 @@
from .utils import InsertPostInitMethodToModuleSubClasses
import torch
from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup
from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup, ReplicaSpec
from colossalai.nn.parallel.layers import register_colo_module, \
ColoLinear, ColoEmbedding

View File

@ -4,7 +4,7 @@ import pytest
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor import distspec
from colossalai.tensor import ShardSpec
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
@ -37,13 +37,13 @@ class Conv1D(nn.Module):
def init_1d_row(weight, bias, pg: ProcessGroup):
spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
def init_1d_col(weight, bias, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
bias.set_tensor_spec(*spec)

View File

@ -4,10 +4,9 @@ import torch.distributed as dist
import pytest
import colossalai
import torch.multiprocessing as mp
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import DistSpecManager, distspec, ProcessGroup
from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec
from functools import partial
@ -18,10 +17,10 @@ def run():
depth = int(math.sqrt(size))
assert depth == math.sqrt(size)
x = torch.rand(8, 8).cuda()
old_dist_spec = distspec.replicate()
row_spec = distspec.shard([0], [size])
col_spec = distspec.shard([-1], [size])
mat_spec = distspec.shard([0, 1], [depth, depth])
old_dist_spec = ReplicaSpec()
row_spec = ShardSpec([0], [size])
col_spec = ShardSpec([-1], [size])
mat_spec = ShardSpec([0, 1], [depth, depth])
row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec, group)
assert torch.equal(x.chunk(size, 0)[rank], row_shard)
assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec, group))
@ -40,8 +39,8 @@ def check_mem():
x = torch.rand(32, 32).cuda()
orig_mem = x.numel() * x.element_size()
assert torch.cuda.memory_allocated() == orig_mem
old_dist_spec = distspec.replicate()
row_spec = distspec.shard([0], [size])
old_dist_spec = ReplicaSpec()
row_spec = ShardSpec([0], [size])
x.data = DistSpecManager._shard_as(x, old_dist_spec, row_spec, pg)
assert x.size(0) == 32 // size and x.size(1) == 32
assert torch.cuda.memory_allocated() == orig_mem // size
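
`mat_spec = ShardSpec([0, 1], [depth, depth])` in the test above describes a 2D block partition over a depth x depth grid of ranks. A plain-torch sketch of which block one rank would own under such a spec; the row-major rank-to-grid mapping used here is an assumption for illustration, not necessarily what DistSpecManager implements:

```python
import torch

size = 4                               # illustrative number of ranks
depth = 2                              # sqrt(size): a 2 x 2 rank grid
x = torch.rand(8, 8)

def block_of(rank: int) -> torch.Tensor:
    # Assumed row-major mapping of rank -> (grid_row, grid_col).
    row, col = divmod(rank, depth)
    return x.chunk(depth, dim=0)[row].chunk(depth, dim=1)[col]

# Each rank holds an (8 / depth) x (8 / depth) block of the full tensor.
assert block_of(3).shape == (8 // depth, 8 // depth)
```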

View File

@ -1,5 +1,5 @@
import torch
from colossalai.tensor import distspec, ColoParameter
from colossalai.tensor import ShardSpec, ColoParameter
from torch.nn import functional as F
from functools import partial
@ -14,7 +14,7 @@ from _utils import tensor_equal, tensor_shard_equal
def init_1d_col(weight, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)

View File

@ -1,5 +1,5 @@
import torch
from colossalai.tensor import ColoTensor, distspec
from colossalai.tensor import ColoTensor, ShardSpec
from torch.nn import functional as F
from functools import partial
@ -14,13 +14,13 @@ from _utils import tensor_equal, tensor_shard_equal
def init_1d_row(weight, pg: ProcessGroup):
spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
def init_1d_col(weight, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)

View File

@ -12,7 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
@ -20,7 +20,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_spec(model, pg: ProcessGroup):
tensor_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n:
@ -28,7 +28,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):
def init_1d_col_spec(model, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'ln' not in n and ('weight' in n or 'bias' in n):

View File

@ -1,5 +1,5 @@
import torch
from colossalai.tensor import ColoTensor, distspec
from colossalai.tensor import ColoTensor, ShardSpec
from functools import partial
@ -15,13 +15,13 @@ from _utils import tensor_equal, tensor_shard_equal
def init_1d_row(weight, bias, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
def init_1d_col(weight, bias, pg: ProcessGroup):
spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
bias.set_tensor_spec(*spec)

View File

@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
from colossalai.utils import get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import distspec, ComputeSpec, ComputePattern
from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern
def check_cross_entropy():
@ -22,7 +22,7 @@ def check_cross_entropy():
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
input_shard = input_t_colo.redistribute(distspec.shard([-1], [pg.tp_world_size()]))
input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
output = F.cross_entropy(input_t, target)

View File

@ -11,7 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import distspec, ColoTensorSpec, ComputePattern, \
from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
from colossalai.nn.optimizer import ColoOptimizer
@ -19,28 +19,28 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_process_group(pg)
weight.set_tensor_spec(*spec)
def init_1d_col_linear(weight, pg):
spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_process_group(pg)
weight.set_tensor_spec(*spec)
def init_1d_row_embedding(weight, pg):
spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_process_group(pg)
weight.set_tensor_spec(*spec)
def init_1d_col_embedding(weight, pg):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_process_group(pg)
weight.set_tensor_spec(*spec)

View File

@ -5,7 +5,7 @@ from functools import partial
import torch
import torch.multiprocessing as mp
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ShardSpec, ReplicaSpec
from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
from _utils import tensor_equal, tensor_shard_equal, set_seed
@ -13,7 +13,7 @@ import colossalai
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import distspec, ProcessGroup
from colossalai.tensor import distspec, ProcessGroup, ReplicaSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
@ -159,7 +159,7 @@ def run_check_shared_param():
# They are all Linear layers, so 'row' mode is allowed for both. This should pass the check.
init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
# This should be detected by the check, because the weight cannot be set as 'row' while the bias is set as 'col'.
col_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
# TODO(jiaruifang) optimize this line
if not model.cls.predictions.bias.has_initialized:

View File

@ -4,7 +4,7 @@ import colossalai
import torch.nn.functional as F
import torch.multiprocessing as mp
from functools import partial
from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec
from colossalai.utils import get_current_device
from torch.nn import Parameter
from colossalai.testing import rerun_if_address_is_in_use
@ -47,7 +47,7 @@ def check_element_wise_ops():
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
t = torch.rand(2, 2)
x = ColoTensor(t, spec=ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()])))
x = ColoTensor(t, spec=ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()])))
check_spec_eq(x, x.cuda())
assert torch.equal(x.cuda(), t.cuda())

View File

@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import distspec, ColoTensor, ProcessGroup
from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec
from functools import partial
@ -55,7 +55,7 @@ def _run_operand(world_size):
pg = ProcessGroup(tp_degree=world_size)
t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
t.set_dist_spec(distspec.shard([0], [world_size]))
t.set_dist_spec(ShardSpec([0], [world_size]))
t_new = torch.zeros_like(t)
assert isinstance(t_new, ColoTensor)
assert t_new.is_sharded()
@ -69,7 +69,7 @@ def _run_view(world_size):
rank = gpc.get_global_rank()
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
t = ColoTensor.from_torch_tensor(
t_ref, ColoTensorSpec(pg, dist_attr=distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])))
t_ref, ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])))
assert t.size_global()[0] == 4 * world_size
assert t.size_global(1) == 5
@ -82,7 +82,7 @@ def _run_view(world_size):
def _run_tensor_shard_init(world_size):
t_ref = torch.randn(4, 5)
pg = ProcessGroup(tp_degree=world_size)
shard_attr = distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])
shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
t.set_dist_spec(distspec.replicate())

View File

@ -17,7 +17,7 @@ from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
def check_param_equal(model, torch_model, pg: ProcessGroup):
@ -45,7 +45,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
def init_1d_row_spec(model, pg: ProcessGroup):
spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n:
@ -53,7 +53,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):
def init_1d_col_spec(model, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'ln' not in n and ('weight' in n or 'bias' in n):

View File

@ -16,7 +16,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@ -81,7 +81,7 @@ class MLP(nn.Module):
def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'weight' in n: