[shardformer] supported fused qkv checkpoint (#4073)

pull/4157/head
Frank Lee 2023-06-23 16:07:09 +08:00
parent 0803a61412
commit 70c58cfd4f
10 changed files with 420 additions and 88 deletions

View File

@ -1,5 +1,6 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
try:
import fused_mix_prec_layer_norm_cuda
@ -46,6 +47,53 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None
class MatmulWithAsyncCommunication(torch.autograd.Function):
"""
Matrix multiplication with asynchronous all-reduce of the input gradient in the backward pass.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input_, weight)
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
total_input = input
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
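For reference, a minimal single-process sketch of the overlap pattern used in the backward above: the all-reduce of `grad_input` is launched asynchronously and the weight-gradient matmul runs while the communication is in flight. The gloo/world_size=1 setup and the shapes are illustrative assumptions, not taken from the patch itself.

```python
import torch
import torch.distributed as dist

# single-process group only so the snippet runs standalone
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29511",
                        rank=0, world_size=1)

x = torch.randn(4, 8)          # activation saved in forward
w = torch.randn(8, 16)         # weight laid out (in, out) as used by torch.matmul
grad_out = torch.randn(4, 16)  # incoming gradient

grad_input = grad_out.matmul(w.T)                    # dL/dx
handle = dist.all_reduce(grad_input, async_op=True)  # start communication early
grad_weight = x.t().matmul(grad_out)                 # overlapped with the all-reduce
grad_bias = grad_out.sum(dim=0)
handle.wait()                                        # grad_input is valid after this

dist.destroy_process_group()
```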
class LinearWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
@ -58,9 +106,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input_, weight.t())
if bias is not None:
output = output + bias
output = F.linear(input_, weight, bias)
else:
output = F.linear(input_, weight)
return output
@staticmethod
@ -114,7 +163,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
return _gather(grad_output, ctx.dim, ctx.process_group), None, None
class _ReduceInput(torch.autograd.Function):
class _ReduceForward(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.
@ -132,6 +181,25 @@ class _ReduceInput(torch.autograd.Function):
return grad_output, None
class _ReduceBackward(torch.autograd.Function):
"""
All-reduce the gradient from the model parallel region in the backward pass; the forward pass is an identity.
Args:
input_: input matrix.
process_group: the process group used for the all-reduce.
"""
@staticmethod
def forward(ctx, input_, process_group):
ctx.process_group = process_group
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output, ctx.process_group), None
def _reduce(input_, process_group):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
@ -198,6 +266,10 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
return _split(grad_output, ctx.dim, ctx.process_group), None, None
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
@ -210,5 +282,9 @@ def split_forward_gather_backward(input_, dim, process_group):
return _SplitForwardGatherBackward.apply(input_, dim, process_group)
def reduce_input(input_, process_group):
return _ReduceInput.apply(input_, process_group)
def reduce_forward(input_, process_group):
return _ReduceForward.apply(input_, process_group)
def reduce_backward(input_, process_group):
return _ReduceBackward.apply(input_, process_group)
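A self-contained (no communication) sketch of why the row-parallel path all-reduces its partial outputs in the forward pass, which is what `reduce_forward` wraps: with the weight split along its input dimension, each rank only computes a partial product, and the partials must be summed. The two-way split and shapes are illustrative assumptions.

```python
import torch

x = torch.randn(4, 8)
w = torch.randn(8, 16)
full = x @ w

# simulate two ranks: shard the input along its last dim and the weight along dim 0
x0, x1 = x.chunk(2, dim=-1)
w0, w1 = w.chunk(2, dim=0)

partial0 = x0 @ w0   # what "rank 0" would compute
partial1 = x1 @ w1   # what "rank 1" would compute

# the all-reduce performed by reduce_forward corresponds to summing the partials
assert torch.allclose(full, partial0 + partial1, atol=1e-5)
```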

View File

@ -15,7 +15,7 @@ from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
from ._operation import gather_forward_split_backward, reduce_input
from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset
@ -276,5 +276,5 @@ class VocabParallelEmbedding1D(ParallelModule):
# Mask the output embedding.
output_parallel[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(output_parallel, self.process_group)
output = reduce_forward(output_parallel, self.process_group)
return output

View File

@ -15,12 +15,11 @@ from torch.nn.parameter import Parameter
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
from colossalai.utils.cuda import get_current_device
from ._operation import (
gather_forward_split_backward,
linear_with_async_comm,
reduce_input,
reduce_forward,
split_forward_gather_backward,
)
from .parallel_module import ParallelModule
@ -148,9 +147,10 @@ class Linear1D_Col(ParallelModule):
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
@ -209,17 +209,14 @@ class Linear1D_Row(ParallelModule):
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
sharded_weight = shard_colwise(weight, self.process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
@ -327,8 +324,7 @@ class Linear1D_Row(ParallelModule):
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, self.process_group)
output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:
@ -336,5 +332,3 @@ class Linear1D_Row(ParallelModule):
return output
else:
return output, self.bias

View File

@ -14,13 +14,18 @@ from torch.nn.parameter import Parameter
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
from colossalai.utils.cuda import get_current_device
from colossalai.tensor.d_tensor.api import (
customized_distributed_tensor_to_param,
distribute_tensor_with_customization,
shard_rowwise,
sharded_tensor_to_param,
)
from ._operation import (
gather_forward_split_backward,
linear_with_async_comm,
reduce_input,
matmul_with_async_comm,
reduce_backward,
reduce_forward,
split_forward_gather_backward,
)
from .parallel_module import ParallelModule
@ -29,11 +34,69 @@ from .utils import create_randomizer_with_offset
__all__ = ['LinearConv1D_Col', 'LinearConv1D_Row']
def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):
"""
The fused qkv tensor is laid out as [Q1, Q2, K1, K2, V1, V2]; this function splits it so that each rank keeps only its own slices, e.g. [Q1, K1, V1] on rank 0 and [Q2, K2, V2] on rank 1.
"""
# get this rank's position and the total number of slices for the fused qkv
rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group)
order = torch.arange(world_size * n_fused)
# split the fused qkv
# from
# [Q, K, V]
# to
# [Q1, Q2, K1, K2, V1, V2]
weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1)
# rearrange the slice into the final order
# from
# [Q1, Q2, K1, K2, V1, V2]
# to
# [Q1, K1, V1], [Q2, K2, V2]
weight_chunks_of_current_rank = [weight_chunks[i] for i in order[rank::world_size]]
weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1)
return weight_of_current_rank
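A quick single-process illustration of the interleaved index pattern used above: with `world_size * n_fused` chunks, rank `r` keeps the chunks at positions `r, r + world_size, r + 2 * world_size, ...`. The world_size of 2 and the toy column blocks are assumptions for the demo only.

```python
import torch

n_fused, world_size = 3, 2
# label each column block 0..5 so the reordering is visible:
# 0=Q1, 1=Q2, 2=K1, 3=K2, 4=V1, 5=V2
fused = torch.arange(n_fused * world_size).repeat_interleave(4).reshape(1, -1)

chunks = torch.chunk(fused, world_size * n_fused, dim=-1)
order = torch.arange(world_size * n_fused)

for rank in range(world_size):
    mine = torch.cat([chunks[i] for i in order[rank::world_size]], dim=-1)
    print(rank, mine.unique_consecutive().tolist())
# rank 0 keeps blocks [0, 2, 4] -> Q1, K1, V1
# rank 1 keeps blocks [1, 3, 5] -> Q2, K2, V2
```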
def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup):
"""
The split qkv tensors look like [Q1, K1, V1] and [Q2, K2, V2] on each rank; this function gathers them back into the fused [Q1, Q2, K1, K2, V1, V2] layout.
"""
world_size = dist.get_world_size(group=process_group)
# gather the tensors
# from
# [Q1, K1, V1], [Q2, K2, V2]
# to
# [Q1, K1, V1, Q2, K2, V2]
origin_device = qkv.device
qkv = qkv.cuda()
gather_list = [torch.zeros_like(qkv) for _ in range(world_size)]
dist.all_gather(gather_list, qkv, group=process_group)
gather_weight = torch.cat(gather_list, dim=-1)
gather_weight = gather_weight.to(origin_device)
qkv = qkv.to(origin_device)
# rearrange the tensor slices
# from
# [Q1, K1, V1, Q2, K2, V2]
# to
# [Q1, Q2, K1, K2, V1, V2]
weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1)
reordered_chunk_list = []
for i in range(n_fused):
reordered_chunk_list.extend(weight_chunks[i::n_fused])
reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1)
return reordered_gather_weight
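And the inverse reordering performed after the all-gather: the concatenated shards arrive as [Q1, K1, V1, Q2, K2, V2], and taking every `n_fused`-th chunk restores the fused [Q1, Q2, K1, K2, V1, V2] layout. The toy string labels below stand in for tensor chunks and are purely illustrative.

```python
n_fused = 3
fused_layout = ["Q1", "Q2", "K1", "K2", "V1", "V2"]

# what torch.cat over the all_gather list produces (rank 0 shard, then rank 1 shard)
gathered = ["Q1", "K1", "V1", "Q2", "K2", "V2"]

# the chunk[i::n_fused] reordering used in gather_fused_qkv
reordered = []
for i in range(n_fused):
    reordered.extend(gathered[i::n_fused])  # i=0 -> Q1, Q2; i=1 -> K1, K2; i=2 -> V1, V2

assert reordered == fused_layout
```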
class LinearConv1D_Col(ParallelModule):
r"""Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface.
its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit the `Conv1D` layer (fused QKV) in HuggingFace GPT-2.
Args:
in_features (int): size of each input sample.
@ -41,6 +104,7 @@ class LinearConv1D_Col(ParallelModule):
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number of weights fused together (e.g. 3 for QKV), defaults to 3.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
@ -63,8 +127,10 @@ class LinearConv1D_Col(ParallelModule):
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False,
skip_bias_add: bool = False,
n_fused: int = 3,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
@ -75,23 +141,34 @@ class LinearConv1D_Col(ParallelModule):
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
self.device = device
self.n_fused = n_fused
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
self.async_communication = async_communication
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
self.out_features_per_partition = divide(out_features, self.num_partitions)
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
def shard_fn(tensor):
return split_fused_qkv(tensor, self.n_fused, self.process_group)
def gather_fn(tensor):
return gather_fused_qkv(tensor, self.n_fused, self.process_group)
with torch.no_grad():
sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
self.weight = customized_distributed_tensor_to_param(sharded_weight)
if bias:
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
bias = torch.empty(self.out_features, **factory_kwargs)
with torch.no_grad():
sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
self.bias = customized_distributed_tensor_to_param(sharded_bias)
else:
self.bias = None
@ -103,7 +180,7 @@ class LinearConv1D_Col(ParallelModule):
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
*args, **kwargs) -> ParallelModule:
r"""
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
@ -135,29 +212,12 @@ class LinearConv1D_Col(ParallelModule):
# TODO: copy the sharded weights
with torch.no_grad():
# the weight of the linear layer is transposed
# thus sharding on rows is equal to sharding on columns
# first rearrange the order of weight and bias
world_size = dist.get_world_size(group=process_group)
order = torch.arange(world_size * n_fused)
new_order = []
for i in range(world_size):
new_order.append(order[i::world_size])
new_order = torch.cat(new_order)
weight_chunks = torch.chunk(module.weight.data, world_size * n_fused, dim=1)
rearanged_weight_chunks = [weight_chunks[i] for i in new_order]
rearanged_weight = torch.cat(rearanged_weight_chunks, dim=1)
sharded_weight = shard_colwise(rearanged_weight, process_group)
linear_1d.weight.data.copy_(sharded_weight.T.contiguous())
sharded_weight = split_fused_qkv(module.weight.data, n_fused=n_fused, process_group=process_group)
linear_1d.weight.data.copy_(sharded_weight.data)
if bias:
bias_chunks = torch.chunk(module.bias.data, world_size * n_fused, dim=0)
rearanged_bias_chunks = [bias_chunks[i] for i in new_order]
rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0)
sharded_bias = shard_colwise(rearanged_bias, process_group)
linear_1d.bias.copy_(sharded_bias.contiguous())
sharded_bias = split_fused_qkv(module.bias.data, n_fused=n_fused, process_group=process_group)
linear_1d.bias.data.copy_(sharded_bias.data)
return linear_1d
@ -169,15 +229,18 @@ class LinearConv1D_Col(ParallelModule):
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[-1], \
assert input_.shape[-1] == self.weight.shape[0], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[0])
# Set up backprop all-reduce.
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_
input_parallel = reduce_backward(input_, self.process_group)
# input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
self.async_communication)
if self.gather_output:
# All-gather across the partitions.
@ -192,7 +255,8 @@ class LinearConv1D_Col(ParallelModule):
class LinearConv1D_Row(ParallelModule):
r""" Linear layer with row parallelism
r""" Linear layer with row parallelism.
This layer is used to fit the `Conv1D` layer (fused QKV) in HuggingFace GPT-2.
Args:
in_features (int): size of each input sample.
@ -243,11 +307,10 @@ class LinearConv1D_Row(ParallelModule):
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
sharded_weight = shard_rowwise(weight, self.process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
if self.stream_chunk_num > 1:
# TODO() work for inference only
@ -295,7 +358,7 @@ class LinearConv1D_Row(ParallelModule):
# the weight of the linear layer is transposed
# thus sharding on columns is equal to sharding on rows
sharded_weight = shard_rowwise(module.weight.data, process_group)
linear_1d.weight.data.copy_(sharded_weight.T.contiguous())
linear_1d.weight.data.copy_(sharded_weight.data)
if bias:
linear_1d.bias.copy_(module.bias.data)
@ -325,12 +388,12 @@ class LinearConv1D_Row(ParallelModule):
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[-1], \
assert input_.shape[-1] == self.weight.shape[0], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[0])
input_ = input_
else:
assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \
assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions)
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
@ -342,7 +405,7 @@ class LinearConv1D_Row(ParallelModule):
output_parallel_list = [None for i in range(self.stream_chunk_num)]
handle_list = []
for i in range(self.stream_chunk_num):
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
output_parallel_list[i] = torch.matmul(input_, self.weight_list[i])
handle = torch.distributed.all_reduce(output_parallel_list[i],
group=self.process_group,
async_op=True)
@ -352,9 +415,8 @@ class LinearConv1D_Row(ParallelModule):
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, self.process_group)
output_parallel = torch.matmul(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:

View File

@ -12,11 +12,14 @@ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module
from colossalai.tensor.d_tensor import (
distribute_tensor,
distribute_tensor_with_customization,
get_device_mesh,
get_sharding_spec,
is_customized_distributed_tensor,
is_distributed_tensor,
sharded_tensor_to_param,
to_global,
to_global_for_customized_distributed_tensor,
)
__all__ = ['ParallelModule']
@ -54,9 +57,10 @@ class ParallelModule(nn.Module, ABC):
for name, param in self._parameters.items():
if param is not None:
param_ = param if keep_vars else param.detach()
if is_distributed_tensor(param_):
destination[prefix + name] = to_global(param_)
elif is_customized_distributed_tensor(param_):
destination[prefix + name] = to_global_for_customized_distributed_tensor(param_)
else:
destination[prefix + name] = param_
@ -124,6 +128,8 @@ class ParallelModule(nn.Module, ABC):
sharding_spec = get_sharding_spec(param)
sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
input_param = sharded_tensor_to_param(sharded_tensor)
elif is_customized_distributed_tensor(param):
input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)
# This is used to avoid copying uninitialized parameters into
# non-lazy modules, since they dont have the hook to do the checks

View File

@ -1,10 +1,13 @@
from .api import (
compute_global_numel,
customized_distributed_tensor_to_param,
distribute_tensor,
distribute_tensor_with_customization,
get_device_mesh,
get_global_shape,
get_layout,
get_sharding_spec,
is_customized_distributed_tensor,
is_distributed_tensor,
is_sharded,
redistribute,
@ -12,6 +15,7 @@ from .api import (
shard_rowwise,
sharded_tensor_to_param,
to_global,
to_global_for_customized_distributed_tensor,
)
from .layout import Layout
from .sharding_spec import ShardingSpec
@ -19,6 +23,6 @@ from .sharding_spec import ShardingSpec
__all__ = [
'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise',
'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh',
'redistribute', 'get_layout'
'Layout', 'ShardingSpec'
'redistribute', 'get_layout', 'is_customized_distributed_tensor', 'distribute_tensor_with_customization',
'to_global_for_customized_distributed_tensor', 'customized_distributed_tensor_to_param', 'Layout', 'ShardingSpec'
]

View File

@ -305,3 +305,130 @@ def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
return dtensor.dist_layout.sharding_spec
# ======================================================
# Some sharding does not obey the SPMD style
# e.g. Fused QKV layer in GPT2
# we support customize sharding with the following APIs
# ======================================================
def is_customized_distributed_tensor(tensor: torch.Tensor):
"""
Check whether the given tensor is a customized distributed tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
bool: Whether the given tensor is a customized distributed tensor.
"""
return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn')
def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
"""
Hijack the detach and clone methods of the tensor so that the shard_fn and gather_fn attributes are carried over to the returned tensor.
Args:
dtensor (torch.Tensor): The customized distributed tensor to be hijacked.
Returns:
torch.Tensor: The hijacked tensor.
"""
dtensor._old_detach = dtensor.detach
dtensor._old_clone = dtensor.clone
def new_detach(self):
t_ = self._old_detach()
t_.shard_fn = self.shard_fn
t_.gather_fn = self.gather_fn
return t_
def new_clone(self, *args, **kwargs):
t_ = self._old_clone(*args, **kwargs)
t_.shard_fn = self.shard_fn
t_.gather_fn = self.gather_fn
return t_
# bind the new methods to the tensor
dtensor.detach = new_detach.__get__(dtensor)
dtensor.clone = new_clone.__get__(dtensor)
return dtensor
def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn: callable, gather_fn: callable):
"""
Distribute the given tensor with the given shard_fn and gather_fn.
Example:
```python
# define shard and gather functions
def shard_fn(tensor):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
return tensor.chunk(world_size, dim=0)[rank]
def gather_fn(tensor):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
torch.distributed.all_gather(shard_list, tensor)
return torch.cat(shard_list, dim=0)
# create a distributed tensor
tensor = torch.rand(4, 4)
dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn)
```
Args:
tensor (torch.Tensor): The tensor to be distributed.
shard_fn (callable): The function to shard the tensor.
gather_fn (callable): The function to gather the tensor.
Returns:
torch.Tensor: The distributed tensor.
"""
assert callable(shard_fn), 'The shard_fn must be callable.'
assert callable(gather_fn), 'The gather_fn must be callable.'
assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
sharded_tensor = shard_fn(tensor)
# set the shard_fn and gather_fn as attributes of the distributed tensor
sharded_tensor.shard_fn = shard_fn
sharded_tensor.gather_fn = gather_fn
# hijack detach and clone so the shard_fn and gather_fn attributes survive those calls
_hijack_detach_and_clone_for_customized_distributed_tensor(sharded_tensor)
return sharded_tensor
def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
"""
Gather the given tensor to the global tensor.
Args:
dtensor (torch.Tensor): The distributed tensor.
Returns:
torch.Tensor: The global tensor.
"""
assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
return dtensor.gather_fn(dtensor)
def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
"""
Convert the given customized distributed tensor to a parameter.
"""
assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
# make it distributed as well
param.shard_fn = dtensor.shard_fn
param.gather_fn = dtensor.gather_fn
_hijack_detach_and_clone_for_customized_distributed_tensor(param)
return param
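A minimal single-process round trip through the APIs added above: distribute a tensor with custom shard/gather functions, wrap it as a parameter, and recover the global tensor as is done when saving a checkpoint. The gloo/world_size=1 setup and the dim-0 shard/gather functions are illustrative assumptions, not a prescribed usage.

```python
import torch
import torch.distributed as dist
from colossalai.tensor.d_tensor import (
    customized_distributed_tensor_to_param,
    distribute_tensor_with_customization,
    is_customized_distributed_tensor,
    to_global_for_customized_distributed_tensor,
)

dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29512",
                        rank=0, world_size=1)

def shard_fn(tensor):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return tensor.chunk(world_size, dim=0)[rank]

def gather_fn(tensor):
    world_size = dist.get_world_size()
    shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(shard_list, tensor)
    return torch.cat(shard_list, dim=0)

tensor = torch.rand(4, 4)
dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn)
param = customized_distributed_tensor_to_param(dtensor)

assert is_customized_distributed_tensor(param)
# detach/clone keep shard_fn and gather_fn thanks to the hijacked methods
assert is_customized_distributed_tensor(param.detach())
# this gather is what ParallelModule performs when saving a state dict
global_tensor = to_global_for_customized_distributed_tensor(param.detach())
assert global_tensor.shape == tensor.shape

dist.destroy_process_group()
```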

View File

@ -27,8 +27,13 @@ def check_linear_1d_col():
# check computation correctness
x = torch.rand(4, 32).cuda()
out = linear(x)
gather_out = linear_col(x)
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard.requires_grad_(True)
out = linear(x_for_unshard)
gather_out = linear_col(x_for_shard)
assert_close(out, gather_out)
# check backward correctness
@ -39,6 +44,11 @@ def check_linear_1d_col():
target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
assert_close(target_grad, linear_col.weight.grad)
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_1d_row():
linear = nn.Linear(32, 128).cuda()
@ -49,8 +59,14 @@ def check_linear_1d_row():
# check computation correctness
x = torch.rand(4, 32).cuda()
out = linear(x)
gather_out = linear_row(x)
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard.requires_grad_(True)
# run forward
out = linear(x_for_unshard)
gather_out = linear_row(x_for_shard)
assert_close(out, gather_out)
# check backward correctness
@ -61,11 +77,49 @@ def check_linear_1d_row():
target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
assert_close(target_grad, linear_row.weight.grad)
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_col_plus_row():
linear_1 = nn.Linear(32, 128).cuda()
linear_2 = nn.Linear(128, 32).cuda()
linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)
# check computation correctness
x = torch.rand(4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard.requires_grad_(True)
# run forward
unshard_out = linear_2(linear_1(x_for_unshard))
shard_out = linear_row(linear_col(x_for_shard))
assert_close(unshard_out, shard_out)
# check backward correctness
unshard_out.sum().backward()
shard_out.sum().backward()
rank = dist.get_rank()
target_1_grad = torch.chunk(linear_1.weight.grad, 2, dim=0)[rank]
assert_close(target_1_grad, linear_col.weight.grad)
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_linear_1d_col()
# check_linear_1d_row()
check_linear_1d_row()
check_linear_col_plus_row()
@rerun_if_address_is_in_use()

View File

@ -5,6 +5,7 @@ from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row
from colossalai.shardformer.layer.linear_conv import split_fused_qkv
from colossalai.testing import rerun_if_address_is_in_use, spawn
@ -53,9 +54,15 @@ def check_linear_conv_1d_col():
linear = Conv1D(192, 48).cuda()
linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3)
assert linear_conv_col.weight.shape == torch.Size([96, 48])
assert linear.weight.shape == torch.Size([48, 192])
assert linear.bias.shape == torch.Size([192])
assert linear_conv_col.weight.shape == torch.Size([48, 96])
assert linear_conv_col.bias.shape == torch.Size([96])
# ensure weights are reversibly loadable
linear_conv_col.load_state_dict(linear.state_dict())
linear.load_state_dict(linear_conv_col.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
out = linear(x)
@ -66,16 +73,16 @@ def check_linear_conv_1d_col():
out.sum().backward()
gather_out.sum().backward()
rank = dist.get_rank()
target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
assert_close(target_grad.transpose(0, 1).contiguous(), linear_conv_col.weight.grad)
target_grad = split_fused_qkv(linear.weight.grad, 3, None)
assert_close(target_grad, linear_conv_col.weight.grad)
def check_linear_1d_row():
linear = Conv1D(192, 48).cuda()
linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
assert linear_row.weight.shape == torch.Size([192, 24])
assert linear.weight.shape == torch.Size([48, 192])
assert linear_row.weight.shape == torch.Size([24, 192])
assert linear_row.bias.shape == torch.Size([192])
# check computation correctness
@ -89,13 +96,14 @@ def check_linear_1d_row():
gather_out.sum().backward()
rank = dist.get_rank()
target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
assert_close(target_grad, linear_row.weight.grad)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_linear_conv_1d_col()
check_linear_1d_row()
@rerun_if_address_is_in_use()

View File

@ -20,20 +20,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
# check grad equality
if org_model.__class__.__name__ == 'GPT2Model':
org_grad = org_model.h[0].attn.c_attn.weight.grad
shard_grad = sharded_model.h[0].attn.c_attn.weight.grad.transpose(0, 1).contiguous()
org_grad = org_model.h[0].mlp.c_fc.weight.grad
shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
else:
org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad
shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad.transpose(0, 1).contiguous()
shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=1)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
assert torch.allclose(
org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
def check_gpt2(rank, world_size, port):