mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] Add function to spec and update linear 1Drow and unit tests (#869)
parent
11f54c7b6b
commit
26d4ab8b03
|
@ -1,7 +1,9 @@
|
|||
from .spec import ComputePattern, ParallelAction, TensorSpec
|
||||
from .op_wrapper import (
|
||||
colo_op_impl,)
|
||||
from .colo_tensor import ColoTensor
|
||||
from .utils import convert_parameter
|
||||
from ._ops import *
|
||||
|
||||
__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl']
|
||||
__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern',
|
||||
'TensorSpec', 'ParallelAction']
|
||||
|
|
|
@ -2,4 +2,4 @@ from .init import colo_uniform
|
|||
from .linear import colo_linear
|
||||
from .element_wise import colo_mean
|
||||
from .layernorm import colo_layernorm
|
||||
from .loss import colo_cross_entropy
|
||||
from .loss import colo_cross_entropy
|
||||
|
|
|
@ -6,8 +6,7 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
|
|||
from colossalai.nn.layer.utils import divide
|
||||
from colossalai.core import global_context as gpc
|
||||
from packaging import version
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
|
||||
|
||||
@colo_op_impl(torch.nn.functional.linear)
|
||||
def colo_linear(types, args, kwargs, pg):
|
||||
|
@ -30,32 +29,36 @@ def colo_linear(types, args, kwargs, pg):
|
|||
|
||||
# Add communication logic before and after linear call.
|
||||
if isinstance(weight, ColoTensor):
|
||||
if weight.shard_spec == None:
|
||||
if weight.shard_spec == None or weight.shard_spec.num_action == 0:
|
||||
if isinstance(input_tensor, ColoTensor):
|
||||
input_tensor = input_tensor.torch_tensor()
|
||||
if isinstance(weight, ColoTensor):
|
||||
weight = weight.torch_tensor()
|
||||
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||
elif weight.shard_spec == '1Drow':
|
||||
# Input:S[1] x Weight:S[0] = Output:P
|
||||
# All-Reduce(Output) + bias = res
|
||||
assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
|
||||
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_tensor.shape, weight.size, weight.size[-1] * gpc.tensor_parallel_size)
|
||||
# Input:S[1]
|
||||
input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
# Output:P
|
||||
device = get_current_device() # TODO where to put to(deivce)?
|
||||
weight_ = weight.torch_tensor().to(device)
|
||||
partial_output = torch.nn.functional.linear(input_per_partition, weight_)
|
||||
# Reduce(Output)
|
||||
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
|
||||
# Bias
|
||||
if bias is not None:
|
||||
bias_ = bias.to(device)
|
||||
output = output + bias_
|
||||
return output
|
||||
|
||||
elif weight.shard_spec.num_action == 1:
|
||||
if ComputePattern.TP1DRow in weight.shard_spec.compute_patterns:
|
||||
# Input:S[1] x Weight:S[0] = Output:P
|
||||
# All-Reduce(Output) + bias = res
|
||||
assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
|
||||
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size)
|
||||
# Input:S[1]
|
||||
if isinstance(input_tensor, ColoTensor):
|
||||
input_tensor = input_tensor.torch_tensor()
|
||||
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
|
||||
input_per_partition = split_forward_gather_backward(input_tensor, parallel_action.parallel_mode, dim=-1)
|
||||
# Output:P
|
||||
weight_ = weight.torch_tensor()
|
||||
partial_output = torch.nn.functional.linear(input_per_partition, weight_)
|
||||
# Reduce(Output)
|
||||
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
|
||||
# Bias
|
||||
if bias is not None:
|
||||
bias_ = bias
|
||||
output = output + bias_
|
||||
return ColoTensor.init_from_torch_tensor(output)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
from colossalai.context import parallel_mode
|
||||
from .op_wrapper import _COLOSSAL_OPS
|
||||
|
||||
import torch
|
||||
from typing import Tuple, Optional
|
||||
from numpy import product
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.nn.layer.utils import divide
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
|
||||
|
||||
class ColoTensor(object):
|
||||
""" Data Structure for Tensor in Colossal-AI
|
||||
|
@ -28,7 +27,7 @@ class ColoTensor(object):
|
|||
pin_memory=False,
|
||||
device=None,
|
||||
torch_tensor=torch.empty(0),
|
||||
shard_spec: str = None,
|
||||
shard_spec: TensorSpec = TensorSpec(),
|
||||
):
|
||||
self._size = size
|
||||
self._dtype = dtype
|
||||
|
@ -39,7 +38,7 @@ class ColoTensor(object):
|
|||
self._shard_spec = shard_spec
|
||||
|
||||
@property
|
||||
def shard_spec(self) -> Optional[str]:
|
||||
def shard_spec(self) -> TensorSpec:
|
||||
return self._shard_spec
|
||||
|
||||
@property
|
||||
|
@ -109,27 +108,27 @@ class ColoTensor(object):
|
|||
device=self._device)
|
||||
return self._torch_tensor
|
||||
|
||||
def set_spec(self, spec: str, lazy_shard: bool = False) -> None:
|
||||
def set_spec(self, spec: TensorSpec, lazy_shard: bool = False) -> None:
|
||||
self._shard_spec = spec
|
||||
if lazy_shard == False:
|
||||
self._shard()
|
||||
|
||||
def _shard(self):
|
||||
assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
|
||||
if self._shard_spec == "1Drow": # TODO It actually represents the sharding layout for Linear-1Drow-weight, but we make it simpler now.
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||
dim = -1
|
||||
chunk_size = divide(self._size[dim], num_partition)
|
||||
device = get_current_device()
|
||||
# Reshape to get shard for this rank and we don't want autograd
|
||||
# recording here for the narrow op and 'local_shard' should be a
|
||||
# leaf variable in the autograd graph.
|
||||
self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
|
||||
).contiguous() # TODO Shall we clone() here since detach() will point to the old tensor?
|
||||
self._torch_tensor.requires_grad = self._requires_grad
|
||||
self._size = self._torch_tensor.size()
|
||||
self._device = device # TODO A `fake` device now because torch_tensor.device always = cpu
|
||||
if self._shard_spec.num_action == 1:
|
||||
if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
|
||||
parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
|
||||
num_partition = gpc.get_world_size(parallel_action.parallel_mode)
|
||||
local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
|
||||
dim = -1
|
||||
chunk_size = divide(self._size[dim], num_partition)
|
||||
# Reshape to get shard for this rank and we don't want autograd
|
||||
# recording here for the narrow op and 'local_shard' should be a
|
||||
# leaf variable in the autograd graph.
|
||||
self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
|
||||
).contiguous() # TODO Shall we clone() here since detach() will point to the old tensor?
|
||||
self._torch_tensor.requires_grad = self._requires_grad
|
||||
self._size = self._torch_tensor.size()
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
|
@ -151,5 +150,5 @@ class ColoTensor(object):
|
|||
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def backward(self, retain_graph: bool = False):
|
||||
self._torch_tensor.backward(retain_graph=retain_graph)
|
||||
def backward(self, gradient: Optional[torch.Tensor] = None , retain_graph: bool = False):
|
||||
self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
from enum import Enum
|
||||
from typing import Tuple, List
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
|
||||
class ComputePattern(Enum):
|
||||
TP1DRow = 1
|
||||
|
@ -12,17 +10,13 @@ class ComputePattern(Enum):
|
|||
|
||||
|
||||
class ParallelAction(object):
|
||||
priority = 0
|
||||
compute_pattern = ComputePattern.DP
|
||||
process_group = gpc.get_group(ParallelMode.DATA)
|
||||
|
||||
def __init__(self, priority, compute_pattern, process_group) -> None:
|
||||
def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
|
||||
self.priority = priority
|
||||
self.compute_pattern = compute_pattern
|
||||
self.process_group = process_group
|
||||
self.parallel_mode = parallel_mode
|
||||
|
||||
|
||||
class TensorSpec(Enum):
|
||||
class TensorSpec(object):
|
||||
"""
|
||||
It contains two aspects of information:
|
||||
First, How are tensors distributed in Heterougenous memory space.
|
||||
|
@ -44,4 +38,28 @@ class TensorSpec(Enum):
|
|||
# Before Linear Op, we gather the tensors according to ZeRO.
|
||||
# We perform Linear Op according to compute pattern of TP1DRow.
|
||||
# After Linear Op, we split the tensors according to ZeRO.
|
||||
parallel_action_list: List[ParallelAction] = []
|
||||
def __init__(self, parallel_action_list: List[ParallelAction] = []):
|
||||
self._parallel_action_list = parallel_action_list
|
||||
self.sort()
|
||||
|
||||
@property
|
||||
def parallel_action_list(self):
|
||||
return self._parallel_action_list
|
||||
|
||||
@property
|
||||
def num_action(self):
|
||||
return len(self._parallel_action_list)
|
||||
|
||||
@property
|
||||
def compute_patterns(self):
|
||||
return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]
|
||||
|
||||
def sort(self):
|
||||
if len(self._parallel_action_list) > 0:
|
||||
self._parallel_action_list.sort(key=lambda parallel_action : parallel_action.priority)
|
||||
|
||||
def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
|
||||
for parallel_action in self._parallel_action_list:
|
||||
if parallel_action.compute_pattern == compute_pattern:
|
||||
return parallel_action
|
||||
return None
|
||||
|
|
|
@ -12,6 +12,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
|||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
|
||||
|
||||
from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk
|
||||
|
||||
|
@ -45,7 +46,11 @@ def run_linear_tp1d_row_test():
|
|||
|
||||
# replace the torch nn.Parameters with ColoTensor
|
||||
sharded_weight = ColoTensor.init_from_torch_tensor(W)
|
||||
sharded_weight.set_spec(spec="1Drow") # reshard
|
||||
parallel_action_list = [
|
||||
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
|
||||
]
|
||||
spec = TensorSpec(parallel_action_list)
|
||||
sharded_weight.set_spec(spec=spec) # reshard
|
||||
sharded_bias = ColoTensor.init_from_torch_tensor(B)
|
||||
replace_parameter_add_grad(layer, sharded_weight, sharded_bias)
|
||||
out = layer(A)
|
||||
|
|
Loading…
Reference in New Issue