[tensor] customized op returns ColoTensor (#875)

* [tensor] customized op returns ColoTensor

* polish

* polish code
Jiarui Fang 3 years ago committed by GitHub
parent 26d4ab8b03
commit 96211c2cc8

@@ -1,5 +1,4 @@
from .init import colo_uniform
from .linear import colo_linear
from .element_wise import colo_mean
from .element_wise import *
from .layernorm import colo_layernorm
from .loss import colo_cross_entropy

@@ -29,3 +29,22 @@ def register_elementwise_op(op):
register_elementwise_op(torch.nn.functional.gelu)
register_elementwise_op(torch.nn.functional.relu)
@colo_op_impl(torch.sum)
def sum_op(types, args=(), kwargs=None, pg=None):
"""
Handles ``__torch_function__`` dispatch for ops such as ``torch.sum``.
This method computes on either a normal tensor or a sharded tensor.
"""
if len(args) > 0:
input_tensor = args[0]
if kwargs is None:
kwargs = {}
if 'input' in kwargs:
input_tensor = kwargs['input']
# Validate types
if not isinstance(input_tensor, ColoTensor):
raise TypeError("input needs to be a ColoTensor")
return ColoTensor.init_from_torch_tensor(torch.sum(input_tensor.torch_tensor()))
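
For reference, a minimal usage sketch of the new dispatch path, assuming the ColoTensor API shown elsewhere in this diff (``init_from_torch_tensor`` and ``torch_tensor``); the tensor shape is arbitrary:

import torch
from colossalai.tensor.colo_tensor import ColoTensor

t = ColoTensor.init_from_torch_tensor(torch.randn(4, 4))
# torch.sum on a ColoTensor is routed to sum_op via __torch_function__
s = torch.sum(t)
assert isinstance(s, ColoTensor)   # the customized op now returns a ColoTensor
print(s.torch_tensor())            # unwrap the underlying 0-dim torch.Tensor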

@@ -1,29 +0,0 @@
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
def validate_param(param, param_name):
if param is None:
raise ValueError(f"param: {param_name} shouldn't be None!")
@colo_op_impl(torch.nn.init.uniform_)
def colo_uniform(types, args=(), kwargs=None, pg=None):
r"""
Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
distribution :math:`\mathcal{U}(a, b)`.
Args:
sharded_tensor: tensor sharded across devices
a: the lower bound of the uniform distribution
b: the upper bound of the uniform distribution
"""
validate_param(kwargs, "kwargs")
stateful_tensor = kwargs["tensor"]
validate_param(stateful_tensor, "stateful_tensor")
a = kwargs['a']
validate_param(a, "a")
b = kwargs['b']
validate_param(b, "b")
torch.nn.init.uniform_(stateful_tensor.torch_tensor(), a=a, b=b)
return stateful_tensor

@@ -6,7 +6,8 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
from colossalai.nn.layer.utils import divide
from colossalai.core import global_context as gpc
from packaging import version
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from colossalai.tensor import ComputePattern
@colo_op_impl(torch.nn.functional.linear)
def colo_linear(types, args, kwargs, pg):
@@ -25,6 +26,7 @@ def colo_linear(types, args, kwargs, pg):
bias = kwargs.get('bias', None)
if isinstance(bias, ColoTensor):
assert bias.shard_spec.num_action == 0, "We currently only support a bias that is duplicated among processes in the linear operator"
bias = bias.torch_tensor()
# Add communication logic before and after linear call.
@@ -34,7 +36,7 @@ def colo_linear(types, args, kwargs, pg):
input_tensor = input_tensor.torch_tensor()
if isinstance(weight, ColoTensor):
weight = weight.torch_tensor()
return torch.nn.functional.linear(input_tensor, weight, bias)
return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
elif weight.shard_spec.num_action == 1:
if ComputePattern.TP1DRow in weight.shard_spec.compute_patterns:
# Input:S[1] x Weight:S[0] = Output:P
@@ -54,8 +56,7 @@ def colo_linear(types, args, kwargs, pg):
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
# Bias
if bias is not None:
bias_ = bias
output = output + bias_
output = output + bias
return ColoTensor.init_from_torch_tensor(output)
else:
raise NotImplementedError
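
To make the TP1DRow branch above concrete, here is an illustration of the same communication pattern written with plain torch.distributed primitives rather than Colossal-AI's helpers. The function name and setup are hypothetical: it assumes a default process group is initialized, each rank holds the input split on its last dimension and the weight split on its in_features dimension, and the bias is duplicated on every rank.

import torch
import torch.distributed as dist

def row_parallel_linear_sketch(x_shard, w_shard, bias=None):
    # Input:S[1] x Weight:S[0] = Output:P -- each rank computes a partial result
    partial = torch.nn.functional.linear(x_shard, w_shard)   # (batch, out_features) partial sums
    # partial output -> replicated output: sum the partial results across ranks
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)
    if bias is not None:
        # the bias is duplicated on every rank, so it is added once after the reduction
        partial = partial + bias
    return partial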

@@ -2,6 +2,7 @@ from enum import Enum
from typing import Tuple, List
from colossalai.context.parallel_mode import ParallelMode
class ComputePattern(Enum):
TP1DRow = 1
TP1DCol = 2
@@ -10,6 +11,7 @@ class ComputePattern(Enum):
class ParallelAction(object):
def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
self.priority = priority
self.compute_pattern = compute_pattern
@@ -24,6 +26,7 @@ class TensorSpec(object):
parallel computation pattern of the Operator (Layer).
We have to consider the hybrid parallel mode.
"""
# a list of parallel actions.
# For example: on 8 GPUs, a hybrid parallel strategy is applied
# using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
@@ -38,6 +41,7 @@
# Before Linear Op, we gather the tensors according to ZeRO.
# We perform the Linear Op according to the compute pattern of TP1DRow.
# After Linear Op, we split the tensors according to ZeRO.
def __init__(self, parallel_action_list: List[ParallelAction] = []):
self._parallel_action_list = parallel_action_list
self.sort()
@@ -56,8 +60,8 @@
def sort(self):
if len(self._parallel_action_list) > 0:
self._parallel_action_list.sort(key=lambda parallel_action : parallel_action.priority)
self._parallel_action_list.sort(key=lambda parallel_action: parallel_action.priority)
def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
for parallel_action in self._parallel_action_list:
if parallel_action.compute_pattern == compute_pattern:
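
A small construction sketch for the hybrid case described in the comments above, using only the signatures visible in this diff (the import path is the one used in linear.py before this change; using ComputePattern.DP to stand in for the ZeRO-side action and the chosen ParallelMode values are assumptions):

from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction

spec = TensorSpec([
    # sort() orders the actions by ascending priority
    ParallelAction(priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA),
    ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D),
])
# look up the action an operator cares about, e.g. the row-wise TP pattern
tp_action = spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)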

@@ -1,19 +1,13 @@
from cProfile import label
from statistics import mode
from colossalai.tensor.colo_tensor import ColoTensor
from tests.components_to_test.registry import non_distributed_component_funcs
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.utils import ColoInitContext
import torch.distributed as dist
from functools import partial

@@ -53,11 +53,11 @@ def test_linear():
# torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
out = fc(input_tensor)
loss = out.sum()
loss = torch.sum(out)
loss.backward()
out_ref = fc_ref(input_ref)
loss_ref = out_ref.sum()
loss_ref = torch.sum(out_ref)
loss_ref.backward()
assert (loss_ref == loss)
