[feat] support zbv in mixtral benchmark; (#6083)

* [feat] support zbv in mixtral benchmark;

* [fix] MixtralForCausalLMPolicy get_held_layer support zbv;

* [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;

* [feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forward for zbv

* [fix] fix llama, mixtral benchmark zbv loss none bug; update mixtral & llama policy and modeling;

* [feat] Linear1D_COL/ROW support zbv WeightGradStore;

* [feat] support use_zbv in llama, mixtral modeling; only replace Linear1D_Col/Row policy;

* [fix] fix test case; moe error in second iter

* [feat]EPMixtralSparseMoeBlock (op in MOE) support zbv;

* [fix] fix bwd b; now bwd w only for Layer replaced by Linear1D_Col/Row; other layer perform a fully bwd;

* [fix] debug zbv llama test;

* [fix] rm use_zbv flag in Shardconfig; rm debug info;

* [fix] add & fix  llama test

* [feat] support meta cache, meta_grad_send, meta_tensor_send; fix runtime too long in Recv Bwd; benchmark for llama + Hybrid(tp+pp);

* [fix\ fix fail case test_shard_llama

* [fix] fix test_shard_llama

* [fix] fix llama modeling policy;

* [fix] fix test_shard_llama ci;

* [fix] fix test zerobubble

* [fix] fix handle name; rm useless comments;

* [fix] fix send recv signature;

* [fix] fix comment in llama & benchmark

* [feat] support no tensor parallel Linear in shardformer; Add test for use weightGradStore and not use WeightGradStore

* [fix] fix linear (no tp) ops func name;
feature/zerobubble
duanjunwen 2024-10-31 18:17:29 +08:00 committed by GitHub
parent dac0e07b13
commit aed20fb2df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 938 additions and 151 deletions

View File

@ -1,16 +1,18 @@
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_flatten, tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.pipeline.weight_grad_store import WeightGradStore
from ._utils import (
clone,
@ -61,11 +63,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self.do_post_validation = False
# P2PMeta cache
# self.enable_metadata_cache = enable_metadata_cache
# self.send_tensor_metadata = True
# self.send_grad_metadata = True
# self.tensor_metadata_recv = None
# self.grad_metadata_recv = None
self.enable_metadata_cache = enable_metadata_cache
self.send_tensor_metadata = True
self.send_grad_metadata = True
self.tensor_metadata_recv = None
self.grad_metadata_recv = None
# P2P communication
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
@ -104,8 +106,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# dy buffer for local send bwd
self.local_send_backward_buffer = []
# wait pp buffer
self.wait_handles = []
def assert_buffer_empty(self):
# assert buuffer is empty at end
# assert buffer is empty at end
assert len(self.input_tensors[0]) == 0
assert len(self.input_tensors[1]) == 0
assert len(self.output_tensors[0]) == 0
@ -201,7 +206,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> List:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For ZBV.
@ -220,7 +225,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
#################
if self.stage_manager.is_first_stage(ignore_chunk=True):
return None, []
# return None, []
return []
################
# chunk = 0 & not is_first_stage
@ -228,9 +234,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
#################
else:
prev_rank = self.stage_manager.get_prev_rank()
input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank)
input_tensor, wait_handles = self.comm.recv_forward(
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv
)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles
# return input_tensor, wait_handles
return wait_handles
else:
################
@ -238,7 +249,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; cause u get y from local_send_forward_buffer in schedule f
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
return None, []
# return None, []
return []
################
# chunk = 1 & not is_last_stage
@ -246,11 +258,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
################
else:
next_rank = self.stage_manager.get_next_rank()
input_tensor, wait_handles = self.comm.recv_forward(next_rank)
input_tensor, wait_handles = self.comm.recv_forward(
next_rank, metadata_recv=self.tensor_metadata_recv
)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles
# return input_tensor, wait_handles
return wait_handles
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For ZBV.
@ -270,7 +287,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; Already get dy from local_send_backward_buffer in schedule b
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
return None, []
# return None, []
return []
################
# chunk = 0 & not is_last_stage
@ -278,9 +296,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
################
else:
next_rank = self.stage_manager.get_next_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank)
output_tensor_grad, wait_handles = self.comm.recv_backward(
next_rank, metadata_recv=self.grad_metadata_recv
)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, wait_handles
# return output_tensor_grad, wait_handles
return wait_handles
else:
# bwd chunk1 is left V;
@ -289,7 +312,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; get loss from local
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
return None, []
# return None, []
return []
################
# chunk = 1 & not first stage
@ -297,9 +321,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
################
else:
prev_rank = self.stage_manager.get_prev_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank)
output_tensor_grad, wait_handles = self.comm.recv_backward(
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv
)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, wait_handles
# return output_tensor_grad, wait_handles
return wait_handles
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
"""Sends the input tensor to the next stage in pipeline.
@ -329,7 +358,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
next_rank = self.stage_manager.get_next_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank)
send_handles = self.comm.send_forward(
output_object=output_tensor, next_rank=next_rank, send_metadata=self.send_tensor_metadata
)
self.send_tensor_metadata = not self.enable_metadata_cache
return send_handles
else:
@ -347,7 +379,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
prev_rank = self.stage_manager.get_prev_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(output_tensor, prev_rank)
send_handles = self.comm.send_forward(
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata
)
self.send_tensor_metadata = not self.enable_metadata_cache
return send_handles
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
@ -379,7 +414,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
prev_rank = self.stage_manager.get_prev_rank()
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
send_handles = self.comm.send_backward(
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
)
self.send_grad_metadata = not self.enable_metadata_cache
return send_handles
# bwd chunk1 is left V;
@ -398,7 +436,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
next_rank = self.stage_manager.get_next_rank()
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
send_handles = self.comm.send_backward(
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata
)
self.send_grad_metadata = not self.enable_metadata_cache
return send_handles
def forward_step(
@ -432,7 +473,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
internal_inputs = {} if input_obj is None else input_obj
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
# last layer in model
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
loss = criterion(output_obj, micro_batch) / self.num_microbatch
@ -479,11 +519,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj_grad_ = []
# For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx.
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
return None
# if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
# return None
# For loss backward; output_obj is loss; output_obj_grad should be None
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
assert output_obj_grad is None
input_obj_, _ = tree_flatten(input_obj)
output_obj_.append(output_obj) # LOSS
@ -504,17 +544,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
ctx = optimizer.no_sync()
except AttributeError:
ctx = model_chunk.no_sync()
with ctx:
optimizer.backward_by_grad(
tensor=output_obj_,
grad=output_obj_grad_,
inputs=input_obj_,
retain_graph=True,
# inputs=input_obj_,
retain_graph=False,
)
# Format output_obj_grad
input_obj_grad = {}
input_obj_grad = dict()
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
pass
else:
@ -651,10 +689,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Do not release_tensor_data loss, release_tensor_data other output_obj;
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
self.output_tensors[model_chunk_id].append(output_obj)
self.output_tensors_dw[model_chunk_id].append(output_obj)
# self.output_tensors_dw[model_chunk_id].append(output_obj)
else:
self.output_tensors[model_chunk_id].append(output_obj)
self.output_tensors_dw[model_chunk_id].append(output_obj)
# self.output_tensors_dw[model_chunk_id].append(output_obj)
# add output to send_fwd_buffer
if model_chunk_id == 0: # chunk 0
@ -706,15 +744,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0)
# save output_tensor_grad for dw
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# we save loss here
self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
else:
# we save output_tensor_grad here
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
# # save output_tensor_grad for dw
# if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# # we save loss here
# self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
# else:
# # we save output_tensor_grad here
# self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
# the_output_obj_grad = []
# if isinstance(output_obj, dict):
# for (k, v) in output_obj.items():
# the_output_obj_grad.append(v.requires_grad)
# else:
# the_output_obj_grad.append(output_obj.requires_grad)
# Step2: bwd step
input_object_grad = self.backward_b_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
@ -739,6 +782,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# send to next
else:
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
WeightGradStore.flush(chunk=model_chunk_id)
def schedule_w(
self,
@ -758,16 +802,17 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
# get y & dy from buffer
output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
# output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
# output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
WeightGradStore.pop(chunk=model_chunk_id)
self.backward_w_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
optimizer=optimizer,
output_obj=output_obj,
output_obj_grad=output_obj_grad,
)
# self.backward_w_step(
# model_chunk=model_chunk,
# model_chunk_id=model_chunk_id,
# optimizer=optimizer,
# output_obj=output_obj,
# output_obj_grad=output_obj_grad,
# )
def run_forward_only(
self,
@ -844,7 +889,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication
communication_func = self.communication_map[scheduled_node.type]
communication_func(scheduled_node.chunk)
wait_handle = communication_func(scheduled_node.chunk)
self.wait_handles.append(wait_handle)
elif scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,
@ -868,6 +914,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id=scheduled_node.chunk,
optimizer=optimizer,
)
for h in self.wait_handles:
for hh in h:
hh.wait()
# return loss & output
if outputs is not None:
@ -907,5 +956,4 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
)
self.assert_buffer_empty()
return result

View File

@ -223,7 +223,6 @@ class PipelineStageManager:
# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages * num_model_chunks
# deal with the rest layers
if remainder > 0:
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2

View File

@ -0,0 +1,32 @@
import queue
class WeightGradStore:
cache = []
weight_grad_queue = [queue.Queue(), queue.Queue()]
@classmethod
def put(cls, total_input, grad_output, weight, func):
# func(total_input, grad_output, weight.main_grad)
cls.cache.append((total_input, grad_output, weight, func))
@classmethod
def flush(cls, chunk=0):
cls.weight_grad_queue[chunk].put(cls.cache)
cls.cache = []
@classmethod
def pop(cls, chunk=0):
# print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}")
if cls.weight_grad_queue[chunk].qsize() > 0:
stored_grads = cls.weight_grad_queue[chunk].get()
for total_input, grad_output, weight, func in stored_grads:
if weight.grad is not None:
func(total_input, grad_output, weight.grad)
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
else:
grad_weight = func(total_input, grad_output)
weight.grad = grad_weight
else:
raise Exception("Pop empty queue.")

View File

@ -2,7 +2,7 @@ from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
@ -11,6 +11,7 @@ from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2
__all__ = [
"Embedding1D",
"VocabParallelEmbedding1D",
"LinearWithGradAccum",
"Linear1D_Col",
"Linear1D_Row",
"GPT2FusedLinearConv1D_Col",

View File

@ -1,7 +1,11 @@
import functools
import torch
import torch.distributed as dist
import torch.nn.functional as F
from colossalai.pipeline.weight_grad_store import WeightGradStore
from .utils import is_share_sp_tp
try:
@ -125,12 +129,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
ctx.use_zbv = use_zbv
if bias is not None:
output = F.linear(input_, weight, bias)
else:
@ -143,6 +148,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication
use_zbv = ctx.use_zbv
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
@ -164,24 +176,160 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
if use_zbv:
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce and not fp8_communication:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
class LinearWithGradAccum(torch.autograd.Function):
"""
Linear layer baseline (no tensor parallel version).
"""
@staticmethod
def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.async_grad_allreduce = async_grad_allreduce
ctx.use_zbv = use_zbv
if bias is not None:
output = F.linear(input_, weight, bias)
else:
output = F.linear(input_, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
use_zbv = ctx.use_zbv
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
bias.view(bias.shape)
total_input = input.contiguous()
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
return grad_input, grad_weight, grad_bias, None, None, None, None
@ -1043,12 +1191,18 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre
)
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
def linear_with_async_comm(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
):
return LinearWithAsyncCommunication.apply(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv
)
def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False):
return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)
def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):

View File

@ -28,6 +28,7 @@ from ._operation import (
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
linear_with_grad_accum,
reduce_forward,
reducescatter_forward_gather_backward,
split_forward_gather_backward,
@ -35,7 +36,148 @@ from ._operation import (
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset
__all__ = ["Linear1D_Col", "Linear1D_Row"]
__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]
class LinearWithGradAccum(ParallelModule):
r"""Linear layer with no parallelism.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (`typing.Callable`):
The initializer of bias, defaults to xavier uniform initializer.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
use_zbv: bool = False,
**kwargs,
):
super().__init__(weight=weight, bias_=bias_, **kwargs)
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.skip_bias_add = skip_bias_add
self.device = device
self.use_zbv = use_zbv
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=None)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
else:
assert bias_ is None, "bias_ must be None if weight is None"
# Parameters.
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
else:
self.bias = None
if weight is None:
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, **kwargs) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
linear_1d = LinearWithGradAccum(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
weight=module.weight,
bias_=module.bias,
**kwargs,
)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with self.randomizer.fork_rng(enable_cpu=True):
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert (
input_.shape[-1] == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1]
)
# Set up backprop all-reduce.
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_grad_accum(
input_parallel,
self.weight,
bias,
False,
use_zbv=self.use_zbv,
)
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
return output
class Linear1D_Col(ParallelModule):
@ -85,6 +227,7 @@ class Linear1D_Col(ParallelModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
use_zbv: bool = False,
**kwargs,
):
super().__init__(weight=weight, bias_=bias_, **kwargs)
@ -100,6 +243,7 @@ class Linear1D_Col(ParallelModule):
self.device = device
self.process_group = process_group
self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@ -201,13 +345,18 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
input_parallel,
self.weight,
bias,
self.process_group,
False,
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
)
elif self.seq_parallel_mode == "ring":
output_parallel = linear_gather_forward_reducescatter_backward(
@ -215,9 +364,14 @@ class Linear1D_Col(ParallelModule):
)
else:
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
input_parallel,
self.weight,
bias,
self.process_group,
True,
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(
@ -273,6 +427,7 @@ class Linear1D_Row(ParallelModule):
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1,
fp8_communication: bool = False,
use_zbv: bool = False,
):
super().__init__()
@ -288,6 +443,7 @@ class Linear1D_Row(ParallelModule):
self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@ -429,10 +585,14 @@ class Linear1D_Row(ParallelModule):
output = torch.cat(output_parallel_list, dim=-1)
else:
if self.seq_parallel_mode is None:
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output_parallel = linear_with_async_comm(
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather":
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output_parallel = linear_with_async_comm(
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
@ -445,7 +605,9 @@ class Linear1D_Row(ParallelModule):
ring=True,
)
else:
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output_parallel = linear_with_async_comm(
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
)
output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:

View File

@ -82,7 +82,7 @@ class LlamaPipelineForwards:
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape[:2]
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:

View File

@ -60,6 +60,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
fp8_communication: bool = False,
use_zbv: bool = False,
):
assert tp_group is not None
assert moe_dp_group is not None
@ -70,6 +71,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
self.ep_rank = dist.get_rank(ep_group)
self.ep_group = ep_group
self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
if self.num_experts % self.ep_size != 0:
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
@ -89,13 +91,13 @@ class EPMixtralSparseMoeBlock(ParallelModule):
if self.tp_group.size() > 1:
for expert in held_experts:
expert.w1 = Linear1D_Col.from_native_module(
expert.w1, self.tp_group, fp8_communication=self.fp8_communication
expert.w1, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
)
expert.w3 = Linear1D_Col.from_native_module(
expert.w3, self.tp_group, fp8_communication=self.fp8_communication
expert.w3, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
)
expert.w2 = Linear1D_Row.from_native_module(
expert.w2, self.tp_group, fp8_communication=self.fp8_communication
expert.w2, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
)
for p in self.experts.parameters():
@ -399,6 +401,7 @@ class MixtralPipelineForwards:
if output_router_logits and past_router_logits is not None:
all_router_logits = past_router_logits + all_router_logits
if stage_manager.is_last_stage():
if not return_dict:
return tuple(
@ -512,7 +515,6 @@ class MixtralPipelineForwards:
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n

View File

@ -60,6 +60,8 @@ class LlamaPolicy(Policy):
else:
norm_cls = RMSNorm
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
@ -126,37 +128,65 @@ class LlamaPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
),
],
)
@ -265,7 +295,6 @@ class LlamaPolicy(Policy):
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
):
held_layers.append(module.norm)
else:
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
@ -385,6 +414,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
from transformers import LlamaForSequenceClassification
policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification
@ -397,6 +427,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
]

