[Tensor] activation is an attr of ColoTensor (#897)

pull/900/head
Jiarui Fang 3 years ago committed by GitHub
parent e76f76c08b
commit 676f191532
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,17 +9,19 @@ from packaging import version
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTensor) -> ColoTensor:
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
# Input:S[1]
if input_tensor.is_gathered():
# Not splited yet.
assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size)
input_per_partition = split_forward_gather_backward(input_tensor.torch_tensor(), parallel_action.parallel_mode, dim=-1)
input_per_partition = split_forward_gather_backward(input_tensor.torch_tensor(),
parallel_action.parallel_mode,
dim=-1)
elif input_tensor.shard_pattern == ShardPattern.Col:
# Splited by 1Dcol
assert input_tensor.shape[-1] == weight.size(-1), \
@ -40,7 +42,8 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTen
output = ColoTensor.init_from_torch_tensor(output)
return output
def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTensor) -> ColoTensor:
def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
@ -59,14 +62,9 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTen
'Invalid bias spec for 1Dcol Linear op'
output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias.torch_tensor())
output = ColoTensor.init_from_torch_tensor(output_parallel)
out_parallel_action_list = [
ParallelAction(
priority=1, compute_pattern=ComputePattern.Activation,
parallel_mode=parallel_action.parallel_mode
)
]
out_parallel_action_list = [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]
output_spec = TensorSpec(out_parallel_action_list)
output.set_spec(output_spec, shard=False)
output.set_shard_pattern(ShardPattern.Col)
@ -75,6 +73,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTen
output.gather()
return output
@colo_op_impl(torch.nn.functional.linear)
def colo_linear(types, args, kwargs, pg):
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
@ -99,15 +98,15 @@ def colo_linear(types, args, kwargs, pg):
if bias is not None and not isinstance(bias, ColoTensor):
bias = ColoTensor.init_from_torch_tensor(bias)
# Add communication logic before and after linear call.
if not weight.has_spec(): # No Model Parallel Applied
if not weight.has_spec(): # No Model Parallel Applied
assert not bias.has_spec(), 'Invalid bias spec for native Linear op'
input_tensor = input_tensor.torch_tensor()
weight = weight.torch_tensor()
bias = bias.torch_tensor()
return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
elif weight.shard_spec.num_action == 1: # Single Model Parallel Applied
elif weight.shard_spec.num_action == 1: # Single Model Parallel Applied
compute_patterns = weight.shard_spec.compute_patterns
if ComputePattern.TP1DRow in compute_patterns:
return colo_linear_1Drow(input_tensor, weight, bias)

