mirror of https://github.com/hpcaitech/ColossalAI
[shardformer/sequence parallel] Cherry pick commit to new branch (#4450)
* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384) * [sequence parallel] add sequence parallel linear col/row support (#4336) * add sequence parallel linear col/row support * add annotation * add annotation * add support for gpt2 fused qkv linear layer * support sequence parallel in GPT2 * add docstring and note * add requirments * remove unused flash-attb * modify flash attn test * modify flash attn setting * modify flash attn code * add assert before divide, rename forward function * [shardformer/test] fix gpt2 test with seq-parallel * [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401) * overlap gather input / grad computing during col backward * modify test for overlap * simplify code * fix code and modify cuda stream synchronize * [shardformer/sequence parallel] polish codepull/4446/head
parent
d20dceb9a3
commit
424629fea0
|
@ -152,6 +152,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||||
enable_fused_normalization: bool = False,
|
enable_fused_normalization: bool = False,
|
||||||
enable_flash_attention: bool = False,
|
enable_flash_attention: bool = False,
|
||||||
enable_jit_fused: bool = False,
|
enable_jit_fused: bool = False,
|
||||||
|
enable_sequence_parallelism: bool = False,
|
||||||
num_microbatches: Optional[int] = None,
|
num_microbatches: Optional[int] = None,
|
||||||
initial_scale: float = 2**16,
|
initial_scale: float = 2**16,
|
||||||
min_scale: float = 1,
|
min_scale: float = 1,
|
||||||
|
@ -178,6 +179,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||||
self.enable_fused_normalization = enable_fused_normalization
|
self.enable_fused_normalization = enable_fused_normalization
|
||||||
self.enable_flash_attention = enable_flash_attention
|
self.enable_flash_attention = enable_flash_attention
|
||||||
self.enable_jit_fused = enable_jit_fused
|
self.enable_jit_fused = enable_jit_fused
|
||||||
|
self.enable_sequence_parallelism = enable_sequence_parallelism
|
||||||
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
|
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
|
||||||
self.stage_manager = None
|
self.stage_manager = None
|
||||||
self.schedule = None
|
self.schedule = None
|
||||||
|
@ -195,7 +197,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||||
enable_all_optimization=self.enable_all_optimization,
|
enable_all_optimization=self.enable_all_optimization,
|
||||||
enable_fused_normalization=self.enable_fused_normalization,
|
enable_fused_normalization=self.enable_fused_normalization,
|
||||||
enable_flash_attention=self.enable_flash_attention,
|
enable_flash_attention=self.enable_flash_attention,
|
||||||
enable_jit_fused=self.enable_jit_fused)
|
enable_jit_fused=self.enable_jit_fused,
|
||||||
|
enable_sequence_parallelism=enable_sequence_parallelism)
|
||||||
self.amp_config = dict(
|
self.amp_config = dict(
|
||||||
initial_scale=initial_scale,
|
initial_scale=initial_scale,
|
||||||
growth_factor=growth_factor,
|
growth_factor=growth_factor,
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -141,6 +143,215 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||||
return grad_input, grad_weight, grad_bias, None, None, None
|
return grad_input, grad_weight, grad_bias, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||||
|
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
|
||||||
|
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
|
||||||
|
overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
|
||||||
|
ctx.save_for_backward(input_, weight)
|
||||||
|
ctx.use_bias = bias is not None
|
||||||
|
ctx.process_group = process_group
|
||||||
|
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
|
||||||
|
ctx.dim = dim
|
||||||
|
ctx.overlap = overlap
|
||||||
|
|
||||||
|
input_parallel = _gather(input_, dim, process_group)
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
output = F.linear(input_parallel, weight, bias)
|
||||||
|
else:
|
||||||
|
output = F.linear(input_parallel, weight)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
input_, weight = ctx.saved_tensors
|
||||||
|
use_bias = ctx.use_bias
|
||||||
|
dim = ctx.dim
|
||||||
|
process_group = ctx.process_group
|
||||||
|
overlap = ctx.overlap
|
||||||
|
|
||||||
|
if not overlap:
|
||||||
|
# TODO: overlap SP input with gradient computation
|
||||||
|
input_parallel = _gather(input_, dim, process_group)
|
||||||
|
|
||||||
|
total_input = input_parallel
|
||||||
|
grad_input = grad_output.matmul(weight)
|
||||||
|
grad_output = grad_output.contiguous()
|
||||||
|
# Convert the tensor shapes to 2D for execution compatibility
|
||||||
|
if len(grad_output.shape) > 2:
|
||||||
|
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||||
|
total_input = total_input.view(-1, total_input.shape[-1])
|
||||||
|
|
||||||
|
# TODO: overlap SP input with gradient computation
|
||||||
|
if ctx.async_grad_reduce_scatter:
|
||||||
|
# Asynchronous reduce-scatter
|
||||||
|
input_list = [
|
||||||
|
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
|
||||||
|
]
|
||||||
|
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
|
||||||
|
device=input_parallel.device).contiguous()
|
||||||
|
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
||||||
|
# Delay the start of weight gradient computation shortly (3us) to have
|
||||||
|
# reduce-scatter scheduled first and have GPU resources allocated
|
||||||
|
_ = torch.empty(1, device=grad_output.device) + 1
|
||||||
|
|
||||||
|
grad_weight = grad_output.t().matmul(total_input)
|
||||||
|
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||||
|
|
||||||
|
if ctx.async_grad_reduce_scatter:
|
||||||
|
handle.wait()
|
||||||
|
|
||||||
|
else:
|
||||||
|
# create new stream for calculate the gradient
|
||||||
|
calculate_stream = torch.cuda.Stream()
|
||||||
|
|
||||||
|
# do all gather in default stream
|
||||||
|
input_ = input_.contiguous()
|
||||||
|
world_size = dist.get_world_size(process_group)
|
||||||
|
rank = dist.get_rank(process_group)
|
||||||
|
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||||
|
tensor_list[rank] = input_
|
||||||
|
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
|
||||||
|
|
||||||
|
# calculate gradient in calculate_stream
|
||||||
|
with torch.cuda.stream(calculate_stream):
|
||||||
|
# calculate
|
||||||
|
grad_input = grad_output.matmul(weight)
|
||||||
|
grad_output = grad_output.contiguous()
|
||||||
|
# Convert the tensor shapes to 2D for execution compatibility
|
||||||
|
if len(grad_output.shape) > 2:
|
||||||
|
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||||
|
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||||
|
|
||||||
|
# prepare data
|
||||||
|
input_list = [
|
||||||
|
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
|
||||||
|
]
|
||||||
|
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
|
||||||
|
|
||||||
|
torch.cuda.current_stream().wait_stream(calculate_stream)
|
||||||
|
|
||||||
|
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
||||||
|
with torch.cuda.stream(calculate_stream):
|
||||||
|
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
|
||||||
|
if len(input_parallel.shape) > 2:
|
||||||
|
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
|
||||||
|
print(grad_output.shape, input_parallel.shape)
|
||||||
|
grad_weight = grad_output.t().matmul(input_parallel)
|
||||||
|
|
||||||
|
torch.cuda.current_stream().wait_stream(calculate_stream)
|
||||||
|
|
||||||
|
return output, grad_weight, grad_bias, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
|
||||||
|
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
|
||||||
|
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_, process_group, dim):
|
||||||
|
ctx.dim = dim
|
||||||
|
ctx.process_group = process_group
|
||||||
|
|
||||||
|
# do reduce-scatter
|
||||||
|
new_shape = list(input_.shape)
|
||||||
|
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
|
||||||
|
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
|
||||||
|
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
|
||||||
|
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
|
||||||
|
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
|
||||||
|
dist.reduce_scatter(output, input_list, group=process_group)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
dim = ctx.dim
|
||||||
|
process_group = ctx.process_group
|
||||||
|
|
||||||
|
return _gather(grad_output, dim, process_group), None, None
|
||||||
|
|
||||||
|
|
||||||
|
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
This class is designed for matmul operation with gather forward and reduce-scatter backward.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ (`torch.Tensor`): input matrix.
|
||||||
|
dim (int): the dimension to perform split and gather
|
||||||
|
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
|
||||||
|
ctx.save_for_backward(input_, weight)
|
||||||
|
ctx.use_bias = bias is not None
|
||||||
|
ctx.process_group = process_group
|
||||||
|
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
|
||||||
|
ctx.dim = dim
|
||||||
|
|
||||||
|
input_parallel = _gather(input_, dim, process_group)
|
||||||
|
|
||||||
|
output = torch.matmul(input_parallel, weight)
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
output = output + bias
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
input_, weight = ctx.saved_tensors
|
||||||
|
use_bias = ctx.use_bias
|
||||||
|
dim = ctx.dim
|
||||||
|
process_group = ctx.process_group
|
||||||
|
|
||||||
|
# TODO: overlap SP input with gradient computation
|
||||||
|
input_parallel = _gather(input_, dim, process_group)
|
||||||
|
|
||||||
|
total_input = input_parallel
|
||||||
|
grad_input = grad_output.matmul(weight.T)
|
||||||
|
grad_output = grad_output.contiguous()
|
||||||
|
# Convert the tensor shapes to 2D for execution compatibility
|
||||||
|
if len(grad_output.shape) > 2:
|
||||||
|
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||||
|
total_input = total_input.view(-1, total_input.shape[-1])
|
||||||
|
|
||||||
|
# TODO: overlap SP input with gradient computation
|
||||||
|
if ctx.async_grad_reduce_scatter:
|
||||||
|
# Asynchronous reduce-scatter
|
||||||
|
input_list = [
|
||||||
|
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
|
||||||
|
]
|
||||||
|
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
|
||||||
|
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
||||||
|
# Delay the start of weight gradient computation shortly (3us) to have
|
||||||
|
# reduce-scatter scheduled first and have GPU resources allocated
|
||||||
|
_ = torch.empty(1, device=grad_output.device) + 1
|
||||||
|
|
||||||
|
grad_weight = total_input.t().matmul(grad_output)
|
||||||
|
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||||
|
|
||||||
|
if ctx.async_grad_reduce_scatter:
|
||||||
|
handle.wait()
|
||||||
|
|
||||||
|
return output, grad_weight, grad_bias, None, None, None
|
||||||
|
|
||||||
|
|
||||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||||
"""
|
"""
|
||||||
Split the input and keep only the corresponding chuck to the rank.
|
Split the input and keep only the corresponding chuck to the rank.
|
||||||
|
@ -200,6 +411,26 @@ class _ReduceBackward(torch.autograd.Function):
|
||||||
return _reduce(grad_output, ctx.process_group), None
|
return _reduce(grad_output, ctx.process_group), None
|
||||||
|
|
||||||
|
|
||||||
|
class _GatherForwardSplitBackward(torch.autograd.Function):
|
||||||
|
"""Gather the input from model parallel region and concatenate.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_: input matrix.
|
||||||
|
parallel_mode: parallel mode.
|
||||||
|
dim: dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_, dim, process_group):
|
||||||
|
ctx.process_group = process_group
|
||||||
|
ctx.dim = dim
|
||||||
|
return _gather(input_, dim, process_group)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
return _split(grad_output, ctx.dim, ctx.process_group), None, None
|
||||||
|
|
||||||
|
|
||||||
def _reduce(input_, process_group):
|
def _reduce(input_, process_group):
|
||||||
# skip if only one rank involved
|
# skip if only one rank involved
|
||||||
if dist.get_world_size(process_group) == 1:
|
if dist.get_world_size(process_group) == 1:
|
||||||
|
@ -235,6 +466,7 @@ def _gather(input_, dim=-1, process_group=None):
|
||||||
return input_
|
return input_
|
||||||
|
|
||||||
# all gather
|
# all gather
|
||||||
|
input_ = input_.contiguous()
|
||||||
rank = dist.get_rank(process_group)
|
rank = dist.get_rank(process_group)
|
||||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||||
tensor_list[rank] = input_
|
tensor_list[rank] = input_
|
||||||
|
@ -246,24 +478,27 @@ def _gather(input_, dim=-1, process_group=None):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
class _GatherForwardSplitBackward(torch.autograd.Function):
|
def _reduce_scatter(input_, dim=1, process_group=None):
|
||||||
"""Gather the input from model parallel region and concatenate.
|
""" Do reduce-scatter operation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_: input matrix.
|
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
|
||||||
parallel_mode: parallel mode.
|
dim (int): The dimension to perform reduce-scatter.
|
||||||
dim: dimension
|
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
|
||||||
"""
|
"""
|
||||||
|
world_size = dist.get_world_size(process_group)
|
||||||
|
if world_size == 1:
|
||||||
|
return input_
|
||||||
|
|
||||||
@staticmethod
|
# reduce-scatter
|
||||||
def forward(ctx, input_, dim, process_group):
|
new_shape = list(input_.shape)
|
||||||
ctx.process_group = process_group
|
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
|
||||||
ctx.dim = dim
|
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
|
||||||
return _gather(input_, dim, process_group)
|
new_shape[dim] = new_shape[dim] // world_size
|
||||||
|
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
|
||||||
|
dist.reduce_scatter(output, input_, group=process_group)
|
||||||
|
|
||||||
@staticmethod
|
return output
|
||||||
def backward(ctx, grad_output):
|
|
||||||
return _split(grad_output, ctx.dim, ctx.process_group), None, None
|
|
||||||
|
|
||||||
|
|
||||||
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
|
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
|
||||||
|
@ -274,6 +509,21 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
|
||||||
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
|
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
|
||||||
|
|
||||||
|
|
||||||
|
def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
|
||||||
|
overlap):
|
||||||
|
return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
|
||||||
|
async_grad_reduce_scatter, dim, overlap)
|
||||||
|
|
||||||
|
|
||||||
|
def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
|
||||||
|
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
|
||||||
|
|
||||||
|
|
||||||
|
def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
|
||||||
|
return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
|
||||||
|
async_grad_reduce_scatter, dim)
|
||||||
|
|
||||||
|
|
||||||
def gather_forward_split_backward(input_, dim, process_group):
|
def gather_forward_split_backward(input_, dim, process_group):
|
||||||
return _GatherForwardSplitBackward.apply(input_, dim, process_group)
|
return _GatherForwardSplitBackward.apply(input_, dim, process_group)
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,8 @@ from colossalai.tensor.d_tensor.api import (
|
||||||
|
|
||||||
from ._operation import (
|
from ._operation import (
|
||||||
gather_forward_split_backward,
|
gather_forward_split_backward,
|
||||||
|
linear_gather_forward_reducescatter_backward,
|
||||||
|
linear_reducescatter_forward_gather_backward,
|
||||||
linear_with_async_comm,
|
linear_with_async_comm,
|
||||||
reduce_forward,
|
reduce_forward,
|
||||||
split_forward_gather_backward,
|
split_forward_gather_backward,
|
||||||
|
@ -50,6 +52,8 @@ class Linear1D_Col(ParallelModule):
|
||||||
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
||||||
to all GPUs, otherwise, every GPU will have its output
|
to all GPUs, otherwise, every GPU will have its output
|
||||||
which is :math:`Y_i = XA_i`, defaults to False
|
which is :math:`Y_i = XA_i`, defaults to False
|
||||||
|
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
|
||||||
|
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
|
||||||
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||||
which is preserved for kernel fusion, defaults to False
|
which is preserved for kernel fusion, defaults to False
|
||||||
weight_initializer (`typing.Callable`):
|
weight_initializer (`typing.Callable`):
|
||||||
|
@ -69,6 +73,8 @@ class Linear1D_Col(ParallelModule):
|
||||||
device: torch.device = None,
|
device: torch.device = None,
|
||||||
process_group: ProcessGroup = None,
|
process_group: ProcessGroup = None,
|
||||||
gather_output: bool = False,
|
gather_output: bool = False,
|
||||||
|
seq_parallel: bool = False,
|
||||||
|
overlap: bool = False,
|
||||||
skip_bias_add: bool = False,
|
skip_bias_add: bool = False,
|
||||||
weight: Optional[Parameter] = None,
|
weight: Optional[Parameter] = None,
|
||||||
bias_: Optional[Parameter] = None,
|
bias_: Optional[Parameter] = None,
|
||||||
|
@ -80,6 +86,8 @@ class Linear1D_Col(ParallelModule):
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
self.gather_output = gather_output
|
self.gather_output = gather_output
|
||||||
|
self.seq_parallel = seq_parallel
|
||||||
|
self.overlap = overlap
|
||||||
self.skip_bias_add = skip_bias_add
|
self.skip_bias_add = skip_bias_add
|
||||||
self.device = device
|
self.device = device
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
|
@ -180,7 +188,11 @@ class Linear1D_Col(ParallelModule):
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
if self.seq_parallel:
|
||||||
|
output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
|
||||||
|
self.process_group, True, 1, self.overlap)
|
||||||
|
else:
|
||||||
|
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
||||||
|
|
||||||
if self.gather_output:
|
if self.gather_output:
|
||||||
# All-gather across the partitions.
|
# All-gather across the partitions.
|
||||||
|
@ -203,6 +215,8 @@ class Linear1D_Row(ParallelModule):
|
||||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||||
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
|
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
|
||||||
|
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||||
|
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
|
||||||
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||||
which is preserved for kernel fusion, defaults to False
|
which is preserved for kernel fusion, defaults to False
|
||||||
weight_initializer (:class:`typing.Callable`, optional):
|
weight_initializer (:class:`typing.Callable`, optional):
|
||||||
|
@ -221,6 +235,7 @@ class Linear1D_Row(ParallelModule):
|
||||||
dtype: torch.dtype = None,
|
dtype: torch.dtype = None,
|
||||||
device: torch.device = None,
|
device: torch.device = None,
|
||||||
process_group: ProcessGroup = None,
|
process_group: ProcessGroup = None,
|
||||||
|
seq_parallel: bool = False,
|
||||||
parallel_input: bool = True,
|
parallel_input: bool = True,
|
||||||
skip_bias_add: bool = False,
|
skip_bias_add: bool = False,
|
||||||
weight: Optional[Parameter] = None,
|
weight: Optional[Parameter] = None,
|
||||||
|
@ -238,6 +253,7 @@ class Linear1D_Row(ParallelModule):
|
||||||
self.parallel_input = parallel_input
|
self.parallel_input = parallel_input
|
||||||
self.skip_bias_add = skip_bias_add
|
self.skip_bias_add = skip_bias_add
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
|
self.seq_parallel = seq_parallel
|
||||||
self.num_partitions = dist.get_world_size(self.process_group)
|
self.num_partitions = dist.get_world_size(self.process_group)
|
||||||
|
|
||||||
if skip_bias_add and not bias:
|
if skip_bias_add and not bias:
|
||||||
|
@ -373,7 +389,10 @@ class Linear1D_Row(ParallelModule):
|
||||||
output = torch.cat(output_parallel_list, dim=-1)
|
output = torch.cat(output_parallel_list, dim=-1)
|
||||||
else:
|
else:
|
||||||
output_parallel = F.linear(input_, self.weight)
|
output_parallel = F.linear(input_, self.weight)
|
||||||
output = reduce_forward(output_parallel, self.process_group)
|
if self.seq_parallel:
|
||||||
|
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
|
||||||
|
else:
|
||||||
|
output = reduce_forward(output_parallel, self.process_group)
|
||||||
|
|
||||||
if not self.skip_bias_add:
|
if not self.skip_bias_add:
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
|
|
|
@ -25,7 +25,9 @@ from colossalai.tensor.d_tensor.api import (
|
||||||
|
|
||||||
from ._operation import (
|
from ._operation import (
|
||||||
gather_forward_split_backward,
|
gather_forward_split_backward,
|
||||||
|
linear_reducescatter_forward_gather_backward,
|
||||||
linear_with_async_comm,
|
linear_with_async_comm,
|
||||||
|
matmul_gather_forward_reducescatter_backward,
|
||||||
matmul_with_async_comm,
|
matmul_with_async_comm,
|
||||||
reduce_backward,
|
reduce_backward,
|
||||||
reduce_forward,
|
reduce_forward,
|
||||||
|
@ -150,6 +152,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
device (`torch.device`): The device of parameters, defaults to None.
|
device (`torch.device`): The device of parameters, defaults to None.
|
||||||
n_fused (int): The number items fused, defaults to 3 (QKV).
|
n_fused (int): The number items fused, defaults to 3 (QKV).
|
||||||
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||||
|
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
|
||||||
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
||||||
to all GPUs, otherwise, every GPU will have its output
|
to all GPUs, otherwise, every GPU will have its output
|
||||||
which is :math:`Y_i = XA_i`, defaults to False
|
which is :math:`Y_i = XA_i`, defaults to False
|
||||||
|
@ -173,6 +176,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
process_group: ProcessGroup = None,
|
process_group: ProcessGroup = None,
|
||||||
async_communication: bool = False,
|
async_communication: bool = False,
|
||||||
gather_output: bool = False,
|
gather_output: bool = False,
|
||||||
|
seq_parallel: bool = False,
|
||||||
skip_bias_add: bool = False,
|
skip_bias_add: bool = False,
|
||||||
n_fused: int = 3,
|
n_fused: int = 3,
|
||||||
weight: Optional[Parameter] = None,
|
weight: Optional[Parameter] = None,
|
||||||
|
@ -185,6 +189,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
self.gather_output = gather_output
|
self.gather_output = gather_output
|
||||||
|
self.seq_parallel = seq_parallel
|
||||||
self.skip_bias_add = skip_bias_add
|
self.skip_bias_add = skip_bias_add
|
||||||
self.device = device
|
self.device = device
|
||||||
self.n_fused = n_fused
|
self.n_fused = n_fused
|
||||||
|
@ -296,15 +301,19 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||||
assert input_.shape[-1] == self.weight.shape[0], \
|
assert input_.shape[-1] == self.weight.shape[0], \
|
||||||
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||||
# Set up backprop all-reduce.
|
|
||||||
input_parallel = reduce_backward(input_, self.process_group)
|
|
||||||
# input_parallel = input_
|
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
|
|
||||||
output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
|
if self.seq_parallel:
|
||||||
self.async_communication)
|
input_parallel = input_
|
||||||
|
output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
|
||||||
|
self.process_group, True, 1)
|
||||||
|
else:
|
||||||
|
# Set up backprop all-reduce.
|
||||||
|
input_parallel = reduce_backward(input_, self.process_group)
|
||||||
|
output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
|
||||||
|
self.async_communication)
|
||||||
|
|
||||||
if self.gather_output:
|
if self.gather_output:
|
||||||
# All-gather across the partitions.
|
# All-gather across the partitions.
|
||||||
|
@ -329,6 +338,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||||
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
|
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
|
||||||
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||||
|
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
|
||||||
which is preserved for kernel fusion, defaults to False
|
which is preserved for kernel fusion, defaults to False
|
||||||
weight_initializer (:class:`typing.Callable`, optional):
|
weight_initializer (:class:`typing.Callable`, optional):
|
||||||
The initializer of weight, defaults to kaiming uniform initializer.
|
The initializer of weight, defaults to kaiming uniform initializer.
|
||||||
|
@ -346,6 +356,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||||
dtype: torch.dtype = None,
|
dtype: torch.dtype = None,
|
||||||
device: torch.device = None,
|
device: torch.device = None,
|
||||||
process_group: ProcessGroup = None,
|
process_group: ProcessGroup = None,
|
||||||
|
seq_parallel: bool = False,
|
||||||
parallel_input: bool = True,
|
parallel_input: bool = True,
|
||||||
skip_bias_add: bool = False,
|
skip_bias_add: bool = False,
|
||||||
weight: Optional[Parameter] = None,
|
weight: Optional[Parameter] = None,
|
||||||
|
@ -363,6 +374,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||||
self.parallel_input = parallel_input
|
self.parallel_input = parallel_input
|
||||||
self.skip_bias_add = skip_bias_add
|
self.skip_bias_add = skip_bias_add
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
|
self.seq_parallel = seq_parallel
|
||||||
self.num_partitions = dist.get_world_size(self.process_group)
|
self.num_partitions = dist.get_world_size(self.process_group)
|
||||||
|
|
||||||
if skip_bias_add and not bias:
|
if skip_bias_add and not bias:
|
||||||
|
@ -499,7 +511,10 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||||
output = torch.cat(output_parallel_list, dim=-1)
|
output = torch.cat(output_parallel_list, dim=-1)
|
||||||
else:
|
else:
|
||||||
output_parallel = torch.matmul(input_, self.weight)
|
output_parallel = torch.matmul(input_, self.weight)
|
||||||
output = reduce_forward(output_parallel, self.process_group)
|
if self.seq_parallel:
|
||||||
|
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
|
||||||
|
else:
|
||||||
|
output = reduce_forward(output_parallel, self.process_group)
|
||||||
|
|
||||||
if not self.skip_bias_add:
|
if not self.skip_bias_add:
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
|
|
|
@ -0,0 +1,222 @@
|
||||||
|
# this code is modified from transformers.models.gpt2.modeling_gpt2
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L670
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||||
|
from colossalai.shardformer.shard import ShardConfig
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: put all contents in `gpt2.py` and make it compatible with pipeline
|
||||||
|
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (output_hidden_states
|
||||||
|
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size = inputs_embeds.shape[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||||
|
if position_ids is not None:
|
||||||
|
position_ids = position_ids.view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
if past_key_values is None:
|
||||||
|
past_length = 0
|
||||||
|
past_key_values = tuple([None] * len(self.h))
|
||||||
|
else:
|
||||||
|
past_length = past_key_values[0][0].size(-2)
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
# GPT2Attention mask.
|
||||||
|
if attention_mask is not None:
|
||||||
|
if batch_size <= 0:
|
||||||
|
raise ValueError("batch_size has to be defined and > 0")
|
||||||
|
attention_mask = attention_mask.view(batch_size, -1)
|
||||||
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
|
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||||
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
|
attention_mask = attention_mask[:, None, None, :]
|
||||||
|
|
||||||
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
|
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||||
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
|
# effectively the same as removing these entirely.
|
||||||
|
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||||
|
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||||
|
|
||||||
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||||
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||||
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
|
if encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
|
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_attention_mask = None
|
||||||
|
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
|
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.wte(input_ids)
|
||||||
|
position_embeds = self.wpe(position_ids)
|
||||||
|
hidden_states = inputs_embeds + position_embeds
|
||||||
|
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_embeds = self.wte(token_type_ids)
|
||||||
|
hidden_states = hidden_states + token_type_embeds
|
||||||
|
|
||||||
|
hidden_states = self.drop(hidden_states)
|
||||||
|
|
||||||
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
presents = () if use_cache else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
|
# split the input tensor along sequence dimension
|
||||||
|
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
|
||||||
|
hidden_states = split_forward_gather_backward(hidden_states,
|
||||||
|
dim=1,
|
||||||
|
process_group=shard_config.tensor_parallel_process_group)
|
||||||
|
|
||||||
|
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||||
|
# Model parallel
|
||||||
|
if self.model_parallel:
|
||||||
|
torch.cuda.set_device(hidden_states.device)
|
||||||
|
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||||
|
if layer_past is not None:
|
||||||
|
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||||
|
# Ensure that attention_mask is always on the same device as hidden_states
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(hidden_states.device)
|
||||||
|
if isinstance(head_mask, torch.Tensor):
|
||||||
|
head_mask = head_mask.to(hidden_states.device)
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs, use_cache, output_attentions)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
hidden_states,
|
||||||
|
None,
|
||||||
|
attention_mask,
|
||||||
|
head_mask[i],
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
outputs = block(
|
||||||
|
hidden_states,
|
||||||
|
layer_past=layer_past,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask[i],
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
if use_cache is True:
|
||||||
|
presents = presents + (outputs[1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||||
|
|
||||||
|
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||||
|
if self.model_parallel:
|
||||||
|
for k, v in self.device_map.items():
|
||||||
|
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||||
|
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||||
|
|
||||||
|
# When sequence parallelism done, gather the output tensor in forward and split it in backward
|
||||||
|
hidden_states = gather_forward_split_backward(hidden_states,
|
||||||
|
dim=1,
|
||||||
|
process_group=shard_config.tensor_parallel_process_group)
|
||||||
|
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
hidden_states = hidden_states.view(output_shape)
|
||||||
|
# Add last hidden state
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||||
|
if v is not None)
|
||||||
|
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=presents,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
return forward
|
|
@ -11,17 +11,12 @@ from torch.nn import Module
|
||||||
|
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
|
||||||
|
from ..layer.parallel_module import ParallelModule
|
||||||
from ..shard.shard_config import ShardConfig
|
from ..shard.shard_config import ShardConfig
|
||||||
|
|
||||||
__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]
|
__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]
|
||||||
|
|
||||||
|
|
||||||
class ParallelModule():
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SubModuleReplacementDescription:
|
class SubModuleReplacementDescription:
|
||||||
r"""
|
r"""
|
||||||
|
@ -231,3 +226,22 @@ class Policy(ABC):
|
||||||
end_idx = num_layers_per_stage_accumulated[stage + 1]
|
end_idx = num_layers_per_stage_accumulated[stage + 1]
|
||||||
|
|
||||||
return [start_idx, end_idx]
|
return [start_idx, end_idx]
|
||||||
|
|
||||||
|
def append_seq_parallel_to_policy(
|
||||||
|
self,
|
||||||
|
suffix_list: List[str],
|
||||||
|
module_policy_description: ModulePolicyDescription,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Append the sequence parallel policy to the policy for the given key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
suffix_list (List[str]): the suffix list of the module to be parallelized
|
||||||
|
policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
|
||||||
|
"""
|
||||||
|
|
||||||
|
for sub_description in module_policy_description.sub_module_replacement:
|
||||||
|
if (sub_description.suffix in suffix_list):
|
||||||
|
if sub_description.kwargs is None:
|
||||||
|
sub_description.kwargs = {}
|
||||||
|
sub_description.kwargs["seq_parallel"] = True
|
||||||
|
|
|
@ -7,6 +7,7 @@ import colossalai.shardformer.layer as col_nn
|
||||||
|
|
||||||
from .._utils import getattr_, setattr_
|
from .._utils import getattr_, setattr_
|
||||||
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
|
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
|
||||||
|
from ..modeling.gpt2_seq import gpt2_sequence_parallel_forward_fn
|
||||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -49,6 +50,9 @@ class GPT2Policy(Policy):
|
||||||
target_module=col_nn.DropoutForParallelInput,
|
target_module=col_nn.DropoutForParallelInput,
|
||||||
),
|
),
|
||||||
])
|
])
|
||||||
|
if self.shard_config.enable_sequence_parallelism:
|
||||||
|
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
|
||||||
|
|
||||||
policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
|
policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
|
||||||
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||||
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||||
|
@ -120,6 +124,11 @@ class GPT2Policy(Policy):
|
||||||
policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
|
policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
|
||||||
'forward': get_gpt2_flash_attention_forward(),
|
'forward': get_gpt2_flash_attention_forward(),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if self.shard_config.enable_sequence_parallelism:
|
||||||
|
suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"]
|
||||||
|
self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block])
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
def postprocess(self):
|
def postprocess(self):
|
||||||
|
|
|
@ -28,6 +28,7 @@ class ShardConfig:
|
||||||
enable_all_optimization: bool = False
|
enable_all_optimization: bool = False
|
||||||
enable_flash_attention: bool = False
|
enable_flash_attention: bool = False
|
||||||
enable_jit_fused: bool = False
|
enable_jit_fused: bool = False
|
||||||
|
enable_sequence_parallelism: bool = False
|
||||||
|
|
||||||
# pipeline_parallel_size: int
|
# pipeline_parallel_size: int
|
||||||
# data_parallel_size: int
|
# data_parallel_size: int
|
||||||
|
|
|
@ -53,8 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
|
||||||
return rearanged_tensor
|
return rearanged_tensor
|
||||||
|
|
||||||
|
|
||||||
@parameterize('lazy_init', [False, True])
|
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool):
|
||||||
def check_linear_conv_1d_col(lazy_init: bool):
|
|
||||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
linear = Conv1D(192, 48).cuda()
|
linear = Conv1D(192, 48).cuda()
|
||||||
with ctx:
|
with ctx:
|
||||||
|
@ -62,6 +61,7 @@ def check_linear_conv_1d_col(lazy_init: bool):
|
||||||
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
|
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
|
||||||
process_group=None,
|
process_group=None,
|
||||||
gather_output=True,
|
gather_output=True,
|
||||||
|
seq_parallel=seq_parallel,
|
||||||
n_fused=3)
|
n_fused=3)
|
||||||
|
|
||||||
assert linear.weight.shape == torch.Size([48, 192])
|
assert linear.weight.shape == torch.Size([48, 192])
|
||||||
|
@ -76,10 +76,11 @@ def check_linear_conv_1d_col(lazy_init: bool):
|
||||||
linear.load_state_dict(linear_conv_col.state_dict())
|
linear.load_state_dict(linear_conv_col.state_dict())
|
||||||
|
|
||||||
# check computation correctness
|
# check computation correctness
|
||||||
x = torch.rand(4, 48).cuda()
|
x = torch.rand(1, 4, 48).cuda()
|
||||||
out = linear(x)
|
out = linear(x)
|
||||||
gather_out = linear_conv_col(x)
|
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
|
||||||
assert_close(rearrange(out, 1), gather_out)
|
gather_out = linear_conv_col(x_for_shard)
|
||||||
|
assert_close(rearrange(out, -1), gather_out)
|
||||||
|
|
||||||
# check backward correctness
|
# check backward correctness
|
||||||
out.sum().backward()
|
out.sum().backward()
|
||||||
|
@ -89,14 +90,16 @@ def check_linear_conv_1d_col(lazy_init: bool):
|
||||||
assert_close(target_grad, linear_conv_col.weight.grad)
|
assert_close(target_grad, linear_conv_col.weight.grad)
|
||||||
|
|
||||||
|
|
||||||
@parameterize('lazy_init', [False, True])
|
def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
|
||||||
def check_linear_conv_1d_row(lazy_init: bool):
|
|
||||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
|
|
||||||
linear = Conv1D(192, 48).cuda()
|
linear = Conv1D(192, 48).cuda()
|
||||||
with ctx:
|
with ctx:
|
||||||
linear_copy = Conv1D(192, 48).cuda()
|
linear_copy = Conv1D(192, 48).cuda()
|
||||||
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
|
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy,
|
||||||
|
process_group=None,
|
||||||
|
parallel_input=False,
|
||||||
|
seq_parallel=seq_parallel)
|
||||||
|
|
||||||
assert linear.weight.shape == torch.Size([48, 192])
|
assert linear.weight.shape == torch.Size([48, 192])
|
||||||
assert linear_row.weight.shape == torch.Size([24, 192])
|
assert linear_row.weight.shape == torch.Size([24, 192])
|
||||||
|
@ -109,10 +112,11 @@ def check_linear_conv_1d_row(lazy_init: bool):
|
||||||
linear.load_state_dict(linear_row.state_dict())
|
linear.load_state_dict(linear_row.state_dict())
|
||||||
|
|
||||||
# check computation correctness
|
# check computation correctness
|
||||||
x = torch.rand(4, 48).cuda()
|
x = torch.rand(1, 4, 48).cuda()
|
||||||
out = linear(x)
|
out = linear(x)
|
||||||
gather_out = linear_row(x)
|
gather_out = linear_row(x)
|
||||||
assert_close(out, gather_out)
|
target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
|
||||||
|
assert_close(target_out, gather_out)
|
||||||
|
|
||||||
# check backward correctness
|
# check backward correctness
|
||||||
out.sum().backward()
|
out.sum().backward()
|
||||||
|
@ -123,12 +127,18 @@ def check_linear_conv_1d_row(lazy_init: bool):
|
||||||
assert_close(target_grad, linear_row.weight.grad)
|
assert_close(target_grad, linear_row.weight.grad)
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize('lazy_init', [False, True])
|
||||||
|
@parameterize('seq_parallel', [False, True])
|
||||||
|
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool):
|
||||||
|
check_linear_conv_1d_col(lazy_init, seq_parallel)
|
||||||
|
check_linear_conv_1d_row(lazy_init, seq_parallel)
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
|
||||||
# test for linear conv
|
# test for linear conv
|
||||||
check_linear_conv_1d_col()
|
check_gpt2_qkv_fused_linear_1d()
|
||||||
check_linear_conv_1d_row()
|
|
||||||
|
|
||||||
|
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
|
|
|
@ -12,13 +12,16 @@ from colossalai.tensor.d_tensor import is_distributed_tensor
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
@parameterize('lazy_init', [False, True])
|
def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
|
||||||
def check_linear_1d_col(lazy_init: bool):
|
|
||||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
linear = nn.Linear(32, 128).cuda()
|
linear = nn.Linear(32, 128).cuda()
|
||||||
with ctx:
|
with ctx:
|
||||||
linear_copy = nn.Linear(32, 128).cuda()
|
linear_copy = nn.Linear(32, 128).cuda()
|
||||||
linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True)
|
linear_col = Linear1D_Col.from_native_module(linear_copy,
|
||||||
|
process_group=None,
|
||||||
|
gather_output=True,
|
||||||
|
seq_parallel=seq_parallel,
|
||||||
|
overlap=overlap)
|
||||||
|
|
||||||
# ensure that the parameters are distributed
|
# ensure that the parameters are distributed
|
||||||
assert is_distributed_tensor(linear_col.weight)
|
assert is_distributed_tensor(linear_col.weight)
|
||||||
|
@ -35,10 +38,11 @@ def check_linear_1d_col(lazy_init: bool):
|
||||||
linear_col.load_state_dict(linear.state_dict())
|
linear_col.load_state_dict(linear.state_dict())
|
||||||
|
|
||||||
# check computation correctness
|
# check computation correctness
|
||||||
x = torch.rand(4, 32).cuda()
|
# [batch_size, seq_len, hidden_size]
|
||||||
|
x = torch.rand(2, 4, 32).cuda()
|
||||||
x_for_unshard = x.expand_as(x.clone())
|
x_for_unshard = x.expand_as(x.clone())
|
||||||
x_for_unshard.requires_grad_(True)
|
x_for_unshard.requires_grad_(True)
|
||||||
x_for_shard = x.expand_as(x.clone())
|
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
|
||||||
x_for_shard.requires_grad_(True)
|
x_for_shard.requires_grad_(True)
|
||||||
|
|
||||||
out = linear(x_for_unshard)
|
out = linear(x_for_unshard)
|
||||||
|
@ -56,17 +60,21 @@ def check_linear_1d_col(lazy_init: bool):
|
||||||
# check the input gradients
|
# check the input gradients
|
||||||
assert x_for_shard.grad is not None
|
assert x_for_shard.grad is not None
|
||||||
assert x_for_unshard.grad is not None
|
assert x_for_unshard.grad is not None
|
||||||
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk(
|
||||||
|
x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
|
||||||
|
assert_close(target_unshard_gard, x_for_shard.grad)
|
||||||
|
|
||||||
|
|
||||||
@parameterize('lazy_init', [False, True])
|
def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
|
||||||
def check_linear_1d_row(lazy_init: bool):
|
|
||||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
|
|
||||||
linear = nn.Linear(32, 128).cuda()
|
linear = nn.Linear(32, 128).cuda()
|
||||||
with ctx:
|
with ctx:
|
||||||
linear_copy = nn.Linear(32, 128).cuda()
|
linear_copy = nn.Linear(32, 128).cuda()
|
||||||
linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
|
linear_row = Linear1D_Row.from_native_module(linear_copy,
|
||||||
|
process_group=None,
|
||||||
|
parallel_input=False,
|
||||||
|
seq_parallel=seq_parallel)
|
||||||
|
|
||||||
assert linear_row.weight.shape == torch.Size([128, 16])
|
assert linear_row.weight.shape == torch.Size([128, 16])
|
||||||
assert linear_row.bias.shape == torch.Size([128])
|
assert linear_row.bias.shape == torch.Size([128])
|
||||||
|
@ -77,7 +85,8 @@ def check_linear_1d_row(lazy_init: bool):
|
||||||
linear_row.load_state_dict(linear.state_dict())
|
linear_row.load_state_dict(linear.state_dict())
|
||||||
|
|
||||||
# check computation correctness
|
# check computation correctness
|
||||||
x = torch.rand(4, 32).cuda()
|
# [batch_size, seq_len, hidden_size]
|
||||||
|
x = torch.rand(2, 4, 32).cuda()
|
||||||
x_for_unshard = x.expand_as(x.clone())
|
x_for_unshard = x.expand_as(x.clone())
|
||||||
x_for_unshard.requires_grad_(True)
|
x_for_unshard.requires_grad_(True)
|
||||||
x_for_shard = x.expand_as(x.clone())
|
x_for_shard = x.expand_as(x.clone())
|
||||||
|
@ -86,7 +95,8 @@ def check_linear_1d_row(lazy_init: bool):
|
||||||
# run forward
|
# run forward
|
||||||
out = linear(x_for_unshard)
|
out = linear(x_for_unshard)
|
||||||
gather_out = linear_row(x_for_shard)
|
gather_out = linear_row(x_for_shard)
|
||||||
assert_close(out, gather_out)
|
target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
|
||||||
|
assert_close(target_out, gather_out)
|
||||||
|
|
||||||
# check backward correctness
|
# check backward correctness
|
||||||
out.sum().backward()
|
out.sum().backward()
|
||||||
|
@ -102,8 +112,7 @@ def check_linear_1d_row(lazy_init: bool):
|
||||||
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
||||||
|
|
||||||
|
|
||||||
@parameterize('lazy_init', [False, True])
|
def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
|
||||||
def check_linear_col_plus_row(lazy_init: bool):
|
|
||||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
|
|
||||||
linear_1 = nn.Linear(32, 128).cuda()
|
linear_1 = nn.Linear(32, 128).cuda()
|
||||||
|
@ -112,8 +121,15 @@ def check_linear_col_plus_row(lazy_init: bool):
|
||||||
with ctx:
|
with ctx:
|
||||||
linear_1_copy = nn.Linear(32, 128).cuda()
|
linear_1_copy = nn.Linear(32, 128).cuda()
|
||||||
linear_2_copy = nn.Linear(128, 32).cuda()
|
linear_2_copy = nn.Linear(128, 32).cuda()
|
||||||
linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False)
|
linear_col = Linear1D_Col.from_native_module(linear_1_copy,
|
||||||
linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True)
|
process_group=None,
|
||||||
|
gather_output=False,
|
||||||
|
seq_parallel=seq_parallel,
|
||||||
|
overlap=overlap)
|
||||||
|
linear_row = Linear1D_Row.from_native_module(linear_2_copy,
|
||||||
|
process_group=None,
|
||||||
|
parallel_input=True,
|
||||||
|
seq_parallel=seq_parallel)
|
||||||
|
|
||||||
linear_1.load_state_dict(linear_col.state_dict())
|
linear_1.load_state_dict(linear_col.state_dict())
|
||||||
linear_col.load_state_dict(linear_1.state_dict())
|
linear_col.load_state_dict(linear_1.state_dict())
|
||||||
|
@ -121,16 +137,18 @@ def check_linear_col_plus_row(lazy_init: bool):
|
||||||
linear_row.load_state_dict(linear_2.state_dict())
|
linear_row.load_state_dict(linear_2.state_dict())
|
||||||
|
|
||||||
# check computation correctness
|
# check computation correctness
|
||||||
x = torch.rand(4, 32).cuda()
|
# [batch_size, seq_len, hidden_size]
|
||||||
|
x = torch.rand(2, 4, 32).cuda()
|
||||||
x_for_unshard = x.expand_as(x.clone())
|
x_for_unshard = x.expand_as(x.clone())
|
||||||
x_for_unshard.requires_grad_(True)
|
x_for_unshard.requires_grad_(True)
|
||||||
x_for_shard = x.expand_as(x.clone())
|
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
|
||||||
x_for_shard.requires_grad_(True)
|
x_for_shard.requires_grad_(True)
|
||||||
|
|
||||||
# run forward
|
# run forward
|
||||||
unshard_out = linear_2(linear_1(x_for_unshard))
|
unshard_out = linear_2(linear_1(x_for_unshard))
|
||||||
shard_out = linear_row(linear_col(x_for_shard))
|
shard_out = linear_row(linear_col(x_for_shard))
|
||||||
assert_close(unshard_out, shard_out)
|
target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
|
||||||
|
assert_close(target_out, shard_out)
|
||||||
|
|
||||||
# check backward correctness
|
# check backward correctness
|
||||||
unshard_out.sum().backward()
|
unshard_out.sum().backward()
|
||||||
|
@ -143,19 +161,28 @@ def check_linear_col_plus_row(lazy_init: bool):
|
||||||
# check the input gradients
|
# check the input gradients
|
||||||
assert x_for_shard.grad is not None
|
assert x_for_shard.grad is not None
|
||||||
assert x_for_unshard.grad is not None
|
assert x_for_unshard.grad is not None
|
||||||
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk(
|
||||||
|
x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
|
||||||
|
assert_close(target_unshard_gard, x_for_shard.grad)
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
@parameterize('lazy_init', [False, True])
|
||||||
|
@parameterize('seq_parallel', [False, True])
|
||||||
|
@parameterize('overlap', [False, True])
|
||||||
|
def run_dist_linear_test(lazy_init, seq_parallel, overlap):
|
||||||
|
check_linear_1d_col(lazy_init, seq_parallel, overlap)
|
||||||
|
check_linear_1d_row(lazy_init, seq_parallel)
|
||||||
|
check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
|
||||||
|
|
||||||
|
|
||||||
|
def check_dist_linear(rank, world_size, port):
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
check_linear_1d_col()
|
run_dist_linear_test()
|
||||||
check_linear_1d_row()
|
|
||||||
check_linear_col_plus_row()
|
|
||||||
|
|
||||||
|
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_linear():
|
def test_linear():
|
||||||
spawn(run_dist, nprocs=2)
|
spawn(check_dist_linear, nprocs=2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import copy
|
import copy
|
||||||
|
import math
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
@ -25,6 +26,7 @@ def build_model(model_fn,
|
||||||
enable_tensor_parallelism=True,
|
enable_tensor_parallelism=True,
|
||||||
enable_flash_attention=False,
|
enable_flash_attention=False,
|
||||||
enable_jit_fused=False,
|
enable_jit_fused=False,
|
||||||
|
enable_sequence_parallelism=False,
|
||||||
use_lazy_init: bool = False):
|
use_lazy_init: bool = False):
|
||||||
# create new model
|
# create new model
|
||||||
ctx = LazyInitContext() if use_lazy_init else nullcontext()
|
ctx = LazyInitContext() if use_lazy_init else nullcontext()
|
||||||
|
@ -38,7 +40,8 @@ def build_model(model_fn,
|
||||||
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
|
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
|
||||||
enable_tensor_parallelism=enable_tensor_parallelism,
|
enable_tensor_parallelism=enable_tensor_parallelism,
|
||||||
enable_flash_attention=enable_flash_attention,
|
enable_flash_attention=enable_flash_attention,
|
||||||
enable_jit_fused=enable_jit_fused)
|
enable_jit_fused=enable_jit_fused,
|
||||||
|
enable_sequence_parallelism=enable_sequence_parallelism)
|
||||||
model_copy = copy.deepcopy(org_model)
|
model_copy = copy.deepcopy(org_model)
|
||||||
shard_former = ShardFormer(shard_config=shard_config)
|
shard_former = ShardFormer(shard_config=shard_config)
|
||||||
sharded_model, shared_params = shard_former.optimize(model_copy)
|
sharded_model, shared_params = shard_former.optimize(model_copy)
|
||||||
|
@ -135,6 +138,16 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
data = data_gen_fn()
|
data = data_gen_fn()
|
||||||
|
|
||||||
|
if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
|
||||||
|
seq_len = data['input_ids'].shape[1]
|
||||||
|
lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
|
||||||
|
times = lcm // seq_len
|
||||||
|
input_shape = data['input_ids'].shape
|
||||||
|
for k, v in data.items():
|
||||||
|
if v.shape == input_shape:
|
||||||
|
data[k] = v.repeat(1, times)
|
||||||
|
|
||||||
sharded_model.train()
|
sharded_model.train()
|
||||||
if booster.plugin.stage_manager is not None:
|
if booster.plugin.stage_manager is not None:
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
|
|
|
@ -106,6 +106,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
'enable_all_optimization': True,
|
'enable_all_optimization': True,
|
||||||
'use_lazy_init': False,
|
'use_lazy_init': False,
|
||||||
'precision': 'fp32',
|
'precision': 'fp32',
|
||||||
|
}, {
|
||||||
|
'tp_size': 4,
|
||||||
|
'pp_size': 1,
|
||||||
|
'enable_all_optimization': False,
|
||||||
|
'use_lazy_init': True,
|
||||||
|
'enable_sequence_parallelism': True,
|
||||||
|
'precision': 'fp32',
|
||||||
}])
|
}])
|
||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
def run_gpt2_test(test_config):
|
def run_gpt2_test(test_config):
|
||||||
|
|
Loading…
Reference in New Issue