View File

@ -52,6 +52,7 @@ class MixtralPolicy(Policy):
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
tp_size = self.shard_config.tensor_parallel_size
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# modified for both SP and TP
num_q_heads = self.model.config.num_attention_heads
@ -124,27 +125,43 @@ class MixtralPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
suffix="block_sparse_moe.gate",
target_module=Linear1D_Col,
kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},
kwargs={
"gather_output": True,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
),
],
)
@ -179,6 +196,7 @@ class MixtralPolicy(Policy):
"tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
)
],
@ -313,6 +331,7 @@ class MixtralModelPolicy(MixtralPolicy):
class MixtralForCausalLMPolicy(MixtralPolicy):
def module_policy(self):
policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
@ -322,9 +341,13 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
]
],
)
}
policy.update(new_item)
@ -343,7 +366,9 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
return held_layers
@ -369,6 +394,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
from transformers import MixtralForSequenceClassification
policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification
@ -378,7 +404,11 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
SubModuleReplacementDescription(
suffix="score",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
]
)

View File

@ -21,6 +21,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.shardformer import PipelineGradientCheckpointConfig
warnings.filterwarnings("ignore")
@ -39,6 +40,7 @@ MODEL_CONFIGS = {
),
"5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8),
"7b": LlamaConfig(max_position_embeddings=4096),
# "7b": LlamaConfig(num_hidden_layers=4, max_position_embeddings=4096),
"13b": LlamaConfig(
hidden_size=5120,
intermediate_size=13824,
@ -91,7 +93,7 @@ def main():
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument(
@ -106,6 +108,7 @@ def main():
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p")
parser.add_argument("--overlap_allgather", action="store_true")
parser.add_argument(
"--sp_mode",
@ -126,9 +129,12 @@ def main():
{
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
num_ckpt_layers_per_stage=[19, 19, 19, 13],
# num_ckpt_layers_per_stage=[48, 48, 48, 48],
),
"num_layers_per_stage": [19, 20, 20, 21],
"pp_style": "interleaved",
# "num_layers_per_stage": [48, 48, 48, 48],
# "pp_style": "interleaved",
"pp_style": "1f1b",
}
if args.custom_ckpt
else {}
@ -137,6 +143,11 @@ def main():
# ==============================
# Initialize Booster
# ==============================
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
use_empty_init = True
if args.plugin == "gemini":
plugin = GeminiPlugin(
@ -210,6 +221,24 @@ def main():
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "3d":
if args.pp_style == "zbv":
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
mem_w = -32 * config.hidden_size
mem_b = -mem_w - mem_f
scheduler_nodes = PipelineGraph(
n_stage=args.pp,
n_micro=args.batch_size // args.mbs,
f_cost=1000,
b_cost=1000,
w_cost=1000,
c_cost=1,
f_mem=mem_f * 1.5,
b_mem=mem_b * 1.5,
w_mem=mem_w * 1.5,
).get_v_schedule()
else:
scheduler_nodes = None
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
@ -227,6 +256,7 @@ def main():
overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
scheduler_nodes=scheduler_nodes,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
@ -242,7 +272,7 @@ def main():
microbatch_size=args.mbs,
initial_scale=2**8,
precision="bf16",
overlap_p2p=args.overlap,
overlap_p2p=args.overlap_p2p,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
)
@ -260,6 +290,7 @@ def main():
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
@ -319,7 +350,7 @@ def main():
args.profile,
args.ignore_steps,
1, # avoid creating massive log files
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
save_dir=f"./profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
nsys=args.nsys,
) as prof:
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
@ -334,8 +365,12 @@ def main():
return_loss=True,
)
loss = outputs["loss"]
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
if args.pp_style == "zbv":
if coordinator.is_master():
print(f"Step {step} loss: {loss}")
else:
if coordinator.is_last_process():
print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()

