mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] supported fused qkv checkpoint (#4073)
parent 0803a61412
commit 70c58cfd4f
@@ -1,5 +1,6 @@
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F

 try:
     import fused_mix_prec_layer_norm_cuda
@@ -46,6 +47,53 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
         return grad_input, grad_weight, grad_bias, None, None


+class MatmulWithAsyncCommunication(torch.autograd.Function):
+    """
+    Linear layer execution with asynchronous communication in backprop.
+    """
+
+    @staticmethod
+    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
+        ctx.save_for_backward(input_, weight)
+        ctx.use_bias = bias is not None
+        ctx.process_group = process_group
+        ctx.async_grad_allreduce = async_grad_allreduce
+
+        output = torch.matmul(input_, weight)
+
+        if bias is not None:
+            output = output + bias
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight = ctx.saved_tensors
+        use_bias = ctx.use_bias
+
+        total_input = input
+        grad_input = grad_output.matmul(weight.T)
+        grad_output = grad_output.contiguous()
+        # Convert the tensor shapes to 2D for execution compatibility
+        if len(grad_output.shape) > 2:
+            grad_output = grad_output.view(-1, grad_output.shape[-1])
+            total_input = total_input.view(-1, total_input.shape[-1])
+
+        if ctx.async_grad_allreduce:
+            # Asynchronous all-reduce
+            handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
+            # Delay the start of weight gradient computation shortly (3us) to have
+            # all-reduce scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+
+        grad_weight = total_input.t().matmul(grad_output)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+        if ctx.async_grad_allreduce:
+            handle.wait()
+
+        return grad_input, grad_weight, grad_bias, None, None, None
+
+
 class LinearWithAsyncCommunication(torch.autograd.Function):
     """
     Linear layer execution with asynchronous communication in backprop.
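The backward pass of `MatmulWithAsyncCommunication` above launches the all-reduce of `grad_input` asynchronously and computes `grad_weight` while the communication is in flight. Below is a minimal standalone sketch of that overlap pattern, assuming `torch.distributed` is already initialized; the function and tensor names here are illustrative, not part of the patch.

```python
import torch
import torch.distributed as dist


def overlapped_backward(grad_output, input_, weight, group=None):
    # grad_input is needed by earlier layers, so reduce it across tensor-parallel ranks
    grad_input = grad_output.matmul(weight.T)
    handle = dist.all_reduce(grad_input, group=group, async_op=True)

    # this matmul runs on the GPU while the all-reduce is still in flight
    grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
    input_2d = input_.reshape(-1, input_.shape[-1])
    grad_weight = input_2d.t().matmul(grad_output_2d)

    handle.wait()    # block only when grad_input is actually needed
    return grad_input, grad_weight
```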
@@ -58,9 +106,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         ctx.process_group = process_group
         ctx.async_grad_allreduce = async_grad_allreduce

-        output = torch.matmul(input_, weight.t())
         if bias is not None:
-            output = output + bias
+            output = F.linear(input_, weight, bias)
+        else:
+            output = F.linear(input_, weight)
         return output

     @staticmethod
@@ -114,7 +163,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
         return _gather(grad_output, ctx.dim, ctx.process_group), None, None


-class _ReduceInput(torch.autograd.Function):
+class _ReduceForward(torch.autograd.Function):
     """
     All-reduce the input from the model parallel region.

@@ -132,6 +181,25 @@ class _ReduceInput(torch.autograd.Function):
         return grad_output, None


+class _ReduceBackward(torch.autograd.Function):
+    """
+    All-reduce the input from the model parallel region.
+
+    Args:
+        input_: input matrix.
+        parallel_mode: parallel mode.
+    """
+
+    @staticmethod
+    def forward(ctx, input_, process_group):
+        ctx.process_group = process_group
+        return input_
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _reduce(grad_output, ctx.process_group), None
+
+
 def _reduce(input_, process_group):
     # skip if only one rank involved
     if dist.get_world_size(process_group) == 1:
@@ -198,6 +266,10 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
         return _split(grad_output, ctx.dim, ctx.process_group), None, None


+def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
+    return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
+
+
 def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
     return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)

@@ -210,5 +282,9 @@ def split_forward_gather_backward(input_, dim, process_group):
     return _SplitForwardGatherBackward.apply(input_, dim, process_group)


-def reduce_input(input_, process_group):
-    return _ReduceInput.apply(input_, process_group)
+def reduce_forward(input_, process_group):
+    return _ReduceForward.apply(input_, process_group)
+
+
+def reduce_backward(input_, process_group):
+    return _ReduceBackward.apply(input_, process_group)
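`reduce_forward` and `reduce_backward` are conjugate helpers: the former all-reduces activations in the forward pass and is an identity in backward, while the latter is an identity in forward and all-reduces gradients in backward. A hedged usage sketch for a column-parallel/row-parallel pair follows; it assumes the module path `colossalai.shardformer.layer._operation` used by the sibling files in this diff and an already initialized process group, and the toy function itself is illustrative.

```python
import torch.nn.functional as F

from colossalai.shardformer.layer._operation import reduce_backward, reduce_forward


def column_then_row(x, w_col_shard, w_row_shard, process_group):
    # column-parallel half: the replicated input needs its gradient all-reduced,
    # which reduce_backward performs in the backward pass only
    x = reduce_backward(x, process_group)
    hidden = F.linear(x, w_col_shard)      # each rank holds a slice of the output columns

    # row-parallel half: each rank produces a partial result that must be
    # all-reduced in the forward pass, which reduce_forward performs
    partial = F.linear(hidden, w_row_shard)
    return reduce_forward(partial, process_group)
```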
@@ -15,7 +15,7 @@ from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param

-from ._operation import gather_forward_split_backward, reduce_input
+from ._operation import gather_forward_split_backward, reduce_forward
 from .parallel_module import ParallelModule
 from .utils import create_randomizer_with_offset

@@ -276,5 +276,5 @@ class VocabParallelEmbedding1D(ParallelModule):
         # Mask the output embedding.
         output_parallel[input_mask, :] = 0.
         # Reduce across all the model parallel GPUs.
-        output = reduce_input(output_parallel, self.process_group)
+        output = reduce_forward(output_parallel, self.process_group)
         return output
@@ -15,12 +15,11 @@ from torch.nn.parameter import Parameter
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
 from colossalai.utils.cuda import get_current_device

 from ._operation import (
     gather_forward_split_backward,
     linear_with_async_comm,
-    reduce_input,
+    reduce_forward,
     split_forward_gather_backward,
 )
 from .parallel_module import ParallelModule
@@ -148,9 +147,10 @@ class Linear1D_Col(ParallelModule):
         assert input_.shape[-1] == self.weight.shape[-1], \
             'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
                 input_.shape, self.weight.shape, self.weight.shape[-1])

         # Set up backprop all-reduce.
         # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
         input_parallel = input_

         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
         output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
@@ -209,17 +209,14 @@ class Linear1D_Row(ParallelModule):
         self.parallel_input = parallel_input
         self.skip_bias_add = skip_bias_add
         self.process_group = process_group
         self.num_partitions = dist.get_world_size(self.process_group)

         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')

         # Parameters.
         # Initialize weight.
         if device is None:
             device = get_current_device()

         factory_kwargs = {'device': device, 'dtype': dtype}

         weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
         sharded_weight = shard_colwise(weight, self.process_group)
         self.weight = sharded_tensor_to_param(sharded_weight)
@@ -327,8 +324,7 @@ class Linear1D_Row(ParallelModule):
             output = torch.cat(output_parallel_list, dim=-1)
         else:
             output_parallel = F.linear(input_, self.weight)
-            # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
-            output = reduce_input(output_parallel, self.process_group)
+            output = reduce_forward(output_parallel, self.process_group)

         if not self.skip_bias_add:
             if self.bias is not None:
@@ -336,5 +332,3 @@ class Linear1D_Row(ParallelModule):
             return output
         else:
             return output, self.bias
-            return output, self.bias
-            return output, self.bias
@@ -14,13 +14,18 @@ from torch.nn.parameter import Parameter

 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
 from colossalai.utils.cuda import get_current_device
+from colossalai.tensor.d_tensor.api import (
+    customized_distributed_tensor_to_param,
+    distribute_tensor_with_customization,
+    shard_rowwise,
+    sharded_tensor_to_param,
+)

 from ._operation import (
     gather_forward_split_backward,
     linear_with_async_comm,
-    reduce_input,
+    matmul_with_async_comm,
+    reduce_backward,
+    reduce_forward,
     split_forward_gather_backward,
 )
 from .parallel_module import ParallelModule
@@ -29,11 +34,69 @@ from .utils import create_randomizer_with_offset
 __all__ = ['LinearConv1D_Col', 'LinearConv1D_Row']


+def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):
+    """
+    The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split it into [Q1, K1, V1] and [Q2, K2, V2].
+    """
+    # get the number of slices for the fused qkv
+    rank = dist.get_rank(group=process_group)
+    world_size = dist.get_world_size(group=process_group)
+    order = torch.arange(world_size * n_fused)
+
+    # split the fused qkv
+    # from
+    # [Q, K, V]
+    # to
+    # [Q1, Q2, K1, K2, V1, V2]
+    weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
+
+    # rearrange the slices into the final order
+    # from
+    # [Q1, Q2, K1, K2, V1, V2]
+    # to
+    # [Q1, K1, V1], [Q2, K2, V2]
+    weight_chunks_of_current_rank = [weight_chunks[i] for i in order[rank::world_size]]
+    weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1)
+    return weight_of_current_rank
+
+
+def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):
+    """
+    The split qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2].
+    """
+    world_size = dist.get_world_size(group=process_group)

+    # gather the tensors
+    # from
+    # [Q1, K1, V1], [Q2, K2, V2]
+    # to
+    # [Q1, K1, V1, Q2, K2, V2]
+    origin_device = qkv.device
+    qkv = qkv.cuda()
+    gather_list = [torch.zeros_like(qkv) for _ in range(world_size)]
+    dist.all_gather(gather_list, qkv, group=process_group)
+    gather_weight = torch.cat(gather_list, dim=-1)
+    gather_weight = gather_weight.to(origin_device)
+    qkv = qkv.to(origin_device)
+
+    # rearrange the tensor slices
+    # from
+    # [Q1, K1, V1, Q2, K2, V2]
+    # to
+    # [Q1, Q2, K1, K2, V1, V2]
+    weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
+    reordered_chunk_list = []
+    for i in range(n_fused):
+        reordered_chunk_list.extend(weight_chunks[i::n_fused])
+    reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
+    return reordered_gather_weight
+
+
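`split_fused_qkv` and `gather_fused_qkv` above are inverses: the first picks the interleaved column chunks that belong to the current rank, the second re-interleaves the all-gathered shards. The index bookkeeping can be checked on a single process without any process group; a small illustration with made-up sizes (world size 2, `n_fused` 3):

```python
import torch

world_size, n_fused, dim = 2, 3, 4
labels = ['Q1', 'Q2', 'K1', 'K2', 'V1', 'V2']

# fused weight whose columns are laid out as [Q1, Q2, K1, K2, V1, V2]
qkv = torch.cat([torch.full((dim, dim), i, dtype=torch.float32) for i in range(6)], dim=-1)

order = torch.arange(world_size * n_fused)
chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)

for rank in range(world_size):
    picked = order[rank::world_size]                        # same stride trick as split_fused_qkv
    shard = torch.cat([chunks[i] for i in picked], dim=-1)
    print(f"rank {rank} keeps {[labels[i] for i in picked]}, shard shape {tuple(shard.shape)}")
# rank 0 keeps ['Q1', 'K1', 'V1'], shard shape (4, 12)
# rank 1 keeps ['Q2', 'K2', 'V2'], shard shape (4, 12)
```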
 class LinearConv1D_Col(ParallelModule):
     r"""Linear layer with column parallelism.

     The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
-    its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface.
+    its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit the `Conv1D` layer (fused QKV) in huggingface GPT2.

     Args:
         in_features (int): size of each input sample.
@@ -41,6 +104,7 @@ class LinearConv1D_Col(ParallelModule):
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
         dtype (`torch.dtype`): The dtype of parameters, defaults to None.
         device (`torch.device`): The device of parameters, defaults to None.
+        n_fused (int): The number of items fused, defaults to 3 (QKV).
         process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
         gather_output (bool, optional): If true, call all-gather on output and make Y available
                                     to all GPUs, otherwise, every GPU will have its output
@@ -63,8 +127,10 @@ class LinearConv1D_Col(ParallelModule):
                  dtype: torch.dtype = None,
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
+                 async_communication: bool = False,
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
+                 n_fused: int = 3,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         super().__init__()
@@ -75,23 +141,34 @@ class LinearConv1D_Col(ParallelModule):
         self.gather_output = gather_output
         self.skip_bias_add = skip_bias_add
         self.device = device
+        self.n_fused = n_fused
         self.process_group = process_group
-        self.num_partitions = dist.get_world_size(self.process_group)
+        self.async_communication = async_communication

         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')

-        self.out_features_per_partition = divide(out_features, self.num_partitions)
-
         # Parameters.
         # Initialize weight.
         if device is None:
             device = get_current_device()
         factory_kwargs = {'device': device, 'dtype': dtype}
-        self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
+        weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
+
+        def shard_fn(tensor):
+            return split_fused_qkv(tensor, self.n_fused, self.process_group)
+
+        def gather_fn(tensor):
+            return gather_fused_qkv(tensor, 3, self.process_group)
+
+        with torch.no_grad():
+            sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
+            self.weight = customized_distributed_tensor_to_param(sharded_weight)

         if bias:
-            self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
+            bias = torch.empty(self.out_features, **factory_kwargs)
+
+            with torch.no_grad():
+                sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
+                self.bias = customized_distributed_tensor_to_param(sharded_bias)
         else:
             self.bias = None

@@ -103,7 +180,7 @@ class LinearConv1D_Col(ParallelModule):
         self.reset_parameters(weight_initializer, bias_initializer)

     @staticmethod
-    def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
+    def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
                            *args, **kwargs) -> ParallelModule:
         r"""
         Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
@@ -135,29 +212,12 @@ class LinearConv1D_Col(ParallelModule):

         # TODO: copy the sharded weights
         with torch.no_grad():
-            # the weight of the linear layer is a transpose
-            # thus shard on row is equal to shard on column
-
-            # first rearrange the order of weight and bias
-            world_size = dist.get_world_size(group=process_group)
-            order = torch.arange(world_size * n_fused)
-            new_order = []
-            for i in range(world_size):
-                new_order.append(order[i::world_size])
-            new_order = torch.cat(new_order)
-
-            weight_chunks = torch.chunk(module.weight.data, world_size * n_fused, dim=1)
-            rearanged_weight_chunks = [weight_chunks[i] for i in new_order]
-            rearanged_weight = torch.cat(rearanged_weight_chunks, dim=1)
-            sharded_weight = shard_colwise(rearanged_weight, process_group)
-            linear_1d.weight.data.copy_(sharded_weight.T.contiguous())
+            sharded_weight = split_fused_qkv(module.weight.data, n_fused=n_fused, process_group=process_group)
+            linear_1d.weight.data.copy_(sharded_weight.data)

             if bias:
-                bias_chunks = torch.chunk(module.bias.data, world_size * n_fused, dim=0)
-                rearanged_bias_chunks = [bias_chunks[i] for i in new_order]
-                rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0)
-                sharded_bias = shard_colwise(rearanged_bias, process_group)
-                linear_1d.bias.copy_(sharded_bias.contiguous())
+                sharded_bias = split_fused_qkv(module.bias.data, n_fused=n_fused, process_group=process_group)
+                linear_1d.bias.data.copy_(sharded_bias.data)

         return linear_1d

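A usage sketch of the conversion path above, mirroring the unit test later in this diff. It assumes a two-rank tensor-parallel group is already set up and that huggingface's `Conv1D` is importable (from `transformers.pytorch_utils` in recent releases).

```python
from transformers.pytorch_utils import Conv1D

from colossalai.shardformer.layer import LinearConv1D_Col

# GPT2's attn.c_attn is a Conv1D(3 * hidden, hidden); its weight is stored as [in, out],
# with the output columns holding the fused Q, K and V projections.
c_attn = Conv1D(192, 48).cuda()

# split the fused columns across the default process group, keeping Q/K/V aligned per rank
col = LinearConv1D_Col.from_native_module(c_attn, process_group=None, n_fused=3, gather_output=True)

# with a world size of 2, each rank now holds a [48, 96] shard of the original [48, 192] weight
print(col.weight.shape)
```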
@@ -169,15 +229,18 @@ class LinearConv1D_Col(ParallelModule):
             bias_initializer(self.bias, fan_in=fan_in)

     def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
-        assert input_.shape[-1] == self.weight.shape[-1], \
+        assert input_.shape[-1] == self.weight.shape[0], \
             'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
                 input_.shape, self.weight.shape, self.weight.shape[-1])
         # Set up backprop all-reduce.
         # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
-        input_parallel = input_
+        input_parallel = reduce_backward(input_, self.process_group)
+        # input_parallel = input_

         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
-        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+        output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
+                                                 self.async_communication)

         if self.gather_output:
             # All-gather across the partitions.
@@ -192,7 +255,8 @@ class LinearConv1D_Col(ParallelModule):


 class LinearConv1D_Row(ParallelModule):
-    r""" Linear layer with row parallelism
+    r""" Linear layer with row parallelism.
+    This layer is used to fit the `Conv1D` layer (fused QKV) in huggingface GPT2.

     Args:
         in_features (int): size of each input sample.
@@ -243,11 +307,10 @@ class LinearConv1D_Row(ParallelModule):

         # Parameters.
         # Initialize weight.
         if device is None:
             device = get_current_device()

         factory_kwargs = {'device': device, 'dtype': dtype}
-        self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
+        weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
+        sharded_weight = shard_rowwise(weight, self.process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         if self.stream_chunk_num > 1:
             # TODO() work for inference only
@@ -295,7 +358,7 @@ class LinearConv1D_Row(ParallelModule):
             # the weight of the linear layer is a transpose
             # thus shard on col is equal to shard on row
             sharded_weight = shard_rowwise(module.weight.data, process_group)
-            linear_1d.weight.data.copy_(sharded_weight.T.contiguous())
+            linear_1d.weight.data.copy_(sharded_weight.data)

             if bias:
                 linear_1d.bias.copy_(module.bias.data)
@@ -325,12 +388,12 @@ class LinearConv1D_Row(ParallelModule):
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
         if self.parallel_input:
-            assert input_.shape[-1] == self.weight.shape[-1], \
+            assert input_.shape[-1] == self.weight.shape[0], \
                 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
                     input_.shape, self.weight.shape, self.weight.shape[-1])
             input_ = input_
         else:
-            assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \
+            assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \
                 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
                     input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
             input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
@@ -342,7 +405,7 @@ class LinearConv1D_Row(ParallelModule):
             output_parallel_list = [None for i in range(self.stream_chunk_num)]
             handle_list = []
             for i in range(self.stream_chunk_num):
-                output_parallel_list[i] = F.linear(input_, self.weight_list[i])
+                output_parallel_list[i] = torch.matmul(input_, self.weight_list[i])
                 handle = torch.distributed.all_reduce(output_parallel_list[i],
                                                       group=self.process_group,
                                                       async_op=True)
@@ -352,9 +415,8 @@ class LinearConv1D_Row(ParallelModule):
                 handle.wait()
             output = torch.cat(output_parallel_list, dim=-1)
         else:
-            output_parallel = F.linear(input_, self.weight)
-            # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
-            output = reduce_input(output_parallel, self.process_group)
+            output_parallel = torch.matmul(input_, self.weight)
+            output = reduce_forward(output_parallel, self.process_group)

         if not self.skip_bias_add:
             if self.bias is not None:
@@ -12,11 +12,14 @@ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module

 from colossalai.tensor.d_tensor import (
     distribute_tensor,
+    distribute_tensor_with_customization,
     get_device_mesh,
     get_sharding_spec,
+    is_customized_distributed_tensor,
     is_distributed_tensor,
     sharded_tensor_to_param,
     to_global,
+    to_global_for_customized_distributed_tensor,
 )

 __all__ = ['ParallelModule']
@@ -54,9 +57,10 @@ class ParallelModule(nn.Module, ABC):
         for name, param in self._parameters.items():
             if param is not None:
                 param_ = param if keep_vars else param.detach()

                 if is_distributed_tensor(param_):
                     destination[prefix + name] = to_global(param_)
+                elif is_customized_distributed_tensor(param_):
+                    destination[prefix + name] = to_global_for_customized_distributed_tensor(param_)
                 else:
                     destination[prefix + name] = param_

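With the extra branch, `state_dict()` always yields full-size tensors: regular distributed tensors are reconstructed through `to_global`, customized ones through their own `gather_fn`. A condensed, hedged sketch of that dispatch follows; the standalone helper name is illustrative and relies only on the `shard_fn`/`gather_fn` attribute convention defined later in this diff.

```python
import torch


def gather_param_for_checkpoint(param: torch.Tensor) -> torch.Tensor:
    """Return a full-size tensor suitable for saving, regardless of how it is sharded."""
    if hasattr(param, 'shard_fn') and hasattr(param, 'gather_fn'):
        # customized distributed tensor: only its own gather_fn knows how to undo the sharding
        return param.gather_fn(param)
    # plain tensor; regular distributed tensors are handled by to_global in the real code
    return param
```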
@@ -124,6 +128,8 @@ class ParallelModule(nn.Module, ABC):
                     sharding_spec = get_sharding_spec(param)
                     sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
                     input_param = sharded_tensor_to_param(sharded_tensor)
+                elif is_customized_distributed_tensor(param):
+                    input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)

                 # This is used to avoid copying uninitialized parameters into
                 # non-lazy modules, since they don't have the hook to do the checks
@@ -1,10 +1,13 @@
 from .api import (
     compute_global_numel,
+    customized_distributed_tensor_to_param,
     distribute_tensor,
+    distribute_tensor_with_customization,
     get_device_mesh,
     get_global_shape,
     get_layout,
     get_sharding_spec,
+    is_customized_distributed_tensor,
     is_distributed_tensor,
     is_sharded,
     redistribute,
@@ -12,6 +15,7 @@ from .api import (
     shard_rowwise,
     sharded_tensor_to_param,
     to_global,
+    to_global_for_customized_distributed_tensor,
 )
 from .layout import Layout
 from .sharding_spec import ShardingSpec
@@ -19,6 +23,6 @@ from .sharding_spec import ShardingSpec
 __all__ = [
     'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise',
     'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh',
-    'redistribute', 'get_layout'
-    'Layout', 'ShardingSpec'
+    'redistribute', 'get_layout', 'is_customized_distributed_tensor', 'distribute_tensor_with_customization',
+    'to_global_for_customized_distributed_tensor', 'customized_distributed_tensor_to_param', 'Layout', 'ShardingSpec'
 ]
@@ -305,3 +305,130 @@ def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
     """
     assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
     return dtensor.dist_layout.sharding_spec
+
+
+# ======================================================
+# Some sharding does not obey the SPMD style
+# e.g. Fused QKV layer in GPT2
+# we support customized sharding with the following APIs
+# ======================================================
+def is_customized_distributed_tensor(tensor: torch.Tensor):
+    """
+    Check whether the given tensor is a customized distributed tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be checked.
+
+    Returns:
+        bool: Whether the given tensor is a customized distributed tensor.
+    """
+    return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn')
+
+
+def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
+    """
+    Hijack the detach and clone methods of the tensor so that shard_fn and gather_fn are carried over.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be hijacked.
+
+    Returns:
+        torch.Tensor: The hijacked tensor.
+    """
+    dtensor._old_detach = dtensor.detach
+    dtensor._old_clone = dtensor.clone
+
+    def new_detach(self):
+        t_ = self._old_detach()
+        t_.shard_fn = self.shard_fn
+        t_.gather_fn = self.gather_fn
+        return t_
+
+    def new_clone(self, *args, **kwargs):
+        t_ = self._old_clone(*args, **kwargs)
+        t_.shard_fn = self.shard_fn
+        t_.gather_fn = self.gather_fn
+        return t_
+
+    # bind the new methods to the tensor
+    dtensor.detach = new_detach.__get__(dtensor)
+    dtensor.clone = new_clone.__get__(dtensor)
+    return dtensor
+
+
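The hijack matters because routine tensor plumbing, such as the `detach()` call inside `state_dict(keep_vars=False)` or a `clone()` during parameter copies, returns a new tensor that would otherwise lose the `shard_fn`/`gather_fn` attributes and stop being recognized as a customized distributed tensor. A small hedged illustration of the failure mode the hijack prevents:

```python
import torch

t = torch.rand(4, 4)
t.shard_fn = lambda x: x      # placeholder functions, only the attributes matter here
t.gather_fn = lambda x: x

# the stock detach() produces a fresh tensor without the custom attributes
detached = torch.Tensor.detach(t)
print(hasattr(detached, 'shard_fn'))    # False

# after _hijack_detach_and_clone_for_customized_distributed_tensor(t),
# t.detach() and t.clone() copy shard_fn/gather_fn onto the returned tensor instead.
```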
+def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_fn: callable):
+    """
+    Distribute the given tensor with the given shard_fn and gather_fn.
+
+    Example:
+
+    ```python
+    # define shard and gather functions
+    def shard_fn(tensor):
+        rank = torch.distributed.get_rank()
+        world_size = torch.distributed.get_world_size()
+        return tensor.chunk(world_size, dim=0)[rank]
+
+    def gather_fn(tensor):
+        rank = torch.distributed.get_rank()
+        world_size = torch.distributed.get_world_size()
+        shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
+        torch.distributed.all_gather(shard_list, tensor)
+        return torch.cat(shard_list, dim=0)
+
+    # create a distributed tensor
+    tensor = torch.rand(4, 4)
+    dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn)
+    ```
+
+    Args:
+        tensor (torch.Tensor): The tensor to be distributed.
+        shard_fn (callable): The function to shard the tensor.
+        gather_fn (callable): The function to gather the tensor.
+
+    Returns:
+        torch.Tensor: The distributed tensor.
+    """
+    assert callable(shard_fn), 'The shard_fn must be callable.'
+    assert callable(gather_fn), 'The gather_fn must be callable.'
+    assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
+
+    sharded_tensor = shard_fn(tensor)
+
+    # set the shard_fn and gather_fn as attributes of the distributed tensor
+    sharded_tensor.shard_fn = shard_fn
+    sharded_tensor.gather_fn = gather_fn
+
+    # hijack the detach and clone methods so the attributes are preserved
+    _hijack_detach_and_clone_for_customized_distributed_tensor(sharded_tensor)
+
+    return sharded_tensor
+
+
+def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
+    """
+    Gather the given tensor to the global tensor.
+
+    Args:
+        dtensor (torch.Tensor): The distributed tensor.
+
+    Returns:
+        torch.Tensor: The global tensor.
+    """
+    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
+    return dtensor.gather_fn(dtensor)
+
+
+def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
+    """
+    Convert the given customized distributed tensor to a parameter.
+    """
+    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
+
+    param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
+
+    # make it distributed as well
+    param.shard_fn = dtensor.shard_fn
+    param.gather_fn = dtensor.gather_fn
+    _hijack_detach_and_clone_for_customized_distributed_tensor(param)
+    return param
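Putting the new API together, the intended round trip for a non-SPMD parameter is: shard it with `distribute_tensor_with_customization`, wrap it with `customized_distributed_tensor_to_param`, and recover the full tensor with `to_global_for_customized_distributed_tensor` (or implicitly through `ParallelModule.state_dict`). A hedged end-to-end sketch, assuming an initialized NCCL process group and a last dimension divisible by the world size:

```python
import torch
import torch.distributed as dist

from colossalai.tensor.d_tensor.api import (
    customized_distributed_tensor_to_param,
    distribute_tensor_with_customization,
    to_global_for_customized_distributed_tensor,
)


def shard_fn(tensor):
    rank, world_size = dist.get_rank(), dist.get_world_size()
    return tensor.chunk(world_size, dim=-1)[rank].contiguous()


def gather_fn(tensor):
    shards = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(shards, tensor.contiguous())
    return torch.cat(shards, dim=-1)


full_weight = torch.rand(48, 192, device='cuda')
dtensor = distribute_tensor_with_customization(full_weight, shard_fn, gather_fn)
param = customized_distributed_tensor_to_param(dtensor)

# detach() keeps gather_fn thanks to the hijack, so the global view can still be rebuilt
restored = to_global_for_customized_distributed_tensor(param.detach())
assert restored.shape == full_weight.shape
```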
@@ -27,8 +27,13 @@ def check_linear_1d_col():

     # check computation correctness
     x = torch.rand(4, 32).cuda()
-    out = linear(x)
-    gather_out = linear_col(x)
+    x_for_unshard = x.expand_as(x.clone())
+    x_for_unshard.requires_grad_(True)
+    x_for_shard = x.expand_as(x.clone())
+    x_for_shard.requires_grad_(True)
+
+    out = linear(x_for_unshard)
+    gather_out = linear_col(x_for_shard)
     assert_close(out, gather_out)

     # check backward correctness
@@ -39,6 +44,11 @@ def check_linear_1d_col():
     target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
     assert_close(target_grad, linear_col.weight.grad)

+    # check the input gradients
+    assert x_for_shard.grad is not None
+    assert x_for_unshard.grad is not None
+    assert_close(x_for_unshard.grad, x_for_shard.grad)
+

 def check_linear_1d_row():
     linear = nn.Linear(32, 128).cuda()
@@ -49,8 +59,14 @@ def check_linear_1d_row():

     # check computation correctness
     x = torch.rand(4, 32).cuda()
-    out = linear(x)
-    gather_out = linear_row(x)
+    x_for_unshard = x.expand_as(x.clone())
+    x_for_unshard.requires_grad_(True)
+    x_for_shard = x.expand_as(x.clone())
+    x_for_shard.requires_grad_(True)
+
+    # run forward
+    out = linear(x_for_unshard)
+    gather_out = linear_row(x_for_shard)
     assert_close(out, gather_out)

     # check backward correctness
@@ -61,11 +77,49 @@ def check_linear_1d_row():
     target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
     assert_close(target_grad, linear_row.weight.grad)

+    # check the input gradients
+    assert x_for_shard.grad is not None
+    assert x_for_unshard.grad is not None
+    assert_close(x_for_unshard.grad, x_for_shard.grad)
+
+
+def check_linear_col_plus_row():
+    linear_1 = nn.Linear(32, 128).cuda()
+    linear_2 = nn.Linear(128, 32).cuda()
+    linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
+    linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)
+
+    # check computation correctness
+    x = torch.rand(4, 32).cuda()
+    x_for_unshard = x.expand_as(x.clone())
+    x_for_unshard.requires_grad_(True)
+    x_for_shard = x.expand_as(x.clone())
+    x_for_shard.requires_grad_(True)
+
+    # run forward
+    unshard_out = linear_2(linear_1(x_for_unshard))
+    shard_out = linear_row(linear_col(x_for_shard))
+    assert_close(unshard_out, shard_out)
+
+    # check backward correctness
+    unshard_out.sum().backward()
+    shard_out.sum().backward()
+
+    rank = dist.get_rank()
+    target_1_grad = torch.chunk(linear_1.weight.grad, 2, dim=0)[rank]
+    assert_close(target_1_grad, linear_col.weight.grad)
+
+    # check the input gradients
+    assert x_for_shard.grad is not None
+    assert x_for_unshard.grad is not None
+    assert_close(x_for_unshard.grad, x_for_shard.grad)
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     check_linear_1d_col()
-    # check_linear_1d_row()
+    check_linear_1d_row()
+    check_linear_col_plus_row()


 @rerun_if_address_is_in_use()
@@ -5,6 +5,7 @@ from torch.testing import assert_close

 import colossalai
 from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row
+from colossalai.shardformer.layer.linear_conv import split_fused_qkv
 from colossalai.testing import rerun_if_address_is_in_use, spawn


@@ -53,9 +54,15 @@ def check_linear_conv_1d_col():
     linear = Conv1D(192, 48).cuda()
     linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3)

-    assert linear_conv_col.weight.shape == torch.Size([96, 48])
+    assert linear.weight.shape == torch.Size([48, 192])
+    assert linear.bias.shape == torch.Size([192])
+    assert linear_conv_col.weight.shape == torch.Size([48, 96])
+    assert linear_conv_col.bias.shape == torch.Size([96])
+
+    # ensure weights are reversibly loadable
+    linear_conv_col.load_state_dict(linear.state_dict())
+    linear.load_state_dict(linear_conv_col.state_dict())

     # check computation correctness
     x = torch.rand(4, 48).cuda()
     out = linear(x)
@@ -66,16 +73,16 @@ def check_linear_conv_1d_col():
     out.sum().backward()
     gather_out.sum().backward()

-    rank = dist.get_rank()
-    target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
-    assert_close(target_grad.transpose(0, 1).contiguous(), linear_conv_col.weight.grad)
+    target_grad = split_fused_qkv(linear.weight.grad, 3, None)
+    assert_close(target_grad, linear_conv_col.weight.grad)


 def check_linear_1d_row():
     linear = Conv1D(192, 48).cuda()
     linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)

-    assert linear_row.weight.shape == torch.Size([192, 24])
+    assert linear.weight.shape == torch.Size([48, 192])
+    assert linear_row.weight.shape == torch.Size([24, 192])
+    assert linear_row.bias.shape == torch.Size([192])

     # check computation correctness
@@ -89,13 +96,14 @@ def check_linear_1d_row():
     gather_out.sum().backward()

     rank = dist.get_rank()
-    target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
+    target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
     assert_close(target_grad, linear_row.weight.grad)


 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     check_linear_conv_1d_col()
+    check_linear_1d_row()


 @rerun_if_address_is_in_use()
||||
# check grad equality
|
||||
if org_model.__class__.__name__ == 'GPT2Model':
|
||||
org_grad = org_model.h[0].attn.c_attn.weight.grad
|
||||
shard_grad = sharded_model.h[0].attn.c_attn.weight.grad.transpose(0, 1).contiguous()
|
||||
org_grad = org_model.h[0].mlp.c_fc.weight.grad
|
||||
shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
|
||||
else:
|
||||
org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad
|
||||
shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad.transpose(0, 1).contiguous()
|
||||
shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad
|
||||
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
|
||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=1)
|
||||
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
assert torch.allclose(org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
|
||||
assert torch.allclose(
|
||||
org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
|
||||
|
||||
def check_gpt2(rank, world_size, port):
|
||||
|
|