mirror of https://github.com/hpcaitech/ColossalAI
[tensor] a shorter shard and replicate spec (#1245)
parent 2699dfbbfd
commit 9bcd2fd4af
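The whole diff below is a mechanical rename: every `distspec.shard(...)` call becomes `ShardSpec(...)` and every `distspec.replicate()` becomes `ReplicaSpec()`, with the two shorter names re-exported from `colossalai.tensor`. A minimal sketch of the before/after spelling (argument values are illustrative only, and it assumes `torch.distributed` is already initialized so that a `ProcessGroup` can be built):

from colossalai.tensor import ColoTensorSpec, ComputeSpec, ComputePattern, ProcessGroup
from colossalai.tensor import distspec, ShardSpec, ReplicaSpec

pg = ProcessGroup(tp_degree=2)  # illustrative tensor-parallel degree

# Old spelling: factory functions on the distspec module.
old_shard = distspec.shard([0], [pg.tp_world_size()])
old_replica = distspec.replicate()

# New, shorter spelling: aliases exported by colossalai.tensor.
new_shard = ShardSpec([0], [pg.tp_world_size()])
new_replica = ReplicaSpec()

spec = ColoTensorSpec(pg, new_shard, ComputeSpec(ComputePattern.TP1D))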
@@ -1,6 +1,5 @@
 import torch
-from torch.fx.node import map_arg
-from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
+from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec


 def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
@@ -25,7 +24,7 @@ def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)

-    spec = ColoTensorSpec(pg, distspec.shard([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     # As you has constructed a Spec, why not directly convert the tensor to ColoTensor.
     setattr(weight, "fx_attr", spec)
     return weight
@@ -1,7 +1,7 @@
 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor
-from colossalai.tensor import distspec, ColoTensorSpec
+from colossalai.tensor import distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
 from ._utils import GeneralTensor, Number, convert_to_colo_tensor
 from ._utils import reduce_input, reduce_grad

@@ -11,7 +11,8 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # mat1:S[1] x mat2:S[0] = Output:P
     # beta * input + alpha * All-Reduce(Output) = res

-    mat1 = mat1.redistribute(distspec.shard([-1], [mat2.get_tp_world_size()]))
+    mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
+

     # Output:P
     partial_output = torch.mm(mat1, mat2)
@@ -20,7 +21,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # input
     assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
     output = beta * input_tensor + alpha * output
-    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(distspec.replicate()))
+    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(ReplicaSpec()))
     return output


@@ -28,11 +29,11 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
     compute_spec = mat2.compute_spec
-    mat1 = mat1.redistribute(distspec.replicate())
+    mat1 = mat1.redistribute(ReplicaSpec())
     mat1 = reduce_grad(mat1, mat1.get_process_group())

     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
-    output_spec = ColoTensorSpec(input_tensor.get_process_group(), distspec.shard([-1], [mat2.get_tp_world_size()]),
+    output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]),
                                  ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

@@ -1,7 +1,7 @@
 import torch.nn.functional as F
 from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
+from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec
 from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input


@@ -14,7 +14,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                         sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    input_tensor = input_tensor.redistribute(distspec.replicate())
+    input_tensor = input_tensor.redistribute(ReplicaSpec())
+

     output_parallel = F.embedding(input_tensor,
                                   weight,
@@ -23,7 +24,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                   norm_type=norm_type,
                                   scale_grad_by_freq=scale_grad_by_freq,
                                   sparse=sparse)
-    output_spec = ColoTensorSpec(weight.get_process_group(), distspec.shard([-1], [weight.get_tp_world_size()]),
+    output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]),
                                  ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

@@ -46,7 +47,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # Find index in this shard and mask those not here
     # Reduce all
     pg = weight.get_process_group()
-    input_tensor = input_tensor.redistribute(distspec.replicate())
+    input_tensor = input_tensor.redistribute(ReplicaSpec())
+

     # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
     tensor_parallel_rank = weight.get_process_group().tp_local_rank()
@@ -74,7 +76,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     partial_output[input_mask, :] = 0.
     # Reduce across all the model parallel GPUs.
     output = reduce_input(partial_output, weight.get_process_group())
-    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), distspec.replicate()))
+    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec()))
     return output


@@ -2,7 +2,7 @@ import torch.nn.functional as F
 from typing import Optional
 from torch import Tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
+from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
 from ._utils import GeneralTensor, convert_to_colo_tensor


@@ -20,7 +20,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
     # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
     pg = weight.get_process_group()
-    input_tensor = input_tensor.redistribute(distspec.replicate())
+    input_tensor = input_tensor.redistribute(ReplicaSpec())

     output_parallel = F.embedding_bag(input_tensor,
                                       weight,
@@ -33,8 +33,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                                       per_sample_weights=per_sample_weights,
                                       include_last_offset=include_last_offset,
                                       padding_idx=padding_idx)
-    output_spec = ColoTensorSpec(pg, distspec.shard([-1], [weight.get_tp_world_size()]),
-                                 ComputeSpec(ComputePattern.TP1D))
+    output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

     if weight.compute_spec.output_replicate:
@@ -1,7 +1,7 @@
 from typing import List, Optional
 import torch.nn.functional as F
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec
+from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec, ReplicaSpec
 from ._utils import GeneralTensor, convert_to_colo_tensor


@@ -16,7 +16,7 @@ def colo_layernorm(
     assert isinstance(weight, ColoTensor)
     input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
     bias = convert_to_colo_tensor(bias, weight.get_process_group())
-    input_tensor = input_tensor.redistribute(distspec.replicate())
+    input_tensor = input_tensor.redistribute(ReplicaSpec())

     output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
     output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))
@@ -3,8 +3,7 @@ from typing import Optional
 from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from ._utils import reduce_input, reduce_grad
-from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
-from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv
+from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec


 def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
@@ -12,7 +11,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # All-Reduce(Output) + bias = res
     # Input:S[1]
     pg = weight.get_process_group()
-    input_tensor = input_tensor.redistribute(distspec.shard([-1], [weight.get_tp_world_size()]))
+    input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]))

     # Output:P
     partial_output = F.linear(input_tensor, weight)
@@ -24,7 +23,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
         assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
         output = output + bias

-    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, distspec.replicate()))
+    output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec()))
     return output


@@ -33,13 +32,15 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # All-Gather(Output)
     # Input:B
     compute_spec = weight.compute_spec
-    input_tensor = input_tensor.redistribute(distspec.replicate())
+
+    input_tensor = input_tensor.redistribute(ReplicaSpec())
+
     input_parallel = reduce_grad(input_tensor, weight.get_process_group())

     output_parallel = F.linear(input_parallel, weight, bias)
     output = ColoTensor.from_torch_tensor(output_parallel,
                                           spec=ColoTensorSpec(weight.get_process_group(),
-                                                              distspec.shard([-1], [weight.get_tp_world_size()]),
+                                                              ShardSpec([-1], [weight.get_tp_world_size()]),
                                                               ComputeSpec(ComputePattern.TP1D)))
     if compute_spec.output_replicate:
         return output.to_replicate()
@@ -7,16 +7,6 @@ class ColoModule(object):

     def __init__(self):
         self._shard_params: List[str] = []
-        # Example:
-        # {ComputePattern.TP1D:
-        #  'default':
-        #      'weight':
-        #          distspec.shard(xxxxx)
-        #      'bias':
-        #          distspec.shard(xxxxx)
-        #  'row': ...
-        #  'col': ...
-        # }
         self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}

     def _register_shard_params(self, params: List[str]):
@@ -1,7 +1,5 @@
 from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec, ProcessGroup
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec


 class ColoEmbedding(ColoModule):
@@ -21,7 +19,7 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard([0], [pg.tp_world_size()]),
+                'weight': ShardSpec([0], [pg.tp_world_size()]),
             },
             mode='row',
         )
@@ -30,7 +28,7 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard([-1], [pg.tp_world_size()]),
+                'weight': ShardSpec([-1], [pg.tp_world_size()]),
             },
             mode='col',
         )
@@ -1,5 +1,5 @@
 from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec, ProcessGroup
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec


 class ColoLinear(ColoModule):
@@ -19,7 +19,7 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard([-1], [pg.tp_world_size()]),
+                'weight': ShardSpec([-1], [pg.tp_world_size()]),
                 'bias': None
             },
             mode='row',
@@ -29,8 +29,8 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard([0], [pg.tp_world_size()]),
-                'bias': distspec.shard([0], [pg.tp_world_size()])
+                'weight': ShardSpec([0], [pg.tp_world_size()]),
+                'bias': ShardSpec([0], [pg.tp_world_size()])
             },
             mode='col',
         )
@@ -1,5 +1,8 @@
 from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec
+from .distspec import shard as ShardSpec
+from .distspec import replicate as ReplicaSpec
+
 from .compute_spec import ComputeSpec, ComputePattern
 from .colo_tensor import ColoTensor
 from .colo_parameter import ColoParameter
@@ -11,5 +14,5 @@ from . import distspec
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
     'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
-    'ColoTensorSpec', 'TensorSpec'
+    'ColoTensorSpec', 'TensorSpec', 'ShardSpec', 'ReplicaSpec'
 ]
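The hunk above that adds `from .distspec import shard as ShardSpec` and `from .distspec import replicate as ReplicaSpec` shows that the new names are plain re-exports of the existing factory functions, so old `distspec`-based call sites keep working. A small sketch of that equivalence (the tensor-parallel size 4 is illustrative):

from colossalai.tensor import distspec, ShardSpec, ReplicaSpec

# The aliases are the very same factory functions as on the distspec module.
assert ShardSpec is distspec.shard
assert ReplicaSpec is distspec.replicate

row_spec = ShardSpec([0], [4])   # equivalent to distspec.shard([0], [4])
replica_spec = ReplicaSpec()     # equivalent to distspec.replicate()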
@@ -5,7 +5,7 @@ import torch
 from functools import lru_cache

 from colossalai.tensor import ColoTensorSpec
-from colossalai.tensor import distspec, ProcessGroup
+from colossalai.tensor import ProcessGroup, ReplicaSpec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
 from typing import Optional, Set, Callable
@@ -51,21 +51,21 @@ class ColoTensor(torch.Tensor):
     """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
     Args:
         data (torch.Tensor): a torch tensor used as the payload the colotensor.
-        spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(distspec.replicate()).
+        spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).

     The signature of the function has to be consistent with the __new__ except for the 1st arg.
     The class should be initialized with a torch tensor in the following ways.
     1. directly init.
     >>> pg = ProcessGroup()
-    >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())
+    >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
     >>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
-    >>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
+    >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
     >>>                             dims=[0],
     >>>                             num_partitions=[world_size])
     >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
     >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     2. use static method from_torch_tensor
-    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())
+    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
     """

     def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
@@ -85,7 +85,7 @@ class ColoTensor(torch.Tensor):
         # If not set spec, use a DP process group and replicate dist spec
         if spec is None:
             self.has_initialized = False
-            self.dist_spec = distspec.replicate()
+            self.dist_spec = ReplicaSpec()
             self.compute_spec = None
             self.process_group = ProcessGroup()
         else:
@@ -194,13 +194,14 @@ class ColoTensor(torch.Tensor):
         """to_replicate_
         an inline member function, converting dist spec of the tensor to REPLICATE
         """
-        self._redistribute(dist_spec=distspec.replicate())
+        self._redistribute(dist_spec=ReplicaSpec())

     def to_replicate(self) -> 'ColoTensor':
         """to_replicate
         converting dist spec of the tensor to REPLICATE
         """
-        return self.redistribute(distspec.replicate())
+        return self.redistribute(ReplicaSpec())
+

     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
@@ -234,7 +235,7 @@ class ColoTensor(torch.Tensor):
         """
         if self.is_replicate():
             return super().view(*args)
-        replicated_t = self.redistribute(dist_spec=distspec.replicate())
+        replicated_t = self.redistribute(dist_spec=ReplicaSpec())
         return replicated_t.view(*args)

     def size_global(self, args: Optional[int] = None):
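The updated `ColoTensor` docstring above builds tensors with the short specs; a hedged sketch of that flow, following the docstring (it assumes the script is launched with several processes and `torch.distributed` is already initialized):

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec, ReplicaSpec

world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)

# A replicated tensor, as in the docstring's "directly init" example.
t = ColoTensor.from_torch_tensor(torch.randn(2, 3), spec=ColoTensorSpec(pg, ReplicaSpec()))

# Re-distribute it so that dim 0 is sharded across the tensor-parallel group,
# mirroring the redistribute(ShardSpec(...)) calls in the ops above.
sharded = t.redistribute(ShardSpec([0], [pg.tp_world_size()]))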
@@ -1,6 +1,6 @@
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
-from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup
+from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup, ReplicaSpec

 from colossalai.nn.parallel.layers import register_colo_module, \
     ColoLinear, ColoEmbedding
@@ -4,7 +4,7 @@ import pytest
 import torch.nn as nn
 import torch.multiprocessing as mp
 from colossalai.tensor import ColoTensor, ProcessGroup
-from colossalai.tensor import distspec
+from colossalai.tensor import ShardSpec
 from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -37,13 +37,13 @@ class Conv1D(nn.Module):


 def init_1d_row(weight, bias, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)


 def init_1d_col(weight, bias, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)
         bias.set_tensor_spec(*spec)
@@ -4,10 +4,9 @@ import torch.distributed as dist
 import pytest
 import colossalai
 import torch.multiprocessing as mp
-from torch.distributed.distributed_c10d import _get_default_group
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec
 from functools import partial


@@ -18,10 +17,10 @@ def run():
     depth = int(math.sqrt(size))
     assert depth == math.sqrt(size)
     x = torch.rand(8, 8).cuda()
-    old_dist_spec = distspec.replicate()
-    row_spec = distspec.shard([0], [size])
-    col_spec = distspec.shard([-1], [size])
-    mat_spec = distspec.shard([0, 1], [depth, depth])
+    old_dist_spec = ReplicaSpec()
+    row_spec = ShardSpec([0], [size])
+    col_spec = ShardSpec([-1], [size])
+    mat_spec = ShardSpec([0, 1], [depth, depth])
     row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec, group)
     assert torch.equal(x.chunk(size, 0)[rank], row_shard)
     assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec, group))
@@ -40,8 +39,8 @@ def check_mem():
     x = torch.rand(32, 32).cuda()
     orig_mem = x.numel() * x.element_size()
     assert torch.cuda.memory_allocated() == orig_mem
-    old_dist_spec = distspec.replicate()
-    row_spec = distspec.shard([0], [size])
+    old_dist_spec = ReplicaSpec()
+    row_spec = ShardSpec([0], [size])
     x.data = DistSpecManager._shard_as(x, old_dist_spec, row_spec, pg)
     assert x.size(0) == 32 // size and x.size(1) == 32
     assert torch.cuda.memory_allocated() == orig_mem // size
@@ -1,5 +1,5 @@
 import torch
-from colossalai.tensor import distspec, ColoParameter
+from colossalai.tensor import ShardSpec, ColoParameter
 from torch.nn import functional as F
 from functools import partial

@@ -14,7 +14,7 @@ from _utils import tensor_equal, tensor_shard_equal


 def init_1d_col(weight, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)

@@ -1,5 +1,5 @@
 import torch
-from colossalai.tensor import ColoTensor, distspec
+from colossalai.tensor import ColoTensor, ShardSpec
 from torch.nn import functional as F
 from functools import partial

@@ -14,13 +14,13 @@ from _utils import tensor_equal, tensor_shard_equal


 def init_1d_row(weight, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)


 def init_1d_col(weight, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)

@@ -12,7 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.core import global_context as gpc
 from colossalai.context.parallel_mode import ParallelMode
@@ -20,7 +20,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs


 def init_1d_row_spec(model, pg: ProcessGroup):
-    tensor_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
@@ -28,7 +28,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):


 def init_1d_col_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
@@ -1,5 +1,5 @@
 import torch
-from colossalai.tensor import ColoTensor, distspec
+from colossalai.tensor import ColoTensor, ShardSpec

 from functools import partial

@@ -15,13 +15,13 @@ from _utils import tensor_equal, tensor_shard_equal


 def init_1d_row(weight, bias, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)


 def init_1d_col(weight, bias, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(*spec)
         bias.set_tensor_spec(*spec)
@@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
 from colossalai.utils import get_current_device
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, ComputeSpec, ComputePattern
+from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern


 def check_cross_entropy():
@@ -22,7 +22,7 @@ def check_cross_entropy():
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
-    input_shard = input_t_colo.redistribute(distspec.shard([-1], [pg.tp_world_size()]))
+    input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
     input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))

     output = F.cross_entropy(input_t, target)
@@ -11,7 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import distspec, ColoTensorSpec, ComputePattern, \
+from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
     ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
 from colossalai.nn.optimizer import ColoOptimizer

@@ -19,28 +19,28 @@ from tests.components_to_test.registry import non_distributed_component_funcs


 def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_process_group(pg)
         weight.set_tensor_spec(*spec)


 def init_1d_col_linear(weight, pg):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_process_group(pg)
         weight.set_tensor_spec(*spec)


 def init_1d_row_embedding(weight, pg):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_process_group(pg)
         weight.set_tensor_spec(*spec)


 def init_1d_col_embedding(weight, pg):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_process_group(pg)
         weight.set_tensor_spec(*spec)
@@ -5,7 +5,7 @@ from functools import partial
 import torch
 import torch.multiprocessing as mp

-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec
+from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ShardSpec, ReplicaSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed

@@ -13,7 +13,7 @@ import colossalai
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext

-from colossalai.tensor import distspec, ProcessGroup
+from colossalai.tensor import distspec, ProcessGroup, ReplicaSpec

 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -159,7 +159,7 @@ def run_check_shared_param():
     # They are all Linear, so both row is allowed. This should pass check.
     init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
     # This should be detected by check because you can not set weight as row while set bias as col.
-    col_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))

     # TODO(jiaruifang) optimize this line
     if not model.cls.predictions.bias.has_initialized:
@@ -4,7 +4,7 @@ import colossalai
 import torch.nn.functional as F
 import torch.multiprocessing as mp
 from functools import partial
-from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
+from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec
 from colossalai.utils import get_current_device
 from torch.nn import Parameter
 from colossalai.testing import rerun_if_address_is_in_use
@@ -47,7 +47,7 @@ def check_element_wise_ops():
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     t = torch.rand(2, 2)
-    x = ColoTensor(t, spec=ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()])))
+    x = ColoTensor(t, spec=ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()])))

     check_spec_eq(x, x.cuda())
     assert torch.equal(x.cuda(), t.cuda())
@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, ColoTensor, ProcessGroup
+from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec
 from functools import partial


@@ -55,7 +55,7 @@ def _run_operand(world_size):

     pg = ProcessGroup(tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
-    t.set_dist_spec(distspec.shard([0], [world_size]))
+    t.set_dist_spec(ShardSpec([0], [world_size]))
     t_new = torch.zeros_like(t)
     assert isinstance(t_new, ColoTensor)
     assert t_new.is_sharded()
@@ -69,7 +69,7 @@ def _run_view(world_size):
     rank = gpc.get_global_rank()
     pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(
-        t_ref, ColoTensorSpec(pg, dist_attr=distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])))
+        t_ref, ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])))

     assert t.size_global()[0] == 4 * world_size
     assert t.size_global(1) == 5
@@ -82,7 +82,7 @@ def _run_view(world_size):
 def _run_tensor_shard_init(world_size):
     t_ref = torch.randn(4, 5)
     pg = ProcessGroup(tp_degree=world_size)
-    shard_attr = distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])
+    shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
     tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     t.set_dist_spec(distspec.replicate())
@@ -17,7 +17,7 @@ from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup


 def check_param_equal(model, torch_model, pg: ProcessGroup):
@@ -45,7 +45,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):


 def init_1d_row_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
@@ -53,7 +53,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):


 def init_1d_col_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
@@ -16,7 +16,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -81,7 +81,7 @@ class MLP(nn.Module):


 def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n:
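All of the test updates above follow one pattern: build a (dist spec, compute spec) pair with ShardSpec and apply it to a parameter under DistSpecManager.no_grad(). A sketch of that pattern as a stand-alone helper (the function name shard_weight_row is hypothetical, not part of the code base; weight is assumed to be a ColoTensor parameter and pg an initialized ProcessGroup):

from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec

def shard_weight_row(weight, pg: ProcessGroup):
    # Hypothetical helper mirroring the init_1d_* functions in the tests above:
    # shard the last dim across the tensor-parallel group and mark the compute pattern as TP1D.
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(*spec)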