View File

@ -11,6 +11,7 @@ from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator, get_profile_context
from tqdm import tqdm
from transformers import AutoConfig
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import colossalai
@ -20,6 +21,7 @@ from colossalai.booster.plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.shardformer import PipelineGradientCheckpointConfig
warnings.filterwarnings("ignore")
@ -85,7 +87,7 @@ def main():
parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument(
@ -120,7 +122,7 @@ def main():
num_ckpt_layers_per_stage=[19, 19, 19, 13],
),
"num_layers_per_stage": [19, 20, 20, 21],
"pp_style": "interleaved",
# "pp_style": "interleaved",
}
if args.custom_ckpt
else {}
@ -129,7 +131,29 @@ def main():
# ==============================
# Initialize Booster
# ==============================
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
if args.plugin == "3d":
if args.pp_style == "zbv":
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
mem_w = -32 * config.hidden_size
mem_b = -mem_w - mem_f
scheduler_nodes = PipelineGraph(
n_stage=args.pp,
n_micro=args.batch_size // args.mbs,
f_cost=1000,
b_cost=1000,
w_cost=1000,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
).get_v_schedule()
else:
scheduler_nodes = None
plugin = MoeHybridParallelPlugin(
ep_size=args.ep,
tp_size=args.tp,
@ -143,11 +167,13 @@ def main():
enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
num_microbatches=args.batch_size // args.mbs,
precision="bf16",
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
scheduler_nodes=scheduler_nodes,
**hybrid_kwargs,
)
else:
@ -183,8 +209,10 @@ def main():
with init_ctx:
model = MixtralForCausalLM(config=config).to(torch.bfloat16)
# if args.grad_checkpoint:
# model.gradient_checkpointing_enable()
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
@ -229,8 +257,12 @@ def main():
return_loss=True,
)
loss = outputs["loss"]
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
if args.pp_style == "zbv":
if dist.get_rank() == 0:
print(f"Step {step} loss: {loss}")
else:
if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}")
optimizer.step()
optimizer.zero_grad()

