mirror of https://github.com/hpcaitech/ColossalAI
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv * [feat] support chatglm2, command, deepseek for zbv * [feat] support zbv in shardformer policy: falcon,gptj,mistral,opt,qwen2,t5, vit, whisper * [feat] support GPT2FusedLinearConv1D * [feat] support GPT2FusedLinear (without tp) * [fix] debug FusedConvLinear * [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv Col and Row. * [Shardformer] support FusedLinear1D base for zbv * [shardformer] support zbv in FusedLinear1D base, Col, Row * [shardformer] support zbv in blip2 and sam policy * [shardformer] fix bug incorrect number of gradients; add fusedLinear base testcase; * [fix] fix incorrect number of gradients ; * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [Shardformer] add en doc for zbv; * [fix] fix typo in Model compatibility table * [fix] fix API Reference typo * [Shardformer] add zh-Han doc for zbv * [fix] fix Linear name; update en & zh doc * [fix] fix shardformer doc import err * [fix] fix shardconfig import in doc * [fix] fix shardformer doc * [fix] fix shardconfig doc * [fix] fix config * [fix] remove shardconfig * [fix] fix doc * [feat] add zbv doc string * [fix] rm doc * [fix] fix doc * [fix] empty zbv doc * [fix] ifx torch version * [fix] fix torch version * [fix] fix torch versions * [fix] fix torch versions * [fix] fix pyramid versions * [fix] fix pyramid, zope version * [fix] try fix workflow * [fix] try import ShardConfig in yml * [fix] fix workflow * [fix] fix workflow * [fix] fix workflow * [fix] fix workflow * [fix] fix ci * [fix] fix zbv doc * [fix] fix param for qkv linear, gpt2fused linear; fix requirments; * [fix] fix policy use fused_linear * [fix] fix weight grad none, err caused by weight ptr change * [fix] fix comm in WeightGradStore * [fix] fix WeightGradStore pop param * [fix] remove useless param in doc; fix gpt2 qkv test; * [shardformer] simplify execute_w_pass_grad_accum; * [fix] rm useless comments * [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass * [shardformer] Run meaningful doc test * [shadformer] fix doc test cmd; --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/6178/head
parent
af06d162cf
commit
a9bedc7a43
|
@ -58,6 +58,7 @@ jobs:
|
|||
# there is no main branch, so it's safe to checkout the main branch from the merged branch
|
||||
# docer will rebase the remote main branch to the merged branch, so we have to config user
|
||||
- name: Make the merged branch main
|
||||
|
||||
run: |
|
||||
cd ColossalAI
|
||||
git checkout -b main
|
||||
|
|
|
@ -38,6 +38,19 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
|
|||
|
||||
|
||||
class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
r"""
|
||||
ZeroBubbleVPipeScheduler
|
||||
|
||||
Args:
|
||||
stage_manager (PipelineStageManager): If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.
|
||||
schedule (List[ScheduledNode]): Schedule for ZeroBubbleVPipe.
|
||||
num_model_chunks (int) : The number of model chunk in a device.
|
||||
num_microbatch (Optional[int]): The number of microbatch.
|
||||
microbatch_size (Optional[int]): The size per microbatch.
|
||||
enable_metadata_cache (bool): whether to enable metadata cache to acclerate communication.
|
||||
overlap_p2p (bool): whether to use overlap_p2p.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stage_manager: PipelineStageManager,
|
||||
|
|
|
@ -8,7 +8,6 @@ class WeightGradStore:
|
|||
|
||||
@classmethod
|
||||
def put(cls, total_input, grad_output, weight, func):
|
||||
# func(total_input, grad_output, weight.main_grad)
|
||||
cls.cache.append((total_input, grad_output, weight, func))
|
||||
|
||||
@classmethod
|
||||
|
@ -18,15 +17,26 @@ class WeightGradStore:
|
|||
|
||||
@classmethod
|
||||
def pop(cls, chunk=0):
|
||||
# print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}")
|
||||
if cls.weight_grad_queue[chunk].qsize() > 0:
|
||||
stored_grads = cls.weight_grad_queue[chunk].get()
|
||||
for total_input, grad_output, weight, func in stored_grads:
|
||||
if weight.grad is not None:
|
||||
func(total_input, grad_output, weight.grad)
|
||||
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
|
||||
if isinstance(weight, tuple):
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
|
||||
# View will lead to weight ptr change
|
||||
# weight_cal & weight_origin in tuple, weight_cal use to cal dw, weight_origin use to update
|
||||
_, weight_origin = weight
|
||||
if weight_origin.grad is not None:
|
||||
func(total_input, grad_output, weight_origin.grad)
|
||||
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
|
||||
else:
|
||||
grad_weight = func(total_input, grad_output)
|
||||
weight_origin.grad = grad_weight
|
||||
else:
|
||||
grad_weight = func(total_input, grad_output)
|
||||
weight.grad = grad_weight
|
||||
if weight.grad is not None:
|
||||
func(total_input, grad_output, weight.grad)
|
||||
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
|
||||
else:
|
||||
grad_weight = func(total_input, grad_output)
|
||||
weight.grad = grad_weight
|
||||
else:
|
||||
raise Exception("Pop empty queue.")
|
||||
|
|
|
@ -6,7 +6,14 @@ from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHe
|
|||
from .loss import cross_entropy_1d, dist_cross_entropy
|
||||
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
||||
from .parallel_module import ParallelModule
|
||||
from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
from .qkv_fused_linear import (
|
||||
FusedLinear,
|
||||
FusedLinear1D_Col,
|
||||
FusedLinear1D_Row,
|
||||
GPT2FusedLinearConv,
|
||||
GPT2FusedLinearConv1D_Col,
|
||||
GPT2FusedLinearConv1D_Row,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Embedding1D",
|
||||
|
@ -14,8 +21,9 @@ __all__ = [
|
|||
"LinearWithGradAccum",
|
||||
"Linear1D_Col",
|
||||
"Linear1D_Row",
|
||||
"GPT2FusedLinearConv1D_Col",
|
||||
"GPT2FusedLinearConv",
|
||||
"GPT2FusedLinearConv1D_Row",
|
||||
"GPT2FusedLinearConv1D_Col",
|
||||
"DropoutForParallelInput",
|
||||
"DropoutForReplicatedInput",
|
||||
"cross_entropy_1d",
|
||||
|
@ -26,6 +34,7 @@ __all__ = [
|
|||
"FusedLayerNorm",
|
||||
"FusedRMSNorm",
|
||||
"FusedLinear1D_Col",
|
||||
"FusedLinear",
|
||||
"ParallelModule",
|
||||
"PaddingEmbedding",
|
||||
"PaddingLMHead",
|
||||
|
|
|
@ -6,7 +6,13 @@ import torch.nn.functional as F
|
|||
|
||||
from colossalai.pipeline.weight_grad_store import WeightGradStore
|
||||
|
||||
from .utils import is_share_sp_tp
|
||||
from .utils import (
|
||||
execute_conv1d_w_pass,
|
||||
execute_conv1d_w_pass_grad_accum,
|
||||
execute_w_pass,
|
||||
execute_w_pass_grad_accum,
|
||||
is_share_sp_tp,
|
||||
)
|
||||
|
||||
try:
|
||||
import fused_mix_prec_layer_norm_cuda
|
||||
|
@ -73,12 +79,13 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_allreduce = async_grad_allreduce
|
||||
ctx.fp8_communication = fp8_communication
|
||||
ctx.use_zbv = use_zbv
|
||||
|
||||
output = torch.matmul(input_, weight)
|
||||
|
||||
|
@ -92,8 +99,10 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||
input, weight, bias = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
fp8_communication = ctx.fp8_communication
|
||||
use_zbv = ctx.use_zbv
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
|
||||
weight_origin = weight
|
||||
weight = weight.view(weight.shape)
|
||||
if bias is not None:
|
||||
bias = bias.view(bias.shape)
|
||||
|
@ -114,7 +123,42 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
|
||||
|
||||
grad_weight = total_input.t().matmul(grad_output)
|
||||
# split dx & dw
|
||||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
grad = weight.grad
|
||||
if use_zbv:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
(weight, weight_origin),
|
||||
functools.partial(
|
||||
execute_conv1d_w_pass_grad_accum,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
elif grad.dtype == torch.float16:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = total_input.t().matmul(grad_output)
|
||||
else:
|
||||
if use_zbv:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
(weight, weight_origin),
|
||||
functools.partial(
|
||||
execute_conv1d_w_pass,
|
||||
wgrad_gemm_func=torch.matmul,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = total_input.t().matmul(grad_output)
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_allreduce and not fp8_communication:
|
||||
|
@ -123,6 +167,87 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
|
|||
return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||
|
||||
|
||||
class MatmulWithGradAccum(torch.autograd.Function):
|
||||
"""
|
||||
Linear layer execution with grad accum in backprop. (no tp version)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.async_grad_allreduce = async_grad_allreduce
|
||||
ctx.use_zbv = use_zbv
|
||||
|
||||
output = torch.matmul(input_, weight)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight, bias = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
use_zbv = ctx.use_zbv
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
|
||||
weight_origin = weight
|
||||
weight = weight.view(weight.shape)
|
||||
if bias is not None:
|
||||
bias = bias.view(bias.shape)
|
||||
|
||||
total_input = input
|
||||
grad_input = grad_output.matmul(weight.T)
|
||||
grad_output = grad_output.contiguous()
|
||||
# Convert the tensor shapes to 2D for execution compatibility
|
||||
if len(grad_output.shape) > 2:
|
||||
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||
total_input = total_input.view(-1, total_input.shape[-1])
|
||||
|
||||
# split dx & dw
|
||||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
grad = weight.grad
|
||||
|
||||
if use_zbv:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
(weight, weight_origin),
|
||||
functools.partial(
|
||||
execute_conv1d_w_pass_grad_accum,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
elif grad.dtype == torch.float16:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = total_input.t().matmul(grad_output)
|
||||
else:
|
||||
if use_zbv:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
(weight, weight_origin),
|
||||
functools.partial(
|
||||
execute_conv1d_w_pass,
|
||||
wgrad_gemm_func=torch.matmul,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = total_input.t().matmul(grad_output)
|
||||
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||
|
||||
|
||||
class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||
"""
|
||||
Linear layer execution with asynchronous communication in backprop.
|
||||
|
@ -150,12 +275,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
fp8_communication = ctx.fp8_communication
|
||||
use_zbv = ctx.use_zbv
|
||||
|
||||
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
|
||||
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
|
||||
|
||||
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
|
||||
return wgrad_gemm_func(_grad_output_.t(), _input_)
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
|
||||
if use_bias:
|
||||
bias.view(bias.shape)
|
||||
|
@ -179,31 +298,15 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
grad = weight.grad
|
||||
if use_zbv:
|
||||
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
|
||||
if grad.dtype == torch.float32:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass_grad_accum,
|
||||
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
elif grad.dtype in (torch.float16, torch.bfloat16):
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass_grad_accum,
|
||||
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass_grad_accum,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||
|
@ -259,12 +362,6 @@ class LinearWithGradAccum(torch.autograd.Function):
|
|||
use_bias = ctx.use_bias
|
||||
use_zbv = ctx.use_zbv
|
||||
|
||||
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
|
||||
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
|
||||
|
||||
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
|
||||
return wgrad_gemm_func(_grad_output_.t(), _input_)
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
|
||||
if use_bias:
|
||||
bias.view(bias.shape)
|
||||
|
@ -280,31 +377,15 @@ class LinearWithGradAccum(torch.autograd.Function):
|
|||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
grad = weight.grad
|
||||
if use_zbv:
|
||||
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
|
||||
if grad.dtype == torch.float32:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass_grad_accum,
|
||||
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
elif grad.dtype in (torch.float16, torch.bfloat16):
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass_grad_accum,
|
||||
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass_grad_accum,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||
|
@ -454,12 +535,13 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False):
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, use_zbv=False):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
|
||||
ctx.dim = dim
|
||||
ctx.use_zbv = use_zbv
|
||||
|
||||
if ring is True:
|
||||
input_to_gather = {"input": input_}
|
||||
|
@ -491,6 +573,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
use_bias = ctx.use_bias
|
||||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
use_zbv = ctx.use_zbv
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
|
||||
if use_bias:
|
||||
|
@ -518,23 +601,46 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
|
||||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
grad = weight.grad
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||
if use_zbv:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass_grad_accum,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
elif grad.dtype == torch.float16:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
|
||||
else:
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
elif grad.dtype == torch.float16:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
else:
|
||||
if use_zbv:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass,
|
||||
wgrad_gemm_func=torch.matmul,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_reduce_scatter:
|
||||
handle.wait()
|
||||
|
||||
return output, grad_weight, grad_bias, None, None, None, None
|
||||
return output, grad_weight, grad_bias, None, None, None, None, None
|
||||
|
||||
|
||||
def _ring_as_reducescatter(
|
||||
|
@ -606,11 +712,12 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, dim, ring):
|
||||
def forward(ctx, input_, weight, bias, process_group, dim, ring, use_zbv=False):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.dim = dim
|
||||
ctx.use_zbv = use_zbv
|
||||
|
||||
if ring is True:
|
||||
input_to_reducescatter = {"input": input_}
|
||||
|
@ -651,7 +758,7 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
|
|||
use_bias = ctx.use_bias
|
||||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
|
||||
use_zbv = ctx.use_zbv
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
|
||||
if use_bias:
|
||||
bias = bias.view(bias.shape)
|
||||
|
@ -666,10 +773,47 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
|
|||
if len(grad_output.shape) > 2:
|
||||
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||
total_input = total_input.reshape(-1, total_input.shape[-1])
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
|
||||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
grad = weight.grad
|
||||
if use_zbv:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass_grad_accum,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
elif grad.dtype == torch.float16:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
else:
|
||||
if use_zbv:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
weight,
|
||||
functools.partial(
|
||||
execute_w_pass,
|
||||
wgrad_gemm_func=torch.matmul,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
|
||||
# grad_weight = grad_output.t().matmul(total_input)
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||
|
||||
|
||||
class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
|
||||
|
@ -723,13 +867,16 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
|
||||
def forward(
|
||||
ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv=False
|
||||
):
|
||||
ctx.save_for_backward(input_, weight, bias)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.process_group = process_group
|
||||
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
|
||||
ctx.dim = dim
|
||||
ctx.fp8_communication = fp8_communication
|
||||
ctx.use_zbv = use_zbv
|
||||
|
||||
if ring is True:
|
||||
input_to_gather = {"input": input_}
|
||||
|
@ -759,8 +906,10 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
use_bias = ctx.use_bias
|
||||
dim = ctx.dim
|
||||
process_group = ctx.process_group
|
||||
use_zbv = ctx.use_zbv
|
||||
|
||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
|
||||
weight_origin = weight
|
||||
weight = weight.view(weight.shape)
|
||||
if use_bias:
|
||||
bias = bias.view(bias.shape)
|
||||
|
@ -785,13 +934,49 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
|||
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||
# all-reduce scheduled first and have GPU resources allocated
|
||||
|
||||
grad_weight = total_input.t().matmul(grad_output)
|
||||
# split dx & dw
|
||||
if _grad_accum_fusion_available and weight.grad is not None:
|
||||
grad = weight.grad
|
||||
if use_zbv:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
(weight, weight_origin),
|
||||
functools.partial(
|
||||
execute_conv1d_w_pass_grad_accum,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
if grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
elif grad.dtype == torch.float16:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = total_input.t().matmul(grad_output)
|
||||
else:
|
||||
if use_zbv:
|
||||
WeightGradStore.put(
|
||||
total_input,
|
||||
grad_output,
|
||||
(weight, weight_origin),
|
||||
functools.partial(
|
||||
execute_conv1d_w_pass,
|
||||
wgrad_gemm_func=torch.matmul,
|
||||
),
|
||||
)
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = total_input.t().matmul(grad_output)
|
||||
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_reduce_scatter:
|
||||
handle.wait()
|
||||
|
||||
return output, grad_weight, grad_bias, None, None, None, None, None
|
||||
return output, grad_weight, grad_bias, None, None, None, None, None, None
|
||||
|
||||
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||
|
@ -1108,12 +1293,18 @@ def _all_to_all_single(
|
|||
).contiguous()
|
||||
|
||||
|
||||
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
||||
def matmul_with_async_comm(
|
||||
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
|
||||
):
|
||||
return MatmulWithAsyncCommunication.apply(
|
||||
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
|
||||
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv
|
||||
)
|
||||
|
||||
|
||||
def matmul_with_grad_comm(input_, weight, bias, async_grad_allreduce, use_zbv=False):
|
||||
return MatmulWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)
|
||||
|
||||
|
||||
def linear_with_async_comm(
|
||||
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
|
||||
):
|
||||
|
@ -1127,10 +1318,10 @@ def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=F
|
|||
|
||||
|
||||
def linear_gather_forward_reducescatter_backward(
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, use_zbv=False
|
||||
):
|
||||
return _LinearWithGatherForwardReduceScatterBackward.apply(
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, use_zbv
|
||||
)
|
||||
|
||||
|
||||
|
@ -1142,15 +1333,25 @@ def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_commun
|
|||
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication)
|
||||
|
||||
|
||||
def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
|
||||
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)
|
||||
def linear_reducescatter_forward_gather_backward(
|
||||
input_, weight, bias=None, process_group=None, dim=1, ring=False, use_zbv=False
|
||||
):
|
||||
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring, use_zbv)
|
||||
|
||||
|
||||
def matmul_gather_forward_reducescatter_backward(
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False
|
||||
input_,
|
||||
weight,
|
||||
bias,
|
||||
process_group,
|
||||
async_grad_reduce_scatter,
|
||||
dim,
|
||||
ring=False,
|
||||
fp8_communication=False,
|
||||
use_zbv=False,
|
||||
):
|
||||
return _MatmulWithGatherForwardReduceScatterBackward.apply(
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication, use_zbv
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -350,6 +350,7 @@ class Linear1D_Col(ParallelModule):
|
|||
True,
|
||||
self.seq_parallel_dim,
|
||||
ring=self.seq_parallel_mode == "ring",
|
||||
use_zbv=self.use_zbv,
|
||||
)
|
||||
else:
|
||||
output_parallel = linear_with_async_comm(
|
||||
|
@ -580,6 +581,7 @@ class Linear1D_Row(ParallelModule):
|
|||
process_group=self.process_group,
|
||||
dim=self.seq_parallel_dim,
|
||||
ring=self.seq_parallel_mode == "ring",
|
||||
use_zbv=self.use_zbv,
|
||||
)
|
||||
else:
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
|
|
|
@ -7,7 +7,6 @@ from typing import Callable, List, Optional, Tuple, Union
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
|
@ -28,8 +27,10 @@ from ._operation import (
|
|||
linear_gather_forward_reducescatter_backward,
|
||||
linear_reducescatter_forward_gather_backward,
|
||||
linear_with_async_comm,
|
||||
linear_with_grad_accum,
|
||||
matmul_gather_forward_reducescatter_backward,
|
||||
matmul_with_async_comm,
|
||||
matmul_with_grad_comm,
|
||||
reduce_forward,
|
||||
reducescatter_forward_gather_backward,
|
||||
split_forward_gather_backward,
|
||||
|
@ -37,7 +38,14 @@ from ._operation import (
|
|||
from .parallel_module import ParallelModule
|
||||
from .utils import create_randomizer_with_offset, is_share_sp_tp
|
||||
|
||||
__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"]
|
||||
__all__ = [
|
||||
"FusedLinear1D_Col",
|
||||
"FusedLinear1D_Row",
|
||||
"FusedLinear",
|
||||
"GPT2FusedLinearConv1D_Col",
|
||||
"GPT2FusedLinearConv1D_Row",
|
||||
"GPT2FusedLinearConv",
|
||||
]
|
||||
|
||||
# ====================================
|
||||
# For GPT Only
|
||||
|
@ -228,6 +236,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
fp8_communication: bool = False,
|
||||
use_zbv: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -241,6 +250,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
self.split_sizes = split_sizes
|
||||
self.process_group = process_group
|
||||
self.fp8_communication = fp8_communication
|
||||
self.use_zbv = use_zbv
|
||||
|
||||
assert (
|
||||
sum(split_sizes) == out_features
|
||||
|
@ -375,6 +385,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
1,
|
||||
ring=self.seq_parallel_mode == "ring",
|
||||
fp8_communication=self.fp8_communication,
|
||||
use_zbv=self.use_zbv,
|
||||
)
|
||||
elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
|
||||
# Set up backprop all-reduce.
|
||||
|
@ -386,6 +397,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
self.process_group,
|
||||
True,
|
||||
fp8_communication=self.fp8_communication,
|
||||
use_zbv=self.use_zbv,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
|
||||
|
@ -441,6 +453,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1,
|
||||
fp8_communication: bool = False,
|
||||
use_zbv: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -455,6 +468,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
self.seq_parallel_mode = seq_parallel_mode
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
self.fp8_communication = fp8_communication
|
||||
self.use_zbv = use_zbv
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
@ -620,6 +634,152 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
return output, self.bias
|
||||
|
||||
|
||||
class GPT2FusedLinearConv(ParallelModule):
|
||||
r"""Linear layer without parallelism.
|
||||
This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface.
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
out_features (int): size of each output sample.
|
||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||
seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
|
||||
which is preserved for kernel fusion, defaults to False
|
||||
weight_initializer (:class:`typing.Callable`, optional):
|
||||
The initializer of weight, defaults to kaiming uniform initializer.
|
||||
bias_initializer (:class:`typing.Callable`, optional):
|
||||
The initializer of bias, defaults to xavier uniform initializer.
|
||||
|
||||
More details about ``initializer`` please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
seq_parallel_mode: str = None,
|
||||
seq_parallel_dim: int = 1,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
use_zbv: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Keep input parameters
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.seq_parallel_mode = seq_parallel_mode
|
||||
self.seq_parallel_dim = seq_parallel_dim
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.device = device
|
||||
self.use_zbv = use_zbv
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, None)
|
||||
|
||||
# sanity check
|
||||
if weight is not None:
|
||||
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
|
||||
else:
|
||||
assert bias_ is None, "bias_ must be None if weight is None"
|
||||
|
||||
# Parameters.
|
||||
if weight is None:
|
||||
# Initialize weight.
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
|
||||
if bias:
|
||||
if bias_ is None:
|
||||
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||
else:
|
||||
bias_.data = bias_.data.to(device=device, dtype=dtype)
|
||||
self.bias = bias_
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
if weight is None:
|
||||
# init weights
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Module,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
|
||||
|
||||
Args:
|
||||
module (`nn.Linear`): The module to be converted.
|
||||
split_sizes (List[int]): The sizes of the split tensor. In GPT2, Q,K,V are fused in one weight.
|
||||
"""
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes
|
||||
in_features = module.weight.shape[0]
|
||||
out_features = module.weight.shape[1]
|
||||
bias = module.bias is not None
|
||||
device = module.weight.device
|
||||
|
||||
linear_1d = GPT2FusedLinearConv(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return linear_1d
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
with self.randomizer.fork_rng(enable_cpu=True):
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
# Matrix multiply.
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = input_
|
||||
output_parallel = matmul_with_grad_comm(
|
||||
input_parallel,
|
||||
self.weight,
|
||||
bias,
|
||||
False,
|
||||
self.use_zbv,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")
|
||||
|
||||
output = output_parallel
|
||||
|
||||
if self.skip_bias_add:
|
||||
return output, self.bias
|
||||
else:
|
||||
return output
|
||||
|
||||
|
||||
# ====================================
|
||||
# For Fused torch.nn.Linear
|
||||
# ====================================
|
||||
|
@ -671,6 +831,7 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
fp8_communication: bool = False,
|
||||
use_zbv: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Keep input parameters
|
||||
|
@ -684,6 +845,7 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
self.split_sizes = split_sizes
|
||||
self.process_group = process_group
|
||||
self.fp8_communication = fp8_communication
|
||||
self.use_zbv = use_zbv
|
||||
|
||||
assert (
|
||||
sum(split_sizes) == out_features
|
||||
|
@ -811,10 +973,17 @@ class FusedLinear1D_Col(ParallelModule):
|
|||
True,
|
||||
self.seq_parallel_dim,
|
||||
ring=self.seq_parallel_mode == "ring",
|
||||
use_zbv=self.use_zbv,
|
||||
)
|
||||
else:
|
||||
output_parallel = linear_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
|
||||
input_parallel,
|
||||
self.weight,
|
||||
bias,
|
||||
self.process_group,
|
||||
True,
|
||||
fp8_communication=self.fp8_communication,
|
||||
use_zbv=self.use_zbv,
|
||||
)
|
||||
|
||||
if self.gather_output:
|
||||
|
@ -870,6 +1039,7 @@ class FusedLinear1D_Row(ParallelModule):
|
|||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
fp8_communication: bool = False,
|
||||
use_zbv: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Keep input parameters
|
||||
|
@ -883,6 +1053,7 @@ class FusedLinear1D_Row(ParallelModule):
|
|||
self.seq_parallel_dim = seq_parallel_dim
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
self.fp8_communication = fp8_communication
|
||||
self.use_zbv = use_zbv
|
||||
|
||||
assert (
|
||||
sum(split_sizes) == in_features
|
||||
|
@ -1009,9 +1180,18 @@ class FusedLinear1D_Row(ParallelModule):
|
|||
process_group=self.process_group,
|
||||
dim=self.seq_parallel_dim,
|
||||
ring=self.seq_parallel_mode == "ring",
|
||||
use_zbv=self.use_zbv,
|
||||
)
|
||||
else:
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
# output_parallel = F.linear(input_, self.weight) # Replace to LinearWithGradAccum
|
||||
output_parallel = linear_with_grad_accum(
|
||||
input_,
|
||||
self.weight,
|
||||
None,
|
||||
False,
|
||||
use_zbv=self.use_zbv,
|
||||
)
|
||||
|
||||
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||
|
||||
if not self.skip_bias_add:
|
||||
|
@ -1020,3 +1200,156 @@ class FusedLinear1D_Row(ParallelModule):
|
|||
return output
|
||||
else:
|
||||
return output, self.bias
|
||||
|
||||
|
||||
class FusedLinear(ParallelModule):
|
||||
r"""Fused Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
|
||||
its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `torch.nn.Linear` layer (Fused QKV) in normal torch layer of huggingface, like SAM.
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
out_features (int): size of each output sample.
|
||||
split_sizes (List[int]): The sizes of the split tensor.
|
||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||
device (`torch.device`): The device of parameters, defaults to None.
|
||||
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is :math:`Y_i = XA_i`, defaults to False
|
||||
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||
which is preserved for kernel fusion, defaults to False
|
||||
weight_initializer (`typing.Callable`):
|
||||
The initializer of weight, defaults to kaiming uniform initializer.
|
||||
bias_initializer (`typing.Callable`):
|
||||
The initializer of bias, defaults to xavier uniform initializer.
|
||||
|
||||
More details about ``initializer`` please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
seq_parallel_mode: str = None,
|
||||
seq_parallel_dim: int = 1,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
use_zbv: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Keep input parameters
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.seq_parallel_mode = seq_parallel_mode
|
||||
self.seq_parallel_dim = seq_parallel_dim
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.device = device
|
||||
self.use_zbv = use_zbv
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=None)
|
||||
|
||||
# sanity check
|
||||
if weight is not None:
|
||||
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
|
||||
else:
|
||||
assert bias_ is None, "bias_ must be None if weight is None"
|
||||
|
||||
# Parameters.
|
||||
if weight is None:
|
||||
# Initialize weight.
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
|
||||
if bias:
|
||||
if bias_ is None:
|
||||
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||
else:
|
||||
bias_.data = bias_.data.to(device=device, dtype=dtype)
|
||||
self.bias = bias_
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
if weight is None:
|
||||
# init weights
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Module,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a fused `torch.nn.linear` layer to a parallelized linear layer.
|
||||
|
||||
Args:
|
||||
module (`nn.Linear`): The module to be converted.
|
||||
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
|
||||
split_sizes (List[int]): The sizes of the split tensor. In common, Q,K,V are fused in one weight.
|
||||
"""
|
||||
LazyInitContext.materialize(module)
|
||||
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
device = module.weight.device
|
||||
|
||||
linear_1d = FusedLinear(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return linear_1d
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
with self.randomizer.fork_rng(enable_cpu=True):
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
assert (
|
||||
input_.shape[-1] == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = input_
|
||||
|
||||
# Matrix multiply.
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
output_parallel = linear_with_grad_accum(input_parallel, self.weight, bias, True, use_zbv=self.use_zbv)
|
||||
|
||||
output = output_parallel
|
||||
|
||||
if self.skip_bias_add:
|
||||
return output, self.bias
|
||||
else:
|
||||
return output
|
||||
|
|
|
@ -9,6 +9,43 @@ from torch.distributed import ProcessGroup, get_world_size
|
|||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
try:
|
||||
import fused_weight_gradient_mlp_cuda
|
||||
|
||||
_grad_accum_fusion_available = True
|
||||
except ImportError:
|
||||
_grad_accum_fusion_available = False
|
||||
|
||||
|
||||
# execute_w_pass_grad_accum & execute_conv1d_w_pass for GPT2FusedLinearConv1D
|
||||
def execute_conv1d_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
|
||||
if _input_.dtype == torch.float32:
|
||||
wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
|
||||
elif _input_.dtype in (torch.float16, torch.bfloat16):
|
||||
wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
|
||||
else:
|
||||
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
|
||||
wgrad_gemm_accum_func(_grad_output_, _input_, _weight_main_grad_)
|
||||
|
||||
|
||||
def execute_conv1d_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
|
||||
return wgrad_gemm_func(_input_.t(), _grad_output_)
|
||||
|
||||
|
||||
# execute_w_pass_grad_accum & execute_w_pass for Linear (except GPT2FusedLinearConv1D)
|
||||
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_):
|
||||
if _input_.dtype == torch.float32:
|
||||
wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32
|
||||
elif _input_.dtype in (torch.float16, torch.bfloat16):
|
||||
wgrad_gemm_accum_func = fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16
|
||||
else:
|
||||
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
|
||||
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
|
||||
|
||||
|
||||
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
|
||||
return wgrad_gemm_func(_grad_output_.t(), _input_)
|
||||
|
||||
|
||||
class SeqParallelUtils:
|
||||
@staticmethod
|
||||
|
|
|
@ -51,6 +51,8 @@ class BlipPolicy(Policy):
|
|||
else:
|
||||
norm_cls = col_nn.LayerNorm
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||
|
@ -73,6 +75,7 @@ class BlipPolicy(Policy):
|
|||
kwargs={
|
||||
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -80,6 +83,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -88,6 +92,7 @@ class BlipPolicy(Policy):
|
|||
kwargs={
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -95,6 +100,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
|
@ -126,6 +132,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -133,6 +140,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -140,6 +148,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -151,6 +160,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -162,6 +172,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -169,6 +180,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -176,6 +188,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -187,6 +200,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -198,6 +212,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -205,6 +220,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -227,6 +243,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -234,6 +251,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -241,6 +259,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -248,6 +267,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -255,6 +275,7 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -262,6 +283,226 @@ class BlipPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
|
||||
if self.enable_bias_gelu_fused:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_jit_fused_blip2_mlp_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=Blip2MLP,
|
||||
)
|
||||
elif use_zbv:
|
||||
policy[Blip2EncoderLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.qkv",
|
||||
target_module=col_nn.FusedLinear,
|
||||
kwargs={
|
||||
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.projection",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc1",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc2",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
policy[Blip2QFormerModel] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
policy[Blip2QFormerLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.query",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.key",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.value",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.query",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.key",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.value",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.attention.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.output.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.output.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="intermediate_query.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output_query.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output_query.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
policy[OPTDecoderLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.out_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
|
|
|
@ -59,6 +59,8 @@ class BloomPolicy(Policy):
|
|||
|
||||
sp_partial_derived = sp_mode == "split_gather"
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.n_head % self.shard_config.tensor_parallel_size == 0
|
||||
|
@ -78,6 +80,7 @@ class BloomPolicy(Policy):
|
|||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -86,6 +89,7 @@ class BloomPolicy(Policy):
|
|||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -98,6 +102,7 @@ class BloomPolicy(Policy):
|
|||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -106,6 +111,7 @@ class BloomPolicy(Policy):
|
|||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
|
@ -120,6 +126,52 @@ class BloomPolicy(Policy):
|
|||
},
|
||||
)
|
||||
|
||||
if use_zbv:
|
||||
policy[BloomBlock] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.attention_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_h_to_4h",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_4h_to_h",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
|
@ -247,14 +299,27 @@ class BloomPolicy(Policy):
|
|||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.word_embeddings)
|
||||
held_layers.append(module.word_embeddings_layernorm)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.ln_f)
|
||||
if stage_manager.is_interleave:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.word_embeddings)
|
||||
held_layers.append(module.word_embeddings_layernorm)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(module.ln_f)
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.word_embeddings)
|
||||
held_layers.append(module.word_embeddings_layernorm)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.ln_f)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
@ -328,8 +393,14 @@ class BloomForCausalLMPolicy(BloomPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.lm_head)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -351,6 +422,7 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
|
|||
from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification
|
||||
|
||||
policy = super().module_policy()
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
# handle tensor parallelism
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
|
@ -363,6 +435,18 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
|
|||
policy=policy,
|
||||
target_key=BloomForSequenceClassification,
|
||||
)
|
||||
elif use_zbv:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="score",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
gather_output=True, fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=BloomForSequenceClassification,
|
||||
)
|
||||
if self.pipeline_stage_manager:
|
||||
self.set_pipeline_forward(
|
||||
model_cls=BloomForSequenceClassification,
|
||||
|
@ -375,8 +459,14 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.score)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.score)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.score)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -389,6 +479,7 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
|
|||
from transformers.models.bloom.modeling_bloom import BloomForTokenClassification
|
||||
|
||||
policy = super().module_policy()
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
# handle tensor parallelism
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
|
@ -407,6 +498,24 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
|
|||
policy=policy,
|
||||
target_key=BloomForTokenClassification,
|
||||
)
|
||||
elif use_zbv:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="classifier",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
gather_output=True, fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=BloomForTokenClassification,
|
||||
)
|
||||
if self.pipeline_stage_manager:
|
||||
self.set_pipeline_forward(
|
||||
model_cls=BloomForTokenClassification,
|
||||
|
@ -420,9 +529,16 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -448,8 +564,14 @@ class BloomForQuestionAnsweringPolicy(BloomPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
|
|
@ -83,6 +83,8 @@ class ChatGLMPolicy(Policy):
|
|||
attribute_replacement=decoder_attribute_replacement,
|
||||
)
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||
|
@ -145,6 +147,35 @@ class ChatGLMPolicy(Policy):
|
|||
),
|
||||
],
|
||||
)
|
||||
elif use_zbv:
|
||||
policy["GLMBlock"] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"seq_parallel_dim": 0,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"seq_parallel_dim": 0,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.core_attention.attention_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
|
@ -261,17 +292,30 @@ class ChatGLMPolicy(Policy):
|
|||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(module.num_layers)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embedding)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.encoder.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
if module.encoder.post_layer_norm:
|
||||
held_layers.append(module.encoder.final_layernorm)
|
||||
if stage_manager.is_interleave:
|
||||
layers_per_stage = stage_manager.distribute_layers(module.num_layers)
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.embed_tokens)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
if module.encoder.post_layer_norm:
|
||||
held_layers.append(module.encoder.final_layernorm)
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(module.num_layers)
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embedding)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.encoder.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
if module.encoder.post_layer_norm:
|
||||
held_layers.append(module.encoder.final_layernorm)
|
||||
|
||||
# rotary_pos_emb is needed for all stages
|
||||
held_layers.append(module.rotary_pos_emb)
|
||||
# rotary_pos_emb is needed for all stages
|
||||
held_layers.append(module.rotary_pos_emb)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
@ -335,8 +379,15 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.transformer.output_layer)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.transformer.output_layer)
|
||||
else:
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.transformer.output_layer)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
|
|
@ -10,6 +10,7 @@ from colossalai.shardformer.layer import (
|
|||
LayerNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
LinearWithGradAccum,
|
||||
PaddingEmbedding,
|
||||
PaddingLMHead,
|
||||
VocabParallelEmbedding1D,
|
||||
|
@ -107,6 +108,8 @@ class CommandPolicy(Policy):
|
|||
target_key=CohereModel,
|
||||
)
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
num_q_heads % tp_size == 0
|
||||
|
@ -128,41 +131,137 @@ class CommandPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
elif use_zbv:
|
||||
policy[CohereDecoderLayer] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
|
@ -258,7 +357,9 @@ class CommandPolicy(Policy):
|
|||
held_layers.append(module.embed_tokens)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(module.norm)
|
||||
|
||||
else:
|
||||
|
@ -351,8 +452,14 @@ class CommandForCausalLMPolicy(CommandPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.lm_head)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.lm_head)
|
||||
else:
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
|
|
@ -6,7 +6,7 @@ from torch import Tensor
|
|||
from torch.nn import Module
|
||||
from transformers.utils import is_flash_attn_greater_or_equal_2_10
|
||||
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, LinearWithGradAccum
|
||||
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
|
||||
from colossalai.shardformer.layer.linear import Linear1D_Row
|
||||
from colossalai.shardformer.modeling.deepseek import (
|
||||
|
@ -107,6 +107,8 @@ class DeepseekPolicy(Policy):
|
|||
if self.tie_weight:
|
||||
embedding_cls = PaddingEmbedding
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# tensor parallelism for non-moe params
|
||||
assert (
|
||||
|
@ -133,22 +135,58 @@ class DeepseekPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate",
|
||||
target_module=DeepseekMoEGate_Col,
|
||||
kwargs={
|
||||
"gather_output": True,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"config": self.model.config,
|
||||
},
|
||||
ignore_if_not_exist=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
elif use_zbv:
|
||||
policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={"fp8_communication": self.shard_config.fp8_communication, "use_zbv": use_zbv},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate",
|
||||
|
@ -162,7 +200,6 @@ class DeepseekPolicy(Policy):
|
|||
),
|
||||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
|
@ -291,13 +328,26 @@ class DeepseekPolicy(Policy):
|
|||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.embed_tokens)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(module.norm)
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
@ -330,6 +380,7 @@ class DeepseekModelPolicy(DeepseekPolicy):
|
|||
class DeepseekForCausalLMPolicy(DeepseekPolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
# TODO: assign pg mesh from plugin to all modules
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for casual lm
|
||||
|
@ -339,7 +390,29 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
gather_output=True,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
elif use_zbv:
|
||||
# add a new item for casual lm
|
||||
new_item = {
|
||||
"DeepseekForCausalLM": ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
gather_output=True,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
@ -360,8 +433,14 @@ class DeepseekForCausalLMPolicy(DeepseekPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.lm_head)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
|
|
@ -51,6 +51,8 @@ class FalconPolicy(Policy):
|
|||
if self.tie_weight:
|
||||
embedding_cls = col_nn.PaddingEmbedding
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||
|
@ -73,10 +75,16 @@ class FalconPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs=dict(
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.attention_dropout",
|
||||
|
@ -85,8 +93,17 @@ class FalconPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_h_to_4h",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_4h_to_h",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs=dict(
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(suffix="mlp.dense_4h_to_h", target_module=col_nn.Linear1D_Row),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -98,6 +115,44 @@ class FalconPolicy(Policy):
|
|||
"build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
|
||||
},
|
||||
)
|
||||
elif use_zbv:
|
||||
policy[FalconDecoderLayer] = ModulePolicyDescription(
|
||||
method_replacement={"forward": get_tp_falcon_decoder_layer_forward()},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.attention_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_h_to_4h",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dense_4h_to_h",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
|
@ -191,13 +246,26 @@ class FalconPolicy(Policy):
|
|||
module = self.model.transformer
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.word_embeddings)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.ln_f)
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.word_embeddings)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(module.ln_f)
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.word_embeddings)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.ln_f)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
@ -281,8 +349,14 @@ class FalconForCausalLMPolicy(FalconPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.lm_head)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -308,11 +382,23 @@ class FalconForSequenceClassificationPolicy(FalconPolicy):
|
|||
|
||||
policy = super().module_policy()
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
# handle tensor parallelism
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True, use_zbv=use_zbv)
|
||||
),
|
||||
policy=policy,
|
||||
target_key=FalconForSequenceClassification,
|
||||
)
|
||||
elif use_zbv:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="score",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs=dict(gather_output=True, use_zbv=use_zbv),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=FalconForSequenceClassification,
|
||||
|
@ -330,8 +416,14 @@ class FalconForSequenceClassificationPolicy(FalconPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.score)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.score)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.score)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -348,12 +440,32 @@ class FalconForTokenClassificationPolicy(FalconPolicy):
|
|||
|
||||
policy = super().module_policy()
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
# handle tensor parallelism
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
suffix="classifier",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True, use_zbv=use_zbv),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=FalconForTokenClassification,
|
||||
)
|
||||
elif use_zbv:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="classifier",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs=dict(gather_output=True, use_zbv=use_zbv),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
|
@ -375,9 +487,16 @@ class FalconForTokenClassificationPolicy(FalconPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -394,11 +513,25 @@ class FalconForQuestionAnsweringPolicy(FalconPolicy):
|
|||
|
||||
policy = super().module_policy()
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
# handle tensor parallelism
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="qa_outputs", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
suffix="qa_outputs",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True, use_zbv=use_zbv),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=FalconForQuestionAnswering,
|
||||
)
|
||||
elif use_zbv:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="qa_outputs",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True, use_zbv=use_zbv),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=FalconForQuestionAnswering,
|
||||
|
@ -415,8 +548,14 @@ class FalconForQuestionAnsweringPolicy(FalconPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
|
|
@ -67,6 +67,8 @@ class GPT2Policy(Policy):
|
|||
self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
|
||||
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
||||
use_flash_attention = self.shard_config.enable_flash_attention
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||
|
@ -94,12 +96,17 @@ class GPT2Policy(Policy):
|
|||
"split_sizes": [self.model.config.hidden_size] * 3,
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.c_proj",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Row,
|
||||
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.c_fc",
|
||||
|
@ -109,12 +116,17 @@ class GPT2Policy(Policy):
|
|||
"seq_parallel_mode": sp_mode,
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.c_proj",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Row,
|
||||
kwargs={"seq_parallel_mode": sp_mode, "fp8_communication": self.shard_config.fp8_communication},
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.attn_dropout",
|
||||
|
@ -138,6 +150,78 @@ class GPT2Policy(Policy):
|
|||
policy=policy,
|
||||
target_key=GPT2MLP,
|
||||
)
|
||||
elif use_zbv:
|
||||
policy[GPT2Model] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="drop",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
policy[GPT2Block] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.c_attn",
|
||||
target_module=col_nn.GPT2FusedLinearConv,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.c_proj",
|
||||
target_module=col_nn.GPT2FusedLinearConv,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.c_fc",
|
||||
target_module=col_nn.GPT2FusedLinearConv,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.c_proj",
|
||||
target_module=col_nn.GPT2FusedLinearConv,
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.attn_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.resid_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
],
|
||||
)
|
||||
if self.enable_bias_gelu_fused:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_jit_fused_gpt2_mlp_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=GPT2MLP,
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
# padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
|
||||
self.append_or_create_submodule_replacement(
|
||||
|
@ -352,8 +436,17 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.lm_head)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.lm_head)
|
||||
else:
|
||||
if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.lm_head)
|
||||
# if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True):
|
||||
# held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -420,13 +513,24 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
multiple_choice_head = self.model.multiple_choice_head
|
||||
held_layers.append(self.model.lm_head)
|
||||
held_layers.append(multiple_choice_head.summary)
|
||||
held_layers.append(multiple_choice_head.activation)
|
||||
held_layers.append(multiple_choice_head.first_dropout)
|
||||
held_layers.append(multiple_choice_head.last_dropout)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.lm_head)
|
||||
held_layers.append(multiple_choice_head.summary)
|
||||
held_layers.append(multiple_choice_head.activation)
|
||||
held_layers.append(multiple_choice_head.first_dropout)
|
||||
held_layers.append(multiple_choice_head.last_dropout)
|
||||
else:
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
multiple_choice_head = self.model.multiple_choice_head
|
||||
held_layers.append(self.model.lm_head)
|
||||
held_layers.append(multiple_choice_head.summary)
|
||||
held_layers.append(multiple_choice_head.activation)
|
||||
held_layers.append(multiple_choice_head.first_dropout)
|
||||
held_layers.append(multiple_choice_head.last_dropout)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
@ -464,8 +568,17 @@ class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
else:
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
# if self.pipeline_stage_manager.is_last_stage():
|
||||
# held_layers.append(self.model.qa_outputs)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -503,9 +616,20 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
else:
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
# if self.pipeline_stage_manager.is_last_stage():
|
||||
# held_layers.append(self.model.dropout)
|
||||
# held_layers.append(self.model.classifier)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -530,8 +654,18 @@ class GPT2ForSequenceClassificationPolicy(GPT2Policy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.score)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.score)
|
||||
else:
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.score)
|
||||
|
||||
# if self.pipeline_stage_manager.is_last_stage():
|
||||
# held_layers.append(self.model.score)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
|
|
@ -51,6 +51,8 @@ class GPTJPolicy(Policy):
|
|||
self.shard_config.enable_sequence_parallelism = False
|
||||
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||
|
@ -76,6 +78,7 @@ class GPTJPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -83,6 +86,7 @@ class GPTJPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -90,6 +94,7 @@ class GPTJPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -97,6 +102,7 @@ class GPTJPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -104,6 +110,7 @@ class GPTJPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -111,6 +118,72 @@ class GPTJPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.attn_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.resid_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
],
|
||||
)
|
||||
elif use_zbv:
|
||||
policy[GPTJBlock] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.k_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.q_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.v_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.out_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc_in",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.fc_out",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -127,7 +200,6 @@ class GPTJPolicy(Policy):
|
|||
),
|
||||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
|
@ -200,13 +272,25 @@ class GPTJPolicy(Policy):
|
|||
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.h))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.wte)
|
||||
held_layers.append(module.drop)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.ln_f)
|
||||
if stage_manager.is_interleave:
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.wte)
|
||||
held_layers.append(module.drop)
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(module.ln_f)
|
||||
else:
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.wte)
|
||||
held_layers.append(module.drop)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.h[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.ln_f)
|
||||
return held_layers
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
|
@ -309,8 +393,15 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.lm_head)
|
||||
else:
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -349,8 +440,15 @@ class GPTJForSequenceClassificationPolicy(GPTJPolicy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.score)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.score)
|
||||
else:
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.score)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -378,8 +476,15 @@ class GPTJForQuestionAnsweringPolicy(GPTJPolicy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
else:
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
|
|
@ -324,9 +324,10 @@ class MistralPolicy(Policy):
|
|||
held_layers.append(module.embed_tokens)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(module.norm)
|
||||
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
|
@ -419,8 +420,14 @@ class MistralForCausalLMPolicy(MistralPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.lm_head)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.lm_head)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -475,8 +482,14 @@ class MistralForSequenceClassificationPolicy(MistralPolicy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.score)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.score)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.score)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
|
|
@ -10,6 +10,7 @@ from colossalai.shardformer.layer import (
|
|||
LayerNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
LinearWithGradAccum,
|
||||
PaddingEmbedding,
|
||||
PaddingLMHead,
|
||||
VocabParallelEmbedding1D,
|
||||
|
@ -76,6 +77,8 @@ class OPTPolicy(Policy):
|
|||
self.shard_config.enable_sequence_parallelism = False
|
||||
warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||
|
@ -85,10 +88,16 @@ class OPTPolicy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
@ -104,6 +113,7 @@ class OPTPolicy(Policy):
|
|||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -111,6 +121,7 @@ class OPTPolicy(Policy):
|
|||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -118,6 +129,7 @@ class OPTPolicy(Policy):
|
|||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -125,11 +137,67 @@ class OPTPolicy(Policy):
|
|||
target_module=Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
elif use_zbv:
|
||||
policy[OPTDecoderLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
policy[attn_cls] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="q_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="k_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="v_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="out_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
|
@ -221,15 +289,30 @@ class OPTPolicy(Policy):
|
|||
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
held_layers.append(module.embed_positions)
|
||||
held_layers.append(module.project_in)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.final_layer_norm)
|
||||
held_layers.append(module.project_out)
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.embed_tokens)
|
||||
held_layers.append(module.embed_positions)
|
||||
held_layers.append(module.project_in)
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(module.final_layer_norm)
|
||||
held_layers.append(module.project_out)
|
||||
else:
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
held_layers.append(module.embed_positions)
|
||||
held_layers.append(module.project_in)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.final_layer_norm)
|
||||
held_layers.append(module.project_out)
|
||||
return held_layers
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
|
@ -323,8 +406,15 @@ class OPTForCausalLMPolicy(OPTPolicy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.lm_head)
|
||||
else:
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -395,8 +485,15 @@ class OPTForQuestionAnsweringPolicy(OPTPolicy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
else:
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.qa_outputs)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
|
|
@ -9,6 +9,7 @@ from colossalai.shardformer.layer import (
|
|||
FusedRMSNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
LinearWithGradAccum,
|
||||
PaddingEmbedding,
|
||||
RMSNorm,
|
||||
VocabParallelEmbedding1D,
|
||||
|
@ -96,6 +97,8 @@ class Qwen2Policy(Policy):
|
|||
attribute_replacement=decoder_attribute_replacement,
|
||||
)
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||
|
@ -119,37 +122,134 @@ class Qwen2Policy(Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=Linear1D_Row,
|
||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
elif use_zbv:
|
||||
policy[Qwen2DecoderLayer] = ModulePolicyDescription(
|
||||
attribute_replacement=decoder_attribute_replacement,
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.o_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.gate_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.up_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.down_proj",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
seq_parallel_mode=sp_mode,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -278,7 +378,9 @@ class Qwen2Policy(Policy):
|
|||
held_layers.append(module.embed_tokens)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(module.norm)
|
||||
|
||||
else:
|
||||
|
@ -318,6 +420,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
|||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
setattr(self.shard_config, "causal_lm", True)
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for casual lm
|
||||
|
@ -327,7 +430,22 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
|
||||
)
|
||||
],
|
||||
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
elif use_zbv:
|
||||
# add a new item for casual lm
|
||||
new_item = {
|
||||
Qwen2ForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
|
||||
)
|
||||
],
|
||||
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
|
||||
|
@ -347,8 +465,14 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.lm_head)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.lm_head)
|
||||
else:
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -371,6 +495,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
|||
class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for sequence classification
|
||||
new_item = {
|
||||
|
@ -379,7 +504,28 @@ class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="score",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
gather_output=True,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
elif use_zbv:
|
||||
new_item = {
|
||||
Qwen2ForSequenceClassification: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="score",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
gather_output=True,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
@ -399,8 +545,14 @@ class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
|
|||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.score)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.score)
|
||||
else:
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.score)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
|
|
@ -27,6 +27,7 @@ class SamPolicy(Policy):
|
|||
norm_cls = col_nn.FusedLayerNorm
|
||||
else:
|
||||
norm_cls = col_nn.LayerNorm
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
|
@ -44,6 +45,7 @@ class SamPolicy(Policy):
|
|||
kwargs={
|
||||
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -51,6 +53,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -58,6 +61,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -65,6 +69,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
|
@ -80,6 +85,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -87,6 +93,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -94,6 +101,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -101,6 +109,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -108,6 +117,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -115,6 +125,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -122,6 +133,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -129,6 +141,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -136,6 +149,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -143,6 +157,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -150,6 +165,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -157,6 +173,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -164,6 +181,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -171,6 +189,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
|
@ -186,6 +205,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -193,6 +213,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -200,6 +221,7 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -207,6 +229,209 @@ class SamPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# add `DropoutForParallelInput` layer to replace the useage of `nn.functional.dropout`
|
||||
policy[SamVisionAttention] = ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
"dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout)
|
||||
},
|
||||
method_replacement={"forward": forward_fn()},
|
||||
sub_module_replacement=[],
|
||||
)
|
||||
elif use_zbv:
|
||||
policy[SamVisionLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.qkv",
|
||||
target_module=col_nn.FusedLinear,
|
||||
kwargs={
|
||||
"split_sizes": [self.model.config.vision_config.hidden_size] * 3,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.lin1",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.lin2",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
policy[SamTwoWayAttentionBlock] = ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
"self_attn.num_attention_heads": self.model.config.mask_decoder_config.num_attention_heads
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.out_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_token_to_image.q_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_token_to_image.k_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_token_to_image.v_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_token_to_image.out_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.lin1",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.lin2",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_image_to_token.q_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_image_to_token.k_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_image_to_token.v_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="cross_attn_image_to_token.out_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
policy[SamTwoWayTransformer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_attn_token_to_image.q_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_attn_token_to_image.k_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_attn_token_to_image.v_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_attn_token_to_image.out_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
|
|
|
@ -13,6 +13,7 @@ from colossalai.shardformer.layer import (
|
|||
FusedRMSNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
LinearWithGradAccum,
|
||||
PaddingEmbedding,
|
||||
PaddingLMHead,
|
||||
RMSNorm,
|
||||
|
@ -77,6 +78,8 @@ class T5BasePolicy(Policy):
|
|||
self.shard_config.enable_sequence_parallelism = False
|
||||
warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.num_heads % self.shard_config.tensor_parallel_size == 0
|
||||
|
@ -119,6 +122,7 @@ class T5BasePolicy(Policy):
|
|||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -126,6 +130,7 @@ class T5BasePolicy(Policy):
|
|||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -133,6 +138,7 @@ class T5BasePolicy(Policy):
|
|||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -140,6 +146,7 @@ class T5BasePolicy(Policy):
|
|||
target_module=Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -168,6 +175,7 @@ class T5BasePolicy(Policy):
|
|||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -175,6 +183,7 @@ class T5BasePolicy(Policy):
|
|||
target_module=Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -183,6 +192,7 @@ class T5BasePolicy(Policy):
|
|||
kwargs=dict(
|
||||
gather_output=True,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -198,6 +208,7 @@ class T5BasePolicy(Policy):
|
|||
target_module=Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -205,6 +216,142 @@ class T5BasePolicy(Policy):
|
|||
target_module=Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]
|
||||
)
|
||||
elif use_zbv:
|
||||
policy[T5Stack] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]
|
||||
)
|
||||
policy[T5LayerSelfAttention] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]
|
||||
)
|
||||
policy[T5LayerCrossAttention] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
)
|
||||
]
|
||||
)
|
||||
policy[T5Attention] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="q",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="k",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="v",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="o",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="relative_attention_bias",
|
||||
target_module=Embedding1D,
|
||||
kwargs=dict(
|
||||
gather_output=False,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
),
|
||||
ignore_if_not_exist=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
policy[T5LayerFF] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]
|
||||
)
|
||||
policy[T5DenseGatedActDense] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi_0 ",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi_1",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wo",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
gather_output=True,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
]
|
||||
)
|
||||
policy[T5DenseActDense] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wi",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wo",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -213,7 +360,6 @@ class T5BasePolicy(Policy):
|
|||
),
|
||||
]
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
|
@ -369,30 +515,61 @@ class T5BasePolicy(Policy):
|
|||
num_decoder_layers = len(decoder.block) if decoder else 0
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
|
||||
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
|
||||
)
|
||||
start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
|
||||
|
||||
if stage_manager.stage < decoder_starting_stage:
|
||||
# current stage is in t5's encoder
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(model.shared)
|
||||
held_layers.append(encoder.embed_tokens)
|
||||
held_layers.append(encoder.dropout)
|
||||
if stage_manager.stage == decoder_starting_stage - 1:
|
||||
held_layers.append(encoder.final_layer_norm)
|
||||
held_layers.append(encoder.dropout)
|
||||
held_layers.extend(encoder.block[start_idx:end_idx])
|
||||
if stage_manager.is_interleave:
|
||||
layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
|
||||
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
|
||||
)
|
||||
stage_indices = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
|
||||
if stage_manager.stage < decoder_starting_stage:
|
||||
# current stage is in t5's encoder
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(model.shared)
|
||||
held_layers.append(encoder.embed_tokens)
|
||||
held_layers.append(encoder.dropout)
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(encoder.final_layer_norm)
|
||||
held_layers.append(encoder.dropout)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(encoder.block[start_idx:end_idx])
|
||||
else:
|
||||
# current stage is in t5's decoder
|
||||
if stage_manager.stage == decoder_starting_stage:
|
||||
held_layers.append(decoder.embed_tokens)
|
||||
held_layers.append(decoder.dropout)
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(decoder.final_layer_norm)
|
||||
held_layers.append(decoder.dropout)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(decoder.block[start_idx:end_idx])
|
||||
else:
|
||||
# current stage is in t5's decoder
|
||||
if stage_manager.stage == decoder_starting_stage:
|
||||
held_layers.append(decoder.embed_tokens)
|
||||
held_layers.append(decoder.dropout)
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(decoder.final_layer_norm)
|
||||
held_layers.append(decoder.dropout)
|
||||
held_layers.extend(decoder.block[start_idx:end_idx])
|
||||
layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
|
||||
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
|
||||
)
|
||||
start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
|
||||
|
||||
if stage_manager.stage < decoder_starting_stage:
|
||||
# current stage is in t5's encoder
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(model.shared)
|
||||
held_layers.append(encoder.embed_tokens)
|
||||
held_layers.append(encoder.dropout)
|
||||
if stage_manager.stage == decoder_starting_stage - 1:
|
||||
held_layers.append(encoder.final_layer_norm)
|
||||
held_layers.append(encoder.dropout)
|
||||
held_layers.extend(encoder.block[start_idx:end_idx])
|
||||
else:
|
||||
# current stage is in t5's decoder
|
||||
if stage_manager.stage == decoder_starting_stage:
|
||||
held_layers.append(decoder.embed_tokens)
|
||||
held_layers.append(decoder.dropout)
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(decoder.final_layer_norm)
|
||||
held_layers.append(decoder.dropout)
|
||||
held_layers.extend(decoder.block[start_idx:end_idx])
|
||||
return held_layers
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
|
@ -545,8 +722,15 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.lm_head)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -652,9 +836,16 @@ class T5ForTokenClassificationPolicy(T5EncoderPolicy):
|
|||
"""
|
||||
held_layers = super().get_held_layers()
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
else:
|
||||
if stage_manager.is_last_stage(ignore_chunk=True):
|
||||
held_layers.append(self.model.dropout)
|
||||
held_layers.append(self.model.classifier)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
|
|
@ -43,6 +43,8 @@ class ViTPolicy(Policy):
|
|||
self.shard_config.enable_sequence_parallelism = False
|
||||
warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
assert (
|
||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||
|
@ -72,6 +74,7 @@ class ViTPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -79,6 +82,7 @@ class ViTPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -86,6 +90,7 @@ class ViTPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -97,6 +102,7 @@ class ViTPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -109,6 +115,7 @@ class ViTPolicy(Policy):
|
|||
kwargs={
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -116,6 +123,7 @@ class ViTPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -132,7 +140,92 @@ class ViTPolicy(Policy):
|
|||
policy=policy,
|
||||
target_key=ViTIntermediate,
|
||||
)
|
||||
elif use_zbv:
|
||||
policy[ViTEmbeddings] = ModulePolicyDescription(
|
||||
attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForReplicatedInput,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
policy[ViTLayer] = ModulePolicyDescription(
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.query",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.key",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.value",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="intermediate.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"skip_bias_add": self.enable_bias_gelu_fused,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dense",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
),
|
||||
],
|
||||
)
|
||||
if self.enable_bias_gelu_fused:
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_jit_fused_vit_intermediate_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=ViTIntermediate,
|
||||
)
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
self.append_or_create_method_replacement(
|
||||
|
@ -173,11 +266,20 @@ class ViTPolicy(Policy):
|
|||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embeddings)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.encoder.layer[start_idx:end_idx])
|
||||
if stage_manager.is_interleave:
|
||||
assert stage_manager.num_model_chunks is not None
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
|
||||
stage_indices = stage_manager.get_stage_index(layers_per_stage)
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(module.embeddings)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(module.encoder.layer[start_idx:end_idx])
|
||||
else:
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embeddings)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.encoder.layer[start_idx:end_idx])
|
||||
return held_layers
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict):
|
||||
|
@ -213,9 +315,16 @@ class ViTModelPolicy(ViTPolicy):
|
|||
|
||||
module = self.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.layernorm)
|
||||
held_layers.append(module.pooler)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(module.layernorm)
|
||||
held_layers.append(module.pooler)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.layernorm)
|
||||
held_layers.append(module.pooler)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
@ -226,6 +335,9 @@ class ViTForImageClassificationPolicy(ViTPolicy):
|
|||
from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
new_item = {
|
||||
ViTForImageClassification: ModulePolicyDescription(
|
||||
|
@ -233,13 +345,33 @@ class ViTForImageClassificationPolicy(ViTPolicy):
|
|||
SubModuleReplacementDescription(
|
||||
suffix="classifier",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
|
||||
kwargs=dict(
|
||||
gather_output=True,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
elif use_zbv:
|
||||
new_item = {
|
||||
ViTForImageClassification: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="classifier",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs=dict(
|
||||
gather_output=True,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.shard_config.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
|
||||
self.set_pipeline_forward(
|
||||
|
@ -256,9 +388,16 @@ class ViTForImageClassificationPolicy(ViTPolicy):
|
|||
|
||||
module = self.model.vit
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.layernorm)
|
||||
held_layers.append(self.model.classifier)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(module.layernorm)
|
||||
held_layers.append(self.model.classifier)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.layernorm)
|
||||
held_layers.append(self.model.classifier)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
@ -285,8 +424,15 @@ class ViTForMaskedImageModelingPolicy(ViTPolicy):
|
|||
|
||||
module = self.model.vit
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.layernorm)
|
||||
held_layers.append(self.model.decoder)
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(module.layernorm)
|
||||
held_layers.append(self.model.decoder)
|
||||
else:
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.layernorm)
|
||||
held_layers.append(self.model.decoder)
|
||||
|
||||
return held_layers
|
||||
|
|
|
@ -72,6 +72,8 @@ class WhisperPolicy(Policy):
|
|||
"Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
|
||||
)
|
||||
|
||||
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||
|
||||
# TODO using the jit fused add_and_dropout affect the accuracy
|
||||
if self.shard_config.enable_jit_fused:
|
||||
self.shard_config.enable_jit_fused = False
|
||||
|
@ -93,6 +95,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -100,6 +103,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -107,6 +111,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -114,6 +119,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -121,6 +127,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -128,6 +135,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
|
@ -148,6 +156,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -155,6 +164,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -162,6 +172,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -169,6 +180,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -176,6 +188,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -183,6 +196,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -190,6 +204,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -197,6 +212,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -204,6 +220,7 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -211,6 +228,145 @@ class WhisperPolicy(Policy):
|
|||
target_module=col_nn.Linear1D_Row,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
elif use_zbv:
|
||||
policy[WhisperEncoderLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.out_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
policy[WhisperDecoderLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.out_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="encoder_attn.q_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="encoder_attn.k_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="encoder_attn.v_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="encoder_attn.out_proj",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=col_nn.LinearWithGradAccum,
|
||||
kwargs={
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
"use_zbv": use_zbv,
|
||||
},
|
||||
),
|
||||
],
|
||||
|
@ -460,30 +616,66 @@ class WhisperPolicy(Policy):
|
|||
num_decoder_layers = 0
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers(
|
||||
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
|
||||
)
|
||||
start_idx, end_idx = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
|
||||
if stage_manager.is_interleave:
|
||||
layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers(
|
||||
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
|
||||
)
|
||||
stage_indices = self.get_whisper_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
|
||||
|
||||
if stage_manager.stage < decoder_starting_stage:
|
||||
# current stage is in whisper's encoder
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(encoder.embed_positions)
|
||||
held_layers.append(encoder.conv1)
|
||||
held_layers.append(encoder.conv2)
|
||||
if stage_manager.stage == decoder_starting_stage - 1:
|
||||
held_layers.append(encoder.layer_norm)
|
||||
held_layers.extend(encoder.layers[start_idx:end_idx])
|
||||
if stage_manager.stage < decoder_starting_stage:
|
||||
# current stage is in whisper's encoder
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
held_layers.append(encoder.embed_positions)
|
||||
held_layers.append(encoder.conv1)
|
||||
held_layers.append(encoder.conv2)
|
||||
# interleaved: not use_zbv & stage_manager.stage == decoder_starting_stage - 1
|
||||
# zbv: use_zbv & stage_manager.stage == first stage
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and decoder_starting_stage - 1
|
||||
):
|
||||
held_layers.append(encoder.layer_norm)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(encoder.layers[start_idx:end_idx])
|
||||
else:
|
||||
# current stage is in whisper's decoder
|
||||
# TODO:(Jianghai) We divide encoder and decoder layers into different parts here,
|
||||
# the case encoder and decoder put in same stage should be add in the future.
|
||||
if stage_manager.stage == decoder_starting_stage:
|
||||
held_layers.append(decoder.embed_tokens)
|
||||
held_layers.append(decoder.embed_positions)
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(decoder.layer_norm)
|
||||
for start_idx, end_idx in stage_indices:
|
||||
held_layers.extend(encoder.layers[start_idx:end_idx])
|
||||
else:
|
||||
# current stage is in whisper's decoder
|
||||
# TODO:(Jianghai) We divide encoder and decoder layers into different parts here,
|
||||
# the case encoder and decoder put in same stage should be add in the future.
|
||||
if stage_manager.stage == decoder_starting_stage:
|
||||
held_layers.append(decoder.embed_tokens)
|
||||
held_layers.append(decoder.embed_positions)
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(decoder.layer_norm)
|
||||
held_layers.extend(decoder.layers[start_idx:end_idx])
|
||||
layers_per_stage, decoder_starting_stage = self.distribute_whisper_layers(
|
||||
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
|
||||
)
|
||||
start_idx, end_idx = self.get_whisper_stage_index(
|
||||
layers_per_stage, stage_manager.stage, decoder_starting_stage
|
||||
)
|
||||
|
||||
if stage_manager.stage < decoder_starting_stage:
|
||||
# current stage is in whisper's encoder
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(encoder.embed_positions)
|
||||
held_layers.append(encoder.conv1)
|
||||
held_layers.append(encoder.conv2)
|
||||
if stage_manager.stage == decoder_starting_stage - 1:
|
||||
held_layers.append(encoder.layer_norm)
|
||||
held_layers.extend(encoder.layers[start_idx:end_idx])
|
||||
else:
|
||||
# current stage is in whisper's decoder
|
||||
# TODO:(Jianghai) We divide encoder and decoder layers into different parts here,
|
||||
# the case encoder and decoder put in same stage should be add in the future.
|
||||
if stage_manager.stage == decoder_starting_stage:
|
||||
held_layers.append(decoder.embed_tokens)
|
||||
held_layers.append(decoder.embed_positions)
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(decoder.layer_norm)
|
||||
held_layers.extend(decoder.layers[start_idx:end_idx])
|
||||
return held_layers
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
|
@ -575,8 +767,15 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.proj_out)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.proj_out)
|
||||
else:
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.proj_out)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
@ -629,9 +828,17 @@ class WhisperForAudioClassificationPolicy(WhisperPolicy):
|
|||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
held_layers = super().get_held_layers()
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.projector)
|
||||
held_layers.append(self.model.classifier)
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager.is_interleave:
|
||||
if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
|
||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||
):
|
||||
held_layers.append(self.model.projector)
|
||||
held_layers.append(self.model.classifier)
|
||||
else:
|
||||
if self.pipeline_stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.projector)
|
||||
held_layers.append(self.model.classifier)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
|
|
|
@ -0,0 +1,238 @@
|
|||
# ZeroBubble Pipeline Parallelism
|
||||
Author: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217)
|
||||
|
||||
**Related Paper**
|
||||
- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241)
|
||||
|
||||
## Introduction
|
||||
ZeroBubble (V Schedule):
|
||||
Crucially, splitting B into two stages (also known as an activation gradient and a weight gradient) and a scheme like 1F1B1W can further reduce the bubble compared to the 1F1B scheme in earlier work.
|
||||
|
||||
## Hands-On Practice
|
||||
We now demonstrate how to use ZeroBubble with booster API with 4 GPUs.
|
||||
|
||||
### step 1. Import libraries
|
||||
```python
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.testing import assert_close
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaModel
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin
|
||||
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
||||
```
|
||||
|
||||
### step 2. Initialize Distributed Environment and Parallism Group
|
||||
```python
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
```
|
||||
|
||||
### step 3. Initialize Module, Optimizer, and Pipeline Schedule
|
||||
Build our model and Optimizer. We created a Llama with 8 Decoder-Layer. Then, inite the PipelineGraph and Pipeline schedule by get_v_schedule() function.
|
||||
```python
|
||||
# Global Param
|
||||
NUM_BATCH = 8
|
||||
NUM_TOK_PER_BATCH = 4
|
||||
NUM_LAYERS = 8
|
||||
HIDDEN_SIZE_PER_HEAD = 4
|
||||
NUM_HEADS = 4
|
||||
# Init Llama from huggingface
|
||||
configuration = LlamaConfig(
|
||||
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
|
||||
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
|
||||
num_hidden_layers=NUM_LAYERS,
|
||||
num_attention_heads=NUM_HEADS,
|
||||
num_key_value_heads=NUM_HEADS,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
model = LlamaModel(configuration).cuda()
|
||||
optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
|
||||
```
|
||||
### step 4. Initialize Module, Optimizer, and Pipeline Schedul
|
||||
Then, we need to create the PipelineGraph and PipelineSchedule using the get_v_schedule() function. We need to initialise the PipelineGraph with the following parameters.
|
||||
x_cost represents the runtime consumed by operation x of each model chunk.
|
||||
x_mem represents the amount of memory consumed by the operation x of each model chunk.
|
||||
These parameters are estimated and filled in before the pipeline starts. In fact, better results can be obtained based on the runtime and memory cost during the real computation of the model.
|
||||
In the following example, we assume that the computation times for the model's forward, reverse B, and reverse W are 1, 1, 1, respectively, and the p2p communication time is 1.
|
||||
```python
|
||||
# Init schedule
|
||||
h, a, s = config.hidden_size, config.num_attention_heads, 1024
|
||||
mem_f = 34 * h + 5 * a * s
|
||||
mem_w = -32 * h
|
||||
mem_b = -mem_w - mem_f
|
||||
graph = PipelineGraph(
|
||||
n_stage=pp_size,
|
||||
n_micro=num_microbatches,
|
||||
f_cost=1,
|
||||
b_cost=1,
|
||||
w_cost=1,
|
||||
c_cost=1,
|
||||
f_mem=mem_f,
|
||||
b_mem=mem_b,
|
||||
w_mem=mem_w,
|
||||
)
|
||||
zbv_schedule = graph.get_v_schedule()
|
||||
```
|
||||
|
||||
### step 5.Init Booster
|
||||
Pass pp_style="zbv" when initialising the Plugin to use the ZeroBubble Pipeline.
|
||||
```python
|
||||
plugin = HybridParallelPlugin(
|
||||
pp_size=4,
|
||||
num_microbatches=4,
|
||||
tp_size=1,
|
||||
sp_size=1,
|
||||
zero_stage=1,
|
||||
initial_scale=1,
|
||||
find_unused_parameters=True,
|
||||
pp_style="zbv",
|
||||
scheduler_nodes=zbv_schedule,
|
||||
num_model_chunks=2,
|
||||
)
|
||||
|
||||
dp_size = plugin.dp_size
|
||||
booster = Booster(plugin=plugin)
|
||||
```
|
||||
|
||||
### step 6.Train Your Model
|
||||
```python
|
||||
steps = 10
|
||||
for step in range(steps):
|
||||
input_embeddings = torch.rand(
|
||||
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
|
||||
).cuda()
|
||||
dist.all_reduce(
|
||||
input_embeddings, group=plugin.pp_group
|
||||
)
|
||||
data_iter = iter([{"inputs_embeds": input_embeddings}])
|
||||
output = booster.execute_pipeline(
|
||||
data_iter,
|
||||
model,
|
||||
lambda x, y: x.last_hidden_state.mean(),
|
||||
optimizer,
|
||||
return_loss=True,
|
||||
return_outputs=True,
|
||||
)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## Advanced Practice
|
||||
In ColossalAI, you can get better training performance by using MetaCache and HybridParallel with ZeroBubble.
|
||||
### 1.Use MetaCache with ZeroBubble
|
||||
Pass "enable_metadata_cache=True" when initialising the Plugin to use the Meta Cache with ZeroBubble Pipeline.
|
||||
```python
|
||||
plugin = HybridParallelPlugin(
|
||||
pp_size=2,
|
||||
num_microbatches=4,
|
||||
tp_size=2,
|
||||
sp_size=2,
|
||||
zero_stage=1,
|
||||
initial_scale=1,
|
||||
enable_metadata_cache=True,
|
||||
find_unused_parameters=True,
|
||||
pp_style="zbv",
|
||||
scheduler_nodes=zbv_schedule,
|
||||
num_model_chunks=2,
|
||||
)
|
||||
```
|
||||
|
||||
### 2.HybridParallel with ZeroBubble
|
||||
Pass pp_size, tp_size, sp_size when initialising the Plugin to use the HybridParallel with ZeroBubble Pipeline.
|
||||
```python
|
||||
plugin = HybridParallelPlugin(
|
||||
pp_size=2,
|
||||
num_microbatches=2,
|
||||
tp_size=2,
|
||||
sp_size=2,
|
||||
zero_stage=1,
|
||||
initial_scale=1,
|
||||
find_unused_parameters=True,
|
||||
pp_style="zbv",
|
||||
scheduler_nodes=zbv_schedule,
|
||||
num_model_chunks=2,
|
||||
)
|
||||
```
|
||||
Performance Benchmark
|
||||
<table>
|
||||
<tr>
|
||||
<th nowrap="nowrap">HybridParallel Strategy</th>
|
||||
<th nowrap="nowrap" align="center">Pipeline Parallel</th>
|
||||
<th nowrap="nowrap" align="center">Sequence Parallel + Pipeline Parallel</th>
|
||||
<th nowrap="nowrap" align="center">Data Parallel + Pipeline Parallel</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="center" title="1F1B">With 1F1B</td>
|
||||
<td nowrap="nowrap" align="center">15.27 samples/sec</td>
|
||||
<td nowrap="nowrap" align="center">17.22 samples/sec</td>
|
||||
<td nowrap="nowrap" align="center">14.06 samples/sec</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="center" title="Zero Bubble">With Zero Bubble</td>
|
||||
<td nowrap="nowrap" align="center">17.36 samples/sec</td>
|
||||
<td nowrap="nowrap" align="center">18.38 samples/sec</td>
|
||||
<td nowrap="nowrap" align="center">14.44 samples/sec</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td colspan="39"></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### 3.Fine-tuning Scheduler parameters
|
||||
|
||||
```python
|
||||
```
|
||||
## Model compatibility
|
||||
<table>
|
||||
<tr>
|
||||
<th nowrap="nowrap">Shardformer/Model</th>
|
||||
<th nowrap="nowrap" align="center">Bert</th>
|
||||
<th nowrap="nowrap" align="center">Blip2</th>
|
||||
<th nowrap="nowrap" align="center">Bloom</th>
|
||||
<th nowrap="nowrap" align="center">Chatglm2</th>
|
||||
<th nowrap="nowrap" align="center">Command</th>
|
||||
<th nowrap="nowrap" align="center">Deepseek</th>
|
||||
<th nowrap="nowrap" align="center">Falcon</th>
|
||||
<th nowrap="nowrap" align="center">GPT2</th>
|
||||
<th nowrap="nowrap" align="center">Gptj</th>
|
||||
<th nowrap="nowrap" align="center">Llama</th>
|
||||
<th nowrap="nowrap" align="center">Mistral</th>
|
||||
<th nowrap="nowrap" align="center">Opt</th>
|
||||
<th nowrap="nowrap" align="center">Qwen2</th>
|
||||
<th nowrap="nowrap" align="center">Sam</th>
|
||||
<th nowrap="nowrap" align="center">T5</th>
|
||||
<th nowrap="nowrap" align="center">Vit</th>
|
||||
<th nowrap="nowrap" align="center">Whisper</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="center" title="ZeroBubble">ZeroBubble</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td colspan="39"></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## API Reference
|
||||
{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }}
|
||||
|
||||
<!-- doc-test-command: torchrun --standalone --nproc_per_node=4 zerobubble_pipeline_parallelism.py -->
|
|
@ -0,0 +1,237 @@
|
|||
# 零气泡流水线并行
|
||||
作者: [Junwen Duan](https://github.com/duanjunwen), [Hongxin Liu](https://github.com/ver217)
|
||||
|
||||
**相关论文**
|
||||
- [Zero Bubble Pipeline Parallelism](https://arxiv.org/abs/2401.10241)
|
||||
|
||||
## 介绍
|
||||
零气泡(V Schedule):
|
||||
与早期工作中的1F1B方案相比,零气泡流水线并行将B分成两个阶段(也称为激活梯度和权重梯度),形如1F1B1W这样的方案可以进一步减少气泡。
|
||||
|
||||
## 使用
|
||||
我们将演示如何在 4 个 GPU 上使用带有 booster API 的 ZeroBubble
|
||||
|
||||
### step 1. 引用仓库
|
||||
```python
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.testing import assert_close
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaModel
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin
|
||||
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
||||
```
|
||||
|
||||
### step 2. 初始化分布式环境
|
||||
```python
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
```
|
||||
|
||||
### step 3. 初始化模型优化器
|
||||
建立我们的模型和优化器 我们创建了一个带有8层Decoder-Layer的 Llama。然后,使用get_v_schedule()函数创建PipelineGraph和Pipeline schedule。
|
||||
|
||||
```python
|
||||
# Global Param
|
||||
NUM_BATCH = 8
|
||||
NUM_TOK_PER_BATCH = 4
|
||||
NUM_LAYERS = 8
|
||||
HIDDEN_SIZE_PER_HEAD = 4
|
||||
NUM_HEADS = 4
|
||||
# Init Llama from huggingface
|
||||
configuration = LlamaConfig(
|
||||
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
|
||||
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
|
||||
num_hidden_layers=NUM_LAYERS,
|
||||
num_attention_heads=NUM_HEADS,
|
||||
num_key_value_heads=NUM_HEADS,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
model = LlamaModel(configuration).cuda()
|
||||
optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
|
||||
```
|
||||
### step 4.初始化流水线Schedule
|
||||
然后,我们需要使用 get_v_schedule() 函数创建 PipelineGraph 和 PipelineSchedule。我们需要用以下参数初始化 PipelineGraph。
|
||||
x_cost 表示每个模型块的操作 x 所消耗的运行时间。
|
||||
x_mem 表示每个模型块的操作 x 所消耗的内存量。
|
||||
这些参数都是在流水线启动前估算并填入的。事实上,在模型的实际计算过程中,根据运行时间和内存成本可以获得更好的结果。
|
||||
在下面的例子中,我们假设模型的正向、反向 B 和反向 W 的计算时间分别为 1、1、1,p2p 通信时间为 1。
|
||||
```python
|
||||
# Init schedule
|
||||
h, a, s = config.hidden_size, config.num_attention_heads, 1024
|
||||
mem_f = 34 * h + 5 * a * s
|
||||
mem_w = -32 * h
|
||||
mem_b = -mem_w - mem_f
|
||||
graph = PipelineGraph(
|
||||
n_stage=pp_size,
|
||||
n_micro=num_microbatches,
|
||||
f_cost=1,
|
||||
b_cost=1,
|
||||
w_cost=1,
|
||||
c_cost=1,
|
||||
f_mem=mem_f,
|
||||
b_mem=mem_b,
|
||||
w_mem=mem_w,
|
||||
)
|
||||
zbv_schedule = graph.get_v_schedule()
|
||||
```
|
||||
|
||||
### step 5.初始化Booster
|
||||
在初始化Plugin时输入pp_style="zbv",以使用ZeroBubble流水线并行。
|
||||
```python
|
||||
plugin = HybridParallelPlugin(
|
||||
pp_size=4,
|
||||
num_microbatches=4,
|
||||
tp_size=1,
|
||||
sp_size=1,
|
||||
zero_stage=1,
|
||||
initial_scale=1,
|
||||
find_unused_parameters=True,
|
||||
pp_style="zbv",
|
||||
scheduler_nodes=zbv_schedule,
|
||||
num_model_chunks=2,
|
||||
)
|
||||
|
||||
dp_size = plugin.dp_size
|
||||
booster = Booster(plugin=plugin)
|
||||
```
|
||||
|
||||
### step 6.训练模型
|
||||
```python
|
||||
steps = 10
|
||||
for step in range(steps):
|
||||
input_embeddings = torch.rand(
|
||||
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
|
||||
).cuda()
|
||||
dist.all_reduce(
|
||||
input_embeddings, group=plugin.pp_group
|
||||
)
|
||||
data_iter = iter([{"inputs_embeds": input_embeddings}])
|
||||
output = booster.execute_pipeline(
|
||||
data_iter,
|
||||
model,
|
||||
lambda x, y: x.last_hidden_state.mean(),
|
||||
optimizer,
|
||||
return_loss=True,
|
||||
return_outputs=True,
|
||||
)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## 进阶使用技巧
|
||||
在 ColossalAI 中,通过使用MetaCache和混合并行的ZeroBubble,可以获得更好的训练性能。
|
||||
|
||||
### 1.在ZeroBubble中使用元数据缓存
|
||||
在初始化Plugin时输入 "enable_metadata_cache=True",以便在ZeroBubble管道中使用元数据缓存。
|
||||
```python
|
||||
plugin = HybridParallelPlugin(
|
||||
pp_size=2,
|
||||
num_microbatches=4,
|
||||
tp_size=2,
|
||||
sp_size=2,
|
||||
zero_stage=1,
|
||||
initial_scale=1,
|
||||
enable_metadata_cache=True,
|
||||
find_unused_parameters=True,
|
||||
pp_style="zbv",
|
||||
scheduler_nodes=zbv_schedule,
|
||||
num_model_chunks=2,
|
||||
)
|
||||
```
|
||||
|
||||
### 2.同时使用ZeroBubble和混合并行
|
||||
在初始化插件时传递 pp_size, tp_size, sp_size, 以便使用零气泡混合并行管道(HybridParallel with ZeroBubble Pipeline)。
|
||||
```python
|
||||
plugin = HybridParallelPlugin(
|
||||
pp_size=2,
|
||||
num_microbatches=2,
|
||||
tp_size=2,
|
||||
sp_size=2,
|
||||
zero_stage=1,
|
||||
initial_scale=1,
|
||||
find_unused_parameters=True,
|
||||
pp_style="zbv",
|
||||
scheduler_nodes=zbv_schedule,
|
||||
num_model_chunks=2,
|
||||
)
|
||||
```
|
||||
性能指标
|
||||
<table>
|
||||
<tr>
|
||||
<th nowrap="nowrap">HybridParallel Strategy</th>
|
||||
<th nowrap="nowrap" align="center">Pipeline Parallel</th>
|
||||
<th nowrap="nowrap" align="center">Sequence Parallel + Pipeline Parallel</th>
|
||||
<th nowrap="nowrap" align="center">Data Parallel + Pipeline Parallel</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="center" title="1F1B">With 1F1B</td>
|
||||
<td nowrap="nowrap" align="center">15.27 samples/sec</td>
|
||||
<td nowrap="nowrap" align="center">17.22 samples/sec</td>
|
||||
<td nowrap="nowrap" align="center">14.06 samples/sec</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="center" title="Zero Bubble">With Zero Bubble</td>
|
||||
<td nowrap="nowrap" align="center">17.36 samples/sec</td>
|
||||
<td nowrap="nowrap" align="center">18.38 samples/sec</td>
|
||||
<td nowrap="nowrap" align="center">14.44 samples/sec</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td colspan="39"></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## 模型兼容性
|
||||
<table>
|
||||
<tr>
|
||||
<th nowrap="nowrap">Shardformer/Model</th>
|
||||
<th nowrap="nowrap" align="center">Bert</th>
|
||||
<th nowrap="nowrap" align="center">Blip2</th>
|
||||
<th nowrap="nowrap" align="center">Bloom</th>
|
||||
<th nowrap="nowrap" align="center">Chatglm2</th>
|
||||
<th nowrap="nowrap" align="center">Command</th>
|
||||
<th nowrap="nowrap" align="center">Deepseek</th>
|
||||
<th nowrap="nowrap" align="center">Falcon</th>
|
||||
<th nowrap="nowrap" align="center">GPT2</th>
|
||||
<th nowrap="nowrap" align="center">Gptj</th>
|
||||
<th nowrap="nowrap" align="center">Llama</th>
|
||||
<th nowrap="nowrap" align="center">Mistral</th>
|
||||
<th nowrap="nowrap" align="center">Opt</th>
|
||||
<th nowrap="nowrap" align="center">Qwen2</th>
|
||||
<th nowrap="nowrap" align="center">Sam</th>
|
||||
<th nowrap="nowrap" align="center">T5</th>
|
||||
<th nowrap="nowrap" align="center">Vit</th>
|
||||
<th nowrap="nowrap" align="center">Whisper</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td nowrap="nowrap" align="center" title="ZeroBubble">ZeroBubble</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
<td nowrap="nowrap" align="center">✔️</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td colspan="39"></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## API 参考
|
||||
{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }}
|
||||
|
||||
<!-- doc-test-command: torchrun --standalone --nproc_per_node=4 zerobubble_pipeline_parallelism.py -->
|
|
@ -8,7 +8,8 @@ from torch.testing import assert_close
|
|||
|
||||
import colossalai
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
from colossalai.pipeline.weight_grad_store import WeightGradStore
|
||||
from colossalai.shardformer.layer import GPT2FusedLinearConv, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
@ -118,11 +119,82 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
|
|||
assert_close(target_grad, linear_row.weight.grad)
|
||||
|
||||
|
||||
def check_linear_conv_1d_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: str):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
with ctx:
|
||||
linear_copy = Conv1D(192, 48).cuda()
|
||||
linear_base = GPT2FusedLinearConv.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode)
|
||||
|
||||
assert linear.weight.shape == torch.Size([48, 192])
|
||||
assert linear_base.weight.shape == torch.Size([48, 192])
|
||||
assert linear_base.bias.shape == torch.Size([192])
|
||||
assert linear_copy.weight is linear_base.weight
|
||||
assert linear_copy.bias is linear_base.bias
|
||||
|
||||
# ensure weights are reversibly loadable
|
||||
linear_base.load_state_dict(linear.state_dict())
|
||||
linear.load_state_dict(linear_base.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(1, 4, 48).cuda()
|
||||
out = linear(x)
|
||||
gather_out = linear_base(x)
|
||||
assert_close(out, gather_out)
|
||||
|
||||
# check backward correctness
|
||||
out.sum().backward()
|
||||
gather_out.sum().backward()
|
||||
|
||||
# check the input gradients & weight gradients
|
||||
assert_close(out.grad, gather_out.grad)
|
||||
assert_close(linear.weight.grad, linear_base.weight.grad)
|
||||
|
||||
|
||||
def check_linear_conv_1d_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: str):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
with ctx:
|
||||
linear_copy = Conv1D(192, 48).cuda()
|
||||
linear_base = GPT2FusedLinearConv.from_native_module(linear_copy, seq_parallel_mode=seq_parallel_mode, use_zbv=True)
|
||||
|
||||
assert linear.weight.shape == torch.Size([48, 192])
|
||||
assert linear_base.weight.shape == torch.Size([48, 192])
|
||||
assert linear_base.bias.shape == torch.Size([192])
|
||||
assert linear_copy.weight is linear_base.weight
|
||||
assert linear_copy.bias is linear_base.bias
|
||||
|
||||
# ensure weights are reversibly loadable
|
||||
linear_base.load_state_dict(linear.state_dict())
|
||||
linear.load_state_dict(linear_base.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(1, 4, 48).cuda()
|
||||
out = linear(x)
|
||||
gather_out = linear_base(x)
|
||||
assert_close(out, gather_out)
|
||||
|
||||
# check backward correctness
|
||||
out.sum().backward()
|
||||
gather_out.sum().backward()
|
||||
|
||||
WeightGradStore.flush(chunk=0) # flush buffer to chunk 0 Queue
|
||||
WeightGradStore.pop(chunk=0)
|
||||
|
||||
# check the input gradients & weight gradients
|
||||
assert_close(out.grad, gather_out.grad)
|
||||
assert_close(linear.weight.grad, linear_base.weight.grad)
|
||||
|
||||
|
||||
@parameterize("lazy_init", [False, True])
|
||||
@parameterize("seq_parallel_mode", ["split_gather", None])
|
||||
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
|
||||
check_linear_conv_1d_col(lazy_init, seq_parallel_mode)
|
||||
check_linear_conv_1d_row(lazy_init, seq_parallel_mode)
|
||||
check_linear_conv_1d_without_weight_grad_store(lazy_init, None)
|
||||
check_linear_conv_1d_with_weight_grad_store(lazy_init, None)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
|
|
@ -7,7 +7,7 @@ from torch.testing import assert_close
|
|||
|
||||
import colossalai
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row
|
||||
from colossalai.shardformer.layer import FusedLinear, FusedLinear1D_Col, FusedLinear1D_Row
|
||||
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
@ -120,12 +120,45 @@ def check_linear_1d_col_row(lazy_init: bool):
|
|||
assert_close(target_grad2, linear_row.weight.grad)
|
||||
|
||||
|
||||
@parameterize("lazy_init", [False, True])
|
||||
def check_linear_1d_base(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
linear = nn.Linear(8, 80).cuda()
|
||||
with ctx:
|
||||
linear_copy = nn.Linear(8, 80).cuda()
|
||||
linear_base = FusedLinear.from_native_module(linear_copy)
|
||||
|
||||
assert linear.weight.shape == torch.Size([80, 8])
|
||||
assert linear.bias.shape == torch.Size([80])
|
||||
assert linear_base.weight.shape == torch.Size([80, 8])
|
||||
assert linear_base.bias.shape == torch.Size([80])
|
||||
assert linear_copy.weight is linear_base.weight
|
||||
assert linear_copy.bias is linear_base.bias
|
||||
|
||||
# ensure weights are reversibly loadable
|
||||
linear_base.load_state_dict(linear.state_dict())
|
||||
linear.load_state_dict(linear_base.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 8).cuda()
|
||||
out = linear(x)
|
||||
base_out = linear_base(x)
|
||||
assert_close(out, base_out)
|
||||
|
||||
# check backward correctness
|
||||
out.sum().backward()
|
||||
base_out.sum().backward()
|
||||
|
||||
assert_close(linear.weight.grad, linear_base.weight.grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
|
||||
check_linear_1d_col()
|
||||
check_linear_1d_row()
|
||||
check_linear_1d_col_row()
|
||||
check_linear_1d_base()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
|
|
Loading…
Reference in New Issue