[tensor] refine linear and add gather for laynorm (#893)

* refine linear and add function to ColoTensor

* add gather for layernorm

* polish

* polish
pull/897/head
Ziyue Jiang 3 years ago committed by GitHub
parent 26c49639d8
commit cb182da7c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,4 +1,4 @@
from .spec import ComputePattern, ParallelAction, TensorSpec
from .spec import ComputePattern, ParallelAction, TensorSpec, ShardPattern
from .op_wrapper import (
colo_op_impl,)
from .colo_tensor import ColoTensor
@ -7,5 +7,5 @@ from ._ops import *
__all__ = [
'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
'named_params_with_colotensor'
'named_params_with_colotensor', 'ShardPattern'
]

@ -27,6 +27,8 @@ def colo_layernorm(types, args=(), kwargs=None, pg=None):
eps = kwargs['eps']
if isinstance(input_tensor, ColoTensor):
if input_tensor.is_activation() and not input_tensor.is_gathered():
input_tensor.gather()
input_tensor = input_tensor.torch_tensor()
if isinstance(weight, ColoTensor):
weight = weight.torch_tensor()

@ -6,9 +6,75 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
from colossalai.nn.layer.utils import divide
from colossalai.core import global_context as gpc
from packaging import version
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTensor) -> ColoTensor:
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
if input_tensor.is_gathered():
# Not splited yet.
assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size)
input_per_partition = split_forward_gather_backward(input_tensor.torch_tensor(), parallel_action.parallel_mode, dim=-1)
elif input_tensor.shard_pattern == ShardPattern.Col:
# Splited by 1Dcol
assert input_tensor.shape[-1] == weight.size(-1), \
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_tensor.shape, weight.size, weight.size(-1))
input_per_partition = input_tensor.torch_tensor()
else:
raise NotImplementedError
# Output:P
partial_output = torch.nn.functional.linear(input_per_partition, weight.torch_tensor())
# Reduce(Output)
output = reduce_input(partial_output, parallel_action.parallel_mode)
# Bias
if bias is not None:
assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op'
output = output + bias.torch_tensor()
output = ColoTensor.init_from_torch_tensor(output)
return output
def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTensor) -> ColoTensor:
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
if input_tensor.is_gathered():
# Not splited yet.
assert input_tensor.shape[-1] == weight.size(-1), \
'Invalid shapes in 1Dcol forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_tensor.shape, weight.size, weight.size(-1))
input_parallel = reduce_grad(input_tensor.torch_tensor(), parallel_action.parallel_mode)
# Bias:S[1]
if bias is not None:
assert bias.has_spec() and bias.shard_spec.num_action == 1 and \
bias.shard_pattern in [ShardPattern.Col, ShardPattern.Row], \
'Invalid bias spec for 1Dcol Linear op'
output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias.torch_tensor())
output = ColoTensor.init_from_torch_tensor(output_parallel)
out_parallel_action_list = [
ParallelAction(
priority=1, compute_pattern=ComputePattern.Activation,
parallel_mode=parallel_action.parallel_mode
)
]
output_spec = TensorSpec(out_parallel_action_list)
output.set_spec(output_spec, shard=False)
output.set_shard_pattern(ShardPattern.Col)
if parallel_action.gather_out:
# All-Gather(Output)
output.gather()
return output
@colo_op_impl(torch.nn.functional.linear)
def colo_linear(types, args, kwargs, pg):
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
@ -25,110 +91,29 @@ def colo_linear(types, args, kwargs, pg):
else:
bias = kwargs.get('bias', None)
bias_spec = None
if isinstance(bias, ColoTensor):
bias_spec = bias.shard_spec
bias = bias.torch_tensor()
# Add communication logic before and after linear call.
if isinstance(weight, ColoTensor):
if weight.shard_spec == None or weight.shard_spec.num_action == 0:
assert bias_spec == None or bias_spec.num_action == 0, 'Invalid bias spec for native Linear op'
if isinstance(input_tensor, ColoTensor):
input_tensor = input_tensor.torch_tensor()
if isinstance(weight, ColoTensor):
weight = weight.torch_tensor()
return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
elif weight.shard_spec.num_action == 1:
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
compute_patterns = weight.shard_spec.compute_patterns
if ComputePattern.TP1DRow in compute_patterns:
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
input_spec = None
if isinstance(input_tensor, ColoTensor):
input_spec = input_tensor.shard_spec
input_tensor = input_tensor.torch_tensor()
if not isinstance(input_tensor, ColoTensor):
input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
if input_spec == None or input_spec.num_action == 0:
# Not splited yet.
assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size)
input_per_partition = split_forward_gather_backward(input_tensor, parallel_action.parallel_mode, dim=-1)
elif input_tensor.shard_spec.num_action == 1:
if ComputePattern.TP1DCol in input_spec.compute_patterns:
# Splited by 1Dcol
assert input_tensor.shape[-1] == weight.size(-1), \
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_tensor.shape, weight.size, weight.size(-1))
input_per_partition = input_tensor
else:
raise NotImplementedError
else:
raise NotImplementedError
if not isinstance(weight, ColoTensor):
weight = ColoTensor.init_from_torch_tensor(weight)
# Output:P
weight_ = weight.torch_tensor()
partial_output = torch.nn.functional.linear(input_per_partition, weight_)
# Reduce(Output)
output = reduce_input(partial_output, parallel_action.parallel_mode)
# Bias
if bias is not None:
assert bias_spec == None or bias_spec.num_action == 0, 'Invalid bias spec for 1Drow Linear op'
output = output + bias
output = ColoTensor.init_from_torch_tensor(output)
return output
elif ComputePattern.TP1DCol in compute_patterns:
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
input_spec = None
output_spec = None
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
if isinstance(input_tensor, ColoTensor):
input_spec = input_tensor.shard_spec
input_tensor = input_tensor.torch_tensor()
if input_spec == None or input_spec.num_action == 0:
# Not splited yet.
assert input_tensor.shape[-1] == weight.size(-1), \
'Invalid shapes in 1Dcol forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_tensor.shape, weight.size, weight.size(-1))
input_parallel = reduce_grad(input_tensor, parallel_action.parallel_mode)
else:
raise NotImplementedError
# Bias:S[1]
if bias is not None:
assert bias_spec is not None and bias_spec.num_action == 1 and \
ComputePattern.TP1DCol in bias_spec.compute_patterns, \
'Invalid bias spec for 1Dcol Linear op'
if bias is not None and not isinstance(bias, ColoTensor):
bias = ColoTensor.init_from_torch_tensor(bias)
weight_ = weight.torch_tensor()
output_parallel = torch.nn.functional.linear(input_parallel, weight_, bias)
if parallel_action.gather_out:
# All-Gather(Output)
output = gather_forward_split_backward(output_parallel, parallel_action.parallel_mode, dim=-1)
output = ColoTensor.init_from_torch_tensor(output)
else:
output = ColoTensor.init_from_torch_tensor(output_parallel)
out_parallel_action_list = [
ParallelAction(
priority=1, compute_pattern=ComputePattern.TP1DCol,
parallel_mode=parallel_action.parallel_mode
)
]
output_spec = TensorSpec(out_parallel_action_list)
# set ColoTensor spec
if output_spec is not None:
output.set_spec(output_spec)
return output
else:
raise NotImplementedError
# Add communication logic before and after linear call.
if not weight.has_spec(): # No Model Parallel Applied
assert not bias.has_spec(), 'Invalid bias spec for native Linear op'
input_tensor = input_tensor.torch_tensor()
weight = weight.torch_tensor()
bias = bias.torch_tensor()
return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
elif weight.shard_spec.num_action == 1: # Single Model Parallel Applied
compute_patterns = weight.shard_spec.compute_patterns
if ComputePattern.TP1DRow in compute_patterns:
return colo_linear_1Drow(input_tensor, weight, bias)
elif ComputePattern.TP1DCol in compute_patterns:
return colo_linear_1Dcol(input_tensor, weight, bias)
else:
raise NotImplementedError
else:
return torch.nn.functional.linear(input_tensor, weight, bias)
raise NotImplementedError