View File

@ -21,11 +21,16 @@ def divide(x: float, y: float) -> float:
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
return x
# BUG: RuntimeError: Invalid scalar type when use dist.all_reduce(tensor, group=gloo_group)
# # Use CPU tensor to avoid OOM/weird NCCl error
# gloo_group = dist.new_group(backend="gloo")
# tensor = torch.tensor([x], device="cpu")
# dist.all_reduce(tensor, group=gloo_group)
# tensor = tensor / world_size
# return tensor.item()
# Use CPU tensor to avoid OOM/weird NCCl error
gloo_group = dist.new_group(backend="gloo")
tensor = torch.tensor([x], device="cpu")
dist.all_reduce(tensor, group=gloo_group)
tensor = torch.tensor([x], device=torch.cuda.current_device(), dtype=torch.float)
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()

View File

@ -8,12 +8,14 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin, MoeHybridParallelPlugin
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import OptimizerWrapper
from colossalai.logging import disable_existing_loggers
@ -756,10 +758,11 @@ def run_with_hybridplugin(test_config):
@parameterize(
"config",
[
(0, 1, 4, 1, 1),
(1, 2, 2, 1, 1),
(1, 2, 1, 2, 1),
(1, 2, 1, 1, 2),
# (0, 1, 4, 1, 1),
# (1, 2, 2, 1, 1),
(1, 1, 2, 2, 1),
# (1, 2, 1, 2, 1),
# (1, 2, 1, 1, 2),
],
)
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
@ -790,6 +793,8 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
seed_all(10086)
torch_model = MixtralModel(config).to(dtype).cuda()
# TODO: Support MixtralForCausalLM
# torch_model = MixtralForCausalLM(config).to(dtype).cuda()
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
# init schedule
h, a, s = config.hidden_size, config.num_attention_heads, 1024
@ -892,7 +897,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
@ -905,18 +910,177 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
p.grad /= dp_size
torch_optimizer.step()
torch_optimizer.zero_grad()
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
print(f"rank {dist.get_rank()} config {test_config} test passed")
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize(
"config",
[
(1, 2, 2, 1), # Pass
# TODO: only support pp + tp accleration; Will support fully pp and None tp Hybrid in furture;
# (0, 4, 1, 1),
# (1, 2, 1, 2),
# (1, 1, 2, 2),
],
)
def run_with_booster_hybridplugin(config: Tuple[int, ...]):
stage, pp_size, tp_size, sp_size = config
num_microbatches = pp_size
dist.get_world_size()
rank = dist.get_rank()
dtype, precision = torch.float16, "fp16"
torch.cuda.set_device(dist.get_rank())
########
# init base model
########
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
config = LlamaConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=NUM_LAYERS,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
attn_implementation="flash_attention_2",
)
# init model with the same seed
seed_all(10086)
torch_model = LlamaModel(config).to(dtype).cuda()
# TODO: Support MixtralForCausalLM
# torch_model = MixtralForCausalLM(config).to(dtype).cuda()
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
# init schedule
h, a, s = config.hidden_size, config.num_attention_heads, 1024
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f
graph = PipelineGraph(
n_stage=pp_size,
n_micro=num_microbatches,
f_cost=1,
b_cost=1,
w_cost=1,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()
# init HybridParallelPlugin
plugin = HybridParallelPlugin(
pp_size=pp_size,
num_microbatches=pp_size,
tp_size=tp_size,
sp_size=sp_size,
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
overlap_communication=False,
initial_scale=1,
precision=precision,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
dp_size = plugin.dp_size
booster = Booster(plugin=plugin)
########
# init pp model
########
parallel_model = deepcopy(torch_model)
parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
# create different input along dp axis
seed_all(1453 + rank)
torch_model.train()
parallel_model.train()
for _ in range(2):
# gen random input
input_embeddings = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
dist.all_reduce(
input_embeddings, group=plugin.pp_group
) # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input
dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input
# run the model with hybrid parallel
if booster.plugin.stage_manager is not None:
# for test with pp
data_iter = iter([{"inputs_embeds": input_embeddings}])
sharded_output = booster.execute_pipeline(
data_iter,
parallel_model,
lambda x, y: x.last_hidden_state.mean(),
parallel_optimizer,
return_loss=True,
return_outputs=True,
)
# stage 0 chunk 0
parallel_output = None
if (
booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
):
parallel_output = sharded_output["loss"]
else:
parallel_output = torch.tensor(12345.0, device="cuda")
# broadcast along pp axis
dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)
else:
# for test without pp
parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
dist.all_reduce(parallel_output, group=plugin.dp_group)
# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
torch_output.backward()
torch_output_sum += torch_output.detach()
# avg dp grads follows zero optimizer
for p in torch_model.parameters():
if p.grad is not None:
p.grad /= dp_size
torch_optimizer.step()
torch_optimizer.zero_grad()
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# run_fwd_bwd_vschedule_with_optim()
run_with_booster_moehybridplugin()
run_with_booster_hybridplugin()
@pytest.mark.dist
@ -928,5 +1092,6 @@ def test_pp():
)
# python -m pytest -s tests/test_pipeline/test_schedule/test_zerobubble_pp.py
if __name__ == "__main__":
test_pp()

