class ParallelLayer(nn.Module):
global_state_dict: bool = True
def __init__(self):
def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.parallel_mode = parallel_mode
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input_, weight.t())
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
return grad_input, grad_weight, grad_bias, None, None, None
def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
    """Run ``LinearWithAsyncCommunication`` on the given arguments and return its result.

    All five arguments are forwarded positionally, in the order they are received.
    """
    call_args = (input_, weight, bias, parallel_mode, async_grad_allreduce)
    return LinearWithAsyncCommunication.apply(*call_args)
Split the input and keep only the chunk corresponding to the rank.
input_ (`torch.Tensor`): input matrix.
dim (int): the dimension to perform split and gather
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
def forward(ctx, input_, dim, process_group):
    """Store ``dim`` and ``process_group`` on ``ctx`` and return ``_split`` of the input."""
    # Both values are read back from ``ctx`` by the matching backward pass.
    ctx.dim = dim
    ctx.process_group = process_group
    local_chunk = _split(input_, dim, process_group)
    return local_chunk
def backward(ctx, grad_output):
    """Return ``_gather`` of the incoming gradient, followed by two ``None`` placeholders."""
    # The dimension and process group are the ones stored on ``ctx``.
    grad_input = _gather(grad_output, ctx.dim, ctx.process_group)
    return grad_input, None, None
