mirror of https://github.com/hpcaitech/ColossalAI

[shardformer] integrated linear 1D with dtensor (#3996)

* [shardformer] integrated linear 1D with dtensor
* polish code

parent d3bc530849
commit 015af592f8
@@ -10,6 +10,7 @@ from colossalai.core import global_context as gpc


class ParallelLayer(nn.Module):

+   global_state_dict: bool = True

    def __init__(self):
@@ -54,10 +54,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
    """

    @staticmethod
-   def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
+   def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
-       ctx.parallel_mode = parallel_mode
+       ctx.process_group = process_group
        ctx.async_grad_allreduce = async_grad_allreduce

        output = torch.matmul(input_, weight.t())

@@ -74,12 +74,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
        grad_input = grad_output.matmul(weight)
        grad_output = grad_output.contiguous()
        # Convert the tensor shapes to 2D for execution compatibility
-       grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
-       total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
+       if len(grad_output.shape) > 2:
+           grad_output = grad_output.view(-1, grad_output.shape[-1])
+           total_input = total_input.view(-1, total_input.shape[-1])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
-           handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
+           handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # all-reduce scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

@@ -93,5 +94,123 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
        return grad_input, grad_weight, grad_bias, None, None, None


-def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
-    return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)


class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    Split the input and keep only the chunk corresponding to the rank.

    Args:
        input_ (`torch.Tensor`): input matrix.
        dim (int): the dimension to perform split and gather
        process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
    """

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        ctx.process_group = process_group
        ctx.dim = dim
        return _split(input_, dim, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output, ctx.dim, ctx.process_group), None, None


class _ReduceInput(torch.autograd.Function):
    """
    All-reduce the input from the model parallel region.

    Args:
        input_: input matrix.
        process_group (`torch.distributed.ProcessGroup`): the process group used for the all-reduce.
    """

    @staticmethod
    def forward(ctx, input_, process_group):
        return _reduce(input_, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


def _reduce(input_, process_group):
    # skip if only one rank involved
    if dist.get_world_size(process_group) == 1:
        return input_
    else:
        dist.all_reduce(input_, group=process_group)
        return input_


def _split(input_, dim=-1, process_group=None):
    # skip if only one rank involved
    world_size = dist.get_world_size(process_group)
    if world_size == 1:
        return input_

    # Split along last dimension.
    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, \
        f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
        f'cannot split tensor evenly'

    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
    rank = dist.get_rank(process_group)
    output = tensor_list[rank].contiguous()

    return output


def _gather(input_, dim=-1, process_group=None):
    # skip if only one rank involved
    world_size = dist.get_world_size(process_group)
    if world_size == 1:
        return input_

    # all gather
    rank = dist.get_rank(process_group)
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=process_group)

    # concat
    output = torch.cat(tensor_list, dim=dim).contiguous()

    return output


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate.

    Args:
        input_: input matrix.
        dim: the dimension to gather along
        process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
    """

    @staticmethod
    def forward(ctx, input_, dim, process_group):
        ctx.process_group = process_group
        ctx.dim = dim
        return _gather(input_, dim, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.dim, ctx.process_group), None, None


+def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
+    return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)


def gather_forward_split_backward(input_, dim, process_group):
    return _GatherForwardSplitBackward.apply(input_, dim, process_group)


def split_forward_gather_backward(input_, dim, process_group):
    return _SplitForwardGatherBackward.apply(input_, dim, process_group)


def reduce_input(input_, process_group):
    return _ReduceInput.apply(input_, process_group)
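As a quick orientation, here is a minimal sketch of how these wrappers are meant to compose. It is not part of the diff; it assumes a 2-rank tensor-parallel process group has already been initialized with torch.distributed, and all tensor names and sizes are purely illustrative.

```python
import torch
import torch.distributed as dist

# assumed setup: dist.init_process_group(...) was already called, world size 2
tp_group = dist.group.WORLD

x = torch.randn(4, 32, device='cuda')       # replicated activation
w_col = torch.randn(16, 32, device='cuda')  # this rank's [16, 32] shard of a [32, 32] weight

# column-parallel matmul: each rank computes its slice of the output; in backward,
# the all-reduce of the input gradient overlaps with the weight-gradient matmul
partial = linear_with_async_comm(x, w_col, None, tp_group, True)

# expose the full output on every rank (backward splits the gradient again)
full = gather_forward_split_backward(partial, dim=-1, process_group=tp_group)

# row-parallel pattern: split the input in forward, gather its gradient in backward,
# then all-reduce the partial products of the sharded matmul
w_row = torch.randn(32, 16, device='cuda')  # [out_features, in_features_per_partition] shard
x_shard = split_forward_gather_backward(x, dim=-1, process_group=tp_group)
y = reduce_input(torch.matmul(x_shard, w_row.t()), tp_group)
```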
@@ -1,58 +1,20 @@
-import os
-from contextlib import contextmanager
-
import torch
import torch.distributed as dist
import torch.nn as nn
-
-
-class SeedManager:
-    """
-    This class is a random state manager to change random state for different random seed.
-    """
-
-    def __init__(self):
-        original_state = torch.cuda.get_rng_state()
-        # TODO: unify this seed manager with the colossalai.context.random
-        seed = os.getpid()
-        torch.cuda.manual_seed(int(seed))
-        self.dropout_state = torch.cuda.get_rng_state()
-        torch.cuda.set_rng_state(original_state)
-
-    def set_mode(self, rng_state):
-        torch.cuda.set_rng_state(rng_state)
-
-    def get_current_mode(self):
-        current_state = torch.cuda.get_rng_state()
-        return current_state
-
-    @contextmanager
-    def dropout_mode(self):
-        """
-        This is a context manager to change the dropout state and recover the original state.
-
-        Usage:
-        ::
-            >>> with _seed_manager.dropout_mode():
-            >>>     input = super().forward(input)
-        """
-        try:
-            current_mode = self.get_current_mode()
-            yield self.set_mode(self.dropout_state)
-        finally:
-            self.dropout_state = self.get_current_mode()
-            self.set_mode(current_mode)
-
-
-_seed_manager = SeedManager()
+
+from .utils import create_randomizer_with_offset


class Dropout1D(nn.Dropout):

-   def __init__(self, p=0.5, inplace=False):
+   def __init__(self, p=0.5, inplace=False, process_group=None):
        super().__init__(p, inplace)

+       # offset the seed with randomizer index and rank
+       seed = torch.random.initial_seed()
+       self.randomizer = create_randomizer_with_offset(seed, process_group=process_group)
+
    def forward(self, input):
-       with _seed_manager.dropout_mode():
+       with self.randomizer.fork_rng():
            input = super().forward(input)
        return input
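A short hedged usage sketch for the reworked Dropout1D (not part of the diff; a CUDA device and an initialized distributed environment are assumed so the randomizer can offset the seed by rank, and shapes are illustrative):

```python
import torch

# drop-in replacement for nn.Dropout inside a tensor-parallel block;
# each rank forks its own RNG stream, so different ranks draw different masks
drop = Dropout1D(p=0.1, process_group=None)   # None -> global process group

x = torch.randn(4, 32, device='cuda')
y = drop(x)   # the surrounding RNG state is restored after the forward pass
```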
@@ -2,12 +2,16 @@
# -*- encoding: utf-8 -*-

import math
from abc import ABC, abstractmethod
from collections import OrderedDict
-from typing import Callable, Tuple
+from typing import Callable, List, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from colossalai.communication import broadcast

@@ -22,13 +26,11 @@ from colossalai.nn.layer.parallel_1d._utils import (
    gather_forward_split_backward,
    get_parallel_input,
    reduce_grad,
    reduce_input,
    set_parallel_input,
    split_forward_gather_backward,
)
from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition
from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding
from colossalai.registry import LAYERS
+from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
from colossalai.utils.checkpointing import (
    broadcast_state_dict,
    gather_tensor_parallel_state_dict,

@@ -36,7 +38,13 @@ from colossalai.utils.checkpointing import (
)
from colossalai.utils.cuda import get_current_device

-from ._operation import linear_with_async_comm
+from ._operation import (
+    gather_forward_split_backward,
+    linear_with_async_comm,
+    reduce_input,
+    split_forward_gather_backward,
+)
+from .utils import create_randomizer_with_offset

Fast_LN = None
try:
@@ -46,17 +54,172 @@ except ImportError:
    pass


-# @LAYERS.register_module
-class Linear1D(ColossalaiModule):
-    r"""Linear layer for 1D parallelism.
class ParallelModule(nn.Module, ABC):

    @abstractmethod
    def from_native_module(module: nn.Module,
                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule":
        """
        Convert a native PyTorch module to a parallelized module.

        Args:
            module (nn.Module): the module to be converted.
            process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
                If this is a list, the process group at the ith index of the list will correspond to the process group
                in the ith axis of the device mesh. Defaults to None, which means the global process group.
        """
        pass


class Linear1D_Col(ParallelModule):
    r"""Linear layer with column parallelism.

    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
    its second dimension as :math:`A = [A_1, ..., A_p]`.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
-       dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
-       gather_output (bool, optional): Whether to call all-gather on output, defaults to False.
-       skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        device (`torch.device`): The device of parameters, defaults to None.
        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
        gather_output (bool, optional): If true, call all-gather on output and make Y available
            to all GPUs, otherwise, every GPU will have its output
            which is :math:`Y_i = XA_i`, defaults to False
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (`typing.Callable`):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (`typing.Callable`):
            The initializer of bias, defaults to xavier uniform initializer.

    For more details about ``initializer``, please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 process_group: ProcessGroup = None,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
        super().__init__()

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
        self.skip_bias_add = skip_bias_add
        self.device = device
        self.process_group = process_group
        self.num_partitions = dist.get_world_size(self.process_group)

        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        self.out_features_per_partition = divide(out_features, self.num_partitions)

        # Parameters.
        # Initialize weight.
        if device is None:
            device = get_current_device()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))

        if bias:
            self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
        else:
            self.bias = None

        # offset the seed with randomizer index and rank
        seed = torch.random.initial_seed()
        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)

        with self.randomizer.fork_rng(enable_cpu=True):
            self.reset_parameters(weight_initializer, bias_initializer)

    def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
                           **kwargs) -> ParallelModule:
        r"""
        Convert a native PyTorch linear layer to a parallelized linear layer.
        """
        # get the attributes
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, \
                f'Expected only one process group, got {len(process_group)}.'
            process_group = process_group[0]

        linear_1d = Linear1D_Col(in_features=in_features,
                                 out_features=out_features,
                                 bias=bias,
                                 device=device,
                                 process_group=process_group,
                                 *args,
                                 **kwargs)

        # TODO: copy the sharded weights
        with torch.no_grad():
            # the weight of the linear layer is used transposed,
            # thus sharding on row is equal to sharding on column
            sharded_weight = shard_rowwise(module.weight.data, process_group)
            linear_1d.weight.data.copy_(sharded_weight)
            if bias:
                sharded_bias = shard_colwise(module.bias.data, process_group)
                linear_1d.bias.copy_(sharded_bias)

        return linear_1d

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        fan_in, fan_out = self.in_features, self.out_features
        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)

    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
        assert input_.shape[-1] == self.weight.shape[-1], \
            'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
                input_.shape, self.weight.shape, self.weight.shape[-1])
        # Set up backprop all-reduce.
        # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
        input_parallel = input_
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)

        if self.gather_output:
            # All-gather across the partitions.
            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
        else:
            output = output_parallel

        if self.skip_bias_add:
            return output, self.bias
        else:
            return output
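For orientation, a hedged sketch of how the column-parallel layer behaves on a 2-rank group, mirroring the unit test at the end of this diff (shapes are illustrative, and a distributed launch is assumed):

```python
import torch
import torch.nn as nn

full = nn.Linear(32, 128).cuda()
col = Linear1D_Col.from_native_module(full, process_group=None, gather_output=True)

# the output dimension is split across the 2 ranks: each holds a [64, 32] weight shard
assert col.weight.shape == torch.Size([64, 32])

x = torch.rand(4, 32).cuda()
y = col(x)          # gather_output=True -> the full [4, 128] result on every rank
y_ref = full(x)     # matches the unsharded layer up to floating-point error
```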
class Linear1D_Row(ParallelModule):
    r""" Linear layer with row parallelism

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.

@@ -72,32 +235,149 @@ class Linear1D(ColossalaiModule):
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
-                gather_output: bool = False,
+                device: torch.device = None,
+                process_group: ProcessGroup = None,
                 parallel_input: bool = True,
                 skip_bias_add: bool = False,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
-        parallel_input = get_parallel_input()
-        if not parallel_input and not gather_output:
-            layer = Linear1D_Col(in_features,
-                                 out_features,
-                                 bias=bias,
-                                 dtype=dtype,
-                                 skip_bias_add=skip_bias_add,
-                                 weight_initializer=weight_initializer,
-                                 bias_initializer=bias_initializer)
+                bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+                stream_chunk_num: int = 1):
+        super().__init__()

        self.stream_chunk_num = stream_chunk_num

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.parallel_input = parallel_input
        self.skip_bias_add = skip_bias_add
        self.process_group = process_group
        self.num_partitions = dist.get_world_size(self.process_group)

        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        # Divide the weight matrix along the last dimension.
        self.input_size_per_partition = divide(in_features, self.num_partitions)

        # Parameters.
        # Initialize weight.
        if device is None:
            device = get_current_device()

        factory_kwargs = {'device': device, 'dtype': dtype}
        self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))

        if self.stream_chunk_num > 1:
            # TODO() work for inference only
            self.chunk_weight()
        if bias:
            self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
        else:
-            layer = Linear1D_Row(in_features,
-                                 out_features,
            self.bias = None

        # offset the seed with randomizer index and rank
        seed = torch.random.initial_seed()
        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)

        with self.randomizer.fork_rng(enable_cpu=True):
            self.reset_parameters(weight_initializer, bias_initializer)

    @staticmethod
    def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
                           **kwargs) -> ParallelModule:
        r"""
        Convert a native PyTorch linear layer to a parallelized linear layer.
        """
        # get the attributes
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, \
                f'Expected only one process group, got {len(process_group)}.'
            process_group = process_group[0]

        linear_1d = Linear1D_Row(in_features=in_features,
                                 out_features=out_features,
                                 bias=bias,
-                                dtype=dtype,
-                                parallel_input=parallel_input,
-                                skip_bias_add=skip_bias_add,
-                                weight_initializer=weight_initializer,
-                                bias_initializer=bias_initializer)
-        super().__init__(layer)
+                                device=device,
+                                process_group=process_group,
+                                *args,
+                                **kwargs)

        # TODO: copy the sharded weights
        with torch.no_grad():
            # the weight of the linear layer is used transposed,
            # thus sharding on col is equal to sharding on row
            sharded_weight = shard_colwise(module.weight.data, process_group)
            linear_1d.weight.data.copy_(sharded_weight)

            if bias:
                linear_1d.bias.copy_(module.bias.data)

        return linear_1d

    def chunk_weight(self):
        self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        fan_in, fan_out = self.in_features, self.out_features
        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)

        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)
            if self.process_group is None:
                src_rank = 0
            else:
                src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
            dist.broadcast(self.bias, src=src_rank, group=self.process_group)

    def forward(self, input_: Tensor) -> Tensor:
        # Set up backprop all-reduce.
        if self.parallel_input:
            assert input_.shape[-1] == self.weight.shape[-1], \
                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
                    input_.shape, self.weight.shape, self.weight.shape[-1])
            input_ = input_
        else:
            assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \
                'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
                    input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
            input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)

        if self.stream_chunk_num > 1:
            if self.training:
                raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
            with torch.no_grad():
                output_parallel_list = [None for i in range(self.stream_chunk_num)]
                handle_list = []
                for i in range(self.stream_chunk_num):
                    output_parallel_list[i] = F.linear(input_, self.weight_list[i])
                    handle = torch.distributed.all_reduce(output_parallel_list[i],
                                                          group=self.process_group,
                                                          async_op=True)
                    handle_list.append(handle)
                    # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
                for handle in handle_list:
                    handle.wait()
                output = torch.cat(output_parallel_list, dim=-1)
        else:
            output_parallel = F.linear(input_, self.weight)
            # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
            output = reduce_input(output_parallel, self.process_group)

        if not self.skip_bias_add:
            if self.bias is not None:
                output = output + self.bias
            return output
        else:
            return output, self.bias
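Taken together, the two layers compose into the usual column-then-row tensor-parallel pattern. The following is a hedged sketch, not part of the diff; the module indices and the GELU are illustrative, and a 2-process launch as in the test at the bottom of this diff is assumed:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# hypothetical 2-layer MLP sharded over one tensor-parallel group
mlp = nn.Sequential(nn.Linear(32, 128), nn.GELU(), nn.Linear(128, 32)).cuda()

# column-parallel first, row-parallel second: the all-gather after the first layer
# is skipped because the row-parallel layer accepts an already-split input
col = Linear1D_Col.from_native_module(mlp[0], process_group=None, gather_output=False)
row = Linear1D_Row.from_native_module(mlp[2], process_group=None, parallel_input=True)

x = torch.rand(4, 32).cuda()
y = row(F.gelu(col(x)))    # [4, 32]; the partial results are all-reduced inside Linear1D_Row
```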
-# @LAYERS.register_module
class LayerNorm1D(ColossalaiModule):
    r"""
    Layer Normalization for colossalai

@@ -152,7 +432,6 @@ class LayerNorm1D(ColossalaiModule):
        super()._save_to_state_dict(destination, prefix, keep_vars)


-# @LAYERS.register_module
class Classifier1D(ParallelLayer):
    r"""RowLinear with given weight. Classifier of 1D parallelism.

@@ -288,7 +567,6 @@ class Classifier1D(ParallelLayer):
        return output


-# @LAYERS.register_module
class VocabParallelClassifier1D(ParallelLayer):
    r"""ColLinear with given weight. Classifier of 1D parallelism.
@ -424,317 +702,8 @@ class VocabParallelClassifier1D(ParallelLayer):
|
|||
|
||||
|
||||
# @LAYERS.register_module
|
||||
class Linear1D_Col(ParallelLayer):
|
||||
r"""Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
|
||||
its second dimension as :math:`A = [A_1, ..., A_p]`.
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
out_features (int): size of each output sample.
|
||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
|
||||
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is :math:`Y_i = XA_i`, defaults to False
|
||||
skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
|
||||
which is preserved for kernel fusion, defaults to False
|
||||
weight_initializer (:class:`typing.Callable`, optional):
|
||||
The initializer of weight, defaults to kaiming uniform initializer.
|
||||
bias_initializer (:class:`typing.Callable`, optional):
|
||||
The initializer of bias, defaults to xavier uniform initializer.
|
||||
|
||||
More details about ``initializer`` please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.gather_output = gather_output
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
# self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
|
||||
self.out_features_per_partition = out_features
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
|
||||
else:
|
||||
self.bias = None
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
is_parallel_output = not self.gather_output
|
||||
set_parallel_input(is_parallel_output)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
||||
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is not None:
|
||||
local_state[weight_key] = weight
|
||||
# bias
|
||||
if self.bias is not None:
|
||||
bias = state_dict.pop(bias_key, None)
|
||||
if bias is not None:
|
||||
local_state[bias_key] = bias
|
||||
|
||||
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
})
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
if self.bias is not None:
|
||||
local_state[bias_key] = self.bias
|
||||
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={
|
||||
weight_key: 0,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: True
|
||||
},
|
||||
keep_vars=keep_vars)
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
# Set up backprop all-reduce.
|
||||
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
|
||||
input_parallel = input_
|
||||
# Matrix multiply.
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
# output_parallel = F.linear(input_parallel, self.weight, bias)
|
||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
if self.skip_bias_add:
|
||||
return output, self.bias
|
||||
else:
|
||||
return output
|
||||
|
||||
|
||||
# @LAYERS.register_module
|
||||
class Linear1D_Row(ParallelLayer):
|
||||
r""" Linear layer with row parallelism
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
out_features (int): size of each output sample.
|
||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
|
||||
parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to False.
|
||||
skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
|
||||
which is preserved for kernel fusion, defaults to False
|
||||
weight_initializer (:class:`typing.Callable`, optional):
|
||||
The initializer of weight, defaults to kaiming uniform initializer.
|
||||
bias_initializer (:class:`typing.Callable`, optional):
|
||||
The initializer of bias, defaults to xavier uniform initializer.
|
||||
|
||||
More details about ``initializer`` please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
parallel_input: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1):
|
||||
super().__init__()
|
||||
|
||||
self.stream_chunk_num = stream_chunk_num
|
||||
|
||||
# Keep input parameters
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.parallel_input = parallel_input
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
# self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
|
||||
self.input_size_per_partition = in_features
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
# TODO() work for inference only
|
||||
self.chunk_weight()
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||
else:
|
||||
self.bias = None
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
set_parallel_input(False)
|
||||
|
||||
def chunk_weight(self):
|
||||
self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
# weight
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is not None:
|
||||
local_state[weight_key] = weight
|
||||
# bias
|
||||
if self.bias is not None:
|
||||
bias = state_dict.pop(bias_key, None)
|
||||
if bias is not None:
|
||||
local_state[bias_key] = bias
|
||||
|
||||
local_state = partition_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: False
|
||||
})
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
if self.bias is not None:
|
||||
local_state[bias_key] = self.bias
|
||||
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
dims={
|
||||
weight_key: -1,
|
||||
bias_key: 0
|
||||
},
|
||||
partition_states={
|
||||
weight_key: True,
|
||||
bias_key: False
|
||||
},
|
||||
keep_vars=keep_vars)
|
||||
destination.update(local_state)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# Set up backprop all-reduce.
|
||||
if self.parallel_input:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
input_ = input_
|
||||
else:
|
||||
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
|
||||
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
if self.training:
|
||||
raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
|
||||
with torch.no_grad():
|
||||
output_parallel_list = [None for i in range(self.stream_chunk_num)]
|
||||
handle_list = []
|
||||
for i in range(self.stream_chunk_num):
|
||||
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
|
||||
handle = torch.distributed.all_reduce(output_parallel_list[i],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D),
|
||||
async_op=True)
|
||||
handle_list.append(handle)
|
||||
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
|
||||
for handle in handle_list:
|
||||
handle.wait()
|
||||
output = torch.cat(output_parallel_list, dim=-1)
|
||||
else:
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
|
||||
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
|
||||
if not self.skip_bias_add:
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
return output
|
||||
else:
|
||||
return output, self.bias
|
||||
|
||||
|
||||
# @LAYERS.register_module
|
||||
class Embedding1D(ParallelLayer):
|
||||
r"""Embedding for 1D parallelism.
|
||||
|
||||
|
@@ -842,7 +811,6 @@ class Embedding1D(ParallelLayer):
        return output


-# @LAYERS.register_module
class VocabParallelEmbedding1D(ParallelLayer):
    r"""Embedding parallelized in the vocabulary dimension.

@@ -960,7 +928,6 @@ class VocabParallelEmbedding1D(ParallelLayer):
        return output


-# @LAYERS.register_module
class Dropout1D(ParallelLayer):
    """Dropout layer of 1D parallelism.
@@ -0,0 +1,138 @@
from contextlib import contextmanager

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


class Randomizer:
    """
    Randomizer enables the program to be executed under a different seed within the context.

    Example:

    ```python
    randomizer = Randomizer(seed=1024)

    with randomizer.fork_rng():
        # do something here with seed 1024
        do_something()
    ```

    Args:
        seed (int): The random seed to set.
        enable_cpu (bool): fork the CPU RNG state as well.
        with_index (bool): whether to use the index of the randomizer.
    """

    _INDEX = 0

    def __init__(self, seed: int):
        # TODO: remove colossalai.context.random

        self.seed = seed

        # Handle CUDA rng state
        # 1. get the current rng state
        # 2. set the seed and store the rng state
        # 3. recover the original rng state
        cuda_original_rng_state = torch.cuda.get_rng_state()
        torch.cuda.manual_seed(seed)
        self.cuda_rng_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(cuda_original_rng_state)

        # do the same for the CPU rng state
        cpu_original_rng_state = torch.get_rng_state()
        torch.manual_seed(seed)
        self.cpu_rng_state = torch.get_rng_state()
        torch.set_rng_state(cpu_original_rng_state)

    def _set_cuda_rng_state(self, rng_state):
        torch.cuda.set_rng_state(rng_state)

    def _get_cuda_rng_state(self):
        current_state = torch.cuda.get_rng_state()
        return current_state

    def _set_cpu_rng_state(self, rng_state):
        torch.set_rng_state(rng_state)

    def _get_cpu_rng_state(self):
        current_state = torch.get_rng_state()
        return current_state

    @contextmanager
    def fork_rng(self, enable_cpu: bool = False):
        """
        This is a context manager to switch to the managed RNG state and recover the original state afterwards.

        Usage:
        ::
            >>> with randomizer.fork_rng():
            >>>     input = super().forward(input)
        """
        try:
            current_cuda_rng_state = self._get_cuda_rng_state()
            self._set_cuda_rng_state(self.cuda_rng_state)

            if enable_cpu:
                current_cpu_rng_state = self._get_cpu_rng_state()
                self._set_cpu_rng_state(self.cpu_rng_state)
            yield
        finally:
            self.cuda_rng_state = self._get_cuda_rng_state()
            self._set_cuda_rng_state(current_cuda_rng_state)

            if enable_cpu:
                self.cpu_rng_state = self._get_cpu_rng_state()
                self._set_cpu_rng_state(current_cpu_rng_state)

    @staticmethod
    def index():
        """
        Return the index of the randomizer. The index is useful when the user wants
        to introduce some randomness in the program.

        Note:
            The index will increment by one each time this method is called.

        Example:

        ```python
        # assume we need a randomizer to init the weight of different layers
        # we can use the index of the randomizer so that
        # each layer has its own randomizer with a different seed
        base_seed = torch.random.initial_seed()
        seed = base_seed + Randomizer.index()
        randomizer = Randomizer(seed)

        with randomizer.fork_rng():
            init_weights()
        ```
        """
        idx = Randomizer._INDEX
        Randomizer._INDEX += 1
        return idx


def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None):
    """
    Create a randomizer with an offset. The offset is the sum of the process rank and the randomizer index.

    Args:
        seed (int): The base random seed to set.
        enable_cpu (bool): fork the CPU RNG state as well.
        process_group (ProcessGroup): the process group to get the rank from.

    Returns:
        Randomizer: the randomizer with offset.
    """
    offset = Randomizer.index()

    if dist.is_initialized():
        rank = dist.get_rank(process_group)
        offset += rank

    seed += offset
    return Randomizer(seed=seed)
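A brief hedged sketch of how the offset behaves (not part of the diff; assumes torch.distributed is initialized so a rank is available):

```python
import torch

base_seed = torch.random.initial_seed()

# first call: offset = randomizer index 0 + rank; second call: offset = index 1 + rank.
# Two layers on the same rank therefore fork different RNG streams, and the same
# layer on two different ranks also differs, which is exactly what sharded dropout needs.
r0 = create_randomizer_with_offset(base_seed, process_group=None)
r1 = create_randomizer_with_offset(base_seed, process_group=None)

with r0.fork_rng(enable_cpu=True):
    w = torch.randn(4, 4)   # drawn under the offset seed; the outer RNG state is restored afterwards
```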
@@ -0,0 +1,44 @@
from typing import Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.device.device_mesh import DeviceMesh

from .d_tensor import DTensor
from .sharding_spec import ShardingSpec


def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
    """
    Shard the first dim of the given tensor.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD

    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
        device_mesh = group_or_device_mesh
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
    return DTensor(tensor, device_mesh, sharding_spec)


def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
    """
    Shard the last dim of the given tensor.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD

    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for column-wise sharding.'
        device_mesh = group_or_device_mesh
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
    return DTensor(tensor, device_mesh, sharding_spec)
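For context, a hedged restatement of how `Linear1D_Col.from_native_module` above consumes these helpers (not part of the diff; a 2-rank group and the shapes are illustrative):

```python
import torch
import torch.nn as nn

full = nn.Linear(32, 128)

# On a 2-rank group, shard_rowwise splits dim 0 of the [128, 32] weight, so each
# rank is left holding a [64, 32] local shard wrapped as a DTensor.
sharded_weight = shard_rowwise(full.weight.data)    # None group -> global process group

# Same copy pattern as Linear1D_Col.from_native_module: the local parameter of the
# parallel layer receives the shard that belongs to this rank.
local_param = torch.empty(64, 32)
with torch.no_grad():
    local_param.copy_(sharded_weight)
```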
@@ -34,7 +34,7 @@ class Layout:
    def get_sharded_shape_per_device(self):
        sharded_shape = list(self.entire_shape)
        for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
-           mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
+           mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
            shard_partitions = reduce(operator.mul, mesh_list, 1)
            assert sharded_shape[
                dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'

@@ -45,14 +45,15 @@ class Layout:
        sharding_spec = self.sharding_spec

        # make sure all axes in logical device mesh only be used once
-       dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
-       for dim, shard_list in sharding_spec.dim_partition_dict.items():
-           for element in shard_list:
-               if element in dim_check_list:
-                   dim_check_list.remove(element)
-               else:
-                   raise DuplicatedShardingDimensionError(
-                       f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
+       if self.device_mesh.logical_mesh_id is not None:
+           dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
+           for dim, shard_list in sharding_spec.dim_partition_dict.items():
+               for element in shard_list:
+                   if element in dim_check_list:
+                       dim_check_list.remove(element)
+                   else:
+                       raise DuplicatedShardingDimensionError(
+                           f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")

        # make sure that the sharding for a dimension is divisible by the number of devices
        for dim, shard_list in sharding_spec.dim_partition_dict.items():

@@ -60,7 +61,7 @@ class Layout:
            num_devices = 1

            for element in shard_list:
-               num_devices *= self.device_mesh.mesh_shape[element]
+               num_devices *= self.device_mesh.shape[element]

            if tensor_dim_size % num_devices != 0:
                raise ShardingNotDivisibleError(
@@ -304,7 +304,7 @@ class LayoutConverter(metaclass=SingletonMeta):
        process_groups_dict = source_layout.device_mesh.process_groups_dict

        # legal sharding dims means the mesh_id is still available to use.
-       legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]
+       legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
        for dim, shard_list in source_spec.dim_partition_dict.items():
            for element in shard_list:
                legal_sharding_dims.remove(element)
@@ -0,0 +1,67 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


def check_linear_1d_col():
    linear = nn.Linear(32, 128).cuda()
    linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

    assert linear_col.weight.shape == torch.Size([64, 32])
    assert linear_col.bias.shape == torch.Size([64])

    # check computation correctness
    x = torch.rand(4, 32).cuda()
    out = linear(x)
    gather_out = linear_col(x)
    assert_close(out, gather_out)

    # check backward correctness
    out.sum().backward()
    gather_out.sum().backward()

    rank = dist.get_rank()
    target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
    assert_close(target_grad, linear_col.weight.grad)


def check_linear_1d_row():
    linear = nn.Linear(32, 128).cuda()
    linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)

    assert linear_row.weight.shape == torch.Size([128, 16])
    assert linear_row.bias.shape == torch.Size([128])

    # check computation correctness
    x = torch.rand(4, 32).cuda()
    out = linear(x)
    gather_out = linear_row(x)
    assert_close(out, gather_out)

    # check backward correctness
    out.sum().backward()
    gather_out.sum().backward()

    rank = dist.get_rank()
    target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
    assert_close(target_grad, linear_row.weight.grad)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    check_linear_1d_col()
    check_linear_1d_row()


@rerun_if_address_is_in_use()
def test_linear():
    spawn(run_dist, nprocs=2)


if __name__ == '__main__':
    test_linear()