[Shardformer] Support zbv in Shardformer Policy (#6150)

* [feat] Shardformer supports zbv

* [feat] support chatglm2, command, deepseek for zbv

* [feat] support zbv in shardformer policy:
falcon, gptj, mistral, opt, qwen2, t5, vit, whisper

* [feat] support GPT2FusedLinearConv1D

* [feat] support GPT2FusedLinear (without tp)

* [fix] debug FusedConvLinear

* [shardformer] support gpt2 policy for zbv; support GPT2FusedLinearConv1D_Col and GPT2FusedLinearConv1D_Row.

* [Shardformer] support FusedLinear1D base for zbv

* [shardformer] support zbv in FusedLinear1D base, Col, Row

* [shardformer] support zbv in blip2 and sam policy

* [shardformer] fix bug of incorrect number of gradients; add FusedLinear base test case

* [fix] fix incorrect number of gradients;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [Shardformer] add en doc for zbv;

* [fix] fix typo in Model compatibility table

* [fix] fix API Reference typo

* [Shardformer] add zh-Hans doc for zbv

* [fix] fix Linear name; update en & zh doc

* [fix] fix shardformer doc import err

* [fix] fix shardconfig import in doc

* [fix] fix shardformer doc

* [fix] fix shardconfig doc

* [fix] fix config

* [fix] remove shardconfig

* [fix] fix doc

* [feat] add zbv doc string

* [fix] rm doc

* [fix] fix doc

* [fix] empty zbv doc

* [fix] fix torch version

* [fix] fix torch version

* [fix] fix torch versions

* [fix] fix torch versions

* [fix] fix pyramid versions

* [fix] fix pyramid, zope version

* [fix] try fix workflow

* [fix] try import ShardConfig in yml

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix ci

* [fix] fix zbv doc

* [fix] fix params for qkv linear and gpt2 fused linear; fix requirements;

* [fix] fix policy use fused_linear

* [fix] fix weight grad None error caused by weight ptr change

* [fix] fix comm in WeightGradStore

* [fix] fix WeightGradStore pop param

* [fix] remove useless param in doc; fix gpt2 qkv test;

* [shardformer] simplify execute_w_pass_grad_accum;

* [fix] rm useless comments

* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass

* [shardformer] Run meaningful doc test

* [shardformer] fix doc test cmd;

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/6178/head
duanjunwen 2025-01-02 10:22:26 +08:00 committed by GitHub
parent af06d162cf
commit a9bedc7a43
27 changed files with 3511 additions and 316 deletions
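
For context, here is a minimal end-to-end sketch of how ZBV is typically enabled on the user side. It follows the pattern described in the zbv docs added by this PR; the PipelineGraph arguments and the plugin kwargs (pp_style="zbv", num_model_chunks, scheduler_nodes) are assumptions based on that doc, not verified signatures.

# Hypothetical usage sketch; treat module paths and argument names as assumptions.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

colossalai.launch_from_torch()

pp_size, num_microbatches = 2, 4
# Build the V-shaped zero-bubble schedule; the cost/memory figures are relative placeholders.
graph = PipelineGraph(
    n_stage=pp_size,
    n_micro=num_microbatches,
    f_cost=1, b_cost=1, w_cost=1, c_cost=1,
    f_mem=2, b_mem=-1, w_mem=-1,
)
zbv_schedule = graph.get_v_schedule()

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=pp_size,
    num_microbatches=num_microbatches,
    pp_style="zbv",                # zero-bubble V schedule
    num_model_chunks=2,            # ZBV places two model chunks on each pipeline rank
    scheduler_nodes=zbv_schedule,
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)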

View File

@ -58,6 +58,7 @@ jobs:
# there is no main branch, so it's safe to checkout the main branch from the merged branch
# docer will rebase the remote main branch to the merged branch, so we have to config user
- name: Make the merged branch main
run: |
cd ColossalAI
git checkout -b main

View File

@ -38,6 +38,19 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
class ZeroBubbleVPipeScheduler(PipelineSchedule):
r"""
ZeroBubbleVPipeScheduler
Args:
stage_manager (PipelineStageManager): If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.
schedule (List[ScheduledNode]): Schedule for ZeroBubbleVPipe.
num_model_chunks (int): The number of model chunks on each device.
num_microbatch (Optional[int]): The number of microbatches.
microbatch_size (Optional[int]): The size of each microbatch.
enable_metadata_cache (bool): whether to enable metadata cache to accelerate communication.
overlap_p2p (bool): whether to overlap p2p communication.
"""
def __init__(
self,
stage_manager: PipelineStageManager,

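The docstring above lists everything the scheduler needs; below is a minimal construction sketch under those arguments (the import path, and the stage_manager / scheduled_nodes values, are placeholders that the plugin and a V-schedule builder supply in practice):

# Sketch only: stage_manager and scheduled_nodes are assumed to be built elsewhere.
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler  # path assumed

scheduler = ZeroBubbleVPipeScheduler(
    stage_manager=stage_manager,    # PipelineStageManager created with use_zbv enabled
    schedule=scheduled_nodes,       # List[ScheduledNode] for this rank
    num_model_chunks=2,             # ZBV uses two model chunks per device
    num_microbatch=num_microbatches,
    microbatch_size=None,           # give either num_microbatch or microbatch_size
    enable_metadata_cache=True,     # cache p2p metadata to accelerate communication
    overlap_p2p=True,               # overlap p2p communication with compute
)
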
View File

@ -8,7 +8,6 @@ class WeightGradStore:
@classmethod
def put(cls, total_input, grad_output, weight, func):
# func(total_input, grad_output, weight.main_grad)
cls.cache.append((total_input, grad_output, weight, func))
@classmethod
@ -18,15 +17,26 @@ class WeightGradStore:
@classmethod
def pop(cls, chunk=0):
# print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}")
if cls.weight_grad_queue[chunk].qsize() > 0:
stored_grads = cls.weight_grad_queue[chunk].get()
for total_input, grad_output, weight, func in stored_grads:
if weight.grad is not None:
func(total_input, grad_output, weight.grad)
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
if isinstance(weight, tuple):
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
# The view changes the weight's data ptr, so both tensors are kept in a tuple:
# weight_cal is used to compute dw, weight_origin is used to update/accumulate the grad.
_, weight_origin = weight
if weight_origin.grad is not None:
func(total_input, grad_output, weight_origin.grad)
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
else:
grad_weight = func(total_input, grad_output)
weight_origin.grad = grad_weight
else:
grad_weight = func(total_input, grad_output)
weight.grad = grad_weight
if weight.grad is not None:
func(total_input, grad_output, weight.grad)
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
else:
grad_weight = func(total_input, grad_output)
weight.grad = grad_weight
else:
raise Exception("Pop empty queue.")
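
To make the put/pop contract above concrete, here is a simplified, self-contained stand-in for the store (not the actual class: the real one is chunk-aware and its flush() method is outside this hunk). It shows why the B pass can return grad_weight=None and still produce correct gradients later.

import queue
import torch

class TinyWeightGradStore:
    """Simplified stand-in: defer dW computation so it can run later, in a pipeline bubble."""
    cache = []
    weight_grad_queue = [queue.Queue(), queue.Queue()]  # one queue per model chunk

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        # B pass: only record what is needed to compute dW later.
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        # Move everything recorded so far into the chunk's queue.
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        # W pass: actually compute (or accumulate) the weight gradients.
        stored_grads = cls.weight_grad_queue[chunk].get()
        for total_input, grad_output, weight, func in stored_grads:
            if weight.grad is not None:
                func(total_input, grad_output, weight.grad)   # accumulate in place
            else:
                weight.grad = func(total_input, grad_output)  # first backward: assign

# The scheduler calls put() inside backward(), then flush() and pop() at a point
# where the pipeline would otherwise sit idle.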

View File

@ -6,7 +6,14 @@ from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHe
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from .qkv_fused_linear import (
FusedLinear,
FusedLinear1D_Col,
FusedLinear1D_Row,
GPT2FusedLinearConv,
GPT2FusedLinearConv1D_Col,
GPT2FusedLinearConv1D_Row,
)
__all__ = [
"Embedding1D",
@ -14,8 +21,9 @@ __all__ = [
"LinearWithGradAccum",
"Linear1D_Col",
"Linear1D_Row",
"GPT2FusedLinearConv1D_Col",
"GPT2FusedLinearConv",
"GPT2FusedLinearConv1D_Row",
"GPT2FusedLinearConv1D_Col",
"DropoutForParallelInput",
"DropoutForReplicatedInput",
"cross_entropy_1d",
@ -26,6 +34,7 @@ __all__ = [
"FusedLayerNorm",
"FusedRMSNorm",
"FusedLinear1D_Col",
"FusedLinear",
"ParallelModule",
"PaddingEmbedding",
"PaddingLMHead",

View File

@ -6,7 +6,13 @@ import torch.nn.functional as F
from colossalai.pipeline.weight_grad_store import WeightGradStore
from .utils import is_share_sp_tp
from .utils import (
execute_conv1d_w_pass,
execute_conv1d_w_pass_grad_accum,
execute_w_pass,
execute_w_pass_grad_accum,
is_share_sp_tp,
)
try:
import fused_mix_prec_layer_norm_cuda
@ -73,12 +79,13 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
ctx.use_zbv = use_zbv
output = torch.matmul(input_, weight)
@ -92,8 +99,10 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication
use_zbv = ctx.use_zbv
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
weight_origin = weight
weight = weight.view(weight.shape)
if bias is not None:
bias = bias.view(bias.shape)
@ -114,7 +123,42 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
grad_weight = total_input.t().matmul(grad_output)
# split dx & dw
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
(weight, weight_origin),
functools.partial(
execute_conv1d_w_pass_grad_accum,
),
)
grad_weight = None
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = total_input.t().matmul(grad_output)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
(weight, weight_origin),
functools.partial(
execute_conv1d_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce and not fp8_communication:
@ -123,6 +167,87 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None, None
class MatmulWithGradAccum(torch.autograd.Function):
"""
Linear layer execution with grad accum in backprop. (no tp version)
"""
@staticmethod
def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.async_grad_allreduce = async_grad_allreduce
ctx.use_zbv = use_zbv
output = torch.matmul(input_, weight)
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
use_zbv = ctx.use_zbv
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
weight_origin = weight
weight = weight.view(weight.shape)
if bias is not None:
bias = bias.view(bias.shape)
total_input = input
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
# split dx & dw
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
(weight, weight_origin),
functools.partial(
execute_conv1d_w_pass_grad_accum,
),
)
grad_weight = None
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = total_input.t().matmul(grad_output)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
(weight, weight_origin),
functools.partial(
execute_conv1d_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
return grad_input, grad_weight, grad_bias, None, None, None, None
class LinearWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
@ -150,12 +275,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
fp8_communication = ctx.fp8_communication
use_zbv = ctx.use_zbv
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
bias.view(bias.shape)
@ -179,31 +298,15 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
),
)
grad_weight = None
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@ -259,12 +362,6 @@ class LinearWithGradAccum(torch.autograd.Function):
use_bias = ctx.use_bias
use_zbv = ctx.use_zbv
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
bias.view(bias.shape)
@ -280,31 +377,15 @@ class LinearWithGradAccum(torch.autograd.Function):
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
),
)
grad_weight = None
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
@ -454,12 +535,13 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.use_zbv = use_zbv
if ring is True:
input_to_gather = {"input": input_}
@ -491,6 +573,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
use_zbv = ctx.use_zbv
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
@ -518,23 +601,46 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
),
)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
return output, grad_weight, grad_bias, None, None, None, None
return output, grad_weight, grad_bias, None, None, None, None, None
def _ring_as_reducescatter(
@ -606,11 +712,12 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, dim, ring):
def forward(ctx, input_, weight, bias, process_group, dim, ring, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.dim = dim
ctx.use_zbv = use_zbv
if ring is True:
input_to_reducescatter = {"input": input_}
@ -651,7 +758,7 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
use_zbv = ctx.use_zbv
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)
@ -666,10 +773,47 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.reshape(-1, total_input.shape[-1])
grad_weight = grad_output.t().matmul(total_input)
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
),
)
grad_weight = None
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
# grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
return grad_input, grad_weight, grad_bias, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None
class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
@ -723,13 +867,16 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
def forward(
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv=False
):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.fp8_communication = fp8_communication
ctx.use_zbv = use_zbv
if ring is True:
input_to_gather = {"input": input_}
@ -759,8 +906,10 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
use_zbv = ctx.use_zbv
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
weight_origin = weight
weight = weight.view(weight.shape)
if use_bias:
bias = bias.view(bias.shape)
@ -785,13 +934,49 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated
grad_weight = total_input.t().matmul(grad_output)
# split dx & dw
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
(weight, weight_origin),
functools.partial(
execute_conv1d_w_pass_grad_accum,
),
)
grad_weight = None
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = total_input.t().matmul(grad_output)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
(weight, weight_origin),
functools.partial(
execute_conv1d_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
return output, grad_weight, grad_bias, None, None, None, None, None
return output, grad_weight, grad_bias, None, None, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
@ -1108,12 +1293,18 @@ def _all_to_all_single(
).contiguous()
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
def matmul_with_async_comm(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
):
return MatmulWithAsyncCommunication.apply(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv
)
def matmul_with_grad_comm(input_, weight, bias, async_grad_allreduce, use_zbv=False):
return MatmulWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)
def linear_with_async_comm(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
):
@ -1127,10 +1318,10 @@ def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=F
def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, use_zbv=False
):
return _LinearWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, use_zbv
)
@ -1142,15 +1333,25 @@ def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_commun
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication)
def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)
def linear_reducescatter_forward_gather_backward(
input_, weight, bias=None, process_group=None, dim=1, ring=False, use_zbv=False
):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring, use_zbv)
def matmul_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False
input_,
weight,
bias,
process_group,
async_grad_reduce_scatter,
dim,
ring=False,
fp8_communication=False,
use_zbv=False,
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv
)
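
All of the backward paths above share the same dW branch; the following condensed sketch captures that control flow (the Conv1D variants differ only in passing (weight, weight_origin) and using the execute_conv1d_* helpers; the utils import path is assumed).

import functools
import torch
from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer.utils import execute_w_pass, execute_w_pass_grad_accum  # path assumed

try:
    import fused_weight_gradient_mlp_cuda  # optional apex extension
    _grad_accum_fusion_available = True
except ImportError:
    _grad_accum_fusion_available = False

def compute_or_defer_dw(total_input, grad_output, weight, use_zbv):
    """dx is always produced immediately; dW is fused, deferred to the W pass, or a plain GEMM."""
    if _grad_accum_fusion_available and weight.grad is not None:
        if use_zbv:
            # Defer: the scheduler's W pass will run the fused accumulation kernel later.
            WeightGradStore.put(total_input, grad_output, weight,
                                functools.partial(execute_w_pass_grad_accum))
            return None
        if weight.grad.dtype == torch.float32:
            fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.grad)
            return None
        if weight.grad.dtype == torch.float16:
            fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, weight.grad)
            return None
        return grad_output.t().matmul(total_input)
    if use_zbv:
        # First step (weight.grad is still None) or no fused kernel: defer a plain GEMM W pass.
        WeightGradStore.put(total_input, grad_output, weight,
                            functools.partial(execute_w_pass, wgrad_gemm_func=torch.matmul))
        return None
    return grad_output.t().matmul(total_input)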

View File

@ -350,6 +350,7 @@ class Linear1D_Col(ParallelModule):
True,
self.seq_parallel_dim,
ring=self.seq_parallel_mode == "ring",
use_zbv=self.use_zbv,
)
else:
output_parallel = linear_with_async_comm(
@ -580,6 +581,7 @@ class Linear1D_Row(ParallelModule):
process_group=self.process_group,
dim=self.seq_parallel_dim,
ring=self.seq_parallel_mode == "ring",
use_zbv=self.use_zbv,
)
else:
output_parallel = F.linear(input_, self.weight)

View File

@ -7,7 +7,6 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
@ -28,8 +27,10 @@ from ._operation import (
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
linear_with_grad_accum,
matmul_gather_forward_reducescatter_backward,
matmul_with_async_comm,
matmul_with_grad_comm,
reduce_forward,
reducescatter_forward_gather_backward,
split_forward_gather_backward,
@ -37,7 +38,14 @@ from ._operation import (
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset, is_share_sp_tp
__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"]
__all__ = [
"FusedLinear1D_Col",
"FusedLinear1D_Row",
"FusedLinear",
"GPT2FusedLinearConv1D_Col",
"GPT2FusedLinearConv1D_Row",
"GPT2FusedLinearConv",
]
# ====================================
# For GPT Only
@ -228,6 +236,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
use_zbv: bool = False,
):
super().__init__()
@ -241,6 +250,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.split_sizes = split_sizes
self.process_group = process_group
self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
assert (
sum(split_sizes) == out_features
@ -375,6 +385,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
1,
ring=self.seq_parallel_mode == "ring",
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
)
elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
# Set up backprop all-reduce.
@ -386,6 +397,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.process_group,
True,
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
)
else:
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
@ -441,6 +453,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1,
fp8_communication: bool = False,
use_zbv: bool = False,
):
super().__init__()
@ -455,6 +468,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
self.seq_parallel_mode = seq_parallel_mode
self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@ -620,6 +634,152 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
return output, self.bias
class GPT2FusedLinearConv(ParallelModule):
r"""Linear layer without parallelism.
This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False.
seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
use_zbv: bool = False,
):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.skip_bias_add = skip_bias_add
self.device = device
self.use_zbv = use_zbv
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, None)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
else:
assert bias_ is None, "bias_ must be None if weight is None"
# Parameters.
if weight is None:
# Initialize weight.
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
else:
self.bias = None
if weight is None:
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(
module: nn.Module,
*args,
**kwargs,
) -> ParallelModule:
r"""
Convert a huggingface `Conv1D` layer in gpt2 to a non-parallel linear layer.
Args:
module (`Conv1D`): The module to be converted.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.weight.shape[0]
out_features = module.weight.shape[1]
bias = module.bias is not None
device = module.weight.device
linear_1d = GPT2FusedLinearConv(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
weight=module.weight,
bias_=module.bias,
*args,
**kwargs,
)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with self.randomizer.fork_rng(enable_cpu=True):
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
# Set up backprop all-reduce.
input_parallel = input_
output_parallel = matmul_with_grad_comm(
input_parallel,
self.weight,
bias,
False,
self.use_zbv,
)
else:
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
return output
# ====================================
# For Fused torch.nn.Linear
# ====================================
@ -671,6 +831,7 @@ class FusedLinear1D_Col(ParallelModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
use_zbv: bool = False,
):
super().__init__()
# Keep input parameters
@ -684,6 +845,7 @@ class FusedLinear1D_Col(ParallelModule):
self.split_sizes = split_sizes
self.process_group = process_group
self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
assert (
sum(split_sizes) == out_features
@ -811,10 +973,17 @@ class FusedLinear1D_Col(ParallelModule):
True,
self.seq_parallel_dim,
ring=self.seq_parallel_mode == "ring",
use_zbv=self.use_zbv,
)
else:
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
input_parallel,
self.weight,
bias,
self.process_group,
True,
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
)
if self.gather_output:
@ -870,6 +1039,7 @@ class FusedLinear1D_Row(ParallelModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
use_zbv: bool = False,
):
super().__init__()
# Keep input parameters
@ -883,6 +1053,7 @@ class FusedLinear1D_Row(ParallelModule):
self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
assert (
sum(split_sizes) == in_features
@ -1009,9 +1180,18 @@ class FusedLinear1D_Row(ParallelModule):
process_group=self.process_group,
dim=self.seq_parallel_dim,
ring=self.seq_parallel_mode == "ring",
use_zbv=self.use_zbv,
)
else:
output_parallel = F.linear(input_, self.weight)
# output_parallel = F.linear(input_, self.weight) # Replace to LinearWithGradAccum
output_parallel = linear_with_grad_accum(
input_,
self.weight,
None,
False,
use_zbv=self.use_zbv,
)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
if not self.skip_bias_add:
@ -1020,3 +1200,156 @@ class FusedLinear1D_Row(ParallelModule):
return output
else:
return output, self.bias
class FusedLinear(ParallelModule):
r"""Fused Linear layer without tensor parallelism.
The linear layer is defined as :math:`Y = XA + b`. This layer is used to fit a fused `torch.nn.Linear` layer (fused QKV) in plain torch modules of huggingface models, like SAM.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (`typing.Callable`):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
seq_parallel_mode: str = None,
seq_parallel_dim: int = 1,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
use_zbv: bool = False,
):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.skip_bias_add = skip_bias_add
self.device = device
self.use_zbv = use_zbv
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=None)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
else:
assert bias_ is None, "bias_ must be None if weight is None"
# Parameters.
if weight is None:
# Initialize weight.
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
else:
self.bias = None
if weight is None:
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(
module: nn.Module,
*args,
**kwargs,
) -> ParallelModule:
r"""
Convert a fused `torch.nn.Linear` layer to a non-parallel FusedLinear layer.
Args:
module (`nn.Linear`): The module to be converted.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
linear_1d = FusedLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
weight=module.weight,
bias_=module.bias,
*args,
**kwargs,
)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with self.randomizer.fork_rng(enable_cpu=True):
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert (
input_.shape[-1] == self.weight.shape[-1]
), "Invalid shapes in FusedLinear forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1]
)
# Set up backprop all-reduce.
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_grad_accum(input_parallel, self.weight, bias, True, use_zbv=self.use_zbv)
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
return output
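
As a usage illustration, the non-parallel wrappers above are normally swapped in by a shardformer policy via from_native_module; a minimal sketch (shapes are illustrative, and a distributed context is assumed to be initialized, e.g. via colossalai.launch_from_torch):

import torch
from transformers.pytorch_utils import Conv1D
from colossalai.shardformer.layer import FusedLinear, GPT2FusedLinearConv

hidden = 768
# GPT2's fused QKV projection is a HF Conv1D whose weight has shape (hidden, 3 * hidden).
c_attn = Conv1D(3 * hidden, hidden)
zbv_c_attn = GPT2FusedLinearConv.from_native_module(c_attn, use_zbv=True)

# A fused nn.Linear QKV (e.g. in SAM) maps to FusedLinear instead.
qkv = torch.nn.Linear(hidden, 3 * hidden)
zbv_qkv = FusedLinear.from_native_module(qkv, use_zbv=True)

x = torch.randn(2, 16, hidden)
assert zbv_c_attn(x).shape == (2, 16, 3 * hidden)
assert zbv_qkv(x).shape == (2, 16, 3 * hidden)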

View File

@ -9,6 +9,43 @@ from torch.distributed import ProcessGroup, get_world_size
from colossalai.accelerator import get_accelerator
try:
import fused_weight_gradient_mlp_cuda
_grad_accum_fusion_available = True
except ImportError:
_grad_accum_fusion_available = False
# execute_w_pass_grad_accum & execute_conv1d_w_pass for GPT2FusedLinearConv1D
def execute_conv1d_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
if _input_.dtype == torch.float32:
wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
elif _input_.dtype in (torch.float16, torch.bfloat16):
wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)
def execute_conv1d_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_input_.t(), _grad_output_)
# execute_w_pass_grad_accum & execute_w_pass for Linear (except GPT2FusedLinearConv1D)
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
if _input_.dtype == torch.float32:
wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
elif _input_.dtype in (torch.float16, torch.bfloat16):
wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)
class SeqParallelUtils:
@staticmethod

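The two plain (non-fused) W-pass helpers above differ only in which operand is transposed, matching the two weight layouts; a small shape check makes the distinction explicit (the helper bodies are copied from the hunk above):

import torch

def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
    # nn.Linear weight has shape (out, in): dW = dY^T @ X
    return wgrad_gemm_func(_grad_output_.t(), _input_)

def execute_conv1d_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
    # HF Conv1D weight has shape (in, out): dW = X^T @ dY
    return wgrad_gemm_func(_input_.t(), _grad_output_)

x = torch.randn(8, 4)   # (tokens, in_features)
dy = torch.randn(8, 6)  # (tokens, out_features)
assert execute_w_pass(x, dy, wgrad_gemm_func=torch.matmul).shape == (6, 4)         # (out, in)
assert execute_conv1d_w_pass(x, dy, wgrad_gemm_func=torch.matmul).shape == (4, 6)  # (in, out)
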
View File

@ -51,6 +51,8 @@ class BlipPolicy(Policy):
else:
norm_cls = col_nn.LayerNorm
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@ -73,6 +75,7 @@ class BlipPolicy(Policy):
kwargs={
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -80,6 +83,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -88,6 +92,7 @@ class BlipPolicy(Policy):
kwargs={
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -95,6 +100,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
@ -126,6 +132,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -133,6 +140,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -140,6 +148,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -151,6 +160,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -162,6 +172,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -169,6 +180,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -176,6 +188,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -187,6 +200,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -198,6 +212,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -205,6 +220,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -227,6 +243,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -234,6 +251,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -241,6 +259,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -248,6 +267,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -255,6 +275,7 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -262,6 +283,226 @@ class BlipPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
)
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
if self.enable_bias_gelu_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_blip2_mlp_forward(),
},
policy=policy,
target_key=Blip2MLP,
)
elif use_zbv:
policy[Blip2EncoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="self_attn.qkv",
target_module=col_nn.FusedLinear,
kwargs={
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.projection",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.fc1",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.fc2",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
)
policy[Blip2QFormerModel] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
),
]
)
policy[Blip2QFormerLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="crossattention.attention.query",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="crossattention.attention.key",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="crossattention.attention.value",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="crossattention.attention.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="crossattention.output.dense",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="crossattention.output.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="intermediate_query.dense",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="output_query.dense",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="output_query.dropout",
target_module=col_nn.DropoutForParallelInput,
),
],
)
policy[OPTDecoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="fc1",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
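
The elif use_zbv branches above cover the case where tensor parallelism is off: each nn.Linear is still rewrapped (as LinearWithGradAccum) so that its backward can defer dW through the WeightGradStore. The recurring building block looks like the following sketch (the suffix and kwargs values are illustrative; inside a policy they come from self.shard_config and self.pipeline_stage_manager):

import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription

use_zbv = True  # normally: self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

replacement = SubModuleReplacementDescription(
    suffix="self_attn.q_proj",                 # submodule path inside the decoder layer
    target_module=col_nn.LinearWithGradAccum,  # non-TP linear that supports deferred dW
    kwargs={
        "fp8_communication": False,            # stands in for self.shard_config.fp8_communication
        "use_zbv": use_zbv,
    },
)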

View File

@ -59,6 +59,8 @@ class BloomPolicy(Policy):
sp_partial_derived = sp_mode == "split_gather"
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.n_head % self.shard_config.tensor_parallel_size == 0
@ -78,6 +80,7 @@ class BloomPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -86,6 +89,7 @@ class BloomPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -98,6 +102,7 @@ class BloomPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -106,6 +111,7 @@ class BloomPolicy(Policy):
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
@ -120,6 +126,52 @@ class BloomPolicy(Policy):
},
)
if use_zbv:
policy[BloomBlock] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=[
@ -247,14 +299,27 @@ class BloomPolicy(Policy):
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
held_layers.append(module.word_embeddings_layernorm)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
if stage_manager.is_interleave:
layers_per_stage = stage_manager.distribute_layers(len(module.h))
stage_indices = stage_manager.get_stage_index(layers_per_stage)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.word_embeddings)
held_layers.append(module.word_embeddings_layernorm)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.h[start_idx:end_idx])
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.ln_f)
else:
layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
held_layers.append(module.word_embeddings_layernorm)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
return held_layers
@ -328,8 +393,14 @@ class BloomForCausalLMPolicy(BloomPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.lm_head)
else:
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -351,6 +422,7 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification
policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
@ -363,6 +435,18 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
policy=policy,
target_key=BloomForSequenceClassification,
)
elif use_zbv:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="score",
target_module=col_nn.LinearWithGradAccum,
kwargs=dict(
gather_output=True, fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv
),
),
policy=policy,
target_key=BloomForSequenceClassification,
)
if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=BloomForSequenceClassification,
@ -375,8 +459,14 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.score)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.score)
else:
if stage_manager.is_last_stage():
held_layers.append(self.model.score)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -389,6 +479,7 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
from transformers.models.bloom.modeling_bloom import BloomForTokenClassification
policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
@ -407,6 +498,24 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
policy=policy,
target_key=BloomForTokenClassification,
)
elif use_zbv:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="classifier",
target_module=col_nn.LinearWithGradAccum,
kwargs=dict(
gather_output=True, fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv
),
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
],
policy=policy,
target_key=BloomForTokenClassification,
)
if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=BloomForTokenClassification,
@ -420,9 +529,16 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
else:
if stage_manager.is_last_stage():
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -448,8 +564,14 @@ class BloomForQuestionAnsweringPolicy(BloomPolicy):
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.qa_outputs)
else:
if stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:

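A note on the head-placement condition that repeats in every get_held_layers above: under a plain interleaved schedule the output-side layers (final norm, lm_head, score, qa_outputs, ...) sit on the last pipeline rank, while a ZB-V schedule folds the second model chunk back onto the first rank, so the same layers must be held there instead. The following is a minimal stand-alone sketch of that predicate; StageManagerStub and holds_output_layers are hypothetical names standing in for the real PipelineStageManager, not ColossalAI APIs.

from dataclasses import dataclass

@dataclass
class StageManagerStub:
    # hypothetical stand-in for PipelineStageManager, only what the predicate needs
    stage: int
    num_stages: int
    use_zbv: bool = False

    def is_first_stage(self, ignore_chunk: bool = False) -> bool:
        return self.stage == 0

    def is_last_stage(self, ignore_chunk: bool = False) -> bool:
        return self.stage == self.num_stages - 1

def holds_output_layers(sm: StageManagerStub) -> bool:
    # mirrors the condition used throughout the policies above
    return (sm.use_zbv and sm.is_first_stage(ignore_chunk=True)) or (
        not sm.use_zbv and sm.is_last_stage(ignore_chunk=True)
    )

if __name__ == "__main__":
    # with 4 ranks: plain interleaving keeps the head on rank 3, zbv moves it to rank 0
    assert holds_output_layers(StageManagerStub(stage=3, num_stages=4, use_zbv=False))
    assert holds_output_layers(StageManagerStub(stage=0, num_stages=4, use_zbv=True))
    assert not holds_output_layers(StageManagerStub(stage=3, num_stages=4, use_zbv=True))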

@ -83,6 +83,8 @@ class ChatGLMPolicy(Policy):
attribute_replacement=decoder_attribute_replacement,
)
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@ -145,6 +147,35 @@ class ChatGLMPolicy(Policy):
),
],
)
elif use_zbv:
policy["GLMBlock"] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel_mode": sp_mode,
"seq_parallel_dim": 0,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
"seq_parallel_mode": sp_mode,
"seq_parallel_dim": 0,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attention.core_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
@ -261,17 +292,30 @@ class ChatGLMPolicy(Policy):
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = stage_manager.distribute_layers(module.num_layers)
if stage_manager.is_first_stage():
held_layers.append(module.embedding)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.encoder.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
if module.encoder.post_layer_norm:
held_layers.append(module.encoder.final_layernorm)
if stage_manager.is_interleave:
layers_per_stage = stage_manager.distribute_layers(module.num_layers)
stage_indices = stage_manager.get_stage_index(layers_per_stage)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embedding)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.encoder.layers[start_idx:end_idx])
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
if module.encoder.post_layer_norm:
held_layers.append(module.encoder.final_layernorm)
else:
layers_per_stage = stage_manager.distribute_layers(module.num_layers)
if stage_manager.is_first_stage():
held_layers.append(module.embedding)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.encoder.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
if module.encoder.post_layer_norm:
held_layers.append(module.encoder.final_layernorm)
# rotary_pos_emb is needed for all stages
held_layers.append(module.rotary_pos_emb)
# rotary_pos_emb is needed for all stages
held_layers.append(module.rotary_pos_emb)
return held_layers
@ -335,8 +379,15 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.transformer.output_layer)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.transformer.output_layer)
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.transformer.output_layer)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:


@ -10,6 +10,7 @@ from colossalai.shardformer.layer import (
LayerNorm,
Linear1D_Col,
Linear1D_Row,
LinearWithGradAccum,
PaddingEmbedding,
PaddingLMHead,
VocabParallelEmbedding1D,
@ -107,6 +108,8 @@ class CommandPolicy(Policy):
target_key=CohereModel,
)
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
assert (
num_q_heads % tp_size == 0
@ -128,41 +131,137 @@ class CommandPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
],
)
elif use_zbv:
policy[CohereDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
@ -258,7 +357,9 @@ class CommandPolicy(Policy):
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.norm)
else:
@ -351,8 +452,14 @@ class CommandForCausalLMPolicy(CommandPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.lm_head)
else:
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:

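The decoder policies above all route module replacement the same way: with tensor parallelism the projections become Linear1D_Col / Linear1D_Row (now receiving use_zbv), without tensor parallelism but with zbv they become LinearWithGradAccum, and otherwise they are left untouched. A condensed sketch of that routing follows; pick_linear and the placeholder classes are illustrative only and stand in for the real colossalai.shardformer.layer modules.

# Hypothetical stand-ins; the real classes live in colossalai.shardformer.layer.
class Linear1D_Col: ...
class Linear1D_Row: ...
class LinearWithGradAccum: ...

def pick_linear(is_column: bool, enable_tp: bool, use_zbv: bool):
    """Return the replacement class a policy would choose for one nn.Linear."""
    if enable_tp:
        # sharded linears; the use_zbv flag is simply forwarded as a kwarg
        return Linear1D_Col if is_column else Linear1D_Row
    if use_zbv:
        # no sharding, but the backward pass still splits grad-input from grad-weight
        return LinearWithGradAccum
    return None  # keep the original nn.Linear

assert pick_linear(True, enable_tp=True, use_zbv=True) is Linear1D_Col
assert pick_linear(True, enable_tp=False, use_zbv=True) is LinearWithGradAccum
assert pick_linear(True, enable_tp=False, use_zbv=False) is None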

@ -6,7 +6,7 @@ from torch import Tensor
from torch.nn import Module
from transformers.utils import is_flash_attn_greater_or_equal_2_10
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, LinearWithGradAccum
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.deepseek import (
@ -107,6 +107,8 @@ class DeepseekPolicy(Policy):
if self.tie_weight:
embedding_cls = PaddingEmbedding
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
# tensor parallelism for non-moe params
assert (
@ -133,22 +135,58 @@ class DeepseekPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
),
SubModuleReplacementDescription(
suffix="mlp.gate",
target_module=DeepseekMoEGate_Col,
kwargs={
"gather_output": True,
"fp8_communication": self.shard_config.fp8_communication,
"config": self.model.config,
},
ignore_if_not_exist=True,
),
],
)
elif use_zbv:
policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=LinearWithGradAccum,
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=LinearWithGradAccum,
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=LinearWithGradAccum,
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=LinearWithGradAccum,
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
),
SubModuleReplacementDescription(
suffix="mlp.gate",
@ -162,7 +200,6 @@ class DeepseekPolicy(Policy):
),
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
@ -291,13 +328,26 @@ class DeepseekPolicy(Policy):
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_indices = stage_manager.get_stage_index(layers_per_stage)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.norm)
else:
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers
@ -330,6 +380,7 @@ class DeepseekModelPolicy(DeepseekPolicy):
class DeepseekForCausalLMPolicy(DeepseekPolicy):
def module_policy(self):
policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
@ -339,7 +390,29 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
]
)
}
policy.update(new_item)
elif use_zbv:
# add a new item for causal lm
new_item = {
"DeepseekForCausalLM": ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=LinearWithGradAccum,
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
]
)
@ -360,8 +433,14 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.lm_head)
else:
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:

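In the interleaved branches, distribute_layers splits the decoder stack across pipeline stages and get_stage_index yields one (start, end) range per local model chunk. Under a ZB-V schedule each rank owns two chunks laid out in a V, which is why rank 0 ends up holding both the embedding and the output-side layers. The sketch below illustrates that V-shaped assignment under an even-split assumption; v_schedule_chunks is a hypothetical helper, not the PipelineStageManager implementation.

def v_schedule_chunks(num_layers: int, num_stages: int):
    """Assign layers to (stage, chunk) pairs in a V shape: chunk 0 walks
    stages 0..P-1, chunk 1 walks the remaining layers back down P-1..0."""
    assert num_layers % (2 * num_stages) == 0, "illustration only: assume an even split"
    per_chunk = num_layers // (2 * num_stages)
    assignment = {s: [] for s in range(num_stages)}
    layer = 0
    for stage in range(num_stages):              # chunk 0: forward direction
        assignment[stage].append((layer, layer + per_chunk))
        layer += per_chunk
    for stage in reversed(range(num_stages)):    # chunk 1: back down the V
        assignment[stage].append((layer, layer + per_chunk))
        layer += per_chunk
    return assignment

if __name__ == "__main__":
    # 16 layers over 4 stages: stage 0 holds layers [0:2] and [14:16],
    # i.e. the chunks adjacent to the embedding and to the final norm/head.
    chunks = v_schedule_chunks(16, 4)
    assert chunks[0] == [(0, 2), (14, 16)]
    assert chunks[3] == [(6, 8), (8, 10)]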

@ -51,6 +51,8 @@ class FalconPolicy(Policy):
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@ -73,10 +75,16 @@ class FalconPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
kwargs=dict(
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
kwargs=dict(
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
@ -85,8 +93,17 @@ class FalconPolicy(Policy):
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
kwargs=dict(
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
kwargs=dict(
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row),
],
)
@ -98,6 +115,44 @@ class FalconPolicy(Policy):
"build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
},
)
elif use_zbv:
policy[FalconDecoderLayer] = ModulePolicyDescription(
method_replacement={"forward": get_tp_falcon_decoder_layer_forward()},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.LinearWithGradAccum,
kwargs=dict(
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.LinearWithGradAccum,
kwargs=dict(
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.LinearWithGradAccum,
kwargs=dict(
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.LinearWithGradAccum,
kwargs=dict(
use_zbv=use_zbv,
),
),
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
@ -191,13 +246,26 @@ class FalconPolicy(Policy):
module = self.model.transformer
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.h))
stage_indices = stage_manager.get_stage_index(layers_per_stage)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.word_embeddings)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.h[start_idx:end_idx])
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.ln_f)
else:
layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
return held_layers
@ -281,8 +349,14 @@ class FalconForCausalLMPolicy(FalconPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.lm_head)
else:
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -308,11 +382,23 @@ class FalconForSequenceClassificationPolicy(FalconPolicy):
policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True, use_zbv=use_zbv)
),
policy=policy,
target_key=FalconForSequenceClassification,
)
elif use_zbv:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="score",
target_module=col_nn.LinearWithGradAccum,
kwargs=dict(gather_output=True, use_zbv=use_zbv),
),
policy=policy,
target_key=FalconForSequenceClassification,
@ -330,8 +416,14 @@ class FalconForSequenceClassificationPolicy(FalconPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.score)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.score)
else:
if stage_manager.is_last_stage():
held_layers.append(self.model.score)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -348,12 +440,32 @@ class FalconForTokenClassificationPolicy(FalconPolicy):
policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
suffix="classifier",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True, use_zbv=use_zbv),
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
],
policy=policy,
target_key=FalconForTokenClassification,
)
elif use_zbv:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="classifier",
target_module=col_nn.LinearWithGradAccum,
kwargs=dict(gather_output=True, use_zbv=use_zbv),
),
SubModuleReplacementDescription(
suffix="dropout",
@ -375,9 +487,16 @@ class FalconForTokenClassificationPolicy(FalconPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
else:
if stage_manager.is_last_stage():
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -394,11 +513,25 @@ class FalconForQuestionAnsweringPolicy(FalconPolicy):
policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# handle tensor parallelism
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="qa_outputs", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
suffix="qa_outputs",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True, use_zbv=use_zbv),
),
policy=policy,
target_key=FalconForQuestionAnswering,
)
elif use_zbv:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="qa_outputs",
target_module=col_nn.LinearWithGradAccum,
kwargs=dict(gather_output=True, use_zbv=use_zbv),
),
policy=policy,
target_key=FalconForQuestionAnswering,
@ -415,8 +548,14 @@ class FalconForQuestionAnsweringPolicy(FalconPolicy):
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.qa_outputs)
else:
if stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:

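The LinearWithGradAccum replacements matter for the backward pass: a zero-bubble schedule computes the input gradient immediately, since the previous stage is waiting on it, and defers the weight gradient so it can be flushed later to fill pipeline bubbles. Below is a conceptual sketch of that split with a hypothetical deferred_weight_grads queue; the actual shardformer layer also handles bias, fp8 communication and the real scheduling.

import torch

deferred_weight_grads = []  # hypothetical stand-in for a weight-gradient queue

class LinearSplitBackward(torch.autograd.Function):
    """Conceptual sketch: emit grad_input right away, defer grad_weight."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_input = grad_out @ weight          # needed by the upstream stage now
        # park the weight-gradient computation; a scheduler would run it later
        deferred_weight_grads.append(lambda: grad_out.t() @ x)
        return grad_input, None

if __name__ == "__main__":
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(6, 8, requires_grad=True)
    LinearSplitBackward.apply(x, w).sum().backward()
    assert x.grad is not None and w.grad is None      # weight grad not produced yet
    w_grad = deferred_weight_grads.pop()()            # flushed later by the schedule
    assert w_grad.shape == w.shape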

@ -67,6 +67,8 @@ class GPT2Policy(Policy):
self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
sp_partial_derived = sp_mode in ["split_gather", "ring"]
use_flash_attention = self.shard_config.enable_flash_attention
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@ -94,12 +96,17 @@ class GPT2Policy(Policy):
"split_sizes": [self.model.config.hidden_size] * 3,
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
@ -109,12 +116,17 @@ class GPT2Policy(Policy):
"seq_parallel_mode": sp_mode,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
@ -138,6 +150,78 @@ class GPT2Policy(Policy):
policy=policy,
target_key=GPT2MLP,
)
elif use_zbv:
policy[GPT2Model] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="drop",
target_module=col_nn.DropoutForParallelInput,
),
]
)
policy[GPT2Block] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv,
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv,
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv,
kwargs={
"seq_parallel_mode": sp_mode,
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv,
kwargs={
"seq_parallel_mode": sp_mode,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
],
)
if self.enable_bias_gelu_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_gpt2_mlp_forward(),
},
policy=policy,
target_key=GPT2MLP,
)
if embedding_cls is not None:
# padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
self.append_or_create_submodule_replacement(
@ -352,8 +436,17 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.lm_head)
else:
if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -420,13 +513,24 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
multiple_choice_head = self.model.multiple_choice_head
held_layers.append(self.model.lm_head)
held_layers.append(multiple_choice_head.summary)
held_layers.append(multiple_choice_head.activation)
held_layers.append(multiple_choice_head.first_dropout)
held_layers.append(multiple_choice_head.last_dropout)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.lm_head)
held_layers.append(multiple_choice_head.summary)
held_layers.append(multiple_choice_head.activation)
held_layers.append(multiple_choice_head.first_dropout)
held_layers.append(multiple_choice_head.last_dropout)
else:
if self.pipeline_stage_manager.is_last_stage():
multiple_choice_head = self.model.multiple_choice_head
held_layers.append(self.model.lm_head)
held_layers.append(multiple_choice_head.summary)
held_layers.append(multiple_choice_head.activation)
held_layers.append(multiple_choice_head.first_dropout)
held_layers.append(multiple_choice_head.last_dropout)
return held_layers
@ -464,8 +568,17 @@ class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.qa_outputs)
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -503,9 +616,20 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -530,8 +654,18 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.score)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.score)
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.score)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:

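The GPT-2 replacements target Conv1D-style fused projections, which is why the fused column-parallel layer takes split_sizes=[hidden_size] * 3: it needs to know where q, k and v live inside the single c_attn weight before sharding it. A small illustration of that split with plain torch and toy shapes:

import torch

hidden = 8
# GPT-2's Conv1D stores the fused weight as (in_features, 3 * hidden): y = x @ W + b
w_fused = torch.randn(hidden, 3 * hidden)
x = torch.randn(2, hidden)

q_w, k_w, v_w = torch.split(w_fused, [hidden] * 3, dim=1)   # the split_sizes above
q, k, v = x @ q_w, x @ k_w, x @ v_w

# identical to running the fused projection and slicing the output
fused_out = x @ w_fused
assert torch.allclose(torch.cat([q, k, v], dim=1), fused_out)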

@ -51,6 +51,8 @@ class GPTJPolicy(Policy):
self.shard_config.enable_sequence_parallelism = False
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@ -76,6 +78,7 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -83,6 +86,7 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -90,6 +94,7 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -97,6 +102,7 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -104,6 +110,7 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -111,6 +118,72 @@ class GPTJPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
],
)
elif use_zbv:
policy[GPTJBlock] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.k_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.q_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.v_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.out_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.fc_in",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.fc_out",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -127,7 +200,6 @@ class GPTJPolicy(Policy):
),
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
@ -200,13 +272,25 @@ class GPTJPolicy(Policy):
held_layers = []
layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.wte)
held_layers.append(module.drop)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
if stage_manager.is_interleave:
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.wte)
held_layers.append(module.drop)
stage_indices = stage_manager.get_stage_index(layers_per_stage)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.h[start_idx:end_idx])
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.ln_f)
else:
if stage_manager.is_first_stage():
held_layers.append(module.wte)
held_layers.append(module.drop)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
@ -309,8 +393,15 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.lm_head)
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -349,8 +440,15 @@ class GPTJForSequenceClassificationPolicy(GPTJPolicy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.score)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.score)
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.score)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -378,8 +476,15 @@ class GPTJForQuestionAnsweringPolicy(GPTJPolicy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.qa_outputs)
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:


@ -324,9 +324,10 @@ class MistralPolicy(Policy):
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.norm)
else:
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
@ -419,8 +420,14 @@ class MistralForCausalLMPolicy(MistralPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.lm_head)
else:
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -475,8 +482,14 @@ class MistralForSequenceClassificationPolicy(MistralPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.score)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.score)
else:
if stage_manager.is_last_stage():
held_layers.append(self.model.score)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:

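The classification and QA heads above are replaced with gather_output=True because a column-parallel linear leaves each tensor-parallel rank with only a slice of the output features, while score, classifier and qa_outputs need the full logits on every rank. A single-process illustration, with the shards simulated by tensor chunks instead of torch.distributed:

import torch

torch.manual_seed(0)
tp = 2
hidden, num_labels = 8, 6
w_full = torch.randn(num_labels, hidden)
x = torch.randn(3, hidden)

# column parallelism: each rank keeps num_labels // tp rows of the weight
shards = torch.chunk(w_full, tp, dim=0)
partial_outputs = [x @ w.t() for w in shards]          # each rank: (3, num_labels // tp)

# gather_output=True: all-gather the per-rank slices back into the full logits
gathered = torch.cat(partial_outputs, dim=-1)
assert torch.allclose(gathered, x @ w_full.t())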

@ -10,6 +10,7 @@ from colossalai.shardformer.layer import (
LayerNorm,
Linear1D_Col,
Linear1D_Row,
LinearWithGradAccum,
PaddingEmbedding,
PaddingLMHead,
VocabParallelEmbedding1D,
@ -76,6 +77,8 @@ class OPTPolicy(Policy):
self.shard_config.enable_sequence_parallelism = False
warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@ -85,10 +88,16 @@ class OPTPolicy(Policy):
SubModuleReplacementDescription(
suffix="fc1",
target_module=Linear1D_Col,
kwargs=dict(
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=Linear1D_Row,
kwargs=dict(
use_zbv=use_zbv,
),
),
]
)
@ -104,6 +113,7 @@ class OPTPolicy(Policy):
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -111,6 +121,7 @@ class OPTPolicy(Policy):
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -118,6 +129,7 @@ class OPTPolicy(Policy):
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -125,11 +137,67 @@ class OPTPolicy(Policy):
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
)
elif use_zbv:
policy[OPTDecoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="fc1",
target_module=LinearWithGradAccum,
kwargs=dict(
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=LinearWithGradAccum,
kwargs=dict(
use_zbv=use_zbv,
),
),
]
)
policy[attn_cls] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="k_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="v_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="out_proj",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
@ -221,15 +289,30 @@ class OPTPolicy(Policy):
held_layers = []
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
held_layers.append(module.embed_positions)
held_layers.append(module.project_in)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.final_layer_norm)
held_layers.append(module.project_out)
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embed_tokens)
held_layers.append(module.embed_positions)
held_layers.append(module.project_in)
stage_indices = stage_manager.get_stage_index(layers_per_stage)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.final_layer_norm)
held_layers.append(module.project_out)
else:
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
held_layers.append(module.embed_positions)
held_layers.append(module.project_in)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.final_layer_norm)
held_layers.append(module.project_out)
return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
@ -323,8 +406,15 @@ class OPTForCausalLMPolicy(OPTPolicy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.lm_head)
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -395,8 +485,15 @@ class OPTForQuestionAnsweringPolicy(OPTPolicy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.qa_outputs)
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:


@ -9,6 +9,7 @@ from colossalai.shardformer.layer import (
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
LinearWithGradAccum,
PaddingEmbedding,
RMSNorm,
VocabParallelEmbedding1D,
@ -96,6 +97,8 @@ class Qwen2Policy(Policy):
attribute_replacement=decoder_attribute_replacement,
)
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@ -119,37 +122,134 @@ class Qwen2Policy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
],
)
elif use_zbv:
policy[Qwen2DecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=LinearWithGradAccum,
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
],
)
@ -278,7 +378,9 @@ class Qwen2Policy(Policy):
held_layers.append(module.embed_tokens)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage(ignore_chunk=True):
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.norm)
else:
@ -318,6 +420,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
def module_policy(self):
policy = super().module_policy()
setattr(self.shard_config, "causal_lm", True)
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
@ -327,7 +430,22 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
)
],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
)
}
policy.update(new_item)
elif use_zbv:
# add a new item for causal lm
new_item = {
Qwen2ForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=LinearWithGradAccum,
kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
)
],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
@ -347,8 +465,14 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.lm_head)
else:
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -371,6 +495,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
def module_policy(self):
policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification
new_item = {
@ -379,7 +504,28 @@ class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
SubModuleReplacementDescription(
suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
]
)
}
policy.update(new_item)
elif use_zbv:
new_item = {
Qwen2ForSequenceClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score",
target_module=LinearWithGradAccum,
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
]
)
@ -399,8 +545,14 @@ class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.score)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.score)
else:
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.score)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:


@ -27,6 +27,7 @@ class SamPolicy(Policy):
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
assert (
@ -44,6 +45,7 @@ class SamPolicy(Policy):
kwargs={
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -51,6 +53,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -58,6 +61,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -65,6 +69,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
@ -80,6 +85,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -87,6 +93,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -94,6 +101,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -101,6 +109,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -108,6 +117,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -115,6 +125,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -122,6 +133,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -129,6 +141,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -136,6 +149,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -143,6 +157,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -150,6 +165,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -157,6 +173,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -164,6 +181,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -171,6 +189,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
@ -186,6 +205,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -193,6 +213,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -200,6 +221,7 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -207,6 +229,209 @@ class SamPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
)
# add `DropoutForParallelInput` layer to replace the usage of `nn.functional.dropout`
policy[SamVisionAttention] = ModulePolicyDescription(
attribute_replacement={
"dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout)
},
method_replacement={"forward": forward_fn()},
sub_module_replacement=[],
)
elif use_zbv:
policy[SamVisionLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.qkv",
target_module=col_nn.FusedLinear,
kwargs={
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attn.proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin1",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin2",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
)
policy[SamTwoWayAttentionBlock] = ModulePolicyDescription(
attribute_replacement={
"self_attn.num_attention_heads": self.model.config.mask_decoder_config.num_attention_heads
// self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.q_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.k_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.v_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_token_to_image.out_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin1",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="mlp.lin2",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.q_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.k_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.v_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="cross_attn_image_to_token.out_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
)
policy[SamTwoWayTransformer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.q_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.k_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.v_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="final_attn_token_to_image.out_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],


@ -13,6 +13,7 @@ from colossalai.shardformer.layer import (
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
LinearWithGradAccum,
PaddingEmbedding,
PaddingLMHead,
RMSNorm,
@ -77,6 +78,8 @@ class T5BasePolicy(Policy):
self.shard_config.enable_sequence_parallelism = False
warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_heads % self.shard_config.tensor_parallel_size == 0
@ -119,6 +122,7 @@ class T5BasePolicy(Policy):
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -126,6 +130,7 @@ class T5BasePolicy(Policy):
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -133,6 +138,7 @@ class T5BasePolicy(Policy):
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -140,6 +146,7 @@ class T5BasePolicy(Policy):
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -168,6 +175,7 @@ class T5BasePolicy(Policy):
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -175,6 +183,7 @@ class T5BasePolicy(Policy):
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -183,6 +192,7 @@ class T5BasePolicy(Policy):
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
@ -198,6 +208,7 @@ class T5BasePolicy(Policy):
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -205,6 +216,142 @@ class T5BasePolicy(Policy):
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]
)
elif use_zbv:
policy[T5Stack] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]
)
policy[T5LayerSelfAttention] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]
)
policy[T5LayerCrossAttention] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
)
]
)
policy[T5Attention] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="k",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="v",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="o",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="relative_attention_bias",
target_module=Embedding1D,
kwargs=dict(
gather_output=False,
fp8_communication=self.shard_config.fp8_communication,
),
ignore_if_not_exist=True,
),
],
)
policy[T5LayerFF] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]
)
policy[T5DenseGatedActDense] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi_0 ",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="wi_1",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="wo",
target_module=LinearWithGradAccum,
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
]
)
policy[T5DenseActDense] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wi",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="wo",
target_module=LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -213,7 +360,6 @@ class T5BasePolicy(Policy):
),
]
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
@ -369,30 +515,61 @@ class T5BasePolicy(Policy):
num_decoder_layers = len(decoder.block) if decoder else 0
held_layers = []
layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
)
start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
if stage_manager.stage < decoder_starting_stage:
# current stage is in t5's encoder
if stage_manager.is_first_stage():
held_layers.append(model.shared)
held_layers.append(encoder.embed_tokens)
held_layers.append(encoder.dropout)
if stage_manager.stage == decoder_starting_stage - 1:
held_layers.append(encoder.final_layer_norm)
held_layers.append(encoder.dropout)
held_layers.extend(encoder.block[start_idx:end_idx])
if stage_manager.is_interleave:
layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
)
stage_indices = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
if stage_manager.stage < decoder_starting_stage:
# current stage is in t5's encoder
if stage_manager.is_first_stage():
held_layers.append(model.shared)
held_layers.append(encoder.embed_tokens)
held_layers.append(encoder.dropout)
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(encoder.final_layer_norm)
held_layers.append(encoder.dropout)
for start_idx, end_idx in stage_indices:
held_layers.extend(encoder.block[start_idx:end_idx])
else:
# current stage is in t5's decoder
if stage_manager.stage == decoder_starting_stage:
held_layers.append(decoder.embed_tokens)
held_layers.append(decoder.dropout)
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(decoder.final_layer_norm)
held_layers.append(decoder.dropout)
for start_idx, end_idx in stage_indices:
held_layers.extend(decoder.block[start_idx:end_idx])
else:
# current stage is in t5's decoder
if stage_manager.stage == decoder_starting_stage:
held_layers.append(decoder.embed_tokens)
held_layers.append(decoder.dropout)
if stage_manager.is_last_stage():
held_layers.append(decoder.final_layer_norm)
held_layers.append(decoder.dropout)
held_layers.extend(decoder.block[start_idx:end_idx])
layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
)
start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
if stage_manager.stage < decoder_starting_stage:
# current stage is in t5's encoder
if stage_manager.is_first_stage():
held_layers.append(model.shared)
held_layers.append(encoder.embed_tokens)
held_layers.append(encoder.dropout)
if stage_manager.stage == decoder_starting_stage - 1:
held_layers.append(encoder.final_layer_norm)
held_layers.append(encoder.dropout)
held_layers.extend(encoder.block[start_idx:end_idx])
else:
# current stage is in t5's decoder
if stage_manager.stage == decoder_starting_stage:
held_layers.append(decoder.embed_tokens)
held_layers.append(decoder.dropout)
if stage_manager.is_last_stage():
held_layers.append(decoder.final_layer_norm)
held_layers.append(decoder.dropout)
held_layers.extend(decoder.block[start_idx:end_idx])
return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
@ -545,8 +722,15 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.lm_head)
else:
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -652,9 +836,16 @@ class T5ForTokenClassificationPolicy(T5EncoderPolicy):
"""
held_layers = super().get_held_layers()
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
else:
if stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.dropout)
held_layers.append(self.model.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:


@ -43,6 +43,8 @@ class ViTPolicy(Policy):
self.shard_config.enable_sequence_parallelism = False
warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@ -72,6 +74,7 @@ class ViTPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -79,6 +82,7 @@ class ViTPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -86,6 +90,7 @@ class ViTPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -97,6 +102,7 @@ class ViTPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -109,6 +115,7 @@ class ViTPolicy(Policy):
kwargs={
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -116,6 +123,7 @@ class ViTPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -132,7 +140,92 @@ class ViTPolicy(Policy):
policy=policy,
target_key=ViTIntermediate,
)
elif use_zbv:
policy[ViTEmbeddings] = ModulePolicyDescription(
attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForReplicatedInput,
)
],
)
policy[ViTLayer] = ModulePolicyDescription(
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"skip_bias_add": self.enable_bias_gelu_fused,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
],
)
if self.enable_bias_gelu_fused:
self.append_or_create_method_replacement(
description={
"forward": get_jit_fused_vit_intermediate_forward(),
},
policy=policy,
target_key=ViTIntermediate,
)
# use flash attention
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
@ -173,11 +266,20 @@ class ViTPolicy(Policy):
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
if stage_manager.is_first_stage():
held_layers.append(module.embeddings)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.encoder.layer[start_idx:end_idx])
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
stage_indices = stage_manager.get_stage_index(layers_per_stage)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embeddings)
for start_idx, end_idx in stage_indices:
held_layers.extend(module.encoder.layer[start_idx:end_idx])
else:
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
if stage_manager.is_first_stage():
held_layers.append(module.embeddings)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.encoder.layer[start_idx:end_idx])
return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict):
@ -213,9 +315,16 @@ class ViTModelPolicy(ViTPolicy):
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
held_layers.append(module.layernorm)
held_layers.append(module.pooler)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.layernorm)
held_layers.append(module.pooler)
else:
if stage_manager.is_last_stage():
held_layers.append(module.layernorm)
held_layers.append(module.pooler)
return held_layers
@ -226,6 +335,9 @@ class ViTForImageClassificationPolicy(ViTPolicy):
from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel
policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
new_item = {
ViTForImageClassification: ModulePolicyDescription(
@ -233,13 +345,33 @@ class ViTForImageClassificationPolicy(ViTPolicy):
SubModuleReplacementDescription(
suffix="classifier",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
]
)
}
policy.update(new_item)
elif use_zbv:
new_item = {
ViTForImageClassification: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="classifier",
target_module=col_nn.LinearWithGradAccum,
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
]
)
}
policy.update(new_item)
if self.shard_config.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
self.set_pipeline_forward(
@ -256,9 +388,16 @@ class ViTForImageClassificationPolicy(ViTPolicy):
module = self.model.vit
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
held_layers.append(module.layernorm)
held_layers.append(self.model.classifier)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.layernorm)
held_layers.append(self.model.classifier)
else:
if stage_manager.is_last_stage():
held_layers.append(module.layernorm)
held_layers.append(self.model.classifier)
return held_layers
@ -285,8 +424,15 @@ class ViTForMaskedImageModelingPolicy(ViTPolicy):
module = self.model.vit
stage_manager = self.pipeline_stage_manager
if stage_manager.is_last_stage():
held_layers.append(module.layernorm)
held_layers.append(self.model.decoder)
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.layernorm)
held_layers.append(self.model.decoder)
else:
if stage_manager.is_last_stage():
held_layers.append(module.layernorm)
held_layers.append(self.model.decoder)
return held_layers


@ -72,6 +72,8 @@ class WhisperPolicy(Policy):
"Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
)
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# TODO using the jit fused add_and_dropout affect the accuracy
if self.shard_config.enable_jit_fused:
self.shard_config.enable_jit_fused = False
@ -93,6 +95,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -100,6 +103,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -107,6 +111,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -114,6 +119,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -121,6 +127,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -128,6 +135,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
@ -148,6 +156,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -155,6 +164,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -162,6 +172,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -169,6 +180,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -176,6 +188,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -183,6 +196,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -190,6 +204,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -197,6 +212,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -204,6 +220,7 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -211,6 +228,145 @@ class WhisperPolicy(Policy):
target_module=col_nn.Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
)
elif use_zbv:
policy[WhisperEncoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="fc1",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
)
policy[WhisperDecoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.out_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.q_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.k_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.v_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="encoder_attn.out_proj",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="fc1",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=col_nn.LinearWithGradAccum,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
@ -460,30 +616,66 @@ class WhisperPolicy(Policy):
num_decoder_layers = 0
held_layers = []
layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
)
start_idx, end_idx = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
if stage_manager.is_interleave:
layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
)
stage_indices = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
if stage_manager.stage < decoder_starting_stage:
# current stage is in whisper's encoder
if stage_manager.is_first_stage():
held_layers.append(encoder.embed_positions)
held_layers.append(encoder.conv1)
held_layers.append(encoder.conv2)
if stage_manager.stage == decoder_starting_stage - 1:
held_layers.append(encoder.layer_norm)
held_layers.extend(encoder.layers[start_idx:end_idx])
if stage_manager.stage < decoder_starting_stage:
# current stage is in whisper's encoder
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(encoder.embed_positions)
held_layers.append(encoder.conv1)
held_layers.append(encoder.conv2)
# interleaved: not use_zbv & stage_manager.stage == decoder_starting_stage - 1
# zbv: use_zbv & stage_manager.stage == first stage
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.stage == decoder_starting_stage - 1
):
held_layers.append(encoder.layer_norm)
for start_idx, end_idx in stage_indices:
held_layers.extend(encoder.layers[start_idx:end_idx])
else:
# current stage is in whisper's decoder
# TODO:(Jianghai) We divide encoder and decoder layers into different parts here,
# the case where the encoder and decoder are put in the same stage should be added in the future.
if stage_manager.stage == decoder_starting_stage:
held_layers.append(decoder.embed_tokens)
held_layers.append(decoder.embed_positions)
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(decoder.layer_norm)
for start_idx, end_idx in stage_indices:
held_layers.extend(decoder.layers[start_idx:end_idx])
else:
# current stage is in whisper's decoder
# TODO:(Jianghai) We divide encoder and decoder layers into different parts here,
# the case where the encoder and decoder are put in the same stage should be added in the future.
if stage_manager.stage == decoder_starting_stage:
held_layers.append(decoder.embed_tokens)
held_layers.append(decoder.embed_positions)
if stage_manager.is_last_stage():
held_layers.append(decoder.layer_norm)
held_layers.extend(decoder.layers[start_idx:end_idx])
layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
)
start_idx, end_idx = self.get_whisper_stage_index(
layers_per_stage, stage_manager.stage, decoder_starting_stage
)
if stage_manager.stage < decoder_starting_stage:
# current stage is in whisper's encoder
if stage_manager.is_first_stage():
held_layers.append(encoder.embed_positions)
held_layers.append(encoder.conv1)
held_layers.append(encoder.conv2)
if stage_manager.stage == decoder_starting_stage - 1:
held_layers.append(encoder.layer_norm)
held_layers.extend(encoder.layers[start_idx:end_idx])
else:
# current stage is in whisper's decoder
# TODO:(Jianghai) We divide encoder and decoder layers into different parts here,
# the case where the encoder and decoder are put in the same stage should be added in the future.
if stage_manager.stage == decoder_starting_stage:
held_layers.append(decoder.embed_tokens)
held_layers.append(decoder.embed_positions)
if stage_manager.is_last_stage():
held_layers.append(decoder.layer_norm)
held_layers.extend(decoder.layers[start_idx:end_idx])
return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
@ -575,8 +767,15 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.proj_out)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.proj_out)
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.proj_out)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
@ -629,9 +828,17 @@ class WhisperForAudioClassificationPolicy(WhisperPolicy):
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.projector)
held_layers.append(self.model.classifier)
stage_manager = self.pipeline_stage_manager
if stage_manager.is_interleave:
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(self.model.projector)
held_layers.append(self.model.classifier)
else:
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.projector)
held_layers.append(self.model.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:


@ -0,0 +1,238 @@
# ZeroBubble Pipeline Parallelism
Author: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217)
**Related Paper**
- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241)
## Introduction
ZeroBubble (V Schedule):
Crucially, ZeroBubble splits the backward pass B into two stages, one that computes the activation gradient and one that computes the weight gradient, so that a schedule such as 1F1B1W can further reduce the pipeline bubble compared with the 1F1B scheme used in earlier work.
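To make the split concrete, the sketch below defers the weight-gradient computation of a linear layer: the backward pass returns only the activation gradient immediately and queues the weight-gradient work to be executed later. This mirrors the idea behind the `WeightGradStore` used by the zbv scheduler; the queue class and autograd function here are simplified illustrations, not ColossalAI's actual implementation.
```python
import torch


class SimpleWeightGradStore:
    """Toy queue of deferred weight-gradient computations (illustration only)."""

    def __init__(self):
        self.queue = []

    def put(self, fn):
        self.queue.append(fn)

    def pop_all(self):
        # Run every deferred "W" step; in a real schedule this happens at a chosen slot.
        while self.queue:
            self.queue.pop(0)()


store = SimpleWeightGradStore()


class LinearSplitBW(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        # "B": activation gradient, needed immediately by the previous pipeline stage.
        grad_input = grad_out @ weight

        def compute_weight_grad():
            # "W": weight gradient, safe to delay until a bubble in the schedule.
            with torch.no_grad():
                grad_w = grad_out.reshape(-1, grad_out.shape[-1]).t() @ x.reshape(-1, x.shape[-1])
                weight.grad = grad_w if weight.grad is None else weight.grad + grad_w

        store.put(compute_weight_grad)
        # Return None for the weight so autograd does not accumulate it a second time.
        return grad_input, None


# Usage: backward() runs only the B step, pop_all() runs the deferred W step.
w = torch.nn.Parameter(torch.randn(3, 4))
x = torch.randn(2, 4, requires_grad=True)
LinearSplitBW.apply(x, w).sum().backward()
assert w.grad is None
store.pop_all()
assert w.grad is not None and w.grad.shape == w.shape
```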
## Hands-On Practice
We now demonstrate how to use ZeroBubble with the booster API on 4 GPUs.
### step 1. Import libraries
```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
```
### step 2. Initialize Distributed Environment and Parallelism Group
```python
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
```
### step 3. Initialize Module, Optimizer, and Pipeline Schedule
Build the model and optimizer. Here we create a Llama model with 8 decoder layers, then initialize the PipelineGraph and pipeline schedule with the get_v_schedule() function.
```python
# Global Param
NUM_BATCH = 8
NUM_TOK_PER_BATCH = 4
NUM_LAYERS = 8
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
# Init Llama from huggingface
configuration = LlamaConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=NUM_LAYERS,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
attn_implementation="flash_attention_2",
)
model = LlamaModel(configuration).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1)
```
### step 4. Initialize the PipelineGraph and Pipeline Schedule
Next, we create the PipelineGraph and PipelineSchedule using the get_v_schedule() function. The PipelineGraph is initialized with the following parameters:
- x_cost is the runtime consumed by operation x of each model chunk.
- x_mem is the amount of memory consumed by operation x of each model chunk.
These parameters are estimated and filled in before the pipeline starts; better schedules can be obtained by measuring the actual runtime and memory cost of the model.
In the following example, we assume that the computation times of the forward pass (F), the activation-gradient backward pass (B), and the weight-gradient backward pass (W) are all 1, and that the p2p communication time is also 1.
```python
# Init schedule (pp_size and num_microbatches should match the plugin settings in step 5)
h, a, s = configuration.hidden_size, configuration.num_attention_heads, 1024
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f
graph = PipelineGraph(
n_stage=pp_size,
n_micro=num_microbatches,
f_cost=1,
b_cost=1,
w_cost=1,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()
```
### step 5. Initialize the Booster
Pass `pp_style="zbv"` when initializing the plugin to use the ZeroBubble pipeline.
```python
plugin = HybridParallelPlugin(
pp_size=4,
num_microbatches=4,
tp_size=1,
sp_size=1,
zero_stage=1,
initial_scale=1,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
dp_size = plugin.dp_size
booster = Booster(plugin=plugin)
```
### step 6. Train Your Model
```python
steps = 10
for step in range(steps):
input_embeddings = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
dist.all_reduce(
input_embeddings, group=plugin.pp_group
)
data_iter = iter([{"inputs_embeds": input_embeddings}])
output = booster.execute_pipeline(
data_iter,
model,
lambda x, y: x.last_hidden_state.mean(),
optimizer,
return_loss=True,
return_outputs=True,
)
optimizer.step()
optimizer.zero_grad()
```
## Advanced Practice
In ColossalAI, you can get better training performance by combining ZeroBubble with metadata caching and hybrid parallelism.
### 1. Use MetaCache with ZeroBubble
Pass `enable_metadata_cache=True` when initializing the plugin to enable metadata caching with the ZeroBubble pipeline.
```python
plugin = HybridParallelPlugin(
pp_size=2,
num_microbatches=4,
tp_size=2,
sp_size=2,
zero_stage=1,
initial_scale=1,
enable_metadata_cache=True,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
```
### 2. HybridParallel with ZeroBubble
Pass `pp_size`, `tp_size`, and `sp_size` when initializing the plugin to combine hybrid parallelism with the ZeroBubble pipeline.
```python
plugin = HybridParallelPlugin(
pp_size=2,
num_microbatches=2,
tp_size=2,
sp_size=2,
zero_stage=1,
initial_scale=1,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
```
### Performance Benchmark
<table>
<tr>
<th nowrap="nowrap">HybridParallel Strategy</th>
<th nowrap="nowrap" align="center">Pipeline Parallel</th>
<th nowrap="nowrap" align="center">Sequence Parallel + Pipeline Parallel</th>
<th nowrap="nowrap" align="center">Data Parallel + Pipeline Parallel</th>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="1F1B">With 1F1B</td>
<td nowrap="nowrap" align="center">15.27 samples/sec</td>
<td nowrap="nowrap" align="center">17.22 samples/sec</td>
<td nowrap="nowrap" align="center">14.06 samples/sec</td>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="Zero Bubble">With Zero Bubble</td>
<td nowrap="nowrap" align="center">17.36 samples/sec</td>
<td nowrap="nowrap" align="center">18.38 samples/sec</td>
<td nowrap="nowrap" align="center">14.44 samples/sec</td>
</tr>
<tr>
<td colspan="39"></td>
</tr>
</table>
### 3.Fine-tuning Scheduler parameters
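As a rough sketch, one way to tune the scheduler parameters is to replace the unit costs from step 4 with per-chunk runtime and memory numbers measured on your own workload and then regenerate the schedule. The concrete timing values and the profiling approach mentioned below are hypothetical placeholders, not measured results.
```python
# Hypothetical measured per-chunk costs (e.g. from torch.cuda.Event timings and
# torch.cuda.max_memory_allocated()); replace them with numbers from your own profiling.
f_time, b_time, w_time, comm_time = 18.0, 18.0, 9.0, 2.0

h, a, s = configuration.hidden_size, configuration.num_attention_heads, 1024
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f

graph = PipelineGraph(
    n_stage=4,            # must match pp_size passed to HybridParallelPlugin
    n_micro=4,            # must match num_microbatches
    f_cost=f_time,
    b_cost=b_time,
    w_cost=w_time,
    c_cost=comm_time,
    f_mem=mem_f,
    b_mem=mem_b,
    w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()  # pass to HybridParallelPlugin(scheduler_nodes=zbv_schedule)
```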
## Model compatibility
<table>
<tr>
<th nowrap="nowrap">Shardformer/Model</th>
<th nowrap="nowrap" align="center">Bert</th>
<th nowrap="nowrap" align="center">Blip2</th>
<th nowrap="nowrap" align="center">Bloom</th>
<th nowrap="nowrap" align="center">Chatglm2</th>
<th nowrap="nowrap" align="center">Command</th>
<th nowrap="nowrap" align="center">Deepseek</th>
<th nowrap="nowrap" align="center">Falcon</th>
<th nowrap="nowrap" align="center">GPT2</th>
<th nowrap="nowrap" align="center">Gptj</th>
<th nowrap="nowrap" align="center">Llama</th>
<th nowrap="nowrap" align="center">Mistral</th>
<th nowrap="nowrap" align="center">Opt</th>
<th nowrap="nowrap" align="center">Qwen2</th>
<th nowrap="nowrap" align="center">Sam</th>
<th nowrap="nowrap" align="center">T5</th>
<th nowrap="nowrap" align="center">Vit</th>
<th nowrap="nowrap" align="center">Whisper</th>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="ZeroBubble">ZeroBubble</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td colspan="39"></td>
</tr>
</table>
## API Reference
{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }}
<!-- doc-test-command: torchrun --standalone --nproc_per_node=4 zerobubble_pipeline_parallelism.py -->


@ -0,0 +1,237 @@
# ZeroBubble Pipeline Parallelism
Author: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217)
**Related Paper**
- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241)
## Introduction
ZeroBubble (V Schedule):
Compared with the 1F1B scheme in earlier work, ZeroBubble splits the backward pass B into two stages (the activation gradient and the weight gradient), so that a schedule such as 1F1B1W can further reduce the pipeline bubble.
## Hands-On Practice
We now demonstrate how to use ZeroBubble with the booster API on 4 GPUs.
### step 1. Import libraries
```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
```
### step 2. Initialize the Distributed Environment
```python
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
```
### step 3. Initialize the Model and Optimizer
Build the model and optimizer. Here we create a Llama model with 8 decoder layers, then initialize the PipelineGraph and pipeline schedule with the get_v_schedule() function.
```python
# Global Param
NUM_BATCH = 8
NUM_TOK_PER_BATCH = 4
NUM_LAYERS = 8
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
# Init Llama from huggingface
configuration = LlamaConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=NUM_LAYERS,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
attn_implementation="flash_attention_2",
)
model = LlamaModel(configuration).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1)
```
### step 4. Initialize the PipelineGraph and Pipeline Schedule
Next, we create the PipelineGraph and PipelineSchedule using the get_v_schedule() function. The PipelineGraph is initialized with the following parameters:
- x_cost is the runtime consumed by operation x of each model chunk.
- x_mem is the amount of memory consumed by operation x of each model chunk.
These parameters are estimated and filled in before the pipeline starts; better schedules can be obtained by measuring the actual runtime and memory cost of the model.
In the following example, we assume that the computation times of the forward pass (F), the activation-gradient backward pass (B), and the weight-gradient backward pass (W) are all 1, and that the p2p communication time is also 1.
```python
# Init schedule (pp_size and num_microbatches should match the plugin settings in step 5)
h, a, s = configuration.hidden_size, configuration.num_attention_heads, 1024
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f
graph = PipelineGraph(
n_stage=pp_size,
n_micro=num_microbatches,
f_cost=1,
b_cost=1,
w_cost=1,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()
```
### step 5. Initialize the Booster
Pass `pp_style="zbv"` when initializing the plugin to use the ZeroBubble pipeline.
```python
plugin = HybridParallelPlugin(
pp_size=4,
num_microbatches=4,
tp_size=1,
sp_size=1,
zero_stage=1,
initial_scale=1,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
dp_size = plugin.dp_size
booster = Booster(plugin=plugin)
```
### step 6. Train Your Model
```python
steps = 10
for step in range(steps):
input_embeddings = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
dist.all_reduce(
input_embeddings, group=plugin.pp_group
)
data_iter = iter([{"inputs_embeds": input_embeddings}])
output = booster.execute_pipeline(
data_iter,
model,
lambda x, y: x.last_hidden_state.mean(),
optimizer,
return_loss=True,
return_outputs=True,
)
optimizer.step()
optimizer.zero_grad()
```
## Advanced Practice
In ColossalAI, you can get better training performance by combining ZeroBubble with metadata caching and hybrid parallelism.
### 1. Use MetaCache with ZeroBubble
Pass `enable_metadata_cache=True` when initializing the plugin to enable metadata caching with the ZeroBubble pipeline.
```python
plugin = HybridParallelPlugin(
pp_size=2,
num_microbatches=4,
tp_size=2,
sp_size=2,
zero_stage=1,
initial_scale=1,
enable_metadata_cache=True,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
```
### 2. HybridParallel with ZeroBubble
Pass `pp_size`, `tp_size`, and `sp_size` when initializing the plugin to combine hybrid parallelism with the ZeroBubble pipeline.
```python
plugin = HybridParallelPlugin(
pp_size=2,
num_microbatches=2,
tp_size=2,
sp_size=2,
zero_stage=1,
initial_scale=1,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
```
### Performance Benchmark
<table>
<tr>
<th nowrap="nowrap">HybridParallel Strategy</th>
<th nowrap="nowrap" align="center">Pipeline Parallel</th>
<th nowrap="nowrap" align="center">Sequence Parallel + Pipeline Parallel</th>
<th nowrap="nowrap" align="center">Data Parallel + Pipeline Parallel</th>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="1F1B">With 1F1B</td>
<td nowrap="nowrap" align="center">15.27 samples/sec</td>
<td nowrap="nowrap" align="center">17.22 samples/sec</td>
<td nowrap="nowrap" align="center">14.06 samples/sec</td>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="Zero Bubble">With Zero Bubble</td>
<td nowrap="nowrap" align="center">17.36 samples/sec</td>
<td nowrap="nowrap" align="center">18.38 samples/sec</td>
<td nowrap="nowrap" align="center">14.44 samples/sec</td>
</tr>
<tr>
<td colspan="39"></td>
</tr>
</table>
## Model compatibility
<table>
<tr>
<th nowrap="nowrap">Shardformer/Model</th>
<th nowrap="nowrap" align="center">Bert</th>
<th nowrap="nowrap" align="center">Blip2</th>
<th nowrap="nowrap" align="center">Bloom</th>
<th nowrap="nowrap" align="center">Chatglm2</th>
<th nowrap="nowrap" align="center">Command</th>
<th nowrap="nowrap" align="center">Deepseek</th>
<th nowrap="nowrap" align="center">Falcon</th>
<th nowrap="nowrap" align="center">GPT2</th>
<th nowrap="nowrap" align="center">Gptj</th>
<th nowrap="nowrap" align="center">Llama</th>
<th nowrap="nowrap" align="center">Mistral</th>
<th nowrap="nowrap" align="center">Opt</th>
<th nowrap="nowrap" align="center">Qwen2</th>
<th nowrap="nowrap" align="center">Sam</th>
<th nowrap="nowrap" align="center">T5</th>
<th nowrap="nowrap" align="center">Vit</th>
<th nowrap="nowrap" align="center">Whisper</th>
</tr>
<tr>
<td nowrap="nowrap" align="center" title="ZeroBubble">ZeroBubble</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td colspan="39"></td>
</tr>
</table>
## API Reference
{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }}
<!-- doc-test-command: torchrun --standalone --nproc_per_node=4 zerobubble_pipeline_parallelism.py -->
@ -8,7 +8,8 @@ from torch.testing import assert_close
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import GPT2FusedLinearConv, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@ -118,11 +119,82 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
    assert_close(target_grad, linear_row.weight.grad)


def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: str):
    ctx = LazyInitContext() if lazy_init else nullcontext()

    linear = Conv1D(192, 48).cuda()
    with ctx:
        linear_copy = Conv1D(192, 48).cuda()
    linear_base = GPT2FusedLinearConv.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode)

    assert linear.weight.shape == torch.Size([48, 192])
    assert linear_base.weight.shape == torch.Size([48, 192])
    assert linear_base.bias.shape == torch.Size([192])
    assert linear_copy.weight is linear_base.weight
    assert linear_copy.bias is linear_base.bias

    # ensure weights are reversibly loadable
    linear_base.load_state_dict(linear.state_dict())
    linear.load_state_dict(linear_base.state_dict())

    # check computation correctness
    x = torch.rand(1, 4, 48).cuda()
    out = linear(x)
    gather_out = linear_base(x)
    assert_close(out, gather_out)

    # check backward correctness
    out.sum().backward()
    gather_out.sum().backward()

    # check the input gradients & weight gradients
    assert_close(out.grad, gather_out.grad)
    assert_close(linear.weight.grad, linear_base.weight.grad)


def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: str):
    ctx = LazyInitContext() if lazy_init else nullcontext()

    linear = Conv1D(192, 48).cuda()
    with ctx:
        linear_copy = Conv1D(192, 48).cuda()
    linear_base = GPT2FusedLinearConv.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode, use_zbv=True)

    assert linear.weight.shape == torch.Size([48, 192])
    assert linear_base.weight.shape == torch.Size([48, 192])
    assert linear_base.bias.shape == torch.Size([192])
    assert linear_copy.weight is linear_base.weight
    assert linear_copy.bias is linear_base.bias

    # ensure weights are reversibly loadable
    linear_base.load_state_dict(linear.state_dict())
    linear.load_state_dict(linear_base.state_dict())

    # check computation correctness
    x = torch.rand(1, 4, 48).cuda()
    out = linear(x)
    gather_out = linear_base(x)
    assert_close(out, gather_out)

    # check backward correctness
    out.sum().backward()
    gather_out.sum().backward()
    WeightGradStore.flush(chunk=0)  # flush buffer to chunk 0 Queue
    WeightGradStore.pop(chunk=0)

    # check the input gradients & weight gradients
    assert_close(out.grad, gather_out.grad)
    assert_close(linear.weight.grad, linear_base.weight.grad)


@parameterize("lazy_init", [False, True])
@parameterize("seq_parallel_mode", ["split_gather", None])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
    check_linear_conv_1d_col(lazy_init, seq_parallel_mode)
    check_linear_conv_1d_row(lazy_init, seq_parallel_mode)
    check_linear_conv_1d_without_weight_grad_store(lazy_init, None)
    check_linear_conv_1d_with_weight_grad_store(lazy_init, None)
def run_dist(rank, world_size, port):
@ -7,7 +7,7 @@ from torch.testing import assert_close
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
from colossalai.shardformer.layer import FusedLinear, FusedLinear1D_Col, FusedLinear1D_Row
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@ -120,12 +120,45 @@ def check_linear_1d_col_row(lazy_init: bool):
    assert_close(target_grad2, linear_row.weight.grad)


@parameterize("lazy_init", [False, True])
def check_linear_1d_base(lazy_init: bool):
    ctx = LazyInitContext() if lazy_init else nullcontext()

    linear = nn.Linear(8, 80).cuda()
    with ctx:
        linear_copy = nn.Linear(8, 80).cuda()
    linear_base = FusedLinear.from_native_module(linear_copy)

    assert linear.weight.shape == torch.Size([80, 8])
    assert linear.bias.shape == torch.Size([80])
    assert linear_base.weight.shape == torch.Size([80, 8])
    assert linear_base.bias.shape == torch.Size([80])
    assert linear_copy.weight is linear_base.weight
    assert linear_copy.bias is linear_base.bias

    # ensure weights are reversibly loadable
    linear_base.load_state_dict(linear.state_dict())
    linear.load_state_dict(linear_base.state_dict())

    # check computation correctness
    x = torch.rand(4, 8).cuda()
    out = linear(x)
    base_out = linear_base(x)
    assert_close(out, base_out)

    # check backward correctness
    out.sum().backward()
    base_out.sum().backward()

    assert_close(linear.weight.grad, linear_base.weight.grad)


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    check_linear_1d_col()
    check_linear_1d_row()
    check_linear_1d_col_row()
    check_linear_1d_base()


@rerun_if_address_is_in_use()