[tensor] customized op returns ColoTensor (#875)

* [tensor] customized op returns ColoTensor

* polish

* polish code
Jiarui Fang 2022-04-26 13:23:59 +08:00 committed by GitHub
parent 26d4ab8b03
commit 96211c2cc8
7 changed files with 33 additions and 45 deletions

View File

@@ -1,5 +1,4 @@
-from .init import colo_uniform
 from .linear import colo_linear
-from .element_wise import colo_mean
+from .element_wise import *
 from .layernorm import colo_layernorm
 from .loss import colo_cross_entropy

View File

@@ -29,3 +29,22 @@ def register_elementwise_op(op):
 register_elementwise_op(torch.nn.functional.gelu)
 register_elementwise_op(torch.nn.functional.relu)
+
+
+@colo_op_impl(torch.sum)
+def sum_op(types, args=(), kwargs=None, pg=None):
+    """
+    Handles ``__torch_function__`` dispatch for an elementwise op such
+    as ``torch.sum``.
+    This method computes on either a normal tensor or a sharded tensor.
+    """
+    if len(args) > 0:
+        input_tensor = args[0]
+    if kwargs is None:
+        kwargs = {}
+    if 'input' in kwargs:
+        input_tensor = kwargs['input']
+    # Validate types
+    if not isinstance(input_tensor, ColoTensor):
+        raise TypeError("input needs to be a ColoTensor")
+    return ColoTensor.init_from_torch_tensor(torch.sum(input_tensor.torch_tensor()))
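With this handler registered, calling torch.sum on a ColoTensor is intercepted by ColoTensor's ``__torch_function__`` machinery and the result comes back wrapped as a ColoTensor rather than a plain torch.Tensor, which is the point of this PR. A minimal usage sketch of that intent (only init_from_torch_tensor, torch_tensor and the import path are taken from this diff; the shape is illustrative):

import torch
from colossalai.tensor.colo_tensor import ColoTensor

t = ColoTensor.init_from_torch_tensor(torch.randn(4, 8))

# torch.sum(t) should route to the sum_op handler registered above,
# which unwraps the payload, sums it, and re-wraps the result.
s = torch.sum(t)
assert isinstance(s, ColoTensor)
print(s.torch_tensor())   # the underlying value as a plain torch.Tensor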

View File

@@ -1,29 +0,0 @@
-import torch
-
-from colossalai.tensor.op_wrapper import colo_op_impl
-
-
-def validate_param(param, param_name):
-    if param is None:
-        raise ValueError(f"param: {param_name} shouldn't be None!")
-
-
-@colo_op_impl(torch.nn.init.uniform_)
-def colo_uniform(types, args=(), kwargs=None, pg=None):
-    r"""
-    Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
-    distribution :math:`\mathcal{U}(a, b)`.
-    Args:
-        sharded_tensor: tensor sharded across devices
-        a: the lower bound of the uniform distribution
-        b: the upper bound of the uniform distribution
-    """
-    validate_param(kwargs, "kwargs")
-    stateful_tensor = kwargs["tensor"]
-    validate_param(stateful_tensor, "stateful_tensor")
-    a = kwargs['a']
-    validate_param(a, "a")
-    b = kwargs['b']
-    validate_param(b, "b")
-    torch.nn.init.uniform_(stateful_tensor.torch_tensor(), a=a, b=b)
-    return stateful_tensor
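The deleted file above, like the element-wise ops, relies on colo_op_impl from colossalai.tensor.op_wrapper, which registers a handler for a given torch function so that ColoTensor's ``__torch_function__`` can look it up at dispatch time. The real implementation lives in op_wrapper; the snippet below is only a rough sketch of that registration pattern, and every name other than the torch functions is invented for illustration:

from typing import Callable, Dict

# hypothetical registry table mirroring what colo_op_impl maintains
_OP_TABLE: Dict[Callable, Callable] = {}

def colo_op_impl_sketch(torch_op: Callable):
    """Return a decorator that records `handler` as the ColoTensor implementation of `torch_op`."""
    def decorator(handler: Callable):
        _OP_TABLE[torch_op] = handler
        return handler
    return decorator

# ColoTensor.__torch_function__ would then do, roughly:
#   handler = _OP_TABLE.get(func)
#   return handler(types, args, kwargs) if handler else func(*args, **(kwargs or {}))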

View File

@@ -6,7 +6,8 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
 from colossalai.nn.layer.utils import divide
 from colossalai.core import global_context as gpc
 from packaging import version
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+from colossalai.tensor import ComputePattern
+
 
 @colo_op_impl(torch.nn.functional.linear)
 def colo_linear(types, args, kwargs, pg):
@@ -25,6 +26,7 @@ def colo_linear(types, args, kwargs, pg):
     bias = kwargs.get('bias', None)
     if isinstance(bias, ColoTensor):
+        assert bias.shard_spec.num_action == 0, f"We currently only support bias is duplicated among processes in the linear operator"
         bias = bias.torch_tensor()
 
     # Add communication logic before and after linear call.
@@ -34,7 +36,7 @@ def colo_linear(types, args, kwargs, pg):
             input_tensor = input_tensor.torch_tensor()
         if isinstance(weight, ColoTensor):
             weight = weight.torch_tensor()
-        return torch.nn.functional.linear(input_tensor, weight, bias)
+        return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
     elif weight.shard_spec.num_action == 1:
         if ComputePattern.TP1DRow in weight.shard_spec.compute_patterns:
             # Input:S[1] x Weight:S[0] = Output:P
@@ -54,8 +56,7 @@ def colo_linear(types, args, kwargs, pg):
             output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
             # Bias
             if bias is not None:
-                bias_ = bias
-                output = output + bias_
+                output = output + bias
             return ColoTensor.init_from_torch_tensor(output)
         else:
             raise NotImplementedError
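After this change every exit path of colo_linear wraps its output, so a torch.nn.functional.linear call whose weight is a replicated (unsharded) ColoTensor also yields a ColoTensor. A rough sketch of the expected behaviour, assuming the dispatch is active and a default no-shard spec on freshly wrapped tensors; shapes and the import path (taken from elsewhere in this PR) are illustrative:

import torch
from colossalai.tensor.colo_tensor import ColoTensor

weight = ColoTensor.init_from_torch_tensor(torch.randn(16, 8))   # (out_features, in_features)
bias = ColoTensor.init_from_torch_tensor(torch.zeros(16))
x = torch.randn(4, 8)

# routed through @colo_op_impl(torch.nn.functional.linear); the replicated-weight
# branch now returns ColoTensor.init_from_torch_tensor(...) instead of a bare tensor
out = torch.nn.functional.linear(x, weight, bias)
assert isinstance(out, ColoTensor)
assert out.torch_tensor().shape == (4, 16)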

View File

@@ -2,6 +2,7 @@ from enum import Enum
 from typing import Tuple, List
 from colossalai.context.parallel_mode import ParallelMode
 
+
 class ComputePattern(Enum):
     TP1DRow = 1
     TP1DCol = 2
@@ -10,6 +11,7 @@ class ComputePattern(Enum):
 
+
 class ParallelAction(object):
 
     def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
         self.priority = priority
         self.compute_pattern = compute_pattern
@@ -24,6 +26,7 @@ class TensorSpec(object):
     parallel computation pattern of the Operator (Layer).
     We have to consider the hybrid parallel mode.
     """
+
    # a list of parallel actions.
    # For example: On 8 GPUs, a hybrid parallel strategy is applied using
    # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
@@ -38,6 +41,7 @@ class TensorSpec(object):
    # Before Linear Op, we gather the tensors according to ZeRO.
    # We perform Linear Op according to compute pattern of TP1DRow.
    # After Linear Op, we split the tensors according to ZeRO.
+
     def __init__(self, parallel_action_list: List[ParallelAction] = []):
         self._parallel_action_list = parallel_action_list
         self.sort()
@@ -56,8 +60,8 @@
 
     def sort(self):
         if len(self._parallel_action_list) > 0:
-            self._parallel_action_list.sort(key=lambda parallel_action : parallel_action.priority)
+            self._parallel_action_list.sort(key=lambda parallel_action: parallel_action.priority)
 
     def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
         for parallel_action in self._parallel_action_list:
             if parallel_action.compute_pattern == compute_pattern:
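The spec changes here are whitespace polish, but the hunks document the contract the linear op depends on: a TensorSpec keeps its ParallelAction list sorted by ascending priority, and get_action_by_compute_pattern fetches the action matching a given ComputePattern. A short usage sketch under that reading (the import path follows the old linear.py import; degrees and modes are chosen only for illustration):

from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from colossalai.context.parallel_mode import ParallelMode

actions = [
    ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D),
    ParallelAction(priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA),
]

spec = TensorSpec(actions)   # __init__ calls sort(), ordering the actions by priority

tp_action = spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
assert tp_action.compute_pattern == ComputePattern.TP1DRow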

View File

@@ -1,19 +1,13 @@
-from cProfile import label
-from statistics import mode
-from colossalai.tensor.colo_tensor import ColoTensor
 from tests.components_to_test.registry import non_distributed_component_funcs
 import colossalai
 import pytest
-import torch
 import torch.multiprocessing as mp
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
-from colossalai.core import global_context as gpc
 from colossalai.utils import ColoInitContext
-import torch.distributed as dist
 from functools import partial

View File

@@ -53,11 +53,11 @@ def test_linear():
     # torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
     out = fc(input_tensor)
-    loss = out.sum()
+    loss = torch.sum(out)
     loss.backward()
 
     out_ref = fc_ref(input_ref)
-    loss_ref = out_ref.sum()
+    loss_ref = torch.sum(out_ref)
     loss_ref.backward()
 
     assert (loss_ref == loss)
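The test switches from out.sum() to torch.sum(out) because out is now a ColoTensor wrapper: the functional form is what ColoTensor's ``__torch_function__`` routes to the new sum_op, whereas the method form is not guaranteed to exist on the wrapper at this stage. A condensed, standalone sketch of the pattern the updated test relies on; the weight and input here are hypothetical stand-ins for the test's fc and input_tensor:

import torch
from colossalai.tensor.colo_tensor import ColoTensor

w = ColoTensor.init_from_torch_tensor(torch.randn(4, 8, requires_grad=True))
x = torch.randn(2, 8)

out = torch.nn.functional.linear(x, w)   # ColoTensor, via colo_linear above
loss = torch.sum(out)                    # ColoTensor, via sum_op
assert isinstance(loss, ColoTensor)
loss.backward()                          # mirrors the test: backward runs through the wrapped tensor's graph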