[tensor] refactor colo-tensor (#992)

* refactor colo-tensor and update linear op

* polish code

* polish code

* update ops and unit tests

* update unit tests

* polish code

* rename dist_spec module

* polish code

* polish code

* remove unneeded import

* fix pipelinable
pull/1003/head
ver217 2022-05-19 12:44:59 +08:00 committed by GitHub
parent 1467d83edf
commit ad536e308e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 657 additions and 616 deletions

View File

@ -6,10 +6,10 @@ from .colo_parameter import ColoParameter
from .utils import convert_parameter, named_params_with_colotensor
from ._ops import *
from .optim.colo_optimizer import ColoOptimizer
from . import dist_spec
from . import distspec
from .dist_spec_mgr import DistSpecManager
__all__ = [
'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'dist_spec', 'DistSpecManager'
'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager'
]

View File

@ -0,0 +1,12 @@
import torch
from typing import Union, Optional
from colossalai.tensor import ColoTensor
GeneralTensor = Union[ColoTensor, torch.Tensor]
Number = Union[int, float]
def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTensor]:
if tensor is not None and not isinstance(tensor, ColoTensor):
tensor = ColoTensor.from_torch_tensor(tensor)
return tensor

View File

@ -1,64 +1,66 @@
import torch
from typing import Union
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor
from colossalai.tensor import dist_spec
from colossalai.tensor import distspec
from ._utils import GeneralTensor, Number, convert_to_colo_tensor
def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
alpha: Union[int, float]) -> ColoTensor:
def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
alpha: Number) -> ColoTensor:
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res
mat1.to_dist_spec(dist_spec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group().size()]))
mat1 = mat1.convert_to_dist_spec(
distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]))
# Output:P
partial_output = torch.mm(mat1.torch_tensor(), mat2.torch_tensor())
partial_output = torch.mm(mat1, mat2)
# Reduce(Output)
output = reduce_input(partial_output, parallel_action.parallel_mode)
# input
assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
output = beta * input_tensor.torch_tensor() + alpha * output
output = ColoTensor.init_from_torch_tensor(output,
spec=TensorSpec(dist_spec.replicate(mat2.spec.get_process_group())))
output = beta * input_tensor + alpha * output
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.spec.get_process_group())))
return output
def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
alpha: Union[int, float]) -> ColoTensor:
def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
alpha: Number) -> ColoTensor:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
mat1.to_dist_spec(dist_spec.replicate(mat2.spec.get_process_group()))
mat1_torch_tensor = reduce_grad(mat1.torch_tensor(), parallel_action.parallel_mode)
mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
mat1 = reduce_grad(mat1, parallel_action.parallel_mode)
output_parallel = torch.addmm(input_tensor.torch_tensor(),
mat1_torch_tensor,
mat2.torch_tensor(),
beta=beta,
alpha=alpha)
output_spec = TensorSpec(
dist_spec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group().size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
output = ColoTensor.init_from_torch_tensor(output_parallel, spec=output_spec)
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
if parallel_action.gather_out:
# All-Gather(Output)
output.to_dist_spec(dist_spec.replicate(mat2.spec.get_process_group()))
output = output.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
return output
def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
alpha: Number) -> ColoTensor:
assert mode in ('row', 'col')
funcs = {'row': colo_addmm_1Drow, 'col': colo_addmm_1Dcol}
return funcs[mode](input_tensor, mat1, mat2, beta, alpha)
@colo_op_impl(torch.addmm)
def colo_addmm(types, args, kwargs, pg):
def colo_addmm(input_tensor: GeneralTensor,
mat1: GeneralTensor,
mat2: GeneralTensor,
*args,
beta: Number = 1,
alpha: Number = 1) -> ColoTensor:
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear.
"""
input_tensor, mat1, mat2 = args[:3]
to_colo_tensor = lambda t: t if isinstance(t, ColoTensor) else ColoTensor.init_from_torch_tensor(t)
input_tensor = to_colo_tensor(input_tensor)
mat2 = to_colo_tensor(mat2)
beta = kwargs.get('beta', 1) if kwargs else 1
alpha = kwargs.get('alpha', 1) if kwargs else 1
input_tensor, mat1, mat2 = tuple(map(convert_to_colo_tensor, (input_tensor, mat1, mat2)))
# building the computing graph, inputs -> op
# if GraphGlobalEnv().graph_building:
@ -70,17 +72,15 @@ def colo_addmm(types, args, kwargs, pg):
if not mat2.has_spec(): # No Model Parallel Applied
assert mat2.spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.spec.is_gathered(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.init_from_torch_tensor(
torch.addmm(input_tensor.torch_tensor(), mat1, mat2.torch_tensor(), beta=beta, alpha=alpha))
ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
elif mat2.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
spec = TensorSpec(dist_spec.replicate(mat2.spec.get_process_group()))
mat1 = args[1] if isinstance(args[1], ColoTensor) else ColoTensor.init_from_torch_tensor(args[1], spec=spec)
if mat2.spec.is_1D_row() and input_tensor.spec.is_gathered():
ret_tensor = colo_addmm_1Drow(input_tensor, mat1, mat2, beta, alpha)
mode = 'row'
elif mat2.spec.is_1D_col() and (input_tensor.spec.is_1D_col() or input_tensor.spec.is_1D_row()):
ret_tensor = colo_addmm_1Dcol(input_tensor, mat1, mat2, beta, alpha)
mode = 'col'
else:
raise NotImplementedError
ret_tensor = colo_addmm_1d(mode, input_tensor, mat1, mat2, beta, alpha)
else:
raise NotImplementedError

View File

@ -1,64 +1,28 @@
from copy import copy
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor
@colo_op_impl(torch.allclose)
def colo_mean(types, args=(), kwargs=None, pg=None):
a = args[0]
b = args[1]
if isinstance(a, ColoTensor):
a = a.torch_tensor()
elif isinstance(b, ColoTensor):
b = b.torch_tensor()
if kwargs is None:
kwargs = {}
return torch.allclose(a, b, **kwargs)
@colo_op_impl(torch.mean)
def colo_mean(types, args=(), kwargs=None, pg=None):
input_t = args[0]
if isinstance(input_t, ColoTensor):
input_t = input_t.torch_tensor()
return ColoTensor.init_from_torch_tensor(torch.mean(input_t))
from ._utils import GeneralTensor
def register_elementwise_op(op):
@colo_op_impl(op)
def elementwise_op(types, args=(), kwargs=None, pg=None):
def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs):
"""
Handles ``__torch_function__`` dispatch for the elementwise op such
as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
This method computes on either a normal tensor or a sharded tensor.
"""
input_tensor = args[0]
# Validate types
if not isinstance(input_tensor, ColoTensor):
raise TypeError("input needs to be a ColoTensor")
return ColoTensor.init_from_torch_tensor(op(input_tensor.torch_tensor()))
output = op(input_tensor, *args, **kwargs)
if isinstance(input_tensor, ColoTensor):
spec = copy(input_tensor.spec)
return ColoTensor.from_torch_tensor(output, spec=spec)
return ColoTensor.from_torch_tensor(output)
register_elementwise_op(torch.nn.functional.gelu)
register_elementwise_op(torch.nn.functional.relu)
@colo_op_impl(torch.sum)
def sum_op(types, args=(), kwargs=None, pg=None):
"""
Handles ``__torch_function__`` dispatch for the elementwise op such
as ``torch.sum`.
This method computes on either a normal tensor or a sharded tensor.
"""
if len(args) > 0:
input_tensor = args[0]
if kwargs is None:
kwargs = {}
if 'input' in kwargs:
input_tensor = kwargs['input']
# Validate types
if not isinstance(input_tensor, ColoTensor):
raise TypeError("input needs to be a ColoTensor")
return ColoTensor.init_from_torch_tensor(torch.sum(input_tensor.torch_tensor()))
register_elementwise_op(torch.clone)
register_elementwise_op(torch.Tensor.clone)
register_elementwise_op(torch.Tensor.detach)

View File

@ -1,31 +1,52 @@
import torch
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input
from colossalai.core import global_context as gpc
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, dist_spec
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
from ._utils import GeneralTensor, convert_to_colo_tensor
def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
def colo_embedding_1Dcol(input_tensor: ColoTensor,
weight: ColoTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather splitted lookup table
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
output_parallel = torch.nn.functional.embedding(input_tensor.torch_tensor(), weight.torch_tensor(), *args, **kwargs)
output_parallel = F.embedding(input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
output_spec = TensorSpec(
dist_spec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group().size()]),
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
output = ColoTensor.init_from_torch_tensor(output_parallel, spec=output_spec)
output.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
return output
def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
def colo_embedding_1Drow(input_tensor: ColoTensor,
weight: ColoTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
num_embeddings_per_partition = weight.size(0)
@ -33,53 +54,87 @@ def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwa
vocab_end_index = vocab_start_index + num_embeddings_per_partition
# Build the mask.
input_mask = (input_tensor.torch_tensor() < vocab_start_index) | \
(input_tensor.torch_tensor() >= vocab_end_index)
input_mask = (input_tensor < vocab_start_index) | \
(input_tensor >= vocab_end_index)
# Mask the input.
# TODO(jzy) masked_input may be an activation managed by ColoTensor.
masked_input = input_tensor.torch_tensor().clone() - vocab_start_index
masked_input = input_tensor.clone() - vocab_start_index
masked_input[input_mask] = 0
partial_output = torch.nn.functional.embedding(masked_input, weight.torch_tensor(), *args, **kwargs)
partial_output = F.embedding(masked_input,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
# Mask the output embedding.
partial_output[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(partial_output, parallel_action.parallel_mode)
output = ColoTensor.init_from_torch_tensor(output,
spec=TensorSpec(dist_spec.replicate(weight.spec.get_process_group())))
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
return output
@colo_op_impl(torch.nn.functional.embedding)
def colo_embedding(types, args, kwargs, pg):
def colo_embedding_1d(mode: str,
input_tensor: ColoTensor,
weight: ColoTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
assert mode in ('row', 'col')
funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol}
return funcs[mode](input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
@colo_op_impl(F.embedding)
def colo_embedding(input_tensor: GeneralTensor,
weight: GeneralTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False):
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
This method looks up an embedding table.
"""
input_tensor = args[0]
weight = args[1]
args = args[2:]
if not isinstance(input_tensor, ColoTensor):
input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
if not isinstance(weight, ColoTensor):
weight = ColoTensor.init_from_torch_tensor(weight)
input_tensor, weight = tuple(map(convert_to_colo_tensor, (input_tensor, weight)))
# Handle differen parallel actions.
if not weight.has_spec(): # No Model Parallel Applied
assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
input_tensor = input_tensor.torch_tensor()
weight = weight.torch_tensor()
output = torch.nn.functional.embedding(input_tensor, weight, *args, **kwargs)
return ColoTensor.init_from_torch_tensor(output)
return ColoTensor.from_torch_tensor(
F.embedding(input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse))
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.spec.is_1D_row():
return colo_embedding_1Drow(input_tensor, weight, args, kwargs)
mode = 'row'
elif weight.spec.is_1D_col():
return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
mode = 'col'
else:
raise NotImplementedError
return colo_embedding_1d(mode,
input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
else:
raise NotImplementedError

View File

@ -1,39 +1,24 @@
import torch
import torch.nn.functional as F
from typing import List, Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor, dist_spec
from colossalai.tensor import ColoTensor, distspec
from ._utils import GeneralTensor, convert_to_colo_tensor
@colo_op_impl(torch.nn.functional.layer_norm)
def colo_layernorm(types, args=(), kwargs=None, pg=None):
arg_num = len(args)
if arg_num > 0:
input_tensor = args[0]
if arg_num > 1:
normalized_shape = args[1]
if arg_num > 2:
weight = args[3]
if arg_num > 3:
bias = args[4]
if arg_num > 4:
eps = args[5]
@colo_op_impl(F.layer_norm)
def colo_layernorm(
input_tensor: GeneralTensor,
normalized_shape: List[int],
weight: Optional[GeneralTensor] = None,
bias: Optional[GeneralTensor] = None,
eps: float = 1e-5,
):
input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
if 'input' in kwargs:
input_tensor = kwargs['input']
if 'weight' in kwargs:
weight = kwargs['weight']
if 'bias' in kwargs:
bias = kwargs['bias']
if 'eps' in kwargs:
eps = kwargs['eps']
# TODO (ver217): check dist spec
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.spec.get_process_group()))
if isinstance(input_tensor, ColoTensor):
# TODO (ver217): check input dist spec
input_tensor.to_dist_spec(dist_spec.replicate(input_tensor.spec.get_process_group()))
input_tensor = input_tensor.torch_tensor()
if isinstance(weight, ColoTensor):
weight = weight.torch_tensor()
if isinstance(bias, ColoTensor):
bias = bias.torch_tensor()
return ColoTensor.init_from_torch_tensor(
torch.nn.functional.layer_norm(input_tensor, normalized_shape, weight, bias, eps))
output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
output = ColoTensor.from_torch_tensor(output, input_tensor.spec)
return output

View File

@ -1,108 +1,89 @@
import torch
import torch.nn.functional as F
import torch.distributed as dist
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, reduce_grad
from colossalai.nn.layer.utils import divide
from colossalai.core import global_context as gpc
from packaging import version
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, dist_spec
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
from ._utils import GeneralTensor, convert_to_colo_tensor
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
input_tensor.to_dist_spec(
dist_spec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group().size()]))
input_tensor = input_tensor.convert_to_dist_spec(
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]))
# Output:P
partial_output = torch.nn.functional.linear(input_tensor.torch_tensor(), weight.torch_tensor())
partial_output = F.linear(input_tensor, weight)
# Reduce(Output)
output = reduce_input(partial_output, parallel_action.parallel_mode)
# Bias
if bias is not None:
assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op'
output = output + bias.torch_tensor()
output = ColoTensor.init_from_torch_tensor(output,
spec=TensorSpec(dist_spec.replicate(weight.spec.get_process_group())))
output = output + bias
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
return output
def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
input_parallel = reduce_grad(input_tensor.torch_tensor(), parallel_action.parallel_mode)
if bias is not None:
bias = bias.torch_tensor()
output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias)
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
input_parallel = reduce_grad(input_tensor, parallel_action.parallel_mode)
output = ColoTensor.init_from_torch_tensor(
output_parallel = F.linear(input_parallel, weight, bias)
output = ColoTensor.from_torch_tensor(
output_parallel,
spec=TensorSpec(
dist_spec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group().size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]))
spec=TensorSpec(distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]))
if parallel_action.gather_out:
# All-Gather(Output)
output.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
return output
@colo_op_impl(torch.nn.functional.linear)
def colo_linear(types, args, kwargs, pg):
def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
assert mode in ('row', 'col')
funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
return funcs[mode](input_tensor, weight, bias)
@colo_op_impl(F.linear)
def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None):
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear.
"""
input_tensor = args[0]
weight = args[1]
if version.parse(torch.__version__) > version.parse("1.11.0"):
if len(args) == 3:
bias = args[2]
else:
bias = None
else:
bias = kwargs.get('bias', None)
if not isinstance(input_tensor, ColoTensor):
input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
if not isinstance(weight, ColoTensor):
weight = ColoTensor.init_from_torch_tensor(weight)
if bias is not None and not isinstance(bias, ColoTensor):
bias = ColoTensor.init_from_torch_tensor(bias)
input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
# building the computing graph, inputs -> op
if GraphGlobalEnv().graph_building:
cur_op_node = GraphOpNode('linear', [weight, bias])
cur_op_node.add_prev_tensor(input_tensor)
# Add communication logic before and after linear call.
ret_tensor = None
if not weight.has_spec(): # No Model Parallel Applied
assert bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
assert bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
input_tensor = input_tensor.torch_tensor()
weight = weight.torch_tensor()
if bias is not None:
bias = bias.torch_tensor()
ret_tensor = ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
assert weight.spec.is_gathered(), 'Invalid weight spec for native Linear op'
assert bias is None or bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.spec.is_1D_col() and (bias is None or bias.spec.is_gathered()):
ret_tensor = colo_linear_1Drow(input_tensor, weight, bias)
mode = 'row'
elif weight.spec.is_1D_row() and (bias is None or bias.spec.is_1D_row() or bias.spec.is_1D_col()):
ret_tensor = colo_linear_1Dcol(input_tensor, weight, bias)
mode = 'col'
else:
raise NotImplementedError
ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias)
else:
raise NotImplementedError
# building the computing graph, op -> output
if GraphGlobalEnv().graph_building:
cur_op_node.add_post_tensor(ret_tensor)
return ret_tensor

View File

@ -1,40 +1,37 @@
from colossalai.tensor.dist_spec import DistPlacementPattern
import torch
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor
from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
from ._utils import GeneralTensor, convert_to_colo_tensor
@colo_op_impl(torch.nn.functional.cross_entropy)
def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
arg_num = len(args)
if arg_num > 0:
input_tensor = args[0]
if arg_num > 1:
target = args[1]
if arg_num > 2:
weight = args[2]
if 'input' in kwargs:
input_tensor = kwargs.pop('input')
if 'target' in kwargs:
target = kwargs.pop('target')
if 'weight' in kwargs:
weight = kwargs.pop('weight')
if not isinstance(input_tensor, ColoTensor):
input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
if isinstance(target, ColoTensor):
target = target.torch_tensor()
@colo_op_impl(F.cross_entropy)
def colo_cross_entropy(input_tensor: GeneralTensor,
target: GeneralTensor,
weight: Optional[GeneralTensor] = None,
size_average: Optional[bool] = None,
ignore_index: int = -100,
reduce: Optional[bool] = None,
reduction: str = "mean",
label_smoothing: float = 0.0):
input_tensor, target, weight = tuple(map(convert_to_colo_tensor, (input_tensor, target, weight)))
if input_tensor.spec.is_gathered(): # Input is gathered
return ColoTensor.init_from_torch_tensor(
torch.nn.functional.cross_entropy(input_tensor.torch_tensor(), target, weight))
output = F.cross_entropy(input_tensor,
target,
weight=weight,
size_average=size_average,
ignore_index=ignore_index,
reduce=reduce,
reduction=reduction,
label_smoothing=label_smoothing)
return ColoTensor.from_torch_tensor(output)
elif input_tensor.has_spec() and input_tensor.spec.num_action == 1: # Single Model Parallel Applied
if input_tensor.spec.is_1D_col():
return ColoTensor.init_from_torch_tensor(VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(),
target))
output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
return ColoTensor.from_torch_tensor(output)
else:
raise NotImplementedError
else:

View File

@ -1,6 +1,8 @@
from .colo_tensor import ColoTensor
from .const import TensorType
import torch
from colossalai.tensor import TensorSpec, distspec
from copy import copy
class ColoParameter(ColoTensor):
@ -8,21 +10,26 @@ class ColoParameter(ColoTensor):
"""
def __init__(self, *args, **kargs):
super().__init__(*args, **kargs)
self._type = TensorType.MODEL
def __new__(cls,
data: torch.Tensor,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)
def __new__(cls, *args, **kwargs):
t = super(ColoParameter, cls).__new__(cls)
t._type = TensorType.MODEL
return t
def __init__(self,
data: torch.Tensor,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
self._spec = copy(spec)
self._type = TensorType.MODEL
self._graph_node = None
@staticmethod
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoParameter':
colo_p = ColoParameter(*tensor.size(),
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
pin_memory=tensor.is_pinned(),
device=tensor.device,
torch_tensor=tensor if save_payload else torch.empty(0))
return colo_p
def from_torch_tensor(tensor: torch.Tensor,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
tensor = tensor.as_subclass(ColoParameter)
tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
return tensor

View File

@ -1,16 +1,23 @@
from .op_wrapper import _COLOSSAL_OPS
from copy import copy
import torch
from typing import Tuple, Optional, Callable, Union
from numpy import product
from colossalai.tensor import TensorSpec
from .const import TensorType
from colossalai.tensor import dist_spec
from colossalai.tensor import distspec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.dist_spec import _DistSpec
from colossalai.tensor.distspec import _DistSpec
from torch.overrides import get_default_nowrap_functions
class ColoTensor(object):
def _convert_output(output):
if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
output = ColoTensor.from_torch_tensor(output)
elif isinstance(output, (list, tuple)):
output = type(output)(_convert_output(o) for o in output)
return output
class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI
1. It contains a torch.Tensor as an attribute.
2. It supports lazy init the tensor's payload.
@ -18,120 +25,23 @@ class ColoTensor(object):
4. It supports distributing the tensor's payload to the shards among processes. (TODO)
"""
def __new__(cls, *args, **kwargs):
return super(ColoTensor, cls).__new__(cls)
def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, data.requires_grad)
def __init__(self,
*size: Tuple[int],
dtype=None,
requires_grad=False,
pin_memory=False,
device=None,
torch_tensor=torch.empty(0),
spec: TensorSpec = TensorSpec(dist_spec.replicate())):
self._size = size
self._dtype = dtype
self._requires_grad = requires_grad
self._pin_memory = pin_memory
self._device = device
self._torch_tensor = torch_tensor
def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
self._spec = copy(spec)
self._type = TensorType.NONMODEL
self._graph_node = None
def __getitem__(self, key):
return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])
@property
def spec(self) -> TensorSpec:
return self._spec
@property
def shard_pattern(self):
return self._shard_pattern
@property
def data(self):
return self._torch_tensor.data
@data.setter
def data(self, tensor: Union[torch.Tensor, "ColoTensor"]):
if isinstance(tensor, ColoTensor):
self._torch_tensor.data = tensor.data
elif isinstance(tensor, torch.Tensor):
self._torch_tensor.data = tensor
else:
raise NotImplementedError
@property
def grad(self):
return self._torch_tensor.grad
@property
def size(self):
return self._size
@property
def shape(self):
return torch.Size(self._size)
@property
def device(self):
return self._torch_tensor.device
def size(self, dim=None):
if dim is None:
return self.shape
return self._size[dim]
def dim(self):
return len(self._size)
def normal_(self, mean=0., std=1.):
torch_tensor = self.torch_tensor()
return torch_tensor.normal_(mean=mean, std=std)
def numel(self):
return product(self._size)
@staticmethod
def init_from_torch_tensor(tensor: torch.Tensor,
save_payload=True,
spec: TensorSpec = TensorSpec(dist_spec.replicate())) -> 'ColoTensor':
colo_t = ColoTensor(*tensor.size(),
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
pin_memory=tensor.is_pinned(),
device=tensor.device,
torch_tensor=tensor if save_payload else torch.empty(0),
spec=spec)
return colo_t
def del_torch_tensor(self, save_shape=False) -> None:
"""
delete the payload of the torch tensor.
Args:
save_shape (bool, optional): if saving the shape of the torch_tensor.
If saving the shape, the size of self._torch_tensor is inconsist with the self._size.
Defaults to False.
"""
if not save_shape:
self._size = (0,)
self._torch_tensor = torch.empty((0,), device=self._device, dtype=self._dtype)
def torch_tensor(self) -> torch.Tensor:
if self._torch_tensor.numel() == 0:
self._torch_tensor = torch.empty(*self._size,
dtype=self._dtype,
pin_memory=self._pin_memory,
requires_grad=self._requires_grad,
device=self._device)
return self._torch_tensor
def set_spec(self, spec: TensorSpec) -> None:
spec = copy(spec)
self.to_dist_spec(spec.dist_spec)
self.convert_to_dist_spec_(spec.dist_spec)
self._spec = spec
def has_spec(self) -> bool:
@ -142,89 +52,51 @@ class ColoTensor(object):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if not all(issubclass(cls, t) for t in types):
return NotImplemented
global _COLOSSAL_OPS
if func in _COLOSSAL_OPS:
for arg in args:
if isinstance(arg, ColoTensor):
return _COLOSSAL_OPS[func](types, args, kwargs, None)
func = _COLOSSAL_OPS[func]
for kwarg in kwargs.values():
if isinstance(kwarg, ColoTensor):
return _COLOSSAL_OPS[func](types, args, kwargs, None)
else:
# If we have not hijact the function, convert the ColoTensors in args and kwargs to torch tensors.
args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
if kwargs is None:
kwargs = {}
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return ret
else:
return _convert_output(ret)
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
return cls._filter_outputs_with_colo(func(*args, **kwargs))
def __repr__(self):
return f'ColoTensor: {super().__repr__()}'
def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
def is_model_data(self) -> bool:
return self._type == TensorType.MODEL
def __add__(self, o) -> "ColoTensor":
if isinstance(o, ColoTensor):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
elif isinstance(o, (torch.Tensor, int, float)):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o)
else:
raise TypeError(f'{type(o)} is not supported in ColoTensor __add__')
__radd__ = __add__
def __truediv__(self, o) -> "ColoTensor":
return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
def __getattr__(self, name):
def replace_tensor_with_colo(func):
def execute_func(*args, **kwargs):
# transform the ColoTensor args to torch Tensor.
args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
if kwargs is None:
kwargs = {}
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
return self._filter_outputs_with_colo(func(*args, **kwargs))
return execute_func
if hasattr(self._torch_tensor, name) == False:
raise AttributeError
attr = getattr(self._torch_tensor, name)
if isinstance(attr, Callable):
return replace_tensor_with_colo(attr)
else:
return attr
@classmethod
def _filter_outputs_with_colo(cls, outputs):
if outputs is None: # return None
return None
elif type(outputs) is not tuple: # num of return val = 1
return ColoTensor.init_from_torch_tensor(outputs) if type(outputs) is torch.Tensor else outputs
else: # num of return val > 1
return tuple([
ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output
for output in outputs
])
def __mul__(self, other) -> "ColoTensor":
if isinstance(other, ColoTensor):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other.torch_tensor())
elif isinstance(other, (torch.Tensor, int, float)):
return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other)
else:
raise TypeError(f'{type(other)} is not supported in ColoTensor __mul__')
__rmul__ = __mul__
def to_dist_spec(self, dist_spec: _DistSpec) -> None:
self._torch_tensor = DistSpecManager.handle_trans_spec(self.torch_tensor(), self.spec.dist_spec, dist_spec)
if self._torch_tensor.is_leaf:
self._torch_tensor.requires_grad = self._requires_grad
self._size = self._torch_tensor.size()
def convert_to_dist_spec_(self, dist_spec: _DistSpec) -> None:
with DistSpecManager.no_grad():
self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
self._spec.dist_spec = dist_spec
def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
spec = copy(self._spec)
spec.dist_spec = dist_spec
ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
return ColoTensor.from_torch_tensor(ret, spec)
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
tensor = tensor.as_subclass(ColoTensor)
tensor.__init__(tensor, spec=spec)
return tensor
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoTensor(data, spec=copy(self.spec))
memo[id(self)] = tensor
return tensor

View File

@ -1,4 +1,4 @@
from colossalai.tensor.dist_spec import _DistSpec
from colossalai.tensor.distspec import _DistSpec
from colossalai.nn.layer.utils import divide
from numpy import prod
from contextlib import contextmanager
@ -53,7 +53,7 @@ class DistSpecManager:
@staticmethod
def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
and dist_spec.process_group is not None:
and dist_spec.process_group is not None:
raise NotImplementedError
return tensor
@ -66,7 +66,7 @@ class DistSpecManager:
@staticmethod
def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
if old_dist_spec.process_group != dist_spec.process_group \
and dist_spec.process_group is not None:
and dist_spec.process_group is not None:
raise NotImplementedError
return DistSpecManager._gather(tensor, old_dist_spec)

View File

@ -9,11 +9,6 @@ _COLOSSAL_OPS: Dict[str, Callable] = {}
def _register_colo_op(op, func):
from inspect import signature
if len(signature(func).parameters) != 4:
raise TypeError(f'Custom stateful op function expects signature: '
f'(types, args, kwargs, process_group), but received '
f'signature: {signature(func)}')
global _COLOSSAL_OPS
_COLOSSAL_OPS[op] = func

View File

@ -1,7 +1,8 @@
import torch.distributed as dist
from enum import Enum
from typing import List
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor.dist_spec import _DistSpec, DistPlacementPattern
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
class ComputePattern(Enum):
@ -77,6 +78,9 @@ class TensorSpec(object):
def get_process_group(self):
return self.dist_spec.process_group
def get_process_group_size(self):
return dist.get_world_size(self.dist_spec.process_group)
def get_placement(self):
return self.dist_spec.placement

View File

@ -7,11 +7,13 @@ from torch import nn
from typing import Iterator, Tuple, Union, Optional
# find named_params includes replica
def _named_params_with_replica(
module: nn.Module,
prefix: str = '',
recurse: bool = True,
) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
module: nn.Module,
prefix: str = '',
recurse: bool = True,
) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
for mod_prefix, mod in modules:
@ -21,11 +23,13 @@ def _named_params_with_replica(
name = mod_prefix + ('.' if mod_prefix else '') + name
yield name, val
# Adapted from torch.nn.module.Module.register_param
def _register_parameter_with_colotensor(self, name: str, param):
if '_parameters' not in self.__dict__:
raise AttributeError(
"cannot assign parameter before Module.__init__() call")
raise AttributeError("cannot assign parameter before Module.__init__() call")
if not isinstance(name, torch._six.string_classes):
raise TypeError("parameter name should be a string. "
@ -41,19 +45,21 @@ def _register_parameter_with_colotensor(self, name: str, param):
self._parameters[name] = None
elif not isinstance(param, (torch.nn.Parameter, ColoParameter)):
raise TypeError("cannot assign '{}' object to parameter '{}' "
"(torch.nn.Parameter or ColoParameter or None required)"
.format(torch.typename(param), name))
"(torch.nn.Parameter or ColoParameter or None required)".format(torch.typename(param), name))
elif param.grad_fn:
raise ValueError(
"Cannot assign non-leaf Tensor to parameter '{0}'. Model "
"parameters must be created explicitly. To express '{0}' "
"as a function of another Tensor, compute the value in "
"the forward() method.".format(name))
raise ValueError("Cannot assign non-leaf Tensor to parameter '{0}'. Model "
"parameters must be created explicitly. To express '{0}' "
"as a function of another Tensor, compute the value in "
"the forward() method.".format(name))
else:
self._parameters[name] = param
# Adapted from torch.nn.module.Module.__setattr__
def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]):
def remove_from(*dicts_or_sets):
for d in dicts_or_sets:
if name in d:
@ -65,70 +71,45 @@ def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.n
params = self.__dict__.get('_parameters')
if isinstance(value, (ColoTensor, torch.nn.Parameter)):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call")
raise AttributeError("cannot assign parameters before Module.__init__() call")
remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
self.register_parameter(name, value)
elif params is not None and name in params:
if value is not None:
raise TypeError("cannot assign '{}' as parameter '{}' "
"(torch.nn.Parameter or None expected)"
.format(torch.typename(value), name))
"(torch.nn.Parameter or None expected)".format(torch.typename(value), name))
self.register_parameter(name, value)
else:
modules = self.__dict__.get('_modules')
if isinstance(value, torch.nn.Module):
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call")
raise AttributeError("cannot assign module before Module.__init__() call")
remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
modules[name] = value
elif modules is not None and name in modules:
if value is not None:
raise TypeError("cannot assign '{}' as child module '{}' "
"(torch.nn.Module or None expected)"
.format(torch.typename(value), name))
"(torch.nn.Module or None expected)".format(torch.typename(value), name))
modules[name] = value
else:
buffers = self.__dict__.get('_buffers')
if buffers is not None and name in buffers:
if value is not None and not isinstance(value, torch.Tensor):
raise TypeError("cannot assign '{}' as buffer '{}' "
"(torch.Tensor or None expected)"
.format(torch.typename(value), name))
"(torch.Tensor or None expected)".format(torch.typename(value), name))
buffers[name] = value
else:
object.__setattr__(self, name, value)
def ColoModulize(module):
"""
Replacing the parameters() and named_parameters() with our customized ones
"""
def fake_parameters(self, *args, **kargs):
for p in module.old_parameters(*args, **kargs):
if isinstance(p, ColoTensor):
yield p.torch_tensor()
elif isinstance(p, torch.Tensor):
yield p
def fake_named_parameters(self, *args, **kargs):
for name, p in module.old_named_parameters(*args, **kargs):
if isinstance(p, ColoTensor):
yield name, p.torch_tensor()
elif isinstance(p, torch.Tensor):
yield name, p
module.old_named_parameters = module.named_parameters
module.old_parameters = module.parameters
funcType = types.MethodType
module.parameters = funcType(fake_parameters, module)
module.named_parameters = funcType(fake_named_parameters, module)
module.colo_parameters = module.old_parameters
module.colo_named_parameters = module.old_named_parameters
module._colo_visited = True
class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
@ -159,15 +140,16 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
continue
split = name.rfind('.')
if split >= 0: # param in submodule
if split >= 0: # param in submodule
module_name = name[:split]
param_name = name[split+1:]
param_name = name[split + 1:]
else:
module_name = '' # param in current module
module_name = '' # param in current module
param_name = name
name_list.append((module_name, param_name))
replaced_tensors = dict() # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference
replaced_tensors = dict(
) # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference
for module_name, param_name in name_list:
submodule = module.get_submodule(module_name)
param = submodule.get_parameter(param_name)
@ -177,13 +159,11 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
save_torch_payload = True if not self._lazy_memory_allocate else False
# detaching tensor is necessary for optimizers.
requires_grad = param.requires_grad
tensor_detached = param.to(self._device).detach()
tensor_detached.requires_grad = requires_grad
colo_param = ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload)
colo_param = ColoParameter(param.to(self._device), requires_grad=requires_grad)
# add mapping record
replaced_tensors[param] = colo_param
delattr(submodule, param_name)
setattr(submodule, param_name, colo_param)
ColoModulize(module)
ColoModulize(module)

View File

@ -83,7 +83,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
for name, param in name_list:
delattr(module, name)
setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param, save_payload=False))
setattr(module, name, ColoTensor.from_torch_tensor(param))
def to_layer_list(self, exec_seq=None):
"""
@ -91,7 +91,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
If exec_seq is None, we will take the module initizing order as execution order.
"""
if exec_seq is None:
#if user do not provide the model executing sequence, we use the initialization order as the executing order.
# if user do not provide the model executing sequence, we use the initialization order as the executing order.
children_name = []
for child in self._root_children:
layer_spec = self._layer_spec_dict[id(child)]

View File

@ -1,9 +1,11 @@
import torch
import torch.distributed as dist
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True
def replace_parameter_add_grad(layer, weight=None, bias=None):
if weight is not None:
delattr(layer, 'weight')
@ -14,7 +16,12 @@ def replace_parameter_add_grad(layer, weight=None, bias=None):
setattr(layer, 'bias', bias)
layer.bias.requires_grad = True
def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
dist.broadcast(tensor, src=0)
tensor_chunk = torch.chunk(tensor, chunk_size, dim=-1)[local_rank]
return tensor_chunk.clone()
return tensor_chunk.clone()
def tensor_equal(A, B):
return torch.allclose(A, B, rtol=1e-3, atol=1e-1)

View File

@ -4,7 +4,7 @@ import pytest
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.tensor import ColoTensor
from colossalai.tensor import dist_spec
from colossalai.tensor import distspec
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
from colossalai.context import ParallelMode
from colossalai.testing import rerun_if_address_is_in_use
@ -39,7 +39,7 @@ class Conv1D(nn.Module):
def init_1d_row(weight, bias):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@ -54,7 +54,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias):
def init_1d_col(weight, bias):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@ -70,8 +70,8 @@ def check_grad_1d_col(model: torch.nn.Module, weight, bias):
def run_with_spec(spec_init_func, check_grad_func):
model = Conv1D(4, 16).cuda()
weight = ColoTensor.init_from_torch_tensor(torch.nn.Parameter(model.weight.detach()))
bias = ColoTensor.init_from_torch_tensor(torch.nn.Parameter(model.bias.detach()))
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
spec_init_func(weight, bias)
x = torch.rand(2, 16).cuda()
out = model(x)

View File

@ -1,3 +1,4 @@
import pytest
from colossalai.utils import ColoInitContext
from numpy import allclose, require
@ -8,6 +9,8 @@ from copy import deepcopy
from colossalai.utils.cuda import get_current_device
@pytest.mark.skip
# FIXME(ver217): support lazy init
def test_lazy_init():
in_dim = 4
out_dim = 5
@ -22,6 +25,7 @@ def test_lazy_init():
assert fc.weight._torch_tensor.numel() == in_dim * out_dim
@pytest.mark.skip
def test_device():
in_dim = 4
out_dim = 5

View File

@ -7,7 +7,7 @@ import torch.multiprocessing as mp
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import dist_spec, DistSpecManager
from colossalai.tensor import DistSpecManager, distspec
from functools import partial
@ -18,10 +18,10 @@ def run():
depth = int(math.sqrt(size))
assert depth == math.sqrt(size)
x = torch.rand(8, 8).cuda()
old_dist_spec = dist_spec.replicate()
row_spec = dist_spec.shard(group, [0], [size])
col_spec = dist_spec.shard(group, [-1], [size])
mat_spec = dist_spec.shard(group, [0, 1], [depth, depth])
old_dist_spec = distspec.replicate()
row_spec = distspec.shard(group, [0], [size])
col_spec = distspec.shard(group, [-1], [size])
mat_spec = distspec.shard(group, [0, 1], [depth, depth])
row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec)
assert torch.equal(x.chunk(size, 0)[rank], row_shard)
assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec))

View File

@ -1,6 +1,6 @@
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ColoTensor
from colossalai.tensor import ColoTensor, distspec
from torch.nn import functional as F
from functools import partial
@ -11,12 +11,12 @@ import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_spec, DistSpecManager
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
def init_1d_row(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@ -30,7 +30,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight):
def init_1d_col(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@ -44,7 +44,7 @@ def check_grad_1d_col(model: torch.nn.Module, weight):
def run_with_spec(spec_init_func, check_grad_func):
model = torch.nn.Embedding(12, 32).cuda()
weight = ColoTensor.init_from_torch_tensor(torch.nn.Parameter(model.weight.detach()))
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
spec_init_func(weight)
x = torch.tensor((0, 3, 6, 9)).cuda()
out = model(x)

View File

@ -0,0 +1,240 @@
import pytest
import colossalai
import os
import random
import numpy as np
import torch
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from transformers import GPT2Config, GPT2LMHeadModel
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager, distspec
from colossalai.core import global_context as gpc
from functools import partial
# Hack huggingface Bert ModelOutput
# Make it available to our ColoTensor
from transformers.file_utils import ModelOutput
from dataclasses import fields
from tests.test_tensor._utils import tensor_equal
def _post_init_colotensor(self):
class_fields = fields(self)
# Safety and consistency checks
if len(class_fields) == 0:
raise ValueError(f"{self.__class__.__name__} has no fields.")
if not all(field.default is None for field in class_fields[1:]):
raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
first_field = getattr(self, class_fields[0].name)
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
def is_tensor_with_colo(x):
"""
Tests if `x` is a `ColoTensor` or `torch.Tensor`.
"""
if isinstance(x, torch.Tensor):
return True
return isinstance(x, ColoTensor)
if other_fields_are_none and not is_tensor_with_colo(first_field):
if isinstance(first_field, dict):
iterator = first_field.items()
first_field_iterator = True
else:
try:
iterator = iter(first_field)
first_field_iterator = True
except TypeError:
first_field_iterator = False
# if we provided an iterator as first field and the iterator is a (key, value) iterator
# set the associated fields
if first_field_iterator:
for element in iterator:
if (not isinstance(element, (list, tuple)) or not len(element) == 2 or not isinstance(element[0], str)):
break
setattr(self, element[0], element[1])
if element[1] is not None:
self[element[0]] = element[1]
elif first_field is not None:
self[class_fields[0].name] = first_field
else:
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v
ModelOutput.__post_init__ = _post_init_colotensor
class GPTLMModel(nn.Module):
def __init__(self,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_seq_len=1024,
vocab_size=50304,
checkpoint=False):
super().__init__()
self.checkpoint = checkpoint
self.model = GPT2LMHeadModel(
GPT2Config(n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size,
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0))
if checkpoint:
self.model.gradient_checkpointing_enable()
def forward(self, input_ids, attention_mask):
# Only return lm_logits
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
def gpt2_s(checkpoint=True):
return GPTLMModel(checkpoint=checkpoint)
def gpt2_m(checkpoint=True):
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
def set_seed(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
def get_data(batch_size, seq_len, vocab_size):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
attention_mask = torch.ones_like(input_ids)
return input_ids, attention_mask
def init_1d_row_spec(model):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n:
p.set_spec(spec)
def init_1d_col_spec(model):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'ln' not in n and ('weight' in n or 'bias' in n):
p.set_spec(spec)
def check_tensor_equal_1d(tensor: torch.Tensor, shard: ColoTensor):
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
assert len(shard.spec.dist_spec.dims) == 1
dim = shard.spec.dist_spec.dims[0]
assert torch.equal(tensor.chunk(world_size, dim)[rank], shard.torch_tensor())
def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
assert tensor.ndim == shard.ndim
if tensor.shape == shard.shape:
return tensor_equal(tensor, shard)
else:
dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
if dims_not_eq.numel() == 1:
# 1D shard
dim = dims_not_eq.item()
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
else:
raise NotImplementedError
def check_param_equal(model, torch_model):
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
assert tensor_shard_equal(torch_p, p)
def check_grad_equal(model, torch_model):
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
assert tensor_shard_equal(torch_p.grad, p.grad)
def run_gpt(init_spec_func):
BATCH_SIZE = 4
SEQ_LEN = 1024
VOCAB_SIZE = 50304
NUM_STEPS = 1
criterion = GPTLMLoss()
with ColoInitContext(device=get_current_device()):
model = gpt2_s()
model = model.cuda()
torch_model = gpt2_s().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p)
init_spec_func(model)
check_param_equal(model, torch_model)
model.train()
torch_model.train()
for i in range(NUM_STEPS):
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
logits = model(input_ids, attn_mask)
torch_logits = torch_model(input_ids, attn_mask)
assert tensor_equal(torch_logits, logits)
loss = criterion(logits, input_ids)
torch_loss = criterion(torch_logits, input_ids)
loss.backward()
torch_loss.backward()
check_grad_equal(model, torch_model)
def run_dist(rank, world_size, port):
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_gpt(init_1d_row_spec)
run_gpt(init_1d_col_spec)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_gpt(1)

View File

@ -1,3 +1,4 @@
import pytest
from torch import nn
import torch
from colossalai.tensor import ColoTensor
@ -55,7 +56,7 @@ def count_tensors(use_colossal):
model.eval()
with torch.no_grad():
if use_colossal:
colo_input = ColoTensor.init_from_torch_tensor(torch.randn(4))
colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
graph_ctx = GraphContext()
with graph_ctx:
output = model(colo_input)
@ -73,6 +74,8 @@ def count_tensors(use_colossal):
return _count_tensors()
@pytest.mark.skip
# FIXME(ver217)
def test_check_activation_tensors():
assert count_tensors(False) == count_tensors(True)

View File

@ -1,6 +1,6 @@
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ColoTensor
from colossalai.tensor import ColoTensor, distspec
from functools import partial
@ -12,12 +12,12 @@ import torch.nn.functional as F
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_spec, DistSpecManager
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
def init_1d_row(weight, bias):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@ -32,7 +32,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias):
def init_1d_col(weight, bias):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@ -48,8 +48,8 @@ def check_grad_1d_col(model: torch.nn.Module, weight, bias):
def run_with_spec(spec_init_func, check_grad_func):
model = torch.nn.Linear(4, 8).cuda()
weight = ColoTensor.init_from_torch_tensor(torch.nn.Parameter(model.weight.detach()))
bias = ColoTensor.init_from_torch_tensor(torch.nn.Parameter(model.bias.detach()))
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
spec_init_func(weight, bias)
x = torch.rand(2, 4).cuda()
out = model(x)

View File

@ -9,8 +9,8 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, \
ParallelAction, ColoTensor, ColoOptimizer, dist_spec, DistSpecManager
from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec, ComputePattern, \
ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
@ -89,7 +89,7 @@ def set_seed(seed):
def init_1d_row_linear(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@ -97,7 +97,7 @@ def init_1d_row_linear(weight):
def init_1d_col_linear(weight, gather_out=True):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), [
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), [
ParallelAction(priority=1,
compute_pattern=ComputePattern.TP1D,
parallel_mode=ParallelMode.PARALLEL_1D,
@ -109,7 +109,7 @@ def init_1d_col_linear(weight, gather_out=True):
def init_1d_row_embedding(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@ -117,7 +117,7 @@ def init_1d_row_embedding(weight):
def init_1d_col_embedding(weight):
spec = TensorSpec(
dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
with DistSpecManager.no_grad():
weight.set_spec(spec)
@ -143,7 +143,7 @@ def run_1d_hybrid_tp(model_name):
p2.data.copy_(p1.data)
if 'bert' == model_name:
for name, p in model.colo_named_parameters():
for name, p in model.named_parameters():
if not isinstance(p, ColoTensor):
continue
# print(name)
@ -161,7 +161,7 @@ def run_1d_hybrid_tp(model_name):
init_1d_col_embedding(p)
elif "simple_net" == model_name:
# A naive way to set spec for all weights in Linear
for name, p in model.colo_named_parameters():
for name, p in model.named_parameters():
if not isinstance(p, ColoTensor):
continue
if 'embed' in name and 'weight' in name:
@ -187,7 +187,6 @@ def run_1d_hybrid_tp(model_name):
torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Bcast rank0 data to all processes
if criterion:
output = model(data)
@ -206,10 +205,8 @@ def run_1d_hybrid_tp(model_name):
loss_torch = output_torch
if rank == 0:
# print(loss.torch_tensor().item())
# print('loss torch', loss_torch.item())
with torch.no_grad():
assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)
assert torch.allclose(loss, loss_torch, rtol=1e-2)
loss.backward()
colo_optimizer.step()
@ -257,7 +254,7 @@ def test_model_parameters():
param_cnt += 1
assert param_cnt == 5
for name, colo_p in model.colo_named_parameters():
for name, colo_p in model.named_parameters():
assert colo_p.is_model_data()
param_cnt = 0
@ -314,7 +311,7 @@ def run_1d_row_tp(model_name: str):
model_torch = model_builder(checkpoint=True)
model_torch = model_torch.cuda()
# A naive way to set spec for all weights in Linear
for name, p in model.colo_named_parameters():
for name, p in model.named_parameters():
if not isinstance(p, ColoTensor):
continue
if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
@ -349,9 +346,7 @@ def run_1d_row_tp(model_name: str):
loss_torch = output_torch
if rank == 0:
# print(loss.torch_tensor().item())
# print('loss torch', loss_torch.item())
assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)
assert torch.allclose(loss, loss_torch, rtol=1e-2)
loss.backward()
@ -380,7 +375,7 @@ def _run_pretrain_load():
c_ref += 1
c1 = 0
c2 = 0
for name, param in model.colo_named_parameters():
for name, param in model.named_parameters():
if isinstance(param, ColoParameter):
c1 += 1
else:

View File

@ -1,96 +1,33 @@
from numpy import allclose
import torch
from colossalai.tensor import ColoTensor, ColoParameter
from copy import deepcopy
from colossalai.utils import get_current_device
from torch.nn import Parameter
import torch.nn.functional as F
def test_layernorm():
ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device())
ln_op_colo = deepcopy(ln_op)
input_t = torch.randn(3, 2, device=get_current_device())
input_t_colo = ColoTensor.init_from_torch_tensor(tensor=input_t.clone().detach())
input_t_colo = ColoTensor.from_torch_tensor(input_t.clone().detach())
# prepare colossalai LN
delattr(ln_op_colo, 'weight')
weight_clone = ln_op.weight.clone().detach()
weight_clone.requires_grad = True
setattr(ln_op_colo, 'weight', ColoParameter.init_from_torch_tensor(tensor=weight_clone))
weight = ColoTensor(Parameter(ln_op.weight.detach()))
bias = ColoTensor(Parameter(ln_op.bias.detach()))
output = ln_op(input_t)
output_colo = ln_op_colo(input_t_colo)
output_colo = F.layer_norm(input_t_colo, ln_op.normalized_shape, weight, bias, ln_op.eps)
assert allclose(output_colo.torch_tensor().detach().cpu(), output.detach().cpu())
assert torch.allclose(output_colo, output)
torch.mean(output).backward()
torch.mean(output_colo).backward()
assert allclose(ln_op.weight.grad.cpu(), ln_op_colo.weight.torch_tensor().grad.cpu())
def test_linear():
in_dim = 4
out_dim = 5
fc = torch.nn.Linear(in_dim, out_dim, bias=True)
fc_ref = deepcopy(fc)
input_ref = torch.randn(1, in_dim)
input_tensor = input_ref.clone()
sharded_weight = ColoParameter.init_from_torch_tensor(fc_ref.weight)
sharded_bias = ColoParameter.init_from_torch_tensor(fc_ref.bias)
# replace the torch nn.Parameters with ShardedTensor
delattr(fc, 'weight')
setattr(fc, 'weight', sharded_weight)
delattr(fc, 'bias')
setattr(fc, 'bias', sharded_bias)
fc.weight.requires_grad = True
fc.bias.requires_grad = True
# torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
out = fc(input_tensor)
loss = torch.sum(out)
loss.backward()
out_ref = fc_ref(input_ref)
loss_ref = torch.sum(out_ref)
loss_ref.backward()
assert (loss_ref == loss)
assert allclose(fc_ref.weight.grad, fc.weight.torch_tensor().grad)
# The test case failed
# def test_uniform():
# t = ColoTensor(torch.zeros(3, 5))
# torch.nn.init.uniform_(t)
# print(t)
def test_element_wise():
t_ref = torch.randn(3, 5)
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
assert torch.mean(t) == torch.mean(t_ref)
assert allclose(torch.nn.functional.gelu(t).torch_tensor(), torch.nn.functional.gelu(t_ref))
assert allclose(torch.nn.functional.relu(t).torch_tensor(), torch.nn.functional.relu(t_ref))
# Test a function not wrapped by
def test_no_wrap_op():
t_ref = torch.randn(3, 5)
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
assert torch.sum(t) == torch.sum(t_ref)
assert torch.sum(input=t) == torch.sum(input=t_ref)
assert torch.allclose(ln_op.weight.grad, weight.grad)
def check_all():
test_linear()
test_element_wise()
test_no_wrap_op()
test_layernorm()
if __name__ == '__main__':

View File

@ -1,14 +1,17 @@
import torch
import pytest
from colossalai.tensor import ColoTensor
from numpy import allclose
def test_tensor_indexing():
torch_t = torch.randn(2, 3)
colo_t = ColoTensor.init_from_torch_tensor(torch_t)
assert allclose(torch_t[:, 1], colo_t[:, 1].torch_tensor())
colo_t = ColoTensor(torch_t)
assert allclose(torch_t[:, 1], colo_t[:, 1])
@pytest.mark.skip
# FIXME(ver217): support lazy init
def test_lazy_init_tensor():
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor.numel() == 0
@ -17,7 +20,7 @@ def test_lazy_init_tensor():
def test_wrapped_tensor_func():
t_ref = torch.randn(4, 5)
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
t = ColoTensor.from_torch_tensor(t_ref.clone())
# non-func attr
assert t.is_cuda == t_ref.is_cuda
@ -26,7 +29,7 @@ def test_wrapped_tensor_func():
# return 1 torch.Tensor
t_abs = t.abs()
assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs.torch_tensor(), t_ref.abs())
assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs())
# return 1 non-torch.Tensor
assert t.dim() == t_ref.dim()
@ -38,7 +41,7 @@ def test_wrapped_tensor_func():
def test_operand():
t_ref = torch.randn(4, 5)
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
t = ColoTensor.from_torch_tensor(t_ref.clone())
t_ref_res = t_ref + t_ref
t_res = t + t