mirror of https://github.com/hpcaitech/ColossalAI
[tensor] add ColoTensor 1Dcol (#888)
parent
a0e5971692
commit
1d0aba4153
|
@ -1,12 +1,12 @@
|
|||
import torch
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
from colossalai.tensor.colo_tensor import ColoTensor
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input
|
||||
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, \
|
||||
gather_forward_split_backward, reduce_grad
|
||||
from colossalai.nn.layer.utils import divide
|
||||
from colossalai.core import global_context as gpc
|
||||
from packaging import version
|
||||
from colossalai.tensor import ComputePattern
|
||||
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor
|
||||
|
||||
|
||||
@colo_op_impl(torch.nn.functional.linear)
|
||||
|
@ -25,39 +25,107 @@ def colo_linear(types, args, kwargs, pg):
|
|||
else:
|
||||
bias = kwargs.get('bias', None)
|
||||
|
||||
bias_spec = None
|
||||
if isinstance(bias, ColoTensor):
|
||||
assert bias.shard_spec.num_action == 0, f"We currently only support bias is duplicated among processes in the linear operator"
|
||||
bias_spec = bias.shard_spec
|
||||
bias = bias.torch_tensor()
|
||||
|
||||
|
||||
# Add communication logic before and after linear call.
|
||||
if isinstance(weight, ColoTensor):
|
||||
if weight.shard_spec == None or weight.shard_spec.num_action == 0:
|
||||
assert bias_spec == None or bias_spec.num_action == 0, 'Invalid bias spec for native Linear op'
|
||||
if isinstance(input_tensor, ColoTensor):
|
||||
input_tensor = input_tensor.torch_tensor()
|
||||
if isinstance(weight, ColoTensor):
|
||||
weight = weight.torch_tensor()
|
||||
return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
|
||||
elif weight.shard_spec.num_action == 1:
|
||||
if ComputePattern.TP1DRow in weight.shard_spec.compute_patterns:
|
||||
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
|
||||
compute_patterns = weight.shard_spec.compute_patterns
|
||||
if ComputePattern.TP1DRow in compute_patterns:
|
||||
# Input:S[1] x Weight:S[0] = Output:P
|
||||
# All-Reduce(Output) + bias = res
|
||||
assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
|
||||
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size)
|
||||
# Input:S[1]
|
||||
input_spec = None
|
||||
if isinstance(input_tensor, ColoTensor):
|
||||
input_spec = input_tensor.shard_spec
|
||||
input_tensor = input_tensor.torch_tensor()
|
||||
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
|
||||
input_per_partition = split_forward_gather_backward(input_tensor, parallel_action.parallel_mode, dim=-1)
|
||||
|
||||
if input_spec == None or input_spec.num_action == 0:
|
||||
# Not splited yet.
|
||||
assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
|
||||
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_tensor.shape, weight.size, weight.size(-1) * gpc.tensor_parallel_size)
|
||||
input_per_partition = split_forward_gather_backward(input_tensor, parallel_action.parallel_mode, dim=-1)
|
||||
elif input_tensor.shard_spec.num_action == 1:
|
||||
if ComputePattern.TP1DCol in input_spec.compute_patterns:
|
||||
# Splited by 1Dcol
|
||||
assert input_tensor.shape[-1] == weight.size(-1), \
|
||||
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_tensor.shape, weight.size, weight.size(-1))
|
||||
input_per_partition = input_tensor
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# Output:P
|
||||
weight_ = weight.torch_tensor()
|
||||
partial_output = torch.nn.functional.linear(input_per_partition, weight_)
|
||||
# Reduce(Output)
|
||||
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
|
||||
output = reduce_input(partial_output, parallel_action.parallel_mode)
|
||||
# Bias
|
||||
if bias is not None:
|
||||
assert bias_spec == None or bias_spec.num_action == 0, 'Invalid bias spec for 1Drow Linear op'
|
||||
output = output + bias
|
||||
return ColoTensor.init_from_torch_tensor(output)
|
||||
output = ColoTensor.init_from_torch_tensor(output)
|
||||
return output
|
||||
elif ComputePattern.TP1DCol in compute_patterns:
|
||||
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
|
||||
# All-Gather(Output)
|
||||
# Input:B
|
||||
input_spec = None
|
||||
output_spec = None
|
||||
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
|
||||
if isinstance(input_tensor, ColoTensor):
|
||||
input_spec = input_tensor.shard_spec
|
||||
input_tensor = input_tensor.torch_tensor()
|
||||
|
||||
if input_spec == None or input_spec.num_action == 0:
|
||||
# Not splited yet.
|
||||
assert input_tensor.shape[-1] == weight.size(-1), \
|
||||
'Invalid shapes in 1Dcol forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_tensor.shape, weight.size, weight.size(-1))
|
||||
input_parallel = reduce_grad(input_tensor, parallel_action.parallel_mode)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
# Bias:S[1]
|
||||
if bias is not None:
|
||||
assert bias_spec is not None and bias_spec.num_action == 1 and \
|
||||
ComputePattern.TP1DCol in bias_spec.compute_patterns, \
|
||||
'Invalid bias spec for 1Dcol Linear op'
|
||||
|
||||
weight_ = weight.torch_tensor()
|
||||
output_parallel = torch.nn.functional.linear(input_parallel, weight_, bias)
|
||||
|
||||
if parallel_action.gather_out:
|
||||
# All-Gather(Output)
|
||||
output = gather_forward_split_backward(output_parallel, parallel_action.parallel_mode, dim=-1)
|
||||
output = ColoTensor.init_from_torch_tensor(output)
|
||||
else:
|
||||
output = ColoTensor.init_from_torch_tensor(output_parallel)
|
||||
out_parallel_action_list = [
|
||||
ParallelAction(
|
||||
priority=1, compute_pattern=ComputePattern.TP1DCol,
|
||||
parallel_mode=parallel_action.parallel_mode
|
||||
)
|
||||
]
|
||||
output_spec = TensorSpec(out_parallel_action_list)
|
||||
# set ColoTensor spec
|
||||
if output_spec is not None:
|
||||
output.set_spec(output_spec)
|
||||
return output
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
|
|
|
@ -121,18 +121,25 @@ class ColoTensor(object):
|
|||
assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
|
||||
if self._shard_spec.num_action == 1:
|
||||
if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
|
||||
parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
|
||||
num_partition = gpc.get_world_size(parallel_action.parallel_mode)
|
||||
local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
|
||||
dim = -1
|
||||
chunk_size = divide(self._size[dim], num_partition)
|
||||
# Reshape to get shard for this rank and we don't want autograd
|
||||
# recording here for the narrow op and 'local_shard' should be a
|
||||
# leaf variable in the autograd graph.
|
||||
self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
|
||||
).contiguous() # TODO Shall we clone() here since detach() will point to the old tensor?
|
||||
self._torch_tensor.requires_grad = self._requires_grad
|
||||
self._size = self._torch_tensor.size()
|
||||
parallel_action = self._shard_spec.get_action_by_compute_pattern(
|
||||
ComputePattern.TP1DRow)
|
||||
self._shard_1d(parallel_action=parallel_action, dim=-1)
|
||||
elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns:
|
||||
parallel_action = self._shard_spec.get_action_by_compute_pattern(
|
||||
ComputePattern.TP1DCol)
|
||||
self._shard_1d(parallel_action=parallel_action, dim=0)
|
||||
|
||||
def _shard_1d(self, parallel_action, dim=-1):
|
||||
num_partition = gpc.get_world_size(parallel_action.parallel_mode)
|
||||
local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
|
||||
chunk_size = divide(self._size[dim], num_partition)
|
||||
# Reshape to get shard for this rank and we don't want autograd
|
||||
# recording here for the narrow op and 'local_shard' should be a
|
||||
# leaf variable in the autograd graph.
|
||||
self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
|
||||
).contiguous() # TODO Shall we clone() here since detach() will point to the old tensor?
|
||||
self._torch_tensor.requires_grad = self._requires_grad
|
||||
self._size = self._torch_tensor.size()
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
|
|
|
@ -12,11 +12,11 @@ class ComputePattern(Enum):
|
|||
|
||||
class ParallelAction(object):
|
||||
|
||||
def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
|
||||
def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA, gather_out=True) -> None:
|
||||
self.priority = priority
|
||||
self.compute_pattern = compute_pattern
|
||||
self.parallel_mode = parallel_mode
|
||||
|
||||
self.gather_out = gather_out
|
||||
|
||||
class TensorSpec(object):
|
||||
"""
|
||||
|
|
|
@ -16,6 +16,69 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
|
|||
|
||||
from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk
|
||||
|
||||
def run_linear_tp1d_col_test():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
DEPTH = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
in_features = 4
|
||||
out_features = 8
|
||||
|
||||
local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
|
||||
layer_master = torch.nn.Linear(in_features, out_features)
|
||||
layer = torch.nn.Linear(in_features, out_features)
|
||||
|
||||
A_shape = (2, in_features)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
A = broadcast_tensor_chunk(A_master, chunk_size=1)
|
||||
A.requires_grad = True
|
||||
|
||||
W_shape = (out_features, in_features)
|
||||
W_master = torch.randn(W_shape, dtype=dtype, device=device)
|
||||
W = broadcast_tensor_chunk(W_master, chunk_size=1)
|
||||
W.requires_grad = True
|
||||
|
||||
B_shape = (out_features)
|
||||
B_master = torch.randn(B_shape, dtype=dtype, device=device)
|
||||
B = broadcast_tensor_chunk(B_master, chunk_size=1)
|
||||
B.requires_grad = True
|
||||
|
||||
# replace the torch nn.Parameters with ColoTensor
|
||||
sharded_weight = ColoTensor.init_from_torch_tensor(W)
|
||||
sharded_bias = ColoTensor.init_from_torch_tensor(B)
|
||||
parallel_action_list = [
|
||||
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)
|
||||
]
|
||||
spec = TensorSpec(parallel_action_list)
|
||||
sharded_weight.set_spec(spec) # reshard
|
||||
sharded_bias.set_spec(spec)
|
||||
|
||||
replace_parameter_add_grad(layer, sharded_weight, sharded_bias)
|
||||
out = layer(A)
|
||||
|
||||
replace_parameter_add_grad(layer_master, W_master, B_master)
|
||||
A_master.requires_grad = True
|
||||
#C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
|
||||
C_master = layer_master(A_master)
|
||||
C = C_master.clone()
|
||||
|
||||
check_equal(out, C)
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
grad = broadcast_tensor_chunk(grad_master, chunk_size=1)
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
W_grad = W_master.grad
|
||||
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[local_rank]
|
||||
check_equal(W_grad, layer.weight.grad)
|
||||
|
||||
B_grad = B_master.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[local_rank]
|
||||
check_equal(B_grad, layer.bias.grad)
|
||||
|
||||
def run_linear_tp1d_row_test():
|
||||
device = get_current_device()
|
||||
|
@ -83,7 +146,7 @@ def run_dist(rank, world_size, port):
|
|||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_linear_tp1d_row_test()
|
||||
|
||||
run_linear_tp1d_col_test()
|
||||
|
||||
@pytest.mark.dist
|
||||
@parameterize('world_size', [1, 4])
|
||||
|
|
Loading…
Reference in New Issue