mirror of https://github.com/hpcaitech/ColossalAI
[tensor] customized op returns ColoTensor (#875)
* [tensor] customized op returns ColoTensor
* polish
* polish code

pull/876/head
parent 26d4ab8b03
commit 96211c2cc8
@@ -1,5 +1,4 @@
-from .init import colo_uniform
 from .linear import colo_linear
-from .element_wise import colo_mean
+from .element_wise import *
 from .layernorm import colo_layernorm
 from .loss import colo_cross_entropy
@@ -29,3 +29,22 @@ def register_elementwise_op(op):

 register_elementwise_op(torch.nn.functional.gelu)
 register_elementwise_op(torch.nn.functional.relu)
+
+
+@colo_op_impl(torch.sum)
+def sum_op(types, args=(), kwargs=None, pg=None):
+    """
+    Handles ``__torch_function__`` dispatch for an elementwise op such
+    as ``torch.sum``.
+    This method computes on either a normal tensor or a sharded tensor.
+    """
+    if len(args) > 0:
+        input_tensor = args[0]
+    if kwargs is None:
+        kwargs = {}
+    if 'input' in kwargs:
+        input_tensor = kwargs['input']
+    # Validate types
+    if not isinstance(input_tensor, ColoTensor):
+        raise TypeError("input needs to be a ColoTensor")
+    return ColoTensor.init_from_torch_tensor(torch.sum(input_tensor.torch_tensor()))
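Note: the new sum_op relies on the colo_op_impl decorator from colossalai.tensor.op_wrapper, whose internals are not part of this diff. As a rough, illustrative sketch (every name other than colo_op_impl is hypothetical), such a decorator typically fills a registry that a __torch_function__ override consults:

# Illustrative only -- the real implementation lives in colossalai.tensor.op_wrapper.
_OP_TABLE = {}  # hypothetical registry: maps a torch op to its ColoTensor handler

def colo_op_impl(torch_op):
    def decorator(handler):
        _OP_TABLE[torch_op] = handler  # e.g. torch.sum -> sum_op
        return handler
    return decorator

# A ColoTensor.__torch_function__ override can then route calls roughly as:
#     handler = _OP_TABLE.get(func)
#     return handler(types, args, kwargs) if handler else func(*args, **(kwargs or {}))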
@@ -1,29 +0,0 @@
-import torch
-from colossalai.tensor.op_wrapper import colo_op_impl
-
-
-def validate_param(param, param_name):
-    if param is None:
-        raise ValueError(f"param: {param_name} shouldn't be None!")
-
-
-@colo_op_impl(torch.nn.init.uniform_)
-def colo_uniform(types, args=(), kwargs=None, pg=None):
-    r"""
-    Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
-    distribution :math:`\mathcal{U}(a, b)`.
-    Args:
-        sharded_tensor: tensor sharded across devices
-        a: the lower bound of the uniform distribution
-        b: the upper bound of the uniform distribution
-    """
-    validate_param(kwargs, "kwargs")
-    stateful_tensor = kwargs["tensor"]
-    validate_param(stateful_tensor, "stateful_tensor")
-    a = kwargs['a']
-    validate_param(a, "a")
-    b = kwargs['b']
-    validate_param(b, "b")
-
-    torch.nn.init.uniform_(stateful_tensor.torch_tensor(), a=a, b=b)
-    return stateful_tensor
@@ -6,7 +6,8 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
 from colossalai.nn.layer.utils import divide
 from colossalai.core import global_context as gpc
 from packaging import version
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+from colossalai.tensor import ComputePattern

+
 @colo_op_impl(torch.nn.functional.linear)
 def colo_linear(types, args, kwargs, pg):
@@ -25,6 +26,7 @@ def colo_linear(types, args, kwargs, pg):
     bias = kwargs.get('bias', None)

     if isinstance(bias, ColoTensor):
+        assert bias.shard_spec.num_action == 0, "we currently only support a bias that is duplicated among processes in the linear operator"
         bias = bias.torch_tensor()

     # Add communication logic before and after linear call.
@@ -34,7 +36,7 @@ def colo_linear(types, args, kwargs, pg):
         input_tensor = input_tensor.torch_tensor()
         if isinstance(weight, ColoTensor):
             weight = weight.torch_tensor()
-        return torch.nn.functional.linear(input_tensor, weight, bias)
+        return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
     elif weight.shard_spec.num_action == 1:
         if ComputePattern.TP1DRow in weight.shard_spec.compute_patterns:
             # Input:S[1] x Weight:S[0] = Output:P
@@ -54,8 +56,7 @@ def colo_linear(types, args, kwargs, pg):
             output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
             # Bias
             if bias is not None:
-                bias_ = bias
-                output = output + bias_
+                output = output + bias
             return ColoTensor.init_from_torch_tensor(output)
         else:
             raise NotImplementedError
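Note: the "Input:S[1] x Weight:S[0] = Output:P" comment above describes a Megatron-style row-parallel matmul: each rank multiplies a column shard of the input by a row shard of a math-convention (in_dim x out_dim) weight, and reduce_input all-reduces the partial outputs. A single-process sketch of the arithmetic, with sizes chosen arbitrarily:

# Toy check of the row-parallel pattern; sum() stands in for the cross-rank all-reduce.
import torch

world_size, in_dim, out_dim = 2, 8, 4
x = torch.randn(1, in_dim)
w = torch.randn(in_dim, out_dim)

partials = [
    xs @ ws  # each "rank" computes a partial output from its shards
    for xs, ws in zip(x.chunk(world_size, dim=1),   # Input: S[1]
                      w.chunk(world_size, dim=0))   # Weight: S[0]
]
assert torch.allclose(sum(partials), x @ w, atol=1e-5)  # Output: P, then reduced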
@@ -2,6 +2,7 @@ from enum import Enum
 from typing import Tuple, List
 from colossalai.context.parallel_mode import ParallelMode

+
 class ComputePattern(Enum):
     TP1DRow = 1
     TP1DCol = 2
@@ -10,6 +11,7 @@ class ComputePattern(Enum):

+
 class ParallelAction(object):

     def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
         self.priority = priority
         self.compute_pattern = compute_pattern
@@ -24,6 +26,7 @@ class TensorSpec(object):
     parallel computation pattern of the Operator (Layer).
     We have to consider the hybrid parallel mode.
     """
+
     # a list of parallel actions.
     # For example: on 8 GPUs, a hybrid parallel strategy is applied
     # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
@@ -38,6 +41,7 @@ class TensorSpec(object):
     # Before Linear Op, we gather the tensors according to ZeRO.
     # We perform Linear Op according to compute pattern of TP1DRow.
     # After Linear Op, we split the tensors according to ZeRO.
+
     def __init__(self, parallel_action_list: List[ParallelAction] = []):
         self._parallel_action_list = parallel_action_list
         self.sort()
@@ -56,8 +60,8 @@ class TensorSpec(object):

     def sort(self):
         if len(self._parallel_action_list) > 0:
-            self._parallel_action_list.sort(key=lambda parallel_action : parallel_action.priority)
+            self._parallel_action_list.sort(key=lambda parallel_action: parallel_action.priority)

     def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
         for parallel_action in self._parallel_action_list:
             if parallel_action.compute_pattern == compute_pattern:
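Note: a minimal usage sketch of the classes above. It assumes ComputePattern also defines the DP member referenced by ParallelAction's default, and that get_action_by_compute_pattern (truncated in this hunk) returns the matched action:

from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction

spec = TensorSpec([
    ParallelAction(priority=2, compute_pattern=ComputePattern.TP1DRow),
    ParallelAction(priority=1),  # defaults to ComputePattern.DP
])
# __init__ called sort(), so the actions are now ordered by ascending priority
action = spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
assert action is not None and action.priority == 2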
@@ -1,19 +1,13 @@
-from cProfile import label
-from statistics import mode
-from colossalai.tensor.colo_tensor import ColoTensor
 from tests.components_to_test.registry import non_distributed_component_funcs

 import colossalai
 import pytest
-import torch
 import torch.multiprocessing as mp
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
-from colossalai.core import global_context as gpc
 from colossalai.utils import ColoInitContext

-import torch.distributed as dist
 from functools import partial
@@ -53,11 +53,11 @@ def test_linear():

     # torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
     out = fc(input_tensor)
-    loss = out.sum()
+    loss = torch.sum(out)
     loss.backward()

     out_ref = fc_ref(input_ref)
-    loss_ref = out_ref.sum()
+    loss_ref = torch.sum(out_ref)
     loss_ref.backward()

     assert (loss_ref == loss)
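Note: the switch from out.sum() to torch.sum(out) exercises the new dispatch path: torch.sum is now routed to the @colo_op_impl(torch.sum) handler added in this commit, while the Tensor.sum method has no registered handler. A hedged sketch of the behavior the updated test presumably relies on:

# Sketch, assuming ColoTensor intercepts torch functions as this commit intends.
import torch
from colossalai.tensor.colo_tensor import ColoTensor

t = ColoTensor.init_from_torch_tensor(torch.ones(2, 2))
loss = torch.sum(t)  # dispatched to sum_op, so the result is again a ColoTensor
# loss wraps tensor(4.); loss.backward() presumably runs on the wrapped torch tensor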