View File

@ -8,7 +8,8 @@ from torch.testing import assert_close
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@ -117,6 +118,93 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool):
assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_base = LinearWithGradAccum.from_native_module(
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
)
assert linear_base.weight.shape == torch.Size([128, 32])
assert linear_base.bias.shape == torch.Size([128])
assert linear_copy.weight is linear_base.weight
assert linear_copy.bias is linear_base.bias
linear.load_state_dict(linear_base.state_dict())
linear_base.load_state_dict(linear.state_dict())
# check computation correctness
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard.requires_grad_(True)
# run forward
out = linear(x_for_unshard)
gather_out = linear_base(x_for_shard)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
assert_close(linear.weight.grad, linear_base.weight.grad)
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_base = LinearWithGradAccum.from_native_module(
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
)
assert linear_base.weight.shape == torch.Size([128, 32])
assert linear_base.bias.shape == torch.Size([128])
assert linear_copy.weight is linear_base.weight
assert linear_copy.bias is linear_base.bias
linear.load_state_dict(linear_base.state_dict())
linear_base.load_state_dict(linear.state_dict())
# check computation correctness
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard.requires_grad_(True)
# run forward
out = linear(x_for_unshard)
gather_out = linear_base(x_for_shard)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
# Weight grad is None before we do WeightGradStore pop
assert linear_base.weight.grad is None
# after WeightGradStore pop (dw computation complete), we assert weight grad
WeightGradStore.flush(chunk=0) # flush buffer to chunk 0 Queue
WeightGradStore.pop(chunk=0)
assert_close(linear.weight.grad, linear_base.weight.grad)
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
@ -182,6 +270,8 @@ def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
check_linear_1d_row(lazy_init, seq_parallel_mode)
check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)
check_linear_without_weight_grad_store(lazy_init, seq_parallel_mode)
check_linear_with_weight_grad_store(lazy_init, seq_parallel_mode)
def check_dist_linear(rank, world_size, port):

View File

@ -277,32 +277,33 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"pp_style": "zbv",
"num_model_chunks": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"precision": "fp16",
"zero_stage": 0,
"initial_scale": 1,
"enable_gradient_checkpointing": True,
"parallel_output": False,
},
{
"tp_size": 2,
"pp_size": 2,
"pp_style": "zbv",
"num_model_chunks": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
"enable_gradient_checkpointing": True,
"parallel_output": False,
},
# # TODO: assert layer error
# {
# "tp_size": 2,
# "pp_size": 2,
# "pp_style": "zbv",
# "num_model_chunks": 2,
# "num_microbatches": 4,
# "enable_all_optimization": False,
# "precision": "fp16",
# "zero_stage": 0,
# "initial_scale": 1,
# "enable_gradient_checkpointing": True,
# "parallel_output": False,
# },
# {
# "tp_size": 2,
# "pp_size": 2,
# "pp_style": "zbv",
# "num_model_chunks": 2,
# "num_microbatches": 4,
# "enable_all_optimization": False,
# "precision": "fp16",
# "zero_stage": 1,
# "initial_scale": 1,
# "enable_gradient_checkpointing": True,
# "parallel_output": False,
# },
],
)
def run_llama_test(test_config):