[tensor] a shorter shard and replicate spec (#1245)

pull/1241/head
Jiarui Fang 2022-07-11 15:51:48 +08:00 committed by GitHub
parent 2699dfbbfd
commit 9bcd2fd4af
25 changed files with 91 additions and 98 deletions
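In short, the commit re-exports `distspec.shard` and `distspec.replicate` from `colossalai.tensor` under the shorter names `ShardSpec` and `ReplicaSpec`, then updates every call site. A minimal before/after sketch of the spelling change (assuming `pg`, `ColoTensorSpec`, `ComputeSpec` and `ComputePattern` are already imported as in the files below):

# before: the long spelling via the distspec module
spec = ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))

# after: the shorter aliases introduced by this commit
spec = ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
replica = ReplicaSpec()  # shorthand for distspec.replicate()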

View File

@@ -1,6 +1,5 @@
 import torch
-from torch.fx.node import map_arg
-from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
+from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec
 def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
@@ -25,7 +24,7 @@ def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
-    spec = ColoTensorSpec(pg, distspec.shard([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     # As you has constructed a Spec, why not directly convert the tensor to ColoTensor.
     setattr(weight, "fx_attr", spec)
     return weight

View File

@@ -1,7 +1,7 @@
 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor
-from colossalai.tensor import distspec, ColoTensorSpec
+from colossalai.tensor import distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
 from ._utils import GeneralTensor, Number, convert_to_colo_tensor
 from ._utils import reduce_input, reduce_grad
@@ -11,7 +11,8 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # mat1:S[1] x mat2:S[0] = Output:P
     # beta * input + alpha * All-Reduce(Output) = res
-    mat1 = mat1.redistribute(distspec.shard([-1], [mat2.get_tp_world_size()]))
+    mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
     # Output:P
     partial_output = torch.mm(mat1, mat2)
@@ -20,7 +21,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # input
     assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
     output = beta * input_tensor + alpha * output
-    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(distspec.replicate()))
+    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(ReplicaSpec()))
     return output
@@ -28,11 +29,11 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
     compute_spec = mat2.compute_spec
-    mat1 = mat1.redistribute(distspec.replicate())
+    mat1 = mat1.redistribute(ReplicaSpec())
     mat1 = reduce_grad(mat1, mat1.get_process_group())
     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
-    output_spec = ColoTensorSpec(input_tensor.get_process_group(), distspec.shard([-1], [mat2.get_tp_world_size()]),
+    output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]),
                                  ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
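The comments in this hunk encode the 1Drow contraction: a column-sharded mat1 (S[1]) times a row-sharded mat2 (S[0]) yields partial results (P) whose all-reduce equals the full product. A single-process sketch in plain PyTorch (not the Colossal-AI API) that checks the same identity:

import torch

P = 2                                    # pretend tensor-parallel world size
beta, alpha = 1.0, 1.0
inp = torch.randn(4, 6)
mat1 = torch.randn(4, 8)
mat2 = torch.randn(8, 6)

# each "rank" holds one column shard of mat1 and the matching row shard of mat2
partials = [torch.mm(m1, m2) for m1, m2 in zip(mat1.chunk(P, dim=1), mat2.chunk(P, dim=0))]
reduced = sum(partials)                  # stands in for the All-Reduce
res = beta * inp + alpha * reduced

assert torch.allclose(res, torch.addmm(inp, mat1, mat2, beta=beta, alpha=alpha), atol=1e-5)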

View File

@@ -1,7 +1,7 @@
 import torch.nn.functional as F
 from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
+from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec
 from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input
@@ -14,7 +14,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                          sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    input_tensor = input_tensor.redistribute(distspec.replicate())
+    input_tensor = input_tensor.redistribute(ReplicaSpec())
     output_parallel = F.embedding(input_tensor,
                                   weight,
@@ -23,7 +24,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                   norm_type=norm_type,
                                   scale_grad_by_freq=scale_grad_by_freq,
                                   sparse=sparse)
-    output_spec = ColoTensorSpec(weight.get_process_group(), distspec.shard([-1], [weight.get_tp_world_size()]),
+    output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]),
                                  ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
@@ -46,7 +47,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # Find index in this shard and mask those not here
     # Reduce all
     pg = weight.get_process_group()
-    input_tensor = input_tensor.redistribute(distspec.replicate())
+    input_tensor = input_tensor.redistribute(ReplicaSpec())
     # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
     tensor_parallel_rank = weight.get_process_group().tp_local_rank()
@@ -74,7 +76,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     partial_output[input_mask, :] = 0.
     # Reduce across all the model parallel GPUs.
     output = reduce_input(partial_output, weight.get_process_group())
-    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), distspec.replicate()))
+    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec()))
     return output

View File

@@ -2,7 +2,7 @@ import torch.nn.functional as F
 from typing import Optional
 from torch import Tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
+from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
 from ._utils import GeneralTensor, convert_to_colo_tensor
@@ -20,7 +20,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
     # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
     pg = weight.get_process_group()
-    input_tensor = input_tensor.redistribute(distspec.replicate())
+    input_tensor = input_tensor.redistribute(ReplicaSpec())
     output_parallel = F.embedding_bag(input_tensor,
                                       weight,
@@ -33,8 +33,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                                       per_sample_weights=per_sample_weights,
                                       include_last_offset=include_last_offset,
                                       padding_idx=padding_idx)
-    output_spec = ColoTensorSpec(pg, distspec.shard([-1], [weight.get_tp_world_size()]),
-                                 ComputeSpec(ComputePattern.TP1D))
+    output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
     if weight.compute_spec.output_replicate:

View File

@@ -1,7 +1,7 @@
 from typing import List, Optional
 import torch.nn.functional as F
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec
+from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec, ReplicaSpec
 from ._utils import GeneralTensor, convert_to_colo_tensor
@@ -16,7 +16,7 @@ def colo_layernorm(
     assert isinstance(weight, ColoTensor)
     input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
     bias = convert_to_colo_tensor(bias, weight.get_process_group())
-    input_tensor = input_tensor.redistribute(distspec.replicate())
+    input_tensor = input_tensor.redistribute(ReplicaSpec())
     output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
     output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))

View File

@@ -3,8 +3,7 @@ from typing import Optional
 from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from ._utils import reduce_input, reduce_grad
-from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
+from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
-from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv
 def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
@@ -12,7 +11,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # All-Reduce(Output) + bias = res
     # Input:S[1]
     pg = weight.get_process_group()
-    input_tensor = input_tensor.redistribute(distspec.shard([-1], [weight.get_tp_world_size()]))
+    input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]))
     # Output:P
     partial_output = F.linear(input_tensor, weight)
@@ -24,7 +23,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
         assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
         output = output + bias
-    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, distspec.replicate()))
+    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec()))
     return output
@@ -33,13 +32,15 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # All-Gather(Output)
     # Input:B
     compute_spec = weight.compute_spec
-    input_tensor = input_tensor.redistribute(distspec.replicate())
+    input_tensor = input_tensor.redistribute(ReplicaSpec())
     input_parallel = reduce_grad(input_tensor, weight.get_process_group())
     output_parallel = F.linear(input_parallel, weight, bias)
     output = ColoTensor.from_torch_tensor(output_parallel,
                                           spec=ColoTensorSpec(weight.get_process_group(),
-                                                              distspec.shard([-1], [weight.get_tp_world_size()]),
+                                                              ShardSpec([-1], [weight.get_tp_world_size()]),
                                                               ComputeSpec(ComputePattern.TP1D)))
     if compute_spec.output_replicate:
         return output.to_replicate()
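For the 1Dcol path above, the weight is sharded along dim 0 (the output features), each rank computes a slice of the output along the last dimension, and `to_replicate()` corresponds to an all-gather of those slices. A single-process sketch in plain PyTorch (again not the Colossal-AI API) of the same identity:

import torch
import torch.nn.functional as F

P = 2                                        # pretend tensor-parallel world size
x = torch.randn(4, 8)                        # replicated input (Input:B)
weight = torch.randn(16, 8)                  # sharded along dim 0 in the real op
bias = torch.randn(16)

shards = [F.linear(x, w, b) for w, b in zip(weight.chunk(P, dim=0), bias.chunk(P, dim=0))]
gathered = torch.cat(shards, dim=-1)         # stands in for the All-Gather

assert torch.allclose(gathered, F.linear(x, weight, bias), atol=1e-5)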

View File

@@ -7,16 +7,6 @@ class ColoModule(object):
     def __init__(self):
         self._shard_params: List[str] = []
-        # Example:
-        # {ComputePattern.TP1D:
-        #  'default':
-        #      'weight':
-        #          distspec.shard(xxxxx)
-        #      'bias':
-        #          distspec.shard(xxxxx)
-        #  'row': ...
-        #  'col': ...
-        # }
         self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
     def _register_shard_params(self, params: List[str]):

View File

@@ -1,7 +1,5 @@
 from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec, ProcessGroup
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
 class ColoEmbedding(ColoModule):
@@ -21,7 +19,7 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard([0], [pg.tp_world_size()]),
+                'weight': ShardSpec([0], [pg.tp_world_size()]),
             },
             mode='row',
         )
@@ -30,7 +28,7 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard([-1], [pg.tp_world_size()]),
+                'weight': ShardSpec([-1], [pg.tp_world_size()]),
             },
             mode='col',
         )

View File

@@ -1,5 +1,5 @@
 from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec, ProcessGroup
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
 class ColoLinear(ColoModule):
@@ -19,7 +19,7 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard([-1], [pg.tp_world_size()]),
+                'weight': ShardSpec([-1], [pg.tp_world_size()]),
                 'bias': None
             },
             mode='row',
@@ -29,8 +29,8 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard([0], [pg.tp_world_size()]),
-                'bias': distspec.shard([0], [pg.tp_world_size()])
+                'weight': ShardSpec([0], [pg.tp_world_size()]),
+                'bias': ShardSpec([0], [pg.tp_world_size()])
             },
             mode='col',
         )

View File

@@ -1,5 +1,8 @@
 from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec
+from .distspec import shard as ShardSpec
+from .distspec import replicate as ReplicaSpec
 from .compute_spec import ComputeSpec, ComputePattern
 from .colo_tensor import ColoTensor
 from .colo_parameter import ColoParameter
@@ -11,5 +14,5 @@ from . import distspec
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
     'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
-    'ColoTensorSpec', 'TensorSpec'
+    'ColoTensorSpec', 'TensorSpec', 'ShardSpec', 'ReplicaSpec'
 ]
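These two re-exports are the core of the commit: `ShardSpec` and `ReplicaSpec` are plain aliases of the existing `distspec` factories, so both spellings build the same spec objects. A small sketch (assuming a Colossal-AI checkout at this commit):

from colossalai.tensor import distspec, ShardSpec, ReplicaSpec

assert ShardSpec is distspec.shard          # same factory, shorter name
assert ReplicaSpec is distspec.replicate

row_spec = ShardSpec([0], [4])              # shard dim 0 across 4 ranks
replica_spec = ReplicaSpec()                # fully replicated placement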

View File

@@ -5,7 +5,7 @@ import torch
 from functools import lru_cache
 from colossalai.tensor import ColoTensorSpec
-from colossalai.tensor import distspec, ProcessGroup
+from colossalai.tensor import ProcessGroup, ReplicaSpec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
 from typing import Optional, Set, Callable
@@ -51,21 +51,21 @@ class ColoTensor(torch.Tensor):
     """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
     Args:
         data (torch.Tensor): a torch tensor used as the payload the colotensor.
-        spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(distspec.replicate()).
+        spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
     The signature of the function has to be consistent with the __new__ except for the 1st arg.
     The class should be initialized with a torch tensor in the following ways.
     1. directly init.
     >>> pg = ProcessGroup()
-    >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())
+    >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
     >>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
-    >>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
+    >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
     >>>                            dims=[0],
     >>>                            num_partitions=[world_size])
     >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
     >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     2. use static method from_torch_tensor
-    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())
+    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
     """
     def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
@@ -85,7 +85,7 @@ class ColoTensor(torch.Tensor):
         # If not set spec, use a DP process group and replicate dist spec
         if spec is None:
             self.has_initialized = False
-            self.dist_spec = distspec.replicate()
+            self.dist_spec = ReplicaSpec()
             self.compute_spec = None
             self.process_group = ProcessGroup()
         else:
@@ -194,13 +194,14 @@ class ColoTensor(torch.Tensor):
         """to_replicate_
         an inline member function, converting dist spec of the tensor to REPLICATE
         """
-        self._redistribute(dist_spec=distspec.replicate())
+        self._redistribute(dist_spec=ReplicaSpec())
     def to_replicate(self) -> 'ColoTensor':
         """to_replicate
         converting dist spec of the tensor to REPLICATE
         """
-        return self.redistribute(distspec.replicate())
+        return self.redistribute(ReplicaSpec())
     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
@@ -234,7 +235,7 @@ class ColoTensor(torch.Tensor):
         """
         if self.is_replicate():
             return super().view(*args)
-        replicated_t = self.redistribute(dist_spec=distspec.replicate())
+        replicated_t = self.redistribute(dist_spec=ReplicaSpec())
         return replicated_t.view(*args)
     def size_global(self, args: Optional[int] = None):
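Putting the docstring above into a runnable shape with the new names, here is a minimal construction sketch (assuming colossalai has already initialized the distributed context so a `ProcessGroup` can be built):

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec, ReplicaSpec

pg = ProcessGroup()                                    # default process group
replica_t = ColoTensor(torch.randn(2, 3), spec=ColoTensorSpec(pg, ReplicaSpec()))

shard_spec = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
shard_t = ColoTensor.from_torch_tensor(torch.randn(4, 5), ColoTensorSpec(pg, dist_attr=shard_spec))
replicated_again = shard_t.to_replicate()              # back to a ReplicaSpec() layout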

View File

@@ -1,6 +1,6 @@
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
-from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup
+from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup, ReplicaSpec
 from colossalai.nn.parallel.layers import register_colo_module, \
     ColoLinear, ColoEmbedding

View File

@@ -4,7 +4,7 @@ import pytest
 import torch.nn as nn
 import torch.multiprocessing as mp
 from colossalai.tensor import ColoTensor, ProcessGroup
-from colossalai.tensor import distspec
+from colossalai.tensor import ShardSpec
 from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -37,13 +37,13 @@ class Conv1D(nn.Module):
 def init_1d_row(weight, bias, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)
 def init_1d_col(weight, bias, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)
         bias.set_tensor_spec(*spec)

View File

@@ -4,10 +4,9 @@ import torch.distributed as dist
 import pytest
 import colossalai
 import torch.multiprocessing as mp
-from torch.distributed.distributed_c10d import _get_default_group
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec
 from functools import partial
@@ -18,10 +17,10 @@ def run():
     depth = int(math.sqrt(size))
     assert depth == math.sqrt(size)
     x = torch.rand(8, 8).cuda()
-    old_dist_spec = distspec.replicate()
-    row_spec = distspec.shard([0], [size])
-    col_spec = distspec.shard([-1], [size])
-    mat_spec = distspec.shard([0, 1], [depth, depth])
+    old_dist_spec = ReplicaSpec()
+    row_spec = ShardSpec([0], [size])
+    col_spec = ShardSpec([-1], [size])
+    mat_spec = ShardSpec([0, 1], [depth, depth])
     row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec, group)
     assert torch.equal(x.chunk(size, 0)[rank], row_shard)
     assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec, group))
@@ -40,8 +39,8 @@ def check_mem():
     x = torch.rand(32, 32).cuda()
     orig_mem = x.numel() * x.element_size()
     assert torch.cuda.memory_allocated() == orig_mem
-    old_dist_spec = distspec.replicate()
-    row_spec = distspec.shard([0], [size])
+    old_dist_spec = ReplicaSpec()
+    row_spec = ShardSpec([0], [size])
     x.data = DistSpecManager._shard_as(x, old_dist_spec, row_spec, pg)
     assert x.size(0) == 32 // size and x.size(1) == 32
     assert torch.cuda.memory_allocated() == orig_mem // size

View File

@@ -1,5 +1,5 @@
 import torch
-from colossalai.tensor import distspec, ColoParameter
+from colossalai.tensor import ShardSpec, ColoParameter
 from torch.nn import functional as F
 from functools import partial
@@ -14,7 +14,7 @@ from _utils import tensor_equal, tensor_shard_equal
 def init_1d_col(weight, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)

View File

@@ -1,5 +1,5 @@
 import torch
-from colossalai.tensor import ColoTensor, distspec
+from colossalai.tensor import ColoTensor, ShardSpec
 from torch.nn import functional as F
 from functools import partial
@@ -14,13 +14,13 @@ from _utils import tensor_equal, tensor_shard_equal
 def init_1d_row(weight, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)
 def init_1d_col(weight, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)

View File

@@ -12,7 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.core import global_context as gpc
 from colossalai.context.parallel_mode import ParallelMode
@@ -20,7 +20,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 def init_1d_row_spec(model, pg: ProcessGroup):
-    tensor_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
@@ -28,7 +28,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):
 def init_1d_col_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):

View File

@@ -1,5 +1,5 @@
 import torch
-from colossalai.tensor import ColoTensor, distspec
+from colossalai.tensor import ColoTensor, ShardSpec
 from functools import partial
@@ -15,13 +15,13 @@ from _utils import tensor_equal, tensor_shard_equal
 def init_1d_row(weight, bias, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)
 def init_1d_col(weight, bias, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)
         bias.set_tensor_spec(*spec)

View File

@@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
 from colossalai.utils import get_current_device
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, ComputeSpec, ComputePattern
+from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern
 def check_cross_entropy():
@@ -22,7 +22,7 @@ def check_cross_entropy():
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
-    input_shard = input_t_colo.redistribute(distspec.shard([-1], [pg.tp_world_size()]))
+    input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
     input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
     output = F.cross_entropy(input_t, target)

View File

@@ -11,7 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import distspec, ColoTensorSpec, ComputePattern, \
+from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
     ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
 from colossalai.nn.optimizer import ColoOptimizer
@@ -19,28 +19,28 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_process_group(pg)
         weight.set_tensor_spec(*spec)
 def init_1d_col_linear(weight, pg):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_process_group(pg)
         weight.set_tensor_spec(*spec)
 def init_1d_row_embedding(weight, pg):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_process_group(pg)
         weight.set_tensor_spec(*spec)
 def init_1d_col_embedding(weight, pg):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_process_group(pg)
         weight.set_tensor_spec(*spec)
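The four helpers above all follow the same pattern; below is a compact sketch of that idea for an arbitrary module, under the assumption that the model was built inside `ColoInitContext` (so its parameters are ColoTensors) and that `pg` is a tensor-parallel `ProcessGroup`. The helper name is hypothetical.

import torch.nn as nn
from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern, DistSpecManager, ProcessGroup

def shard_weights_rowwise(model: nn.Module, pg: ProcessGroup):
    # shard every 'weight' along its last dim with the TP1D compute pattern
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for name, param in model.named_parameters():
            if 'weight' in name:
                param.set_process_group(pg)
                param.set_tensor_spec(*spec)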

View File

@@ -5,7 +5,7 @@ from functools import partial
 import torch
 import torch.multiprocessing as mp
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec
+from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ShardSpec, ReplicaSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed
@@ -13,7 +13,7 @@ import colossalai
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import distspec, ProcessGroup
+from colossalai.tensor import distspec, ProcessGroup, ReplicaSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -159,7 +159,7 @@ def run_check_shared_param():
     # They are all Linear, so both row is allowed. This should pass check.
     init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
     # This should be detected by check because you can not set weight as row while set bias as col.
-    col_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     # TODO(jiaruifang) optimize this line
     if not model.cls.predictions.bias.has_initialized:

View File

@@ -4,7 +4,7 @@ import colossalai
 import torch.nn.functional as F
 import torch.multiprocessing as mp
 from functools import partial
-from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
+from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec
 from colossalai.utils import get_current_device
 from torch.nn import Parameter
 from colossalai.testing import rerun_if_address_is_in_use
@@ -47,7 +47,7 @@ def check_element_wise_ops():
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     t = torch.rand(2, 2)
-    x = ColoTensor(t, spec=ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()])))
+    x = ColoTensor(t, spec=ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()])))
     check_spec_eq(x, x.cuda())
     assert torch.equal(x.cuda(), t.cuda())

View File

@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, ColoTensor, ProcessGroup
+from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec
 from functools import partial
@@ -55,7 +55,7 @@ def _run_operand(world_size):
     pg = ProcessGroup(tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
-    t.set_dist_spec(distspec.shard([0], [world_size]))
+    t.set_dist_spec(ShardSpec([0], [world_size]))
     t_new = torch.zeros_like(t)
     assert isinstance(t_new, ColoTensor)
     assert t_new.is_sharded()
@@ -69,7 +69,7 @@ def _run_view(world_size):
     rank = gpc.get_global_rank()
     pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(
-        t_ref, ColoTensorSpec(pg, dist_attr=distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])))
+        t_ref, ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])))
     assert t.size_global()[0] == 4 * world_size
     assert t.size_global(1) == 5
@@ -82,7 +82,7 @@ def _run_view(world_size):
 def _run_tensor_shard_init(world_size):
     t_ref = torch.randn(4, 5)
     pg = ProcessGroup(tp_degree=world_size)
-    shard_attr = distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])
+    shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
     tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     t.set_dist_spec(distspec.replicate())

View File

@@ -17,7 +17,7 @@ from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
 def check_param_equal(model, torch_model, pg: ProcessGroup):
@@ -45,7 +45,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
 def init_1d_row_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
@@ -53,7 +53,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):
 def init_1d_col_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):

View File

@@ -16,7 +16,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -81,7 +81,7 @@ class MLP(nn.Module):
 def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n: