[Tensor] Add function to spec and update linear 1Drow and unit tests (#869)

Ziyue Jiang 2022-04-26 10:15:26 +08:00 committed by GitHub
parent 11f54c7b6b
commit 26d4ab8b03
6 changed files with 85 additions and 58 deletions


@@ -1,7 +1,9 @@
from .spec import ComputePattern, ParallelAction, TensorSpec
from .op_wrapper import (
colo_op_impl,)
from .colo_tensor import ColoTensor
from .utils import convert_parameter
from ._ops import *
__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl']
__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern',
'TensorSpec', 'ParallelAction']
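
With ComputePattern, TensorSpec, and ParallelAction now re-exported from the package root, user code no longer has to reach into colossalai.tensor.spec. A minimal sketch of the new import surface (assuming an install that contains this commit; constructing a spec needs no distributed initialization):

from colossalai.tensor import ColoTensor, TensorSpec, ComputePattern, ParallelAction

# An empty spec means "no parallel action": ops such as colo_linear fall back
# to the plain torch path when num_action == 0.
replicated_spec = TensorSpec()
print(replicated_spec.num_action)  # 0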


@@ -2,4 +2,4 @@ from .init import colo_uniform
from .linear import colo_linear
from .element_wise import colo_mean
from .layernorm import colo_layernorm
from .loss import colo_cross_entropy
from .loss import colo_cross_entropy


@@ -6,8 +6,7 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
from colossalai.nn.layer.utils import divide
from colossalai.core import global_context as gpc
from packaging import version
from colossalai.utils.cuda import get_current_device
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
@colo_op_impl(torch.nn.functional.linear)
def colo_linear(types, args, kwargs, pg):
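
The @colo_op_impl decorator above registers colo_linear as the handler that ColoTensor's __torch_function__ dispatches to when torch.nn.functional.linear sees a ColoTensor argument. The toy sketch below (hypothetical ToyTensor/toy_op_impl names, plain PyTorch only) illustrates that registry-plus-wrapper pattern; it is not the ColossalAI implementation itself:

import torch

_TOY_OPS = {}                                  # registry keyed by the torch function

def toy_op_impl(torch_func):                   # analogue of colo_op_impl
    def register(handler):
        _TOY_OPS[torch_func] = handler
        return handler
    return register

@toy_op_impl(torch.nn.functional.linear)
def toy_linear(types, args=(), kwargs=None):
    # Unwrap before calling torch so the dispatcher is not re-entered.
    unwrapped = [a.torch_tensor() if isinstance(a, ToyTensor) else a for a in args]
    return torch.nn.functional.linear(*unwrapped)

class ToyTensor:
    """Plain wrapper (like ColoTensor) that opts into the __torch_function__ protocol."""

    def __init__(self, t: torch.Tensor):
        self._t = t

    def torch_tensor(self) -> torch.Tensor:
        return self._t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return _TOY_OPS[func](types, args, kwargs or {})

x = ToyTensor(torch.randn(2, 4))
w = torch.randn(3, 4)
print(torch.nn.functional.linear(x, w).shape)  # torch.Size([2, 3]), routed via toy_linear
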
@@ -30,32 +29,36 @@ def colo_linear(types, args, kwargs, pg):
# Add communication logic before and after linear call.
if isinstance(weight, ColoTensor):
if weight.shard_spec == None:
if weight.shard_spec == None or weight.shard_spec.num_action == 0:
if isinstance(input_tensor, ColoTensor):
input_tensor = input_tensor.torch_tensor()
if isinstance(weight, ColoTensor):
weight = weight.torch_tensor()
return torch.nn.functional.linear(input_tensor, weight, bias)
elif weight.shard_spec == '1Drow':
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_tensor.shape, weight.size, weight.size[-1] * gpc.tensor_parallel_size)
# Input:S[1]
input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1)
# Output:P
device = get_current_device() # TODO where to put to(device)?
weight_ = weight.torch_tensor().to(device)
partial_output = torch.nn.functional.linear(input_per_partition, weight_)
# Reduce(Output)
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
# Bias
if bias is not None:
bias_ = bias.to(device)
output = output + bias_
return output
elif weight.shard_spec.num_action == 1:
if ComputePattern.TP1DRow in weight.shard_spec.compute_patterns:
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size)
# Input:S[1]
if isinstance(input_tensor, ColoTensor):
input_tensor = input_tensor.torch_tensor()
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
input_per_partition = split_forward_gather_backward(input_tensor, parallel_action.parallel_mode, dim=-1)
# Output:P
weight_ = weight.torch_tensor()
partial_output = torch.nn.functional.linear(input_per_partition, weight_)
# Reduce(Output)
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
# Bias
if bias is not None:
bias_ = bias
output = output + bias_
return ColoTensor.init_from_torch_tensor(output)
else:
raise NotImplementedError
else:
raise NotImplementedError
else:
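
The TP1DRow branch above implements Input:S[1] x Weight:S[0] = Output:P followed by an all-reduce over the 1D parallel group. The single-process sketch below (plain torch, no process groups; world_size and the shard names are illustrative) shows why summing the per-rank partial products reproduces the dense result:

import torch

world_size, in_features, out_features = 4, 16, 8
x = torch.randn(2, in_features)                 # full input
w = torch.randn(out_features, in_features)      # full weight
b = torch.randn(out_features)

x_shards = x.chunk(world_size, dim=-1)          # Input:S[1] -> shard the last (feature) dim
w_shards = w.chunk(world_size, dim=-1)          # weight sharded along in_features, as _shard() does
partials = [torch.nn.functional.linear(xs, ws)  # Output:P on each "rank"
            for xs, ws in zip(x_shards, w_shards)]
out = sum(partials) + b                         # stands in for All-Reduce(Output) + bias

assert torch.allclose(out, torch.nn.functional.linear(x, w, b), atol=1e-5)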


@@ -1,13 +1,12 @@
from colossalai.context import parallel_mode
from .op_wrapper import _COLOSSAL_OPS
import torch
from typing import Tuple, Optional
from numpy import product
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.nn.layer.utils import divide
from colossalai.utils.cuda import get_current_device
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
class ColoTensor(object):
""" Data Structure for Tensor in Colossal-AI
@@ -28,7 +27,7 @@ class ColoTensor(object):
pin_memory=False,
device=None,
torch_tensor=torch.empty(0),
shard_spec: str = None,
shard_spec: TensorSpec = TensorSpec(),
):
self._size = size
self._dtype = dtype
@@ -39,7 +38,7 @@ class ColoTensor(object):
self._shard_spec = shard_spec
@property
def shard_spec(self) -> Optional[str]:
def shard_spec(self) -> TensorSpec:
return self._shard_spec
@property
@@ -109,27 +108,27 @@ class ColoTensor(object):
device=self._device)
return self._torch_tensor
def set_spec(self, spec: str, lazy_shard: bool = False) -> None:
def set_spec(self, spec: TensorSpec, lazy_shard: bool = False) -> None:
self._shard_spec = spec
if lazy_shard == False:
self._shard()
def _shard(self):
assert self._shard_spec is not None, 'You should call set_spec() before calling _shard() on a ColoTensor.'
if self._shard_spec == "1Drow": # TODO It actually represents the sharding layout for Linear-1Drow-weight, but we make it simpler now.
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
dim = -1
chunk_size = divide(self._size[dim], num_partition)
device = get_current_device()
# Reshape to get shard for this rank and we don't want autograd
# recording here for the narrow op and 'local_shard' should be a
# leaf variable in the autograd graph.
self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
).contiguous() # TODO Shall we clone() here since detach() will point to the old tensor?
self._torch_tensor.requires_grad = self._requires_grad
self._size = self._torch_tensor.size()
self._device = device # TODO A `fake` device now because torch_tensor.device always = cpu
if self._shard_spec.num_action == 1:
if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
num_partition = gpc.get_world_size(parallel_action.parallel_mode)
local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
dim = -1
chunk_size = divide(self._size[dim], num_partition)
# Reshape to get shard for this rank and we don't want autograd
# recording here for the narrow op and 'local_shard' should be a
# leaf variable in the autograd graph.
self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
).contiguous() # TODO Shall we clone() here since detach() will point to the old tensor?
self._torch_tensor.requires_grad = self._requires_grad
self._size = self._torch_tensor.size()
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -151,5 +150,5 @@ class ColoTensor(object):
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
return func(*args, **kwargs)
def backward(self, retain_graph: bool = False):
self._torch_tensor.backward(retain_graph=retain_graph)
def backward(self, gradient: Optional[torch.Tensor] = None , retain_graph: bool = False):
self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
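
The new _shard() keeps only this rank's chunk of the last dimension via narrow(), detached so the shard becomes a leaf in the autograd graph. A plain-torch sketch of that slicing, with num_partition and local_rank as illustrative stand-ins for the values gpc would provide:

import torch

full = torch.randn(8, 12, requires_grad=True)   # unsharded parameter
num_partition, local_rank = 4, 1                # hypothetical world size and rank
dim = -1
chunk_size = full.size(dim) // num_partition    # what divide() computes

# Narrow to this rank's chunk; detach() so the shard carries no autograd history
# and is a leaf, then make it contiguous in memory.
local_shard = full.narrow(dim, local_rank * chunk_size, chunk_size).detach().contiguous()
local_shard.requires_grad = full.requires_grad

print(local_shard.shape)    # torch.Size([8, 3])
print(local_shard.is_leaf)  # True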


@@ -1,8 +1,6 @@
from enum import Enum
from typing import Tuple, List
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
class ComputePattern(Enum):
TP1DRow = 1
@@ -12,17 +10,13 @@ class ComputePattern(Enum):
class ParallelAction(object):
priority = 0
compute_pattern = ComputePattern.DP
process_group = gpc.get_group(ParallelMode.DATA)
def __init__(self, priority, compute_pattern, process_group) -> None:
def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
self.priority = priority
self.compute_pattern = compute_pattern
self.process_group = process_group
self.parallel_mode = parallel_mode
class TensorSpec(Enum):
class TensorSpec(object):
"""
It contains two aspects of information:
First, how are tensors distributed in heterogeneous memory space.
@@ -44,4 +38,28 @@ class TensorSpec(Enum):
# Before Linear Op, we gather the tensors according to ZeRO.
# We perform Linear Op according to compute pattern of TP1DRow.
# After Linear Op, we split the tensors according to ZeRO.
parallel_action_list: List[ParallelAction] = []
def __init__(self, parallel_action_list: List[ParallelAction] = []):
self._parallel_action_list = parallel_action_list
self.sort()
@property
def parallel_action_list(self):
return self._parallel_action_list
@property
def num_action(self):
return len(self._parallel_action_list)
@property
def compute_patterns(self):
return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]
def sort(self):
if len(self._parallel_action_list) > 0:
self._parallel_action_list.sort(key=lambda parallel_action : parallel_action.priority)
def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
for parallel_action in self._parallel_action_list:
if parallel_action.compute_pattern == compute_pattern:
return parallel_action
return None
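
Taken together, the new TensorSpec holds a priority-sorted list of ParallelAction objects and can be queried by compute pattern; since ParallelAction now stores a parallel_mode instead of calling gpc.get_group at class-definition time, building a spec creates no process group. A short usage sketch (passing an explicit list also sidesteps the shared mutable default of parallel_action_list=[]):

from colossalai.context import ParallelMode
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction

actions = [
    ParallelAction(priority=1,
                   compute_pattern=ComputePattern.TP1DRow,
                   parallel_mode=ParallelMode.PARALLEL_1D),
]
spec = TensorSpec(actions)

assert spec.num_action == 1
assert ComputePattern.TP1DRow in spec.compute_patterns
row_action = spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
assert row_action.parallel_mode is ParallelMode.PARALLEL_1D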


@@ -12,6 +12,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk
@@ -45,7 +46,11 @@ def run_linear_tp1d_row_test():
# replace the torch nn.Parameters with ColoTensor
sharded_weight = ColoTensor.init_from_torch_tensor(W)
sharded_weight.set_spec(spec="1Drow") # reshard
parallel_action_list = [
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
]
spec = TensorSpec(parallel_action_list)
sharded_weight.set_spec(spec=spec) # reshard
sharded_bias = ColoTensor.init_from_torch_tensor(B)
replace_parameter_add_grad(layer, sharded_weight, sharded_bias)
out = layer(A)