mirror of https://github.com/hpcaitech/ColossalAI
[feat] Support no-tensor-parallel Linear in shardformer; add tests with and without WeightGradStore
parent 982e4ee1f8
commit d2e05a99b3
@@ -2,7 +2,7 @@ from ._operation import all_to_all_comm
 from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
-from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
+from .linear import Linear1D, Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
 from .loss import cross_entropy_1d, dist_cross_entropy
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
@@ -11,6 +11,7 @@ from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2
 __all__ = [
     "Embedding1D",
     "VocabParallelEmbedding1D",
+    "Linear1D",
     "Linear1D_Col",
     "Linear1D_Row",
     "GPT2FusedLinearConv1D_Col",
@@ -154,7 +154,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)

        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
            # _grad_output_.t().matmul(_input_)
            return wgrad_gemm_func(_grad_output_.t(), _input_)

        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
@@ -236,6 +235,107 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
        return grad_input, grad_weight, grad_bias, None, None, None, None


class LinearBase(torch.autograd.Function):
    """
    Linear layer baseline (no tensor parallel version).
    """

    @staticmethod
    def forward(ctx, input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
        ctx.save_for_backward(input_, weight, bias)
        ctx.use_bias = bias is not None
        ctx.async_grad_allreduce = async_grad_allreduce
        ctx.fp8_communication = fp8_communication
        ctx.use_zbv = use_zbv
        if bias is not None:
            output = F.linear(input_, weight, bias)
        else:
            output = F.linear(input_, weight)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        use_bias = ctx.use_bias
        use_zbv = ctx.use_zbv

        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)

        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
            return wgrad_gemm_func(_grad_output_.t(), _input_)

        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
        if use_bias:
            bias.view(bias.shape)

        total_input = input.contiguous()
        grad_input = grad_output.matmul(weight)
        grad_output = grad_output.contiguous()
        # Convert the tensor shapes to 2D for execution compatibility
        if len(grad_output.shape) > 2:
            grad_output = grad_output.view(-1, grad_output.shape[-1])
            total_input = total_input.view(-1, total_input.shape[-1])

        if _grad_accum_fusion_available and weight.grad is not None:
            grad = weight.grad
            if use_zbv:
                # TODO: append input, grad_output_, weight, grad func to WeightGradStore
                if grad.dtype == torch.float32:
                    WeightGradStore.put(
                        total_input,
                        grad_output,
                        weight,
                        functools.partial(
                            execute_w_pass_grad_accum,
                            wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
                        ),
                    )
                    grad_weight = None
                elif grad.dtype in (torch.float16, torch.bfloat16):
                    WeightGradStore.put(
                        total_input,
                        grad_output,
                        weight,
                        functools.partial(
                            execute_w_pass_grad_accum,
                            wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
                        ),
                    )
                    grad_weight = None
                else:
                    raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
            else:
                if grad.dtype == torch.float32:
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
                    grad_weight = None
                elif grad.dtype == torch.float16:
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
                    grad_weight = None
                else:
                    grad_weight = grad_output.t().matmul(total_input)
        else:
            if use_zbv:
                # ZBV: defer the weight-gradient GEMM; it is executed when WeightGradStore is flushed and popped.
                WeightGradStore.put(
                    total_input,
                    grad_output,
                    weight,
                    functools.partial(
                        execute_w_pass,
                        wgrad_gemm_func=torch.matmul,
                    ),
                )
                grad_weight = None
            else:
                grad_weight = grad_output.t().matmul(total_input)

        grad_bias = grad_output.sum(dim=0) if use_bias else None

        return grad_input, grad_weight, grad_bias, None, None, None, None


def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
    # currently only support one single tensor as output
    group_size = dist.get_world_size(process_group)
@@ -1101,6 +1201,10 @@ def linear_with_async_comm(
     )
 
 
+def linear_base(input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
+    return LinearBase.apply(input_, weight, bias, async_grad_allreduce, fp8_communication, use_zbv)
+
+
 def linear_gather_forward_reducescatter_backward(
     input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
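For orientation, a minimal standalone sketch (illustrative only, not part of the commit; shapes are arbitrary and the import path follows the module layout shown in this diff) of how the new linear_base entry point behaves without ZBV: the forward is a plain F.linear, and the weight and bias gradients are produced directly in LinearBase.backward.

import torch
from colossalai.shardformer.layer._operation import linear_base  # added in this commit

x = torch.randn(4, 32, requires_grad=True)
w = torch.nn.Parameter(torch.randn(128, 32))
b = torch.nn.Parameter(torch.zeros(128))

# async_grad_allreduce=False, use_zbv=False: no communication, no deferred weight gradient.
out = linear_base(x, w, b, False, fp8_communication=False, use_zbv=False)
out.sum().backward()

assert w.grad is not None and b.grad is not None and x.grad is not None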
@@ -25,6 +25,7 @@ from colossalai.tensor.d_tensor.api import (
 from ._operation import (
     gather_forward_reducescatter_backward,
     gather_forward_split_backward,
+    linear_base,
     linear_gather_forward_reducescatter_backward,
     linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
@@ -35,7 +36,159 @@ from ._operation import (
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset

-__all__ = ["Linear1D_Col", "Linear1D_Row"]
+__all__ = ["Linear1D", "Linear1D_Col", "Linear1D_Row"]


class Linear1D(ParallelModule):
    r"""Linear layer with no parallelism.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        device (`torch.device`): The device of parameters, defaults to None.
        gather_output (bool, optional): Kept for interface compatibility with the sharded linear layers;
            it has no effect here because the output is not sharded, defaults to False.
        seq_parallel_mode (`str`): The sequence parallel mode, defaults to None.
        seq_parallel_dim (`int`): The dimension used for sequence parallelism, defaults to 1.
        overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (`typing.Callable`):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (`typing.Callable`):
            The initializer of bias, defaults to xavier uniform initializer.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        device: torch.device = None,
        gather_output: bool = False,
        seq_parallel_mode: str = None,
        seq_parallel_dim: int = 1,
        overlap: torch.cuda.Stream = None,
        skip_bias_add: bool = False,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
        fp8_communication: bool = False,
        use_zbv: bool = False,
        **kwargs,
    ):
        super().__init__(weight=weight, bias_=bias_, **kwargs)

        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
        self.seq_parallel_mode = seq_parallel_mode
        self.seq_parallel_dim = seq_parallel_dim
        self.overlap = overlap
        self.skip_bias_add = skip_bias_add
        self.device = device
        self.fp8_communication = fp8_communication
        self.use_zbv = use_zbv

        if skip_bias_add and not bias:
            raise ValueError("cannot skip bias addition if bias is None")

        # offset the seed with randomizer index and rank
        seed = torch.random.initial_seed()

        self.randomizer = create_randomizer_with_offset(seed, process_group=None)

        # sanity check
        if weight is not None:
            assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
        else:
            assert bias_ is None, "bias_ must be None if weight is None"

        # Parameters.
        if weight is None:
            factory_kwargs = {"device": device, "dtype": dtype}
            self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
        else:
            weight.data = weight.data.to(device=device, dtype=dtype)
            self.weight = weight

        if bias:
            if bias_ is None:
                self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
            else:
                bias_.data = bias_.data.to(device=device, dtype=dtype)
                self.bias = bias_
        else:
            self.bias = None

        if weight is None:
            # init weights
            self.reset_parameters(weight_initializer, bias_initializer)

    @staticmethod
    def from_native_module(module: nn.Linear, **kwargs) -> ParallelModule:
        r"""
        Convert a native PyTorch linear layer to a Linear1D layer.
        """
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device

        linear_1d = Linear1D(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            weight=module.weight,
            bias_=module.bias,
            **kwargs,
        )

        return linear_1d

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        with self.randomizer.fork_rng(enable_cpu=True):
            fan_in, fan_out = self.in_features, self.out_features
            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
            if self.bias is not None:
                bias_initializer(self.bias, fan_in=fan_in)

    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
        assert (
            input_.shape[-1] == self.weight.shape[-1]
        ), "Invalid shapes in Linear1D forward: input={}, weight={}. Expected last dim of input {}.".format(
            input_.shape, self.weight.shape, self.weight.shape[-1]
        )

        # No tensor-parallel setup is needed; use the input as-is.
        input_parallel = input_

        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = linear_base(
            input_parallel,
            self.weight,
            bias,
            False,
            fp8_communication=self.fp8_communication,
            use_zbv=self.use_zbv,
        )

        output = output_parallel

        if self.skip_bias_add:
            return output, self.bias
        else:
            return output


class Linear1D_Col(ParallelModule):
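A condensed usage sketch of the new layer's conversion path, mirroring the test added below (illustrative only; like the test, it assumes a CUDA device and an already-initialized distributed context, since the constructor creates a randomizer):

import torch
import torch.nn as nn
from colossalai.shardformer.layer import Linear1D  # exported by this commit

# Wrap a plain nn.Linear; parameters are shared with the wrapped module.
linear = nn.Linear(32, 128).cuda()
linear_1d = Linear1D.from_native_module(linear)
assert linear_1d.weight is linear.weight

x = torch.rand(2, 4, 32).cuda()
y = linear_1d(x)  # same result as linear(x); no tensor parallelism is applied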
@@ -8,7 +8,8 @@ from torch.testing import assert_close
 
 import colossalai
 from colossalai.lazy import LazyInitContext
-from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
+from colossalai.pipeline.weight_grad_store import WeightGradStore
+from colossalai.shardformer.layer import Linear1D, Linear1D_Col, Linear1D_Row
 from colossalai.tensor.d_tensor import is_distributed_tensor
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
@@ -117,6 +118,93 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool):
    assert_close(x_for_unshard.grad, x_for_shard.grad)


def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
    ctx = LazyInitContext() if lazy_init else nullcontext()

    linear = nn.Linear(32, 128).cuda()
    with ctx:
        linear_copy = nn.Linear(32, 128).cuda()
    linear_base = Linear1D.from_native_module(
        linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
    )
    assert linear_base.weight.shape == torch.Size([128, 32])
    assert linear_base.bias.shape == torch.Size([128])
    assert linear_copy.weight is linear_base.weight
    assert linear_copy.bias is linear_base.bias

    linear.load_state_dict(linear_base.state_dict())
    linear_base.load_state_dict(linear.state_dict())

    # check computation correctness
    # [batch_size, seq_len, hidden_size]
    x = torch.rand(2, 4, 32).cuda()
    x_for_unshard = x.expand_as(x.clone())
    x_for_unshard.requires_grad_(True)
    x_for_shard = x.expand_as(x.clone())
    x_for_shard.requires_grad_(True)

    # run forward
    out = linear(x_for_unshard)
    gather_out = linear_base(x_for_shard)
    assert_close(out, gather_out)

    # check backward correctness
    out.sum().backward()
    gather_out.sum().backward()
    assert_close(linear.weight.grad, linear_base.weight.grad)
    # check the input gradients
    assert x_for_shard.grad is not None
    assert x_for_unshard.grad is not None
    assert_close(x_for_unshard.grad, x_for_shard.grad)


def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
    ctx = LazyInitContext() if lazy_init else nullcontext()

    linear = nn.Linear(32, 128).cuda()
    with ctx:
        linear_copy = nn.Linear(32, 128).cuda()
    linear_base = Linear1D.from_native_module(
        linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
    )
    assert linear_base.weight.shape == torch.Size([128, 32])
    assert linear_base.bias.shape == torch.Size([128])
    assert linear_copy.weight is linear_base.weight
    assert linear_copy.bias is linear_base.bias

    linear.load_state_dict(linear_base.state_dict())
    linear_base.load_state_dict(linear.state_dict())

    # check computation correctness
    # [batch_size, seq_len, hidden_size]
    x = torch.rand(2, 4, 32).cuda()
    x_for_unshard = x.expand_as(x.clone())
    x_for_unshard.requires_grad_(True)
    x_for_shard = x.expand_as(x.clone())
    x_for_shard.requires_grad_(True)

    # run forward
    out = linear(x_for_unshard)
    gather_out = linear_base(x_for_shard)
    assert_close(out, gather_out)

    # check backward correctness
    out.sum().backward()
    gather_out.sum().backward()

    # Weight grad is None before we do WeightGradStore pop
    assert linear_base.weight.grad is None
    # after WeightGradStore pop (dw computation complete), we assert weight grad
    WeightGradStore.flush(chunk=0)  # flush buffer to chunk 0 Queue
    WeightGradStore.pop(chunk=0)
    assert_close(linear.weight.grad, linear_base.weight.grad)

    # check the input gradients
    assert x_for_shard.grad is not None
    assert x_for_unshard.grad is not None
    assert_close(x_for_unshard.grad, x_for_shard.grad)


def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
    ctx = LazyInitContext() if lazy_init else nullcontext()
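The second test exercises the deferred weight-gradient flow: with use_zbv=True, LinearBase.backward does not fill weight.grad but queues the dW GEMM in WeightGradStore. A condensed sketch of that lifecycle (illustrative only; like the tests, it assumes a CUDA device and an initialized distributed context):

import torch
import torch.nn as nn
from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import Linear1D

layer = Linear1D.from_native_module(nn.Linear(32, 128).cuda(), use_zbv=True)
out = layer(torch.rand(2, 4, 32).cuda())
out.sum().backward()

assert layer.weight.grad is None  # dW was queued, not computed
WeightGradStore.flush(chunk=0)    # move the queued work into chunk 0's queue
WeightGradStore.pop(chunk=0)      # execute the stored GEMM(s), filling weight.grad
assert layer.weight.grad is not None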
@@ -182,6 +270,8 @@ def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
     check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
     check_linear_1d_row(lazy_init, seq_parallel_mode)
     check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)
+    check_linear_without_weight_grad_store(lazy_init, seq_parallel_mode)
+    check_linear_with_weight_grad_store(lazy_init, seq_parallel_mode)
 
 
 def check_dist_linear(rank, world_size, port):