class _ReduceInput(torch.autograd.Function):
    """All-reduce the input from the model parallel region.

    Args:
        input_: input matrix.
        process_group: the process group handed to ``_reduce`` for the all-reduce.
    """

    @staticmethod
    def forward(ctx, input_, process_group):
        # Nothing is stored on ``ctx``: backward does not read anything from it.
        return _reduce(input_, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient is returned unchanged for ``input_``; the second slot
        # corresponds to ``process_group`` and is always ``None``.
        return grad_output, None
def _reduce(input_, process_group):
    """All-reduce ``input_`` in place over ``process_group`` and return the same tensor.

    When the group contains a single rank the collective is skipped and the
    tensor is returned untouched.
    """
    if dist.get_world_size(process_group) != 1:
        dist.all_reduce(input_, group=process_group)
    return input_
def _split(input_, dim=-1, process_group=None):
    """Return the chunk of ``input_`` along ``dim`` that belongs to the calling rank.

    The tensor is cut into ``world_size`` equal pieces along ``dim`` and the
    piece at index ``rank`` is returned as a contiguous tensor. With a single
    rank the input is returned as-is.
    """
    world_size = dist.get_world_size(process_group)
    if world_size == 1:
        # Nothing to partition when only one rank is involved.
        return input_

    # The split is along ``dim`` and must divide evenly across the ranks.
    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, \
        f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
        f'cannot split tensor evenly'

    chunk_size = dim_size // world_size
    chunks = torch.split(input_, chunk_size, dim=dim)
    local_chunk = chunks[dist.get_rank(process_group)]
    return local_chunk.contiguous()
def _gather(input_, dim=-1, process_group=None):
    """All-gather ``input_`` from every rank of ``process_group`` and concatenate along ``dim``.

    With a single rank the input is returned as-is; otherwise the result is a
    new contiguous tensor.
    """
    world_size = dist.get_world_size(process_group)
    if world_size == 1:
        # Only one rank involved: nothing to gather.
        return input_

    # One receive buffer per rank; this rank's slot is the input tensor itself.
    rank = dist.get_rank(process_group)
    buffers = []
    for peer in range(world_size):
        buffers.append(input_ if peer == rank else torch.empty_like(input_))
    dist.all_gather(buffers, input_, group=process_group)

    gathered = torch.cat(buffers, dim=dim)
    return gathered.contiguous()
class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from model parallel region and concatenate.

    Args:
        input_: input matrix.
        dim: dimension passed to ``_gather`` in forward and to ``_split`` in backward.
        process_group: the process group passed to ``_gather`` and ``_split``.
    """

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        # Keep the non-tensor arguments on ``ctx`` so backward can reuse them.
        ctx.process_group = process_group
        ctx.dim = dim
        return _gather(input_, dim, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        # Only ``input_`` receives a gradient; ``dim`` and ``process_group`` get ``None``.
        return _split(grad_output, ctx.dim, ctx.process_group), None, None
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
    """Run ``LinearWithAsyncCommunication`` on the given arguments and return its result.

    All five arguments are forwarded positionally, in the order they are received.
    """
    call_args = (input_, weight, bias, process_group, async_grad_allreduce)
    return LinearWithAsyncCommunication.apply(*call_args)
def gather_forward_split_backward(input_, dim, process_group):
    """Functional entry point for ``_GatherForwardSplitBackward``; returns what ``apply`` returns."""
    gathered = _GatherForwardSplitBackward.apply(input_, dim, process_group)
    return gathered
def split_forward_gather_backward(input_, dim, process_group):
    """Functional entry point for ``_SplitForwardGatherBackward``; returns what ``apply`` returns."""
    local_chunk = _SplitForwardGatherBackward.apply(input_, dim, process_group)
    return local_chunk
def reduce_input(input_, process_group):
    """Functional entry point for ``_ReduceInput``; returns what ``apply`` returns."""
    reduced = _ReduceInput.apply(input_, process_group)
    return reduced
import os
from contextlib import contextmanager
import torch
import torch.distributed as dist
import torch.nn as nn
class SeedManager:
This class is a random state manager to change random state for different random seed.
def __init__(self):
original_state = torch.cuda.get_rng_state()
# TODO: unify this seed manager with the colossalai.context.random
seed = os.getpid()
self.dropout_state = torch.cuda.get_rng_state()
def set_mode(self, rng_state):
def get_current_mode(self):
current_state = torch.cuda.get_rng_state()
return current_state
def dropout_mode(self):
This is a context manager to change the dropout state and recover the original state.
>>> with _seed_manager.dropout_mode():
>>> input = super().forward(input)
current_mode = self.get_current_mode()
yield self.set_mode(self.dropout_state)
self.dropout_state = self.get_current_mode()
_seed_manager = SeedManager()
from .utils import create_randomizer_with_offset
class Dropout1D(nn.Dropout):
def __init__(self, p=0.5, inplace=False):
def __init__(self, p=0.5, inplace=False, process_group=None):
super().__init__(p, inplace)
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=process_group)
def forward(self, input):
with _seed_manager.dropout_mode():
with self.randomizer.fork_rng():
input = super().forward(input)
return input
# -*- encoding: utf-8 -*-
import math
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Callable, Tuple
from typing import Callable, List, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.communication import broadcast
from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition
from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding
from colossalai.registry import LAYERS
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
@ -36,7 +38,13 @@ from colossalai.utils.checkpointing import (
from colossalai.utils.cuda import get_current_device
from ._operation import linear_with_async_comm
from ._operation import (
from .utils import create_randomizer_with_offset
Fast_LN = None
@ -46,17 +54,172 @@ except ImportError:
# @LAYERS.register_module
class Linear1D(ColossalaiModule):
r"""Linear layer for 1D parallelism.
class ParallelModule(nn.Module, ABC):
def from_native_module(module: nn.Module,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule":
Convert a native PyTorch module to a parallelized module.
module (nn.Module): the module to be converted.
process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
If this is a list, the process group at the ith index of the list will correspond to the process group
in the ith axis of the device mesh. Defaults to None, which means the global process group.
class Linear1D_Col(ParallelModule):
r"""Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`.
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
gather_output (bool, optional): Whether to call all-gather on output, defaults to False.
skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (`typing.Callable`):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = False,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
self.out_features_per_partition = divide(out_features, self.num_partitions)
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
self.bias = None
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer, bias_initializer)
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
Convert a native PyTorch linear layer to a parallelized linear layer.
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
linear_1d = Linear1D_Col(in_features=in_features,
# TODO: copy the sharded weights
with torch.no_grad():
# the weight of the linear layer is a transpose
# thus shard on row is equal to shard on column
sharded_weight = shard_rowwise(module.weight.data, process_group)
if bias:
sharded_bias = shard_colwise(module.bias.data, process_group)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
    """Run the supplied initializers on this layer's weight and (if present) bias.

    Args:
        weight_initializer: called as ``weight_initializer(weight, fan_in=..., fan_out=...)``.
        bias_initializer: called as ``bias_initializer(bias, fan_in=...)``; skipped
            when ``self.bias`` is ``None``.
    """
    # Fan-in / fan-out come from the full (unpartitioned) feature sizes.
    fan_in = self.in_features
    fan_out = self.out_features
    weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
    if self.bias is None:
        return
    bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
output = output_parallel
if self.skip_bias_add:
return output, self.bias
return output
class Linear1D_Row(ParallelModule):
r""" Linear layer with row parallelism
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to True.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False,
device: torch.device = None,
process_group: ProcessGroup = None,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
parallel_input = get_parallel_input()
if not parallel_input and not gather_output:
layer = Linear1D_Col(in_features,
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1):
self.stream_chunk_num = stream_chunk_num
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, self.num_partitions)
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
if self.stream_chunk_num > 1:
# TODO() work for inference only
if bias:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
layer = Linear1D_Row(in_features,
self.bias = None
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer, bias_initializer)
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
Convert a native PyTorch linear layer to a parallelized linear layer.
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, \
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
linear_1d = Linear1D_Row(in_features=in_features,
# TODO: copy the sharded weights
with torch.no_grad():
# the weight of the linear layer is a transpose
# thus shard on col is equal to shard on row
sharded_weight = shard_colwise(module.weight.data, process_group)
if bias:
return linear_1d
def chunk_weight(self):
    """Split ``self.weight`` along dim 0 into ``self.stream_chunk_num`` chunks.

    The resulting tuple is stored on ``self.weight_list``; nothing is returned.
    """
    num_chunks = self.stream_chunk_num
    self.weight_list = self.weight.chunk(num_chunks, dim=0)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
if self.process_group is None:
src_rank = 0
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
dist.broadcast(self.bias, src=src_rank, group=self.process_group)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
input_ = input_
assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
if self.stream_chunk_num > 1:
if self.training:
raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
with torch.no_grad():
output_parallel_list = [None for i in range(self.stream_chunk_num)]
handle_list = []
for i in range(self.stream_chunk_num):
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
handle = torch.distributed.all_reduce(output_parallel_list[i],
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
for handle in handle_list:
output = torch.cat(output_parallel_list, dim=-1)
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
return output, self.bias
# @LAYERS.register_module
class LayerNorm1D(ColossalaiModule):
Layer Normalization for colossalai
@ -152,7 +432,6 @@ class LayerNorm1D(ColossalaiModule):
super()._save_to_state_dict(destination, prefix, keep_vars)
# @LAYERS.register_module
class Classifier1D(ParallelLayer):
r"""RowLinear with given weight. Classifier of 1D parallelism.
@ -288,7 +567,6 @@ class Classifier1D(ParallelLayer):
return output
# @LAYERS.register_module
class VocabParallelClassifier1D(ParallelLayer):
r"""ColLinear with given weight. Classifier of 1D parallelism.
@ -424,317 +702,8 @@ class VocabParallelClassifier1D(ParallelLayer):
class Linear1D_Col(ParallelLayer):
r"""Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`.
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
self.out_features_per_partition = out_features
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
is_parallel_output = not self.gather_output
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
    """Run the supplied initializers on this layer's weight and (if present) bias.

    Args:
        weight_initializer: called as ``weight_initializer(weight, fan_in=..., fan_out=...)``.
        bias_initializer: called as ``bias_initializer(bias, fan_in=...)``; skipped
            when ``self.bias`` is ``None``.
    """
    # Fan-in / fan-out come from the layer's in_features / out_features.
    fan_in = self.in_features
    fan_out = self.out_features
    weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
    if self.bias is None:
        return
    bias_initializer(self.bias, fan_in=fan_in)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
weight_key: 0,
bias_key: 0
weight_key: True,
bias_key: True
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
weight_key: 0,
bias_key: 0
weight_key: True,
bias_key: True
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
# output_parallel = F.linear(input_parallel, self.weight, bias)
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
output = output_parallel
if self.skip_bias_add:
return output, self.bias
return output
# @LAYERS.register_module
class Linear1D_Row(ParallelLayer):
r""" Linear layer with row parallelism
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to True.
skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1):
self.stream_chunk_num = stream_chunk_num
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Divide the weight matrix along the last dimension.
# self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
self.input_size_per_partition = in_features
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
if self.stream_chunk_num > 1:
# TODO() work for inference only
if bias:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
def chunk_weight(self):
    """Split ``self.weight`` along dim 0 into ``self.stream_chunk_num`` chunks.

    The resulting tuple is stored on ``self.weight_list``; nothing is returned.
    """
    num_chunks = self.stream_chunk_num
    self.weight_list = self.weight.chunk(num_chunks, dim=0)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
weight_key: -1,
bias_key: 0
weight_key: True,
bias_key: False
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
weight_key: -1,
bias_key: 0
weight_key: True,
bias_key: False
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
input_ = input_
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
if self.stream_chunk_num > 1:
if self.training:
raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
with torch.no_grad():
output_parallel_list = [None for i in range(self.stream_chunk_num)]
handle_list = []
for i in range(self.stream_chunk_num):
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
handle = torch.distributed.all_reduce(output_parallel_list[i],
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
for handle in handle_list:
output = torch.cat(output_parallel_list, dim=-1)
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
return output, self.bias
# @LAYERS.register_module
class Embedding1D(ParallelLayer):
r"""Embedding for 1D parallelism.
@ -842,7 +811,6 @@ class Embedding1D(ParallelLayer):
return output
# @LAYERS.register_module
class VocabParallelEmbedding1D(ParallelLayer):
r"""Embedding parallelized in the vocabulary dimension.
@ -960,7 +928,6 @@ class VocabParallelEmbedding1D(ParallelLayer):
return output
# @LAYERS.register_module
class Dropout1D(ParallelLayer):
"""Dropout layer of 1D parallelism.
from contextlib import contextmanager
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
class Randomizer:
Randomizer enables the program to be executed under a different seed within the context.
randomizer = Randomizer(seed=1024)
with randomizer.fork_rng():
# do something here with seed 1024
seed (int): The random seed to set.
enable_cpu (bool): fork the CPU RNG state as well.
with_index (bool): whether to use the index of the randomizer.
_INDEX = 0
def __init__(self, seed: int):
# TODO: remove colossalai.context.random
self.seed = seed
# Handle CUDA rng state
# 1. get the current rng state
# 2. set the seed and store the rng state
# 3. recover the original rng state
cuda_original_rng_state = torch.cuda.get_rng_state()
self.cuda_rng_state = torch.cuda.get_rng_state()
# do the same for the cpu rng state
cpu_original_rng_state = torch.get_rng_state()
self.cpu_rng_state = torch.get_rng_state()
def _set_cuda_rng_state(self, rng_state):
def _get_cuda_rng_state(self):
current_state = torch.cuda.get_rng_state()
return current_state
def _set_cpu_rng_state(self, rng_state):
def _get_cpu_rng_state(self):
current_state = torch.get_rng_state()
return current_state
def fork_rng(self, enable_cpu: bool = False):
This is a context manager to change the dropout state and recover the original state.
>>> with _seed_manager.dropout_mode():
>>> input = super().forward(input)
current_cuda_rng_state = self._get_cuda_rng_state()
if enable_cpu:
current_cpu_rng_state = self._get_cpu_rng_state()
self.cuda_rng_state = self._get_cuda_rng_state()
if enable_cpu:
self.cpu_rng_state = self._get_cpu_rng_state()
def index():
Return the index of the randomizer. The index is useful when the user wants
to introduce some randomness in the program.
The index will increment by one each time this method is called.
# assume we need a randomizer to init the weight of different layers
# we can use the index of the randomizer to do so that
# each layer has its own randomizer with a different seed
base_seed = torch.random.initial_seed()
seed = base_seed + Randomizer.index()
randomizer = Randomizer(seed)
with randomizer.fork_rng():
idx = Randomizer._INDEX
Randomizer._INDEX += 1
return idx
def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None):
    """
    Create a randomizer with an offset. The offset is equal to the rank of the process and the index of the randomizer.

    Args:
        seed (int): The base random seed to set.
        process_group (ProcessGroup): the process group to get the rank from.

    Returns:
        Randomizer: the randomizer with offset.
    """
    # A fresh index is drawn from Randomizer.index() on every call.
    offset = Randomizer.index()

    # The rank only contributes when torch.distributed has been initialized.
    if dist.is_initialized():
        rank = dist.get_rank(process_group)
        offset += rank

    seed += offset
    return Randomizer(seed=seed)
from typing import Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.device.device_mesh import DeviceMesh
from .d_tensor import DTensor
from .sharding_spec import ShardingSpec
def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
    """
    Shard the first dim of the given tensor.

    Args:
        tensor (torch.Tensor): the tensor to shard.
        group_or_device_mesh (Union[ProcessGroup, DeviceMesh]): the process group or
            1D device mesh to shard over. Defaults to None, in which case the global
            process group (``dist.GroupMember.WORLD``) is used.

    Returns:
        DTensor: built from the tensor, the resolved device mesh and a sharding spec
            that partitions tensor dim 0.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD

    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        # Only an explicitly passed DeviceMesh reaches this branch and gets shape-checked.
        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
        device_mesh = group_or_device_mesh

    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
    return DTensor(tensor, device_mesh, sharding_spec)
def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
    """
    Shard the last dim of the given tensor.

    Args:
        tensor (torch.Tensor): the tensor to shard.
        group_or_device_mesh (Union[ProcessGroup, DeviceMesh]): the process group or
            1D device mesh to shard over. Defaults to None, in which case the global
            process group (``dist.GroupMember.WORLD``) is used.

    Returns:
        DTensor: built from the tensor, the resolved device mesh and a sharding spec
            that partitions tensor dim -1.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD

    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        # Only an explicitly passed DeviceMesh reaches this branch and gets shape-checked.
        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for column-wise sharding.'
        device_mesh = group_or_device_mesh

    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
    return DTensor(tensor, device_mesh, sharding_spec)
def get_sharded_shape_per_device(self):
sharded_shape = list(self.entire_shape)
for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
assert sharded_shape[
dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
sharding_spec = self.sharding_spec
# make sure all axes in logical device mesh only be used once
if self.device_mesh.logical_mesh_id is not None:
dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
for dim, shard_list in sharding_spec.dim_partition_dict.items():
for element in shard_list:
@ -60,7 +61,7 @@ class Layout:
for element in shard_list:
num_devices *= self.device_mesh.mesh_shape[element]
num_devices *= self.device_mesh.shape[element]
if tensor_dim_size % num_devices != 0:
raise ShardingNotDivisibleError(
@ -304,7 +304,7 @@ class LayoutConverter(metaclass=SingletonMeta):
process_groups_dict = source_layout.device_mesh.process_groups_dict
# legal sharding dims means the mesh_id is still available to use.
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
for dim, shard_list in source_spec.dim_partition_dict.items():
for element in shard_list:
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def check_linear_1d_col():
    """
    Compare ``Linear1D_Col`` against the native ``nn.Linear`` it was built from.

    Checks the converted parameter shapes, the forward output (with
    ``gather_output=True``) and the weight gradient. The expected shapes and the
    ``torch.chunk(..., 2, dim=0)`` split assume a group of two processes.
    """
    linear = nn.Linear(32, 128).cuda()
    linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

    assert linear_col.weight.shape == torch.Size([64, 32])
    assert linear_col.bias.shape == torch.Size([64])

    # check computation correctness
    x = torch.rand(4, 32).cuda()
    out = linear(x)
    gather_out = linear_col(x)
    assert_close(out, gather_out)

    # check backward correctness
    # Gradients exist only after a backward pass; without it ``.grad`` is None.
    out.sum().backward()
    gather_out.sum().backward()

    rank = dist.get_rank()
    target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
    assert_close(target_grad, linear_col.weight.grad)
def check_linear_1d_row():
    """
    Compare ``Linear1D_Row`` against the native ``nn.Linear`` it was built from.

    Checks the converted parameter shapes, the forward output (with
    ``parallel_input=False``) and the weight gradient. The expected shapes and the
    ``torch.chunk(..., 2, dim=1)`` split assume a group of two processes.
    """
    linear = nn.Linear(32, 128).cuda()
    linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)

    assert linear_row.weight.shape == torch.Size([128, 16])
    assert linear_row.bias.shape == torch.Size([128])

    # check computation correctness
    x = torch.rand(4, 32).cuda()
    out = linear(x)
    gather_out = linear_row(x)
    assert_close(out, gather_out)

    # check backward correctness
    # Gradients exist only after a backward pass; without it ``.grad`` is None.
    out.sum().backward()
    gather_out.sum().backward()

    rank = dist.get_rank()
    target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
    assert_close(target_grad, linear_row.weight.grad)
def run_dist(rank, world_size, port):
    """
    Per-process entry point: launch colossalai, then run the linear-layer checks.

    Args:
        rank: rank passed through to ``colossalai.launch``.
        world_size: world size passed through to ``colossalai.launch``.
        port: port passed through to ``colossalai.launch`` (host is 'localhost').
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # Without these calls the spawned processes would exit without checking anything.
    check_linear_1d_col()
    check_linear_1d_row()
def test_linear():
    """Spawn ``run_dist`` in two processes."""
    num_processes = 2
    spawn(run_dist, nprocs=num_processes)
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    test_linear()