mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] integrated linear 1D with dtensor (#3996)
* [shardformer] integrated linear 1D with dtensor
* polish code

branch: pull/4157/head
parent d3bc530849
commit 015af592f8
@@ -10,6 +10,7 @@ from colossalai.core import global_context as gpc


 class ParallelLayer(nn.Module):

     global_state_dict: bool = True

     def __init__(self):
@@ -54,10 +54,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
    """

    @staticmethod
-    def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
-        ctx.parallel_mode = parallel_mode
+        ctx.process_group = process_group
        ctx.async_grad_allreduce = async_grad_allreduce

        output = torch.matmul(input_, weight.t())
@@ -74,12 +74,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
        grad_input = grad_output.matmul(weight)
        grad_output = grad_output.contiguous()
        # Convert the tensor shapes to 2D for execution compatibility
-        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
-        total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
+        if len(grad_output.shape) > 2:
+            grad_output = grad_output.view(-1, grad_output.shape[-1])
+            total_input = total_input.view(-1, total_input.shape[-1])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
-            handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
+            handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # all-reduce scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1
@@ -93,5 +94,123 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
        return grad_input, grad_weight, grad_bias, None, None, None


-def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
-    return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
+class _SplitForwardGatherBackward(torch.autograd.Function):
+    """
+    Split the input and keep only the chunk corresponding to the rank.
+
+    Args:
+        input_ (`torch.Tensor`): input matrix.
+        dim (int): the dimension to perform split and gather.
+        process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication.
+    """
+
+    @staticmethod
+    def forward(ctx, input_, dim, process_group):
+        ctx.process_group = process_group
+        ctx.dim = dim
+        return _split(input_, dim, process_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather(grad_output, ctx.dim, ctx.process_group), None, None
+
+
+class _ReduceInput(torch.autograd.Function):
+    """
+    All-reduce the input from the model parallel region.
+
+    Args:
+        input_: input matrix.
+        process_group: the process group used for collective communication.
+    """
+
+    @staticmethod
+    def forward(ctx, input_, process_group):
+        return _reduce(input_, process_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None
+
+
+def _reduce(input_, process_group):
+    # skip if only one rank is involved
+    if dist.get_world_size(process_group) == 1:
+        return input_
+    else:
+        dist.all_reduce(input_, group=process_group)
+        return input_
+
+
+def _split(input_, dim=-1, process_group=None):
+    # skip if only one rank is involved
+    world_size = dist.get_world_size(process_group)
+    if world_size == 1:
+        return input_
+
+    # Split along the last dimension.
+    dim_size = input_.size(dim)
+    assert dim_size % world_size == 0, \
+        f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
+        f'cannot split tensor evenly'
+
+    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
+    rank = dist.get_rank(process_group)
+    output = tensor_list[rank].contiguous()
+
+    return output
+
+
+def _gather(input_, dim=-1, process_group=None):
+    # skip if only one rank is involved
+    world_size = dist.get_world_size(process_group)
+    if world_size == 1:
+        return input_
+
+    # all gather
+    rank = dist.get_rank(process_group)
+    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+    tensor_list[rank] = input_
+    torch.distributed.all_gather(tensor_list, input_, group=process_group)
+
+    # concat
+    output = torch.cat(tensor_list, dim=dim).contiguous()
+
+    return output
+
+
+class _GatherForwardSplitBackward(torch.autograd.Function):
+    """Gather the input from the model parallel region and concatenate.
+
+    Args:
+        input_: input matrix.
+        dim: the dimension to perform split and gather.
+        process_group: the process group used for collective communication.
+    """
+
+    @staticmethod
+    def forward(ctx, input_, dim, process_group):
+        ctx.process_group = process_group
+        ctx.dim = dim
+        return _gather(input_, dim, process_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _split(grad_output, ctx.dim, ctx.process_group), None, None
+
+
+def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
+    return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
+
+
+def gather_forward_split_backward(input_, dim, process_group):
+    return _GatherForwardSplitBackward.apply(input_, dim, process_group)
+
+
+def split_forward_gather_backward(input_, dim, process_group):
+    return _SplitForwardGatherBackward.apply(input_, dim, process_group)
+
+
+def reduce_input(input_, process_group):
+    return _ReduceInput.apply(input_, process_group)
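The three wrappers added above pair each forward collective with its dual in the backward pass: split pairs with gather, all-gather pairs with split, and the all-reduce is an identity on gradients. The sketch below only illustrates the call signatures; it assumes the file lands at `colossalai.shardformer.layer._operation` and uses a one-rank gloo group, where every collective short-circuits, so it runs on a single CPU process.

```python
# Minimal sketch of the process-group based API (a one-rank group; the
# collectives are no-ops, this only exercises the autograd plumbing).
import torch
import torch.distributed as dist

from colossalai.shardformer.layer._operation import (  # assumed module path
    gather_forward_split_backward,
    reduce_input,
    split_forward_gather_backward,
)


def main():
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    group = dist.GroupMember.WORLD

    x = torch.randn(4, 8, requires_grad=True)
    # forward: keep the local chunk of the last dim; backward: all-gather grads
    local = split_forward_gather_backward(x, dim=-1, process_group=group)
    # forward: all-gather along the last dim; backward: split grads
    full = gather_forward_split_backward(local, dim=-1, process_group=group)
    # forward: all-reduce partial results; backward: identity
    summed = reduce_input(full, process_group=group)
    summed.sum().backward()


if __name__ == "__main__":
    main()
```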
@@ -1,58 +1,20 @@
-import os
-from contextlib import contextmanager
-
 import torch
+import torch.distributed as dist
 import torch.nn as nn

-
-class SeedManager:
-    """
-    This class is a random state manager to change random state for different random seed.
-
-    """
-
-    def __init__(self):
-        original_state = torch.cuda.get_rng_state()
-        # TODO: unify this seed manager with the colossalai.context.random
-        seed = os.getpid()
-        torch.cuda.manual_seed(int(seed))
-        self.dropout_state = torch.cuda.get_rng_state()
-        torch.cuda.set_rng_state(original_state)
-
-    def set_mode(self, rng_state):
-        torch.cuda.set_rng_state(rng_state)
-
-    def get_current_mode(self):
-        current_state = torch.cuda.get_rng_state()
-        return current_state
-
-    @contextmanager
-    def dropout_mode(self):
-        """
-        This is a context manager to change the dropout state and recover the original state.
-
-        Usage:
-        ::
-            >>> with _seed_manager.dropout_mode():
-            >>>     input = super().forward(input)
-        """
-        try:
-            current_mode = self.get_current_mode()
-            yield self.set_mode(self.dropout_state)
-        finally:
-            self.dropout_state = self.get_current_mode()
-            self.set_mode(current_mode)
-
-
-_seed_manager = SeedManager()
+from .utils import create_randomizer_with_offset


 class Dropout1D(nn.Dropout):

-    def __init__(self, p=0.5, inplace=False):
+    def __init__(self, p=0.5, inplace=False, process_group=None):
         super().__init__(p, inplace)

+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=process_group)
+
     def forward(self, input):
-        with _seed_manager.dropout_mode():
+        with self.randomizer.fork_rng():
             input = super().forward(input)
         return input
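The rewritten Dropout1D now draws its mask inside a per-rank RNG fork: the seed handed to create_randomizer_with_offset is shifted by the randomizer index and the process rank, so tensor-parallel ranks produce different masks while each rank stays reproducible. A small illustration of that seed-offset idea with plain torch generators (not the library API):

```python
# Illustration only: distinct seed offsets give distinct but reproducible masks.
import torch


def dropout_mask(seed: int, shape=(4,)) -> torch.Tensor:
    gen = torch.Generator().manual_seed(seed)
    return (torch.rand(shape, generator=gen) > 0.5).float()


base_seed = 1234
mask_rank0 = dropout_mask(base_seed + 0)   # offset = randomizer index + rank 0
mask_rank1 = dropout_mask(base_seed + 1)   # offset = randomizer index + rank 1
print(mask_rank0, mask_rank1)              # usually different masks
```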
@@ -2,12 +2,16 @@
 # -*- encoding: utf-8 -*-

 import math
+from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Callable, Tuple
+from typing import Callable, List, Tuple, Union

 import torch
+import torch.distributed as dist
+import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
+from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

 from colossalai.communication import broadcast
@@ -22,13 +26,11 @@ from colossalai.nn.layer.parallel_1d._utils import (
     gather_forward_split_backward,
     get_parallel_input,
     reduce_grad,
-    reduce_input,
     set_parallel_input,
-    split_forward_gather_backward,
 )
 from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition
 from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding
-from colossalai.registry import LAYERS
+from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
 from colossalai.utils.checkpointing import (
     broadcast_state_dict,
     gather_tensor_parallel_state_dict,
@@ -36,7 +38,13 @@ from colossalai.utils.checkpointing import (
 )
 from colossalai.utils.cuda import get_current_device

-from ._operation import linear_with_async_comm
+from ._operation import (
+    gather_forward_split_backward,
+    linear_with_async_comm,
+    reduce_input,
+    split_forward_gather_backward,
+)
+from .utils import create_randomizer_with_offset

 Fast_LN = None
 try:
@@ -46,17 +54,172 @@ except ImportError:
     pass


-# @LAYERS.register_module
-class Linear1D(ColossalaiModule):
-    r"""Linear layer for 1D parallelism.
+class ParallelModule(nn.Module, ABC):
+
+    @abstractmethod
+    def from_native_module(module: nn.Module,
+                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule":
+        """
+        Convert a native PyTorch module to a parallelized module.
+
+        Args:
+            module (nn.Module): the module to be converted.
+            process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
+                If this is a list, the process group at the ith index of the list will correspond to the process group
+                in the ith axis of the device mesh. Defaults to None, which means the global process group.
+        """
+        pass
+
+
+class Linear1D_Col(ParallelModule):
+    r"""Linear layer with column parallelism.
+
+    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
+    its second dimension as :math:`A = [A_1, ..., A_p]`.

     Args:
         in_features (int): size of each input sample.
         out_features (int): size of each output sample.
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
-        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
-        gather_output (bool, optional): Whether to call all-gather on output, defaults to False.
-        skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
+        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
+        device (`torch.device`): The device of parameters, defaults to None.
+        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+        gather_output (bool, optional): If true, call all-gather on output and make Y available
+                    to all GPUs, otherwise, every GPU will have its output
+                    which is :math:`Y_i = XA_i`, defaults to False
+        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
+            which is preserved for kernel fusion, defaults to False
+        weight_initializer (`typing.Callable`):
+            The initializer of weight, defaults to kaiming uniform initializer.
+        bias_initializer (`typing.Callable`):
+            The initializer of bias, defaults to xavier uniform initializer.
+
+    More details about ``initializer`` please refer to
+    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
+    """
+
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 bias: bool = True,
+                 dtype: torch.dtype = None,
+                 device: torch.device = None,
+                 process_group: ProcessGroup = None,
+                 gather_output: bool = False,
+                 skip_bias_add: bool = False,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
+        super().__init__()
+
+        # Keep input parameters
+        self.in_features = in_features
+        self.out_features = out_features
+        self.gather_output = gather_output
+        self.skip_bias_add = skip_bias_add
+        self.device = device
+        self.process_group = process_group
+        self.num_partitions = dist.get_world_size(self.process_group)
+
+        if skip_bias_add and not bias:
+            raise ValueError('cannot skip bias addition if bias is None')
+
+        self.out_features_per_partition = divide(out_features, self.num_partitions)
+
+        # Parameters.
+        # Initialize weight.
+        if device is None:
+            device = get_current_device()
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
+
+        if bias:
+            self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
+        else:
+            self.bias = None
+
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+        with self.randomizer.fork_rng(enable_cpu=True):
+            self.reset_parameters(weight_initializer, bias_initializer)
+
+    def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+                           **kwargs) -> ParallelModule:
+        r"""
+        Convert a native PyTorch linear layer to a parallelized linear layer.
+        """
+        # get the attributes
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        device = module.weight.device
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, \
+                f'Expected only one process group, got {len(process_group)}.'
+            process_group = process_group[0]
+
+        linear_1d = Linear1D_Col(in_features=in_features,
+                                 out_features=out_features,
+                                 bias=bias,
+                                 device=device,
+                                 process_group=process_group,
+                                 *args,
+                                 **kwargs)
+
+        # TODO: copy the sharded weights
+        with torch.no_grad():
+            # the weight of the linear layer is a transpose,
+            # thus sharding on row is equal to sharding on column
+            sharded_weight = shard_rowwise(module.weight.data, process_group)
+            linear_1d.weight.data.copy_(sharded_weight)
+            if bias:
+                sharded_bias = shard_colwise(module.bias.data, process_group)
+                linear_1d.bias.copy_(sharded_bias)
+
+        return linear_1d
+
+    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
+        fan_in, fan_out = self.in_features, self.out_features
+        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+        if self.bias is not None:
+            bias_initializer(self.bias, fan_in=fan_in)
+
+    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
+        assert input_.shape[-1] == self.weight.shape[-1], \
+            'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
+                input_.shape, self.weight.shape, self.weight.shape[-1])
+        # Set up backprop all-reduce.
+        # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
+        input_parallel = input_
+        # Matrix multiply.
+        bias = self.bias if not self.skip_bias_add else None
+        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+        else:
+            output = output_parallel
+
+        if self.skip_bias_add:
+            return output, self.bias
+        else:
+            return output
+
+
+class Linear1D_Row(ParallelModule):
+    r""" Linear layer with row parallelism
+
+    Args:
+        in_features (int): size of each input sample.
+        out_features (int): size of each output sample.
+        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
+        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
+        parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
+        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
             which is preserved for kernel fusion, defaults to False
         weight_initializer (:class:`typing.Callable`, optional):
             The initializer of weight, defaults to kaiming uniform initializer.
@@ -72,32 +235,149 @@ class Linear1D(ColossalaiModule):
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
-                gather_output: bool = False,
+                device: torch.device = None,
+                process_group: ProcessGroup = None,
+                parallel_input: bool = True,
                 skip_bias_add: bool = False,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
-        parallel_input = get_parallel_input()
-        if not parallel_input and not gather_output:
-            layer = Linear1D_Col(in_features,
-                                 out_features,
-                                 bias=bias,
-                                 dtype=dtype,
-                                 skip_bias_add=skip_bias_add,
-                                 weight_initializer=weight_initializer,
-                                 bias_initializer=bias_initializer)
-        else:
-            layer = Linear1D_Row(in_features,
-                                 out_features,
-                                 bias=bias,
-                                 dtype=dtype,
-                                 parallel_input=parallel_input,
-                                 skip_bias_add=skip_bias_add,
-                                 weight_initializer=weight_initializer,
-                                 bias_initializer=bias_initializer)
-        super().__init__(layer)
+                bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+                stream_chunk_num: int = 1):
+        super().__init__()
+
+        self.stream_chunk_num = stream_chunk_num
+
+        # Keep input parameters
+        self.in_features = in_features
+        self.out_features = out_features
+        self.parallel_input = parallel_input
+        self.skip_bias_add = skip_bias_add
+        self.process_group = process_group
+        self.num_partitions = dist.get_world_size(self.process_group)
+
+        if skip_bias_add and not bias:
+            raise ValueError('cannot skip bias addition if bias is None')
+
+        # Divide the weight matrix along the last dimension.
+        self.input_size_per_partition = divide(in_features, self.num_partitions)
+
+        # Parameters.
+        # Initialize weight.
+        if device is None:
+            device = get_current_device()
+
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
+
+        if self.stream_chunk_num > 1:
+            # TODO() work for inference only
+            self.chunk_weight()
+        if bias:
+            self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+        else:
+            self.bias = None
+
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+        with self.randomizer.fork_rng(enable_cpu=True):
+            self.reset_parameters(weight_initializer, bias_initializer)
+
+    @staticmethod
+    def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+                           **kwargs) -> ParallelModule:
+        r"""
+        Convert a native PyTorch linear layer to a parallelized linear layer.
+        """
+        # get the attributes
+        in_features = module.in_features
+        out_features = module.out_features
+        bias = module.bias is not None
+        device = module.weight.device
+
+        # ensure only one process group is passed
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, \
+                f'Expected only one process group, got {len(process_group)}.'
+            process_group = process_group[0]
+
+        linear_1d = Linear1D_Row(in_features=in_features,
+                                 out_features=out_features,
+                                 bias=bias,
+                                 device=device,
+                                 process_group=process_group,
+                                 *args,
+                                 **kwargs)
+
+        # TODO: copy the sharded weights
+        with torch.no_grad():
+            # the weight of the linear layer is a transpose,
+            # thus sharding on column is equal to sharding on row
+            sharded_weight = shard_colwise(module.weight.data, process_group)
+            linear_1d.weight.data.copy_(sharded_weight)
+
+            if bias:
+                linear_1d.bias.copy_(module.bias.data)
+
+        return linear_1d
+
+    def chunk_weight(self):
+        self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)
+
+    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
+        fan_in, fan_out = self.in_features, self.out_features
+        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+
+        if self.bias is not None:
+            bias_initializer(self.bias, fan_in=fan_in)
+            if self.process_group is None:
+                src_rank = 0
+            else:
+                src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
+            dist.broadcast(self.bias, src=src_rank, group=self.process_group)
+
+    def forward(self, input_: Tensor) -> Tensor:
+        # Set up backprop all-reduce.
+        if self.parallel_input:
+            assert input_.shape[-1] == self.weight.shape[-1], \
+                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
+                    input_.shape, self.weight.shape, self.weight.shape[-1])
+            input_ = input_
+        else:
+            assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \
+                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
+                    input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
+            input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
+
+        if self.stream_chunk_num > 1:
+            if self.training:
+                raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
+            with torch.no_grad():
+                output_parallel_list = [None for i in range(self.stream_chunk_num)]
+                handle_list = []
+                for i in range(self.stream_chunk_num):
+                    output_parallel_list[i] = F.linear(input_, self.weight_list[i])
+                    handle = torch.distributed.all_reduce(output_parallel_list[i],
+                                                          group=self.process_group,
+                                                          async_op=True)
+                    handle_list.append(handle)
+                    # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
+                for handle in handle_list:
+                    handle.wait()
+                output = torch.cat(output_parallel_list, dim=-1)
+        else:
+            output_parallel = F.linear(input_, self.weight)
+            # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
+            output = reduce_input(output_parallel, self.process_group)
+
+        if not self.skip_bias_add:
+            if self.bias is not None:
+                output = output + self.bias
+            return output
+        else:
+            return output, self.bias


-# @LAYERS.register_module
 class LayerNorm1D(ColossalaiModule):
     r"""
     Layer Normalization for colossalai
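Taken together, the new Linear1D_Col and Linear1D_Row implement the usual Megatron-style pairing: the column-parallel layer shards the output dimension and can skip the gather, and the row-parallel layer consumes that already-split activation and all-reduces its partial results, so a two-layer block needs only one reduction. A single-process numerical sketch of why that works (plain torch, no process group):

```python
# Emulate 2-way column + row parallelism on one device and compare
# against the dense computation; the final sum stands in for the
# all-reduce performed by reduce_input in Linear1D_Row.forward.
import torch

torch.manual_seed(0)
x = torch.randn(4, 32)
w1 = torch.randn(64, 32)   # full column-parallel weight (out_features=64)
w2 = torch.randn(32, 64)   # full row-parallel weight (in_features=64)

ref = (x @ w1.t()) @ w2.t()                    # dense reference

w1_shards = torch.chunk(w1, 2, dim=0)          # Linear1D_Col shards the output dim
w2_shards = torch.chunk(w2, 2, dim=1)          # Linear1D_Row shards the input dim
partials = [(x @ a.t()) @ b.t() for a, b in zip(w1_shards, w2_shards)]
out = sum(partials)                            # the "all-reduce" of partial results

assert torch.allclose(ref, out, rtol=1e-4, atol=1e-4)
```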
@@ -152,7 +432,6 @@ class LayerNorm1D(ColossalaiModule):
         super()._save_to_state_dict(destination, prefix, keep_vars)


-# @LAYERS.register_module
 class Classifier1D(ParallelLayer):
     r"""RowLinear with given weight. Classifier of 1D parallelism.

@@ -288,7 +567,6 @@ class Classifier1D(ParallelLayer):
         return output


-# @LAYERS.register_module
 class VocabParallelClassifier1D(ParallelLayer):
     r"""ColLinear with given weight. Classifier of 1D parallelism.

@@ -424,317 +702,8 @@ class VocabParallelClassifier1D(ParallelLayer):


 # @LAYERS.register_module
-class Linear1D_Col(ParallelLayer):
-    r"""Linear layer with column parallelism.
-
-    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
-    its second dimension as :math:`A = [A_1, ..., A_p]`.
-
-    Args:
-        in_features (int): size of each input sample.
-        out_features (int): size of each output sample.
-        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
-        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
-        gather_output (bool, optional): If true, call all-gather on output and make Y available
-                    to all GPUs, otherwise, every GPU will have its output
-                    which is :math:`Y_i = XA_i`, defaults to False
-        skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
-            which is preserved for kernel fusion, defaults to False
-        weight_initializer (:class:`typing.Callable`, optional):
-            The initializer of weight, defaults to kaiming uniform initializer.
-        bias_initializer (:class:`typing.Callable`, optional):
-            The initializer of bias, defaults to xavier uniform initializer.
-
-    More details about ``initializer`` please refer to
-    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
-    """
-
-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 bias: bool = True,
-                 dtype: torch.dtype = None,
-                 gather_output: bool = False,
-                 skip_bias_add: bool = False,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
-        super().__init__()
-
-        # Keep input parameters
-        self.in_features = in_features
-        self.out_features = out_features
-        self.gather_output = gather_output
-        self.skip_bias_add = skip_bias_add
-
-        if skip_bias_add and not bias:
-            raise ValueError('cannot skip bias addition if bias is None')
-
-        # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
-        self.out_features_per_partition = out_features
-
-        # Parameters.
-        # Initialize weight.
-        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
-        self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
-
-        if bias:
-            self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
-        else:
-            self.bias = None
-        with seed(ParallelMode.TENSOR):
-            self.reset_parameters(weight_initializer, bias_initializer)
-        self._set_tensor_parallel_attributes()
-        is_parallel_output = not self.gather_output
-        set_parallel_input(is_parallel_output)
-
-    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
-        fan_in, fan_out = self.in_features, self.out_features
-        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
-        if self.bias is not None:
-            bias_initializer(self.bias, fan_in=fan_in)
-
-    def _set_tensor_parallel_attributes(self):
-        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
-        set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
-        if self.bias is not None:
-            set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
-
-    def _load_from_global_state_dict(self, state_dict, prefix, *args):
-        local_state = OrderedDict()
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
-        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
-            # weight
-            weight = state_dict.pop(weight_key, None)
-            if weight is not None:
-                local_state[weight_key] = weight
-            # bias
-            if self.bias is not None:
-                bias = state_dict.pop(bias_key, None)
-                if bias is not None:
-                    local_state[bias_key] = bias
-
-        local_state = partition_tensor_parallel_state_dict(local_state,
-                                                           ParallelMode.PARALLEL_1D,
-                                                           dims={
-                                                               weight_key: 0,
-                                                               bias_key: 0
-                                                           },
-                                                           partition_states={
-                                                               weight_key: True,
-                                                               bias_key: True
-                                                           })
-        super()._load_from_global_state_dict(local_state, prefix, *args)
-
-    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight})
-        if self.bias is not None:
-            local_state[bias_key] = self.bias
-        local_state = gather_tensor_parallel_state_dict(local_state,
-                                                        ParallelMode.PARALLEL_1D,
-                                                        dims={
-                                                            weight_key: 0,
-                                                            bias_key: 0
-                                                        },
-                                                        partition_states={
-                                                            weight_key: True,
-                                                            bias_key: True
-                                                        },
-                                                        keep_vars=keep_vars)
-        destination.update(local_state)
-
-    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
-        assert input_.shape[-1] == self.weight.shape[-1], \
-            'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
-                input_.shape, self.weight.shape, self.weight.shape[-1])
-        # Set up backprop all-reduce.
-        # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
-        input_parallel = input_
-        # Matrix multiply.
-        bias = self.bias if not self.skip_bias_add else None
-        # output_parallel = F.linear(input_parallel, self.weight, bias)
-        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True)
-        if self.gather_output:
-            # All-gather across the partitions.
-            output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
-        else:
-            output = output_parallel
-
-        if self.skip_bias_add:
-            return output, self.bias
-        else:
-            return output
-
-
-# @LAYERS.register_module
-class Linear1D_Row(ParallelLayer):
-    r""" Linear layer with row parallelism
-
-    Args:
-        in_features (int): size of each input sample.
-        out_features (int): size of each output sample.
-        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
-        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
-        parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to False.
-        skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
-            which is preserved for kernel fusion, defaults to False
-        weight_initializer (:class:`typing.Callable`, optional):
-            The initializer of weight, defaults to kaiming uniform initializer.
-        bias_initializer (:class:`typing.Callable`, optional):
-            The initializer of bias, defaults to xavier uniform initializer.
-
-    More details about ``initializer`` please refer to
-    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
-    """
-
-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 bias: bool = True,
-                 dtype: torch.dtype = None,
-                 parallel_input: bool = True,
-                 skip_bias_add: bool = False,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-                 stream_chunk_num: int = 1):
-        super().__init__()
-
-        self.stream_chunk_num = stream_chunk_num
-
-        # Keep input parameters
-        self.in_features = in_features
-        self.out_features = out_features
-        self.parallel_input = parallel_input
-        self.skip_bias_add = skip_bias_add
-
-        if skip_bias_add and not bias:
-            raise ValueError('cannot skip bias addition if bias is None')
-
-        # Divide the weight matrix along the last dimension.
-        # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
-        self.input_size_per_partition = in_features
-
-        # Parameters.
-        # Initialize weight.
-        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
-        self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
-
-        if self.stream_chunk_num > 1:
-            # TODO() work for inference only
-            self.chunk_weight()
-        if bias:
-            self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
-        else:
-            self.bias = None
-        with seed(ParallelMode.TENSOR):
-            self.reset_parameters(weight_initializer, bias_initializer)
-        self._set_tensor_parallel_attributes()
-        set_parallel_input(False)
-
-    def chunk_weight(self):
-        self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)
-
-    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
-        fan_in, fan_out = self.in_features, self.out_features
-        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
-        if self.bias is not None:
-            bias_initializer(self.bias, fan_in=fan_in)
-            broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
-
-    def _set_tensor_parallel_attributes(self):
-        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
-        set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
-
-    def _load_from_global_state_dict(self, state_dict, prefix, *args):
-        local_state = OrderedDict()
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
-        if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
-            # weight
-            weight = state_dict.pop(weight_key, None)
-            if weight is not None:
-                local_state[weight_key] = weight
-            # bias
-            if self.bias is not None:
-                bias = state_dict.pop(bias_key, None)
-                if bias is not None:
-                    local_state[bias_key] = bias
-
-        local_state = partition_tensor_parallel_state_dict(local_state,
-                                                           ParallelMode.PARALLEL_1D,
-                                                           dims={
-                                                               weight_key: -1,
-                                                               bias_key: 0
-                                                           },
-                                                           partition_states={
-                                                               weight_key: True,
-                                                               bias_key: False
-                                                           })
-        super()._load_from_global_state_dict(local_state, prefix, *args)
-
-    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
-        weight_key = prefix + 'weight'
-        bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight})
-        if self.bias is not None:
-            local_state[bias_key] = self.bias
-        local_state = gather_tensor_parallel_state_dict(local_state,
-                                                        ParallelMode.PARALLEL_1D,
-                                                        dims={
-                                                            weight_key: -1,
-                                                            bias_key: 0
-                                                        },
-                                                        partition_states={
-                                                            weight_key: True,
-                                                            bias_key: False
-                                                        },
-                                                        keep_vars=keep_vars)
-        destination.update(local_state)
-
-    def forward(self, input_: Tensor) -> Tensor:
-        # Set up backprop all-reduce.
-        if self.parallel_input:
-            assert input_.shape[-1] == self.weight.shape[-1], \
-                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
-                    input_.shape, self.weight.shape, self.weight.shape[-1])
-            input_ = input_
-        else:
-            assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
-                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
-                    input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
-            input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
-
-        if self.stream_chunk_num > 1:
-            if self.training:
-                raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
-            with torch.no_grad():
-                output_parallel_list = [None for i in range(self.stream_chunk_num)]
-                handle_list = []
-                for i in range(self.stream_chunk_num):
-                    output_parallel_list[i] = F.linear(input_, self.weight_list[i])
-                    handle = torch.distributed.all_reduce(output_parallel_list[i],
-                                                          group=gpc.get_group(ParallelMode.PARALLEL_1D),
-                                                          async_op=True)
-                    handle_list.append(handle)
-                    # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
-                for handle in handle_list:
-                    handle.wait()
-                output = torch.cat(output_parallel_list, dim=-1)
-        else:
-            output_parallel = F.linear(input_, self.weight)
-            # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
-            output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
-        if not self.skip_bias_add:
-            if self.bias is not None:
-                output = output + self.bias
-            return output
-        else:
-            return output, self.bias
-
-
-# @LAYERS.register_module
 class Embedding1D(ParallelLayer):
     r"""Embedding for 1D parallelism.

@@ -842,7 +811,6 @@ class Embedding1D(ParallelLayer):
         return output


-# @LAYERS.register_module
 class VocabParallelEmbedding1D(ParallelLayer):
     r"""Embedding parallelized in the vocabulary dimension.

@@ -960,7 +928,6 @@ class VocabParallelEmbedding1D(ParallelLayer):
         return output


-# @LAYERS.register_module
 class Dropout1D(ParallelLayer):
     """Dropout layer of 1D parallelism.

@@ -0,0 +1,138 @@
+from contextlib import contextmanager
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+
+class Randomizer:
+    """
+    Randomizer enables the program to be executed under a different seed within the context.
+
+    Example:
+
+    ```python
+    randomizer = Randomizer(seed=1024)
+
+    with randomizer.fork_rng():
+        # do something here with seed 1024
+        do_something()
+    ```
+
+    Args:
+        seed (int): The random seed to set.
+        enable_cpu (bool): fork the CPU RNG state as well.
+        with_index (bool): whether to use the index of the randomizer.
+    """
+
+    _INDEX = 0
+
+    def __init__(self, seed: int):
+        # TODO: remove colossalai.context.random
+
+        self.seed = seed
+
+        # Handle CUDA rng state
+        # 1. get the current rng state
+        # 2. set the seed and store the rng state
+        # 3. recover the original rng state
+        cuda_original_rng_state = torch.cuda.get_rng_state()
+        torch.cuda.manual_seed(seed)
+        self.cuda_rng_state = torch.cuda.get_rng_state()
+        torch.cuda.set_rng_state(cuda_original_rng_state)
+
+        # do the same for the CPU rng state
+        cpu_original_rng_state = torch.get_rng_state()
+        torch.manual_seed(seed)
+        self.cpu_rng_state = torch.get_rng_state()
+        torch.set_rng_state(cpu_original_rng_state)
+
+    def _set_cuda_rng_state(self, rng_state):
+        torch.cuda.set_rng_state(rng_state)
+
+    def _get_cuda_rng_state(self):
+        current_state = torch.cuda.get_rng_state()
+        return current_state
+
+    def _set_cpu_rng_state(self, rng_state):
+        torch.set_rng_state(rng_state)
+
+    def _get_cpu_rng_state(self):
+        current_state = torch.get_rng_state()
+        return current_state
+
+    @contextmanager
+    def fork_rng(self, enable_cpu: bool = False):
+        """
+        This is a context manager to switch to the managed RNG state and recover the original state on exit.
+
+        Usage:
+        ::
+            >>> with randomizer.fork_rng():
+            >>>     input = super().forward(input)
+        """
+        try:
+            current_cuda_rng_state = self._get_cuda_rng_state()
+            self._set_cuda_rng_state(self.cuda_rng_state)
+
+            if enable_cpu:
+                current_cpu_rng_state = self._get_cpu_rng_state()
+                self._set_cpu_rng_state(self.cpu_rng_state)
+            yield
+        finally:
+            self.cuda_rng_state = self._get_cuda_rng_state()
+            self._set_cuda_rng_state(current_cuda_rng_state)
+
+            if enable_cpu:
+                self.cpu_rng_state = self._get_cpu_rng_state()
+                self._set_cpu_rng_state(current_cpu_rng_state)
+
+    @staticmethod
+    def index():
+        """
+        Return the index of the randomizer. The index is useful when the user wants
+        to introduce some randomness in the program.
+
+        Note:
+            The index will increment by one each time this method is called.
+
+        Example:
+
+        ```python
+        # assume we need a randomizer to init the weight of different layers
+        # we can use the index of the randomizer to do so that
+        # each layer has its own randomizer with a different seed
+        base_seed = torch.random.initial_seed()
+        seed = base_seed + Randomizer.index()
+        randomizer = Randomizer(seed)
+
+        with randomizer.fork_rng():
+            init_weights()
+        ```
+
+        """
+        idx = Randomizer._INDEX
+        Randomizer._INDEX += 1
+        return idx
+
+
+def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None):
+    """
+    Create a randomizer with an offset. The offset is equal to the rank of the process plus the index of the randomizer.
+
+    Args:
+        seed (int): The base random seed to set.
+        process_group (ProcessGroup): the process group to get the rank from.
+
+    Returns:
+        Randomizer: the randomizer with offset.
+    """
+    offset = Randomizer.index()
+
+    if dist.is_initialized():
+        rank = dist.get_rank(process_group)
+        offset += rank
+
+    seed += offset
+    return Randomizer(seed=seed)
@@ -0,0 +1,44 @@
+from typing import Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from colossalai.device.device_mesh import DeviceMesh
+
+from .d_tensor import DTensor
+from .sharding_spec import ShardingSpec
+
+
+def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
+    """
+    Shard the first dim of the given tensor.
+    """
+    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
+    if group_or_device_mesh is None:
+        group_or_device_mesh = dist.GroupMember.WORLD
+
+    if isinstance(group_or_device_mesh, ProcessGroup):
+        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
+    else:
+        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
+        device_mesh = group_or_device_mesh
+    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
+    return DTensor(tensor, device_mesh, sharding_spec)
+
+
+def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
+    """
+    Shard the last dim of the given tensor.
+    """
+    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
+    if group_or_device_mesh is None:
+        group_or_device_mesh = dist.GroupMember.WORLD
+
+    if isinstance(group_or_device_mesh, ProcessGroup):
+        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
+    else:
+        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for column-wise sharding.'
+        device_mesh = group_or_device_mesh
+    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
+    return DTensor(tensor, device_mesh, sharding_spec)
@@ -34,7 +34,7 @@ class Layout:
     def get_sharded_shape_per_device(self):
         sharded_shape = list(self.entire_shape)
         for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
-            mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
+            mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
             shard_partitions = reduce(operator.mul, mesh_list, 1)
             assert sharded_shape[
                 dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
@@ -45,14 +45,15 @@ class Layout:
         sharding_spec = self.sharding_spec

         # make sure all axes in logical device mesh only be used once
-        dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
-        for dim, shard_list in sharding_spec.dim_partition_dict.items():
-            for element in shard_list:
-                if element in dim_check_list:
-                    dim_check_list.remove(element)
-                else:
-                    raise DuplicatedShardingDimensionError(
-                        f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
+        if self.device_mesh.logical_mesh_id is not None:
+            dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
+            for dim, shard_list in sharding_spec.dim_partition_dict.items():
+                for element in shard_list:
+                    if element in dim_check_list:
+                        dim_check_list.remove(element)
+                    else:
+                        raise DuplicatedShardingDimensionError(
+                            f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")

         # make sure that the sharding for a dimension is divisible by the number of devices
         for dim, shard_list in sharding_spec.dim_partition_dict.items():
@@ -60,7 +61,7 @@ class Layout:
             num_devices = 1

             for element in shard_list:
-                num_devices *= self.device_mesh.mesh_shape[element]
+                num_devices *= self.device_mesh.shape[element]

             if tensor_dim_size % num_devices != 0:
                 raise ShardingNotDivisibleError(
@@ -304,7 +304,7 @@ class LayoutConverter(metaclass=SingletonMeta):
        process_groups_dict = source_layout.device_mesh.process_groups_dict

        # legal sharding dims means the mesh_id is still available to use.
-        legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]
+        legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
        for dim, shard_list in source_spec.dim_partition_dict.items():
            for element in shard_list:
                legal_sharding_dims.remove(element)
@@ -0,0 +1,67 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+
+
+def check_linear_1d_col():
+    linear = nn.Linear(32, 128).cuda()
+    linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)
+
+    assert linear_col.weight.shape == torch.Size([64, 32])
+    assert linear_col.bias.shape == torch.Size([64])
+
+    # check computation correctness
+    x = torch.rand(4, 32).cuda()
+    out = linear(x)
+    gather_out = linear_col(x)
+    assert_close(out, gather_out)
+
+    # check backward correctness
+    out.sum().backward()
+    gather_out.sum().backward()
+
+    rank = dist.get_rank()
+    target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
+    assert_close(target_grad, linear_col.weight.grad)
+
+
+def check_linear_1d_row():
+    linear = nn.Linear(32, 128).cuda()
+    linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
+
+    assert linear_row.weight.shape == torch.Size([128, 16])
+    assert linear_row.bias.shape == torch.Size([128])
+
+    # check computation correctness
+    x = torch.rand(4, 32).cuda()
+    out = linear(x)
+    gather_out = linear_row(x)
+    assert_close(out, gather_out)
+
+    # check backward correctness
+    out.sum().backward()
+    gather_out.sum().backward()
+
+    rank = dist.get_rank()
+    target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
+    assert_close(target_grad, linear_row.weight.grad)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_linear_1d_col()
+    check_linear_1d_row()
+
+
+@rerun_if_address_is_in_use()
+def test_linear():
+    spawn(run_dist, nprocs=2)
+
+
+if __name__ == '__main__':
+    test_linear()