@ -7,6 +7,13 @@ from colossalai.core import global_context as gpc
from colossalai.nn.layer.utils import divide
from colossalai.tensor import TensorSpec, ComputePattern, ShardPattern
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, gather_forward_split_backward
from enum import Enum
class TensorType(Enum):
MODEL = 0
NONMODEL = 1 # mainly activations
class ColoTensor(object):
""" Data Structure for Tensor in Colossal-AI
@ -28,6 +35,7 @@ class ColoTensor(object):
device=None,
torch_tensor=torch.empty(0),
shard_spec: TensorSpec = TensorSpec(),
is_model_data: bool = False,
):
self._size = size
self._dtype = dtype
@ -37,6 +45,10 @@ class ColoTensor(object):
self._torch_tensor = torch_tensor
self._shard_spec = shard_spec
self._shard_pattern = ShardPattern.NA
if is_model_data:
self._type = TensorType.MODEL
else:
self._type = TensorType.NONMODEL
def __getitem__(self, key):
return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])
@ -85,13 +97,14 @@ class ColoTensor(object):
return product(self._size)
@staticmethod
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True, is_model_data=False) -> 'ColoTensor':
colo_t = ColoTensor(*tensor.size(),
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
pin_memory=tensor.is_pinned(),
device=tensor.device,
torch_tensor=tensor if save_payload else torch.empty(0))
torch_tensor=tensor if save_payload else torch.empty(0),
is_model_data=is_model_data)
return colo_t
def del_torch_tensor(self, save_shape=False) -> None:
@ -120,31 +133,28 @@ class ColoTensor(object):
self._shard_spec = spec
if shard == True:
self.shard()
def set_shard_pattern(self, shard_pattern: ShardPattern):
self._shard_pattern = shard_pattern
def shard(self):
assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
if self._shard_pattern is not ShardPattern.NA: # reshard
if self._shard_pattern is not ShardPattern.NA: # reshard
self.gather()
# Model Parameters
if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
parallel_action = self._shard_spec.get_action_by_compute_pattern(
ComputePattern.TP1DRow)
parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
self._shard_1d(parallel_action=parallel_action, dim=-1)
self._shard_pattern = ShardPattern.Col # We bind our ComputePattern on weight, which has to be transposed when linear().
self._shard_pattern = ShardPattern.Col # We bind our ComputePattern on weight, which has to be transposed when linear().
elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns:
parallel_action = self._shard_spec.get_action_by_compute_pattern(
ComputePattern.TP1DCol)
parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
self._shard_1d(parallel_action=parallel_action, dim=0)
self._shard_pattern = ShardPattern.Row
def gather(self):
assert self.is_activation(), 'Currently we only support gather Activation ColoTensor.'
assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.'
parallel_action = self._shard_spec.get_action_by_compute_pattern(
ComputePattern.Activation)
parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
if self._shard_pattern == ShardPattern.Row:
dim = 0
elif self._shard_pattern == ShardPattern.Col:
@ -159,9 +169,8 @@ class ColoTensor(object):
return self._shard_spec is not None and self._shard_spec.num_action > 0
def is_activation(self) -> bool:
return self._shard_spec is not None and self._shard_spec.num_action == 1 \
and ComputePattern.Activation in self._shard_spec.compute_patterns
return self._type == TensorType.NONMODEL
def _shard_1d(self, parallel_action, dim=-1):
num_partition = gpc.get_world_size(parallel_action.parallel_mode)
local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
@ -169,8 +178,8 @@ class ColoTensor(object):
# Reshape to get shard for this rank and we don't want autograd
# recording here for the narrow op and 'local_shard' should be a
# leaf variable in the autograd graph.
self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
).contiguous() # TODO Shall we clone() here since detach() will point to the old tensor?
self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach().contiguous(
) # TODO Shall we clone() here since detach() will point to the old tensor?
self._torch_tensor.requires_grad = self._requires_grad
self._size = self._torch_tensor.size()

@ -4,20 +4,25 @@ from colossalai.context.parallel_mode import ParallelMode
class ComputePattern(Enum):
Activation = 0 # TODO(jzy) A tmp place to store Activation info. Find a better place in future.
TP1DRow = 1
TP1DCol = 2
ZeRO = 3
DP = 4
class ShardPattern(Enum):
NA = 0
Row = 1
Col = 2
class ParallelAction(object):
def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA, gather_out=True) -> None:
def __init__(self,
priority=0,
compute_pattern=ComputePattern.DP,
parallel_mode=ParallelMode.DATA,
gather_out=True) -> None:
self.priority = priority
self.compute_pattern = compute_pattern
self.parallel_mode = parallel_mode
@ -64,7 +69,7 @@ class TensorSpec(object):
@property
def compute_patterns(self):
return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]
@property
def shard_pattern(self):
return self._shard_pattern

@ -94,7 +94,10 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
save_torch_payload = True if not self._lazy_memory_allocate else False
for name, param in name_list:
delattr(module, name)
setattr(module, name,
ColoTensor.init_from_torch_tensor(tensor=param.to(self._device), save_payload=save_torch_payload))
setattr(
module, name,
ColoTensor.init_from_torch_tensor(tensor=param.to(self._device),
save_payload=save_torch_payload,
is_model_data=True))
ColoModulize(module)

Loading…
Cancel
Save