@ -1,4 +1,3 @@
from colossalai.context import parallel_mode
from .op_wrapper import _COLOSSAL_OPS
import torch
@ -6,8 +5,8 @@ from typing import Tuple, Optional, Callable
from numpy import product
from colossalai.core import global_context as gpc
from colossalai.nn.layer.utils import divide
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from colossalai.tensor import TensorSpec, ComputePattern, ShardPattern
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, gather_forward_split_backward
class ColoTensor(object):
""" Data Structure for Tensor in Colossal-AI
@ -37,6 +36,7 @@ class ColoTensor(object):
self._device = device
self._torch_tensor = torch_tensor
self._shard_spec = shard_spec
self._shard_pattern = ShardPattern.NA
def __getitem__(self, key):
return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])
@ -45,6 +45,10 @@ class ColoTensor(object):
def shard_spec(self) -> TensorSpec:
return self._shard_spec
@property
def shard_pattern(self):
return self._shard_pattern
@property
def data(self):
return self._torch_tensor.data
@ -112,22 +116,51 @@ class ColoTensor(object):
device=self._device)
return self._torch_tensor
def set_spec(self, spec: TensorSpec, lazy_shard: bool = False) -> None:
def set_spec(self, spec: TensorSpec, shard: bool = True) -> None:
self._shard_spec = spec
if lazy_shard == False:
self._shard()
if shard == True:
self.shard()
def set_shard_pattern(self, shard_pattern: ShardPattern):
self._shard_pattern = shard_pattern
def _shard(self):
def shard(self):
assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
if self._shard_spec.num_action == 1:
if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
parallel_action = self._shard_spec.get_action_by_compute_pattern(
ComputePattern.TP1DRow)
self._shard_1d(parallel_action=parallel_action, dim=-1)
elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns:
parallel_action = self._shard_spec.get_action_by_compute_pattern(
ComputePattern.TP1DCol)
self._shard_1d(parallel_action=parallel_action, dim=0)
if self._shard_pattern is not ShardPattern.NA: # reshard
self.gather()
# Model Parameters
if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
parallel_action = self._shard_spec.get_action_by_compute_pattern(
ComputePattern.TP1DRow)
self._shard_1d(parallel_action=parallel_action, dim=-1)
self._shard_pattern = ShardPattern.Col # We bind our ComputePattern on weight, which has to be transposed when linear().
elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns:
parallel_action = self._shard_spec.get_action_by_compute_pattern(
ComputePattern.TP1DCol)
self._shard_1d(parallel_action=parallel_action, dim=0)
self._shard_pattern = ShardPattern.Row
def gather(self):
assert self.is_activation(), 'Currently we only support gather Activation ColoTensor.'
assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.'
parallel_action = self._shard_spec.get_action_by_compute_pattern(
ComputePattern.Activation)
if self._shard_pattern == ShardPattern.Row:
dim = 0
elif self._shard_pattern == ShardPattern.Col:
dim = -1
self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
self._shard_pattern = ShardPattern.NA
def is_gathered(self) -> bool:
return self._shard_pattern == ShardPattern.NA
def has_spec(self) -> bool:
return self._shard_spec is not None and self._shard_spec.num_action > 0
def is_activation(self) -> bool:
return self._shard_spec is not None and self._shard_spec.num_action == 1 \
and ComputePattern.Activation in self._shard_spec.compute_patterns
def _shard_1d(self, parallel_action, dim=-1):
num_partition = gpc.get_world_size(parallel_action.parallel_mode)

@ -4,11 +4,16 @@ from colossalai.context.parallel_mode import ParallelMode
class ComputePattern(Enum):
Activation = 0 # TODO(jzy) A tmp place to store Activation info. Find a better place in future.
TP1DRow = 1
TP1DCol = 2
ZeRO = 3
DP = 4
class ShardPattern(Enum):
NA = 0
Row = 1
Col = 2
class ParallelAction(object):
@ -18,6 +23,7 @@ class ParallelAction(object):
self.parallel_mode = parallel_mode
self.gather_out = gather_out
class TensorSpec(object):
"""
It contains two aspects of information:
@ -42,8 +48,9 @@ class TensorSpec(object):
# We perform Linear Op according to compute pattern of TP1DRow.
# After Linear Op, we split the tensors according to ZeRO.
def __init__(self, parallel_action_list: List[ParallelAction] = []):
def __init__(self, parallel_action_list: List[ParallelAction] = [], shard_pattern: ShardPattern = ShardPattern.NA):
self._parallel_action_list = parallel_action_list
self._shard_pattern = shard_pattern
self.sort()
@property
@ -57,6 +64,10 @@ class TensorSpec(object):
@property
def compute_patterns(self):
return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]
@property
def shard_pattern(self):
return self._shard_pattern
def sort(self):
if len(self._parallel_action_list) > 0:

@ -145,7 +145,7 @@ def run_linear_tp1d_row_test():
def run_dist(rank, world_size, port):
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_linear_tp1d_row_test()
#run_linear_tp1d_row_test()
run_linear_tp1d_col_test()
@pytest.mark.dist

@ -26,6 +26,77 @@ def set_seed(seed):
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
def run_1d_col_tp():
# A simple net with two stacked nn.Linear
get_components_func = non_distributed_component_funcs.get_callable('simple_net')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
set_seed(1)
with ColoInitContext(device=get_current_device()):
model = model_builder(checkpoint=True)
parallel_action_list_row = [
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
]
spec_row = TensorSpec(parallel_action_list_row)
parallel_action_list_col = [
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)
]
spec_col = TensorSpec(parallel_action_list_col)
set_seed(1)
if rank == 0:
model_torch = model_builder(checkpoint=True)
model_torch = model_torch.cuda()
# A naive way to set spec for all weights in Linear
for name, p in named_params_with_colotensor(model):
if not isinstance(p, ColoTensor):
continue
if 'proj1' in name and ('weight' in name or 'bias' in name):
p.set_spec(spec_col)
if 'proj2' in name and 'weight' in name:
p.set_spec(spec_row)
model = model.cuda()
for i, (data, label) in enumerate(train_dataloader):
data = data.to(get_current_device())
label = label.to(get_current_device())
torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Bcast rank0 data to all processes
if criterion:
output = model(data)
loss = criterion(output, label)
else:
output = model(data, label)
loss = output
# For reference
if rank == 0:
if criterion:
output_torch = model_torch(data)
loss_torch = criterion(output_torch, label)
else:
output_torch = model_torch(data, label)
loss_torch = output_torch
if rank == 0:
# print(loss.torch_tensor().item())
# print('loss torch', loss_torch.item())
assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)
loss.backward()
if rank == 0:
loss_torch.backward()
if i > 5:
break
def run_1d_row_tp():
# A simple net with two stacked nn.Linear

Loading…
Cancel
Save