[feat] support zbv in mixtral benchmark; (#6083)

* [feat] support zbv in mixtral benchmark;

* [fix] MixtralForCausalLMPolicy get_held_layers supports zbv;

* [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;

* [feat] support MixtralPipelineForwards --> mixtral_for_causal_lm_forward for zbv

* [fix] fix llama and mixtral benchmark zbv loss-is-None bug; update mixtral & llama policy and modeling;

* [feat] Linear1D_Col/Row support zbv WeightGradStore;

* [feat] support use_zbv in llama, mixtral modeling; only replace Linear1D_Col/Row policy;

* [fix] fix test case; MoE error in the second iteration

* [feat] EPMixtralSparseMoeBlock (op in MoE) supports zbv;

* [fix] fix bwd b; now bwd w runs only for layers replaced by Linear1D_Col/Row; other layers perform a full bwd;

* [fix] debug zbv llama test;

* [fix] rm use_zbv flag in ShardConfig; rm debug info;

* [fix] add & fix llama test

* [feat] support meta cache, meta_grad_send, meta_tensor_send; fix overly long runtime in Recv Bwd; benchmark for llama + Hybrid (tp+pp);

* [fix] fix failing case test_shard_llama

* [fix] fix test_shard_llama

* [fix] fix llama modeling policy;

* [fix] fix test_shard_llama ci;

* [fix] fix test zerobubble

* [fix] fix handle name; rm useless comments;

* [fix] fix send recv signature;

* [fix] fix comment in llama & benchmark

* [feat] support no-tensor-parallel Linear in shardformer; add tests for using and not using WeightGradStore

* [fix] fix linear (no tp) op func names;
feature/zerobubble
duanjunwen 2024-10-31 18:17:29 +08:00 committed by GitHub
parent dac0e07b13
commit aed20fb2df
16 changed files with 938 additions and 151 deletions
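
The core idea behind the ZBV changes in this commit is to split a layer's backward pass into a B step (the input gradient, which the previous pipeline stage needs immediately) and a W step (the weight gradient, which only the optimizer needs and can therefore be deferred to fill pipeline bubbles). Below is a minimal, self-contained sketch of that split for a plain linear layer; the shapes are made up for illustration and this is not the scheduler's actual call path.

import torch

x = torch.randn(4, 8)     # activation saved during forward
w = torch.randn(16, 8)    # linear weight, shape (out_features, in_features)
dy = torch.randn(4, 16)   # gradient arriving from the next pipeline stage

# B step: compute dx right away so it can be sent backward without waiting.
dx = dy @ w               # shape (4, 8)

# W step: park the dW computation and run it later, when a W node is scheduled.
def deferred_dw():
    return dy.t() @ x     # shape (16, 8), the same math as execute_w_pass below

dw = deferred_dw()        # executed at a scheduled W step, filling the bubble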

View File

@ -1,16 +1,18 @@
from functools import partial from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch import torch
import torch.cuda import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_flatten, tree_map from torch.utils._pytree import tree_flatten, tree_map
from colossalai.accelerator import get_accelerator from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.schedule.v_schedule import ScheduledNode
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.pipeline.weight_grad_store import WeightGradStore
from ._utils import ( from ._utils import (
clone, clone,
@ -61,11 +63,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self.do_post_validation = False self.do_post_validation = False
# P2PMeta cache # P2PMeta cache
# self.enable_metadata_cache = enable_metadata_cache self.enable_metadata_cache = enable_metadata_cache
# self.send_tensor_metadata = True self.send_tensor_metadata = True
# self.send_grad_metadata = True self.send_grad_metadata = True
# self.tensor_metadata_recv = None self.tensor_metadata_recv = None
# self.grad_metadata_recv = None self.grad_metadata_recv = None
# P2P communication # P2P communication
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
@ -104,8 +106,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# dy buffer for local send bwd # dy buffer for local send bwd
self.local_send_backward_buffer = [] self.local_send_backward_buffer = []
# wait pp buffer
self.wait_handles = []
def assert_buffer_empty(self): def assert_buffer_empty(self):
# assert buuffer is empty at end # assert buffer is empty at end
assert len(self.input_tensors[0]) == 0 assert len(self.input_tensors[0]) == 0
assert len(self.input_tensors[1]) == 0 assert len(self.input_tensors[1]) == 0
assert len(self.output_tensors[0]) == 0 assert len(self.output_tensors[0]) == 0
@ -201,7 +206,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id = self.num_model_chunks - model_chunk_id - 1 model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id return model_chunk_id
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]: def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> List:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage. """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For ZBV. For ZBV.
@ -220,7 +225,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; cause u are chunk 0 in first rank, u have no prev rank; # do nothing; cause u are chunk 0 in first rank, u have no prev rank;
################# #################
if self.stage_manager.is_first_stage(ignore_chunk=True): if self.stage_manager.is_first_stage(ignore_chunk=True):
return None, [] # return None, []
return []
################ ################
# chunk = 0 & not is_first_stage # chunk = 0 & not is_first_stage
@ -228,9 +234,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
################# #################
else: else:
prev_rank = self.stage_manager.get_prev_rank() prev_rank = self.stage_manager.get_prev_rank()
input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank) input_tensor, wait_handles = self.comm.recv_forward(
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv
)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles # return input_tensor, wait_handles
return wait_handles
else: else:
################ ################
@ -238,7 +249,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; cause u get y from local_send_forward_buffer in schedule f # do nothing; cause u get y from local_send_forward_buffer in schedule f
################ ################
if self.stage_manager.is_last_stage(ignore_chunk=True): if self.stage_manager.is_last_stage(ignore_chunk=True):
return None, [] # return None, []
return []
################ ################
# chunk = 1 & not is_last_stage # chunk = 1 & not is_last_stage
@ -246,11 +258,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
################ ################
else: else:
next_rank = self.stage_manager.get_next_rank() next_rank = self.stage_manager.get_next_rank()
input_tensor, wait_handles = self.comm.recv_forward(next_rank) input_tensor, wait_handles = self.comm.recv_forward(
next_rank, metadata_recv=self.tensor_metadata_recv
)
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles # return input_tensor, wait_handles
return wait_handles
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]: def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For ZBV. For ZBV.
@ -270,7 +287,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; Already get dy from local_send_backward_buffer in schedule b # do nothing; Already get dy from local_send_backward_buffer in schedule b
################ ################
if self.stage_manager.is_last_stage(ignore_chunk=True): if self.stage_manager.is_last_stage(ignore_chunk=True):
return None, [] # return None, []
return []
################ ################
# chunk = 0 & not is_last_stage # chunk = 0 & not is_last_stage
@ -278,9 +296,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
################ ################
else: else:
next_rank = self.stage_manager.get_next_rank() next_rank = self.stage_manager.get_next_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank) output_tensor_grad, wait_handles = self.comm.recv_backward(
next_rank, metadata_recv=self.grad_metadata_recv
)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, wait_handles # return output_tensor_grad, wait_handles
return wait_handles
else: else:
# bwd chunk1 is left V; # bwd chunk1 is left V;
@ -289,7 +312,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; get loss from local # do nothing; get loss from local
################ ################
if self.stage_manager.is_first_stage(ignore_chunk=True): if self.stage_manager.is_first_stage(ignore_chunk=True):
return None, [] # return None, []
return []
################ ################
# chunk = 1 & not first stage # chunk = 1 & not first stage
@ -297,9 +321,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
################ ################
else: else:
prev_rank = self.stage_manager.get_prev_rank() prev_rank = self.stage_manager.get_prev_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank) output_tensor_grad, wait_handles = self.comm.recv_backward(
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv
)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, wait_handles # return output_tensor_grad, wait_handles
return wait_handles
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
"""Sends the input tensor to the next stage in pipeline. """Sends the input tensor to the next stage in pipeline.
@ -329,7 +358,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else: else:
next_rank = self.stage_manager.get_next_rank() next_rank = self.stage_manager.get_next_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank) send_handles = self.comm.send_forward(
output_object=output_tensor, next_rank=next_rank, send_metadata=self.send_tensor_metadata
)
self.send_tensor_metadata = not self.enable_metadata_cache
return send_handles return send_handles
else: else:
@ -347,7 +379,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else: else:
prev_rank = self.stage_manager.get_prev_rank() prev_rank = self.stage_manager.get_prev_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(output_tensor, prev_rank) send_handles = self.comm.send_forward(
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata
)
self.send_tensor_metadata = not self.enable_metadata_cache
return send_handles return send_handles
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List: def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
@ -379,7 +414,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else: else:
prev_rank = self.stage_manager.get_prev_rank() prev_rank = self.stage_manager.get_prev_rank()
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(input_tensor_grad, prev_rank) send_handles = self.comm.send_backward(
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
)
self.send_grad_metadata = not self.enable_metadata_cache
return send_handles return send_handles
# bwd chunk1 is left V; # bwd chunk1 is left V;
@ -398,7 +436,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else: else:
next_rank = self.stage_manager.get_next_rank() next_rank = self.stage_manager.get_next_rank()
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(input_tensor_grad, next_rank) send_handles = self.comm.send_backward(
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata
)
self.send_grad_metadata = not self.enable_metadata_cache
return send_handles return send_handles
def forward_step( def forward_step(
@ -432,7 +473,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
internal_inputs = {} if input_obj is None else input_obj internal_inputs = {} if input_obj is None else input_obj
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk, micro_batch, internal_inputs) output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
# last layer in model # last layer in model
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
loss = criterion(output_obj, micro_batch) / self.num_microbatch loss = criterion(output_obj, micro_batch) / self.num_microbatch
@ -479,11 +519,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj_grad_ = [] output_obj_grad_ = []
# For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx. # For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx.
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): # if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
return None # return None
# For loss backward; output_obj is loss; output_obj_grad should be None # For loss backward; output_obj is loss; output_obj_grad should be None
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
assert output_obj_grad is None assert output_obj_grad is None
input_obj_, _ = tree_flatten(input_obj) input_obj_, _ = tree_flatten(input_obj)
output_obj_.append(output_obj) # LOSS output_obj_.append(output_obj) # LOSS
@ -504,17 +544,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
ctx = optimizer.no_sync() ctx = optimizer.no_sync()
except AttributeError: except AttributeError:
ctx = model_chunk.no_sync() ctx = model_chunk.no_sync()
with ctx: with ctx:
optimizer.backward_by_grad( optimizer.backward_by_grad(
tensor=output_obj_, tensor=output_obj_,
grad=output_obj_grad_, grad=output_obj_grad_,
inputs=input_obj_, # inputs=input_obj_,
retain_graph=True, retain_graph=False,
) )
# Format output_obj_grad # Format output_obj_grad
input_obj_grad = {} input_obj_grad = dict()
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
pass pass
else: else:
@ -651,10 +689,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Do not release_tensor_data loss, release_tensor_data other output_obj; # Do not release_tensor_data loss, release_tensor_data other output_obj;
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
self.output_tensors[model_chunk_id].append(output_obj) self.output_tensors[model_chunk_id].append(output_obj)
self.output_tensors_dw[model_chunk_id].append(output_obj) # self.output_tensors_dw[model_chunk_id].append(output_obj)
else: else:
self.output_tensors[model_chunk_id].append(output_obj) self.output_tensors[model_chunk_id].append(output_obj)
self.output_tensors_dw[model_chunk_id].append(output_obj) # self.output_tensors_dw[model_chunk_id].append(output_obj)
# add output to send_fwd_buffer # add output to send_fwd_buffer
if model_chunk_id == 0: # chunk 0 if model_chunk_id == 0: # chunk 0
@ -706,15 +744,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
input_obj = self.input_tensors[model_chunk_id].pop(0) input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0)
# save output_tensor_grad for dw # # save output_tensor_grad for dw
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# we save loss here # # we save loss here
self.output_tensors_grad_dw[model_chunk_id].append(output_obj) # self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
else: # else:
# we save output_tensor_grad here # # we save output_tensor_grad here
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) # self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
# the_output_obj_grad = []
# if isinstance(output_obj, dict):
# for (k, v) in output_obj.items():
# the_output_obj_grad.append(v.requires_grad)
# else:
# the_output_obj_grad.append(output_obj.requires_grad)
# Step2: bwd step
input_object_grad = self.backward_b_step( input_object_grad = self.backward_b_step(
model_chunk=model_chunk, model_chunk=model_chunk,
model_chunk_id=model_chunk_id, model_chunk_id=model_chunk_id,
@ -739,6 +782,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# send to next # send to next
else: else:
self.send_backward_buffer[model_chunk_id].append(input_object_grad) self.send_backward_buffer[model_chunk_id].append(input_object_grad)
WeightGradStore.flush(chunk=model_chunk_id)
def schedule_w( def schedule_w(
self, self,
@ -758,16 +802,17 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
""" """
# get y & dy from buffer # get y & dy from buffer
output_obj = self.output_tensors_dw[model_chunk_id].pop(0) # output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0) # output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
WeightGradStore.pop(chunk=model_chunk_id)
self.backward_w_step( # self.backward_w_step(
model_chunk=model_chunk, # model_chunk=model_chunk,
model_chunk_id=model_chunk_id, # model_chunk_id=model_chunk_id,
optimizer=optimizer, # optimizer=optimizer,
output_obj=output_obj, # output_obj=output_obj,
output_obj_grad=output_obj_grad, # output_obj_grad=output_obj_grad,
) # )
def run_forward_only( def run_forward_only(
self, self,
@ -844,7 +889,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication # communication
communication_func = self.communication_map[scheduled_node.type] communication_func = self.communication_map[scheduled_node.type]
communication_func(scheduled_node.chunk) wait_handle = communication_func(scheduled_node.chunk)
self.wait_handles.append(wait_handle)
elif scheduled_node.type == "F": elif scheduled_node.type == "F":
self.schedule_f( self.schedule_f(
scheduled_node=scheduled_node, scheduled_node=scheduled_node,
@ -868,6 +914,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id=scheduled_node.chunk, model_chunk_id=scheduled_node.chunk,
optimizer=optimizer, optimizer=optimizer,
) )
for h in self.wait_handles:
for hh in h:
hh.wait()
# return loss & output # return loss & output
if outputs is not None: if outputs is not None:
@ -907,5 +956,4 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
) )
self.assert_buffer_empty() self.assert_buffer_empty()
return result return result
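
The recv_forward/recv_backward and send_forward/send_backward changes above introduce P2P metadata caching: tensor metadata is exchanged only on the first communication and reused afterwards, which is what addresses the long Recv Bwd runtime mentioned in the commit message. The toy sketch below restates that toggle logic outside the scheduler; the comm object and the create_send_metadata stand-in are invented for illustration and perform no real P2P.

import torch

def create_send_metadata(tensor):
    # stand-in: cache only shape and dtype
    return (tensor.shape, tensor.dtype)

class _FakeComm:
    # stand-in for PipelineP2PCommunication; no real communication happens
    def send_forward(self, tensor, peer, send_metadata):
        return []  # a real comm sends layout info only when send_metadata is True
    def recv_forward(self, peer, metadata_recv=None):
        return torch.zeros(2, 2), []  # a real comm skips the metadata round-trip when metadata_recv is set

class MetaCachingStage:
    def __init__(self, enable_metadata_cache=True):
        self.comm = _FakeComm()
        self.enable_metadata_cache = enable_metadata_cache
        self.send_tensor_metadata = True   # metadata goes out with the first send only
        self.tensor_metadata_recv = None   # filled after the first recv, then reused

    def send(self, tensor, peer):
        handles = self.comm.send_forward(tensor, peer, send_metadata=self.send_tensor_metadata)
        self.send_tensor_metadata = not self.enable_metadata_cache
        return handles

    def recv(self, peer):
        tensor, handles = self.comm.recv_forward(peer, metadata_recv=self.tensor_metadata_recv)
        if self.enable_metadata_cache and self.tensor_metadata_recv is None:
            self.tensor_metadata_recv = create_send_metadata(tensor)
        return tensor, handles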

View File

@ -223,7 +223,6 @@ class PipelineStageManager:
# calculate the num_layers per stage # calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages * num_model_chunks layers_per_stage = [quotient] * num_stages * num_model_chunks
# deal with the rest layers # deal with the rest layers
if remainder > 0: if remainder > 0:
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2 start_position = (num_stages * num_model_chunks) // 2 - remainder // 2

View File

@ -0,0 +1,32 @@
import queue
class WeightGradStore:
cache = []
weight_grad_queue = [queue.Queue(), queue.Queue()]
@classmethod
def put(cls, total_input, grad_output, weight, func):
# func(total_input, grad_output, weight.main_grad)
cls.cache.append((total_input, grad_output, weight, func))
@classmethod
def flush(cls, chunk=0):
cls.weight_grad_queue[chunk].put(cls.cache)
cls.cache = []
@classmethod
def pop(cls, chunk=0):
# print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}")
if cls.weight_grad_queue[chunk].qsize() > 0:
stored_grads = cls.weight_grad_queue[chunk].get()
for total_input, grad_output, weight, func in stored_grads:
if weight.grad is not None:
func(total_input, grad_output, weight.grad)
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
else:
grad_weight = func(total_input, grad_output)
weight.grad = grad_weight
else:
raise Exception("Pop empty queue.")
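
A minimal usage sketch of the WeightGradStore defined above, mirroring how Linear1D_Col/Row enqueue the deferred dW during backward b and how schedule_w later pops it. The import path follows the scheduler import in this commit (colossalai.pipeline.weight_grad_store); shapes and the gradient function are illustrative.

import torch
from colossalai.pipeline.weight_grad_store import WeightGradStore

weight = torch.nn.Parameter(torch.randn(16, 8))
total_input = torch.randn(4, 8)
grad_output = torch.randn(4, 16)

def grad_fn(inp, gout, main_grad=None):
    # same math as execute_w_pass in _operation.py: dW = grad_output^T @ input
    return gout.t() @ inp

# backward b: the layer enqueues the deferred dW computation instead of running it.
WeightGradStore.put(total_input, grad_output, weight, grad_fn)

# end of this microbatch's b step: schedule_b moves the cache into the chunk's queue.
WeightGradStore.flush(chunk=0)

# scheduled W node: schedule_w runs the stored functions; weight.grad is filled here.
WeightGradStore.pop(chunk=0)
assert weight.grad is not None and weight.grad.shape == weight.shape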

View File

@ -2,7 +2,7 @@ from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d, dist_cross_entropy from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule from .parallel_module import ParallelModule
@ -11,6 +11,7 @@ from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2
__all__ = [ __all__ = [
"Embedding1D", "Embedding1D",
"VocabParallelEmbedding1D", "VocabParallelEmbedding1D",
"LinearWithGradAccum",
"Linear1D_Col", "Linear1D_Col",
"Linear1D_Row", "Linear1D_Row",
"GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Col",

View File

@ -1,7 +1,11 @@
import functools
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.pipeline.weight_grad_store import WeightGradStore
from .utils import is_share_sp_tp from .utils import is_share_sp_tp
try: try:
@ -125,12 +129,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
""" """
@staticmethod @staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False):
ctx.save_for_backward(input_, weight, bias) ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None ctx.use_bias = bias is not None
ctx.process_group = process_group ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication ctx.fp8_communication = fp8_communication
ctx.use_zbv = use_zbv
if bias is not None: if bias is not None:
output = F.linear(input_, weight, bias) output = F.linear(input_, weight, bias)
else: else:
@ -143,6 +148,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
input, weight, bias = ctx.saved_tensors input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication fp8_communication = ctx.fp8_communication
use_zbv = ctx.use_zbv
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias: if use_bias:
@ -164,9 +176,35 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None: if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad grad = weight.grad
if use_zbv:
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
if grad.dtype == torch.float32: if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None grad_weight = None
@ -175,6 +213,18 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
grad_weight = None grad_weight = None
else: else:
grad_weight = grad_output.t().matmul(total_input) grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else: else:
grad_weight = grad_output.t().matmul(total_input) grad_weight = grad_output.t().matmul(total_input)
@ -182,6 +232,104 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
if ctx.async_grad_allreduce and not fp8_communication: if ctx.async_grad_allreduce and not fp8_communication:
handle.wait() handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
class LinearWithGradAccum(torch.autograd.Function):
"""
Linear layer baseline (no tensor parallel version).
"""
@staticmethod
def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.async_grad_allreduce = async_grad_allreduce
ctx.use_zbv = use_zbv
if bias is not None:
output = F.linear(input_, weight, bias)
else:
output = F.linear(input_, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
use_zbv = ctx.use_zbv
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
bias.view(bias.shape)
total_input = input.contiguous()
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
return grad_input, grad_weight, grad_bias, None, None, None, None return grad_input, grad_weight, grad_bias, None, None, None, None
@ -1043,12 +1191,18 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre
) )
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): def linear_with_async_comm(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
):
return LinearWithAsyncCommunication.apply( return LinearWithAsyncCommunication.apply(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv
) )
def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False):
return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)
def linear_gather_forward_reducescatter_backward( def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
): ):
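
Putting the pieces together, the sketch below exercises the new no-tensor-parallel path end to end: with use_zbv=True, the backward of linear_with_grad_accum returns dx immediately but parks dW in WeightGradStore until flush/pop run. The module path colossalai.shardformer.layer._operation is an assumption based on the __init__.py diff above, and the shapes are arbitrary.

import torch
from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer._operation import linear_with_grad_accum  # path assumed

x = torch.randn(2, 4, 8, requires_grad=True)
w = torch.nn.Parameter(torch.randn(16, 8))

out = linear_with_grad_accum(x, w, None, False, use_zbv=True)
out.sum().backward()

assert x.grad is not None   # dx produced by the B pass
assert w.grad is None       # dW deferred to the W pass

WeightGradStore.flush(chunk=0)
WeightGradStore.pop(chunk=0)  # dW materialized into w.grad here
assert w.grad is not None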

View File

@ -28,6 +28,7 @@ from ._operation import (
linear_gather_forward_reducescatter_backward, linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward, linear_reducescatter_forward_gather_backward,
linear_with_async_comm, linear_with_async_comm,
linear_with_grad_accum,
reduce_forward, reduce_forward,
reducescatter_forward_gather_backward, reducescatter_forward_gather_backward,
split_forward_gather_backward, split_forward_gather_backward,
@ -35,7 +36,148 @@ from ._operation import (
from .parallel_module import PaddingParallelModule, ParallelModule from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset from .utils import create_randomizer_with_offset
__all__ = ["Linear1D_Col", "Linear1D_Row"] __all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]
class LinearWithGradAccum(ParallelModule):
r"""Linear layer with no parallelism.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (`typing.Callable`):
The initializer of bias, defaults to xavier uniform initializer.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
use_zbv: bool = False,
**kwargs,
):
super().__init__(weight=weight, bias_=bias_, **kwargs)
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.skip_bias_add = skip_bias_add
self.device = device
self.use_zbv = use_zbv
if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=None)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
else:
assert bias_ is None, "bias_ must be None if weight is None"
# Parameters.
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
else:
self.bias = None
if weight is None:
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, **kwargs) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
linear_1d = LinearWithGradAccum(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
weight=module.weight,
bias_=module.bias,
**kwargs,
)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with self.randomizer.fork_rng(enable_cpu=True):
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert (
input_.shape[-1] == self.weight.shape[-1]
), "Invalid shapes in LinearWithGradAccum forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1]
)
# Set up backprop all-reduce.
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_grad_accum(
input_parallel,
self.weight,
bias,
False,
use_zbv=self.use_zbv,
)
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
return output
class Linear1D_Col(ParallelModule): class Linear1D_Col(ParallelModule):
@ -85,6 +227,7 @@ class Linear1D_Col(ParallelModule):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False, fp8_communication: bool = False,
use_zbv: bool = False,
**kwargs, **kwargs,
): ):
super().__init__(weight=weight, bias_=bias_, **kwargs) super().__init__(weight=weight, bias_=bias_, **kwargs)
@ -100,6 +243,7 @@ class Linear1D_Col(ParallelModule):
self.device = device self.device = device
self.process_group = process_group self.process_group = process_group
self.fp8_communication = fp8_communication self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
if skip_bias_add and not bias: if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None") raise ValueError("cannot skip bias addition if bias is None")
@ -201,13 +345,18 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply. # Matrix multiply.
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode == "split_gather": if self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward( input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
) )
output_parallel = linear_with_async_comm( output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication input_parallel,
self.weight,
bias,
self.process_group,
False,
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
) )
elif self.seq_parallel_mode == "ring": elif self.seq_parallel_mode == "ring":
output_parallel = linear_gather_forward_reducescatter_backward( output_parallel = linear_gather_forward_reducescatter_backward(
@ -215,9 +364,14 @@ class Linear1D_Col(ParallelModule):
) )
else: else:
output_parallel = linear_with_async_comm( output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication input_parallel,
self.weight,
bias,
self.process_group,
True,
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
) )
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = gather_forward_split_backward( output = gather_forward_split_backward(
@ -273,6 +427,7 @@ class Linear1D_Row(ParallelModule):
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1, stream_chunk_num: int = 1,
fp8_communication: bool = False, fp8_communication: bool = False,
use_zbv: bool = False,
): ):
super().__init__() super().__init__()
@ -288,6 +443,7 @@ class Linear1D_Row(ParallelModule):
self.seq_parallel_dim = seq_parallel_dim self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group) self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
if skip_bias_add and not bias: if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None") raise ValueError("cannot skip bias addition if bias is None")
@ -429,10 +585,14 @@ class Linear1D_Row(ParallelModule):
output = torch.cat(output_parallel_list, dim=-1) output = torch.cat(output_parallel_list, dim=-1)
else: else:
if self.seq_parallel_mode is None: if self.seq_parallel_mode is None:
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output_parallel = linear_with_async_comm(
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather": elif self.seq_parallel_mode == "split_gather":
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output_parallel = linear_with_async_comm(
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
)
output = reducescatter_forward_gather_backward( output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
) )
@ -445,7 +605,9 @@ class Linear1D_Row(ParallelModule):
ring=True, ring=True,
) )
else: else:
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output_parallel = linear_with_async_comm(
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
)
output = reduce_forward(output_parallel, self.process_group) output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add: if not self.skip_bias_add:

View File

@ -82,7 +82,7 @@ class LlamaPipelineForwards:
elif input_ids is not None: elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2] batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None: elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape[:2] batch_size, seq_length = inputs_embeds.shape[:2]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None: if inputs_embeds is None:

View File

@ -60,6 +60,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
moe_dp_group: ProcessGroup, moe_dp_group: ProcessGroup,
ep_group: ProcessGroup, ep_group: ProcessGroup,
fp8_communication: bool = False, fp8_communication: bool = False,
use_zbv: bool = False,
): ):
assert tp_group is not None assert tp_group is not None
assert moe_dp_group is not None assert moe_dp_group is not None
@ -70,6 +71,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
self.ep_rank = dist.get_rank(ep_group) self.ep_rank = dist.get_rank(ep_group)
self.ep_group = ep_group self.ep_group = ep_group
self.fp8_communication = fp8_communication self.fp8_communication = fp8_communication
self.use_zbv = use_zbv
if self.num_experts % self.ep_size != 0: if self.num_experts % self.ep_size != 0:
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.") raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
@ -89,13 +91,13 @@ class EPMixtralSparseMoeBlock(ParallelModule):
if self.tp_group.size() > 1: if self.tp_group.size() > 1:
for expert in held_experts: for expert in held_experts:
expert.w1 = Linear1D_Col.from_native_module( expert.w1 = Linear1D_Col.from_native_module(
expert.w1, self.tp_group, fp8_communication=self.fp8_communication expert.w1, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
) )
expert.w3 = Linear1D_Col.from_native_module( expert.w3 = Linear1D_Col.from_native_module(
expert.w3, self.tp_group, fp8_communication=self.fp8_communication expert.w3, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
) )
expert.w2 = Linear1D_Row.from_native_module( expert.w2 = Linear1D_Row.from_native_module(
expert.w2, self.tp_group, fp8_communication=self.fp8_communication expert.w2, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
) )
for p in self.experts.parameters(): for p in self.experts.parameters():
@ -399,6 +401,7 @@ class MixtralPipelineForwards:
if output_router_logits and past_router_logits is not None: if output_router_logits and past_router_logits is not None:
all_router_logits = past_router_logits + all_router_logits all_router_logits = past_router_logits + all_router_logits
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
if not return_dict: if not return_dict:
return tuple( return tuple(
@ -512,7 +515,6 @@ class MixtralPipelineForwards:
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
logits = logits.float() logits = logits.float()
loss = None loss = None
if labels is not None: if labels is not None:
# Shift so that tokens < n predict n # Shift so that tokens < n predict n

View File

@ -60,6 +60,8 @@ class LlamaPolicy(Policy):
else: else:
norm_cls = RMSNorm norm_cls = RMSNorm
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
sp_mode = self.shard_config.sequence_parallelism_mode or None sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None sp_group = self.shard_config.sequence_parallel_process_group or None
@ -126,37 +128,65 @@ class LlamaPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.q_proj", suffix="self_attn.q_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.k_proj", suffix="self_attn.k_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.v_proj", suffix="self_attn.v_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.o_proj", suffix="self_attn.o_proj",
target_module=Linear1D_Row, target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.gate_proj", suffix="mlp.gate_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.up_proj", suffix="mlp.up_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.down_proj", suffix="mlp.down_proj",
target_module=Linear1D_Row, target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
), ),
], ],
) )
@ -265,7 +295,6 @@ class LlamaPolicy(Policy):
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
): ):
held_layers.append(module.norm) held_layers.append(module.norm)
else: else:
layers_per_stage = stage_manager.distribute_layers(len(module.layers)) layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage(): if stage_manager.is_first_stage():
@ -385,6 +414,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
from transformers import LlamaForSequenceClassification from transformers import LlamaForSequenceClassification
policy = super().module_policy() policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification # add a new item for sequence classification
@ -397,6 +427,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
kwargs=dict( kwargs=dict(
gather_output=True, gather_output=True,
fp8_communication=self.shard_config.fp8_communication, fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
), ),
) )
] ]

View File

@ -52,6 +52,7 @@ class MixtralPolicy(Policy):
sp_group = self.shard_config.sequence_parallel_process_group or None sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"] sp_partial_derived = sp_mode in ["split_gather", "ring"]
tp_size = self.shard_config.tensor_parallel_size tp_size = self.shard_config.tensor_parallel_size
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# modified for both SP and TP # modified for both SP and TP
num_q_heads = self.model.config.num_attention_heads num_q_heads = self.model.config.num_attention_heads
@ -124,27 +125,43 @@ class MixtralPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.q_proj", suffix="self_attn.q_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication}, kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.k_proj", suffix="self_attn.k_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication}, kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.v_proj", suffix="self_attn.v_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={"fp8_communication": self.shard_config.fp8_communication}, kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.o_proj", suffix="self_attn.o_proj",
target_module=Linear1D_Row, target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication}, kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="block_sparse_moe.gate", suffix="block_sparse_moe.gate",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication}, kwargs={
"gather_output": True,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
},
), ),
], ],
) )
@ -179,6 +196,7 @@ class MixtralPolicy(Policy):
"tp_group": self.shard_config.tensor_parallel_process_group, "tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group, "moe_dp_group": self.shard_config.moe_dp_group,
"fp8_communication": self.shard_config.fp8_communication, "fp8_communication": self.shard_config.fp8_communication,
"use_zbv": use_zbv,
}, },
) )
], ],
@ -313,6 +331,7 @@ class MixtralModelPolicy(MixtralPolicy):
class MixtralForCausalLMPolicy(MixtralPolicy): class MixtralForCausalLMPolicy(MixtralPolicy):
def module_policy(self): def module_policy(self):
policy = super().module_policy() policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
# TODO: assign pg mesh from plugin to all modules # TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm # add a new item for causal lm
@ -322,9 +341,13 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="lm_head", suffix="lm_head",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
) )
] ],
) )
} }
policy.update(new_item) policy.update(new_item)
@ -343,7 +366,9 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
"""Get pipeline layers for current stage.""" """Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers() held_layers = super().get_held_layers()
if stage_manager.is_last_stage(): if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head)
elif stage_manager.is_last_stage(ignore_chunk=True):
held_layers.append(self.model.lm_head) held_layers.append(self.model.lm_head)
return held_layers return held_layers
@ -369,6 +394,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
from transformers import MixtralForSequenceClassification from transformers import MixtralForSequenceClassification
policy = super().module_policy() policy = super().module_policy()
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification # add a new item for sequence classification
@ -378,7 +404,11 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="score", suffix="score",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication), kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
) )
] ]
) )

View File

@ -21,6 +21,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer import PipelineGradientCheckpointConfig
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -39,6 +40,7 @@ MODEL_CONFIGS = {
), ),
"5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8), "5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8),
"7b": LlamaConfig(max_position_embeddings=4096), "7b": LlamaConfig(max_position_embeddings=4096),
# "7b": LlamaConfig(num_hidden_layers=4, max_position_embeddings=4096),
"13b": LlamaConfig( "13b": LlamaConfig(
hidden_size=5120, hidden_size=5120,
intermediate_size=13824, intermediate_size=13824,
@ -91,7 +93,7 @@ def main():
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code") parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument( parser.add_argument(
@ -106,6 +108,7 @@ def main():
parser.add_argument("--no_cache", action="store_true") parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication") parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear") parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p")
parser.add_argument("--overlap_allgather", action="store_true") parser.add_argument("--overlap_allgather", action="store_true")
parser.add_argument( parser.add_argument(
"--sp_mode", "--sp_mode",
@ -126,9 +129,12 @@ def main():
{ {
"gradient_checkpoint_config": PipelineGradientCheckpointConfig( "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
num_ckpt_layers_per_stage=[19, 19, 19, 13], num_ckpt_layers_per_stage=[19, 19, 19, 13],
# num_ckpt_layers_per_stage=[48, 48, 48, 48],
), ),
"num_layers_per_stage": [19, 20, 20, 21], "num_layers_per_stage": [19, 20, 20, 21],
"pp_style": "interleaved", # "num_layers_per_stage": [48, 48, 48, 48],
# "pp_style": "interleaved",
"pp_style": "1f1b",
} }
if args.custom_ckpt if args.custom_ckpt
else {} else {}
@ -137,6 +143,11 @@ def main():
# ============================== # ==============================
# Initialize Booster # Initialize Booster
# ============================== # ==============================
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
use_empty_init = True use_empty_init = True
if args.plugin == "gemini": if args.plugin == "gemini":
plugin = GeminiPlugin( plugin = GeminiPlugin(
@ -210,6 +221,24 @@ def main():
fp8_communication=args.use_fp8_comm, fp8_communication=args.use_fp8_comm,
) )
elif args.plugin == "3d": elif args.plugin == "3d":
if args.pp_style == "zbv":
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
mem_w = -32 * config.hidden_size
mem_b = -mem_w - mem_f
scheduler_nodes = PipelineGraph(
n_stage=args.pp,
n_micro=args.batch_size // args.mbs,
f_cost=1000,
b_cost=1000,
w_cost=1000,
c_cost=1,
f_mem=mem_f * 1.5,
b_mem=mem_b * 1.5,
w_mem=mem_w * 1.5,
).get_v_schedule()
else:
scheduler_nodes = None
plugin = HybridParallelPlugin( plugin = HybridParallelPlugin(
tp_size=args.tp, tp_size=args.tp,
pp_size=args.pp, pp_size=args.pp,
@ -227,6 +256,7 @@ def main():
overlap_allgather=args.overlap_allgather, overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8, use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm, fp8_communication=args.use_fp8_comm,
scheduler_nodes=scheduler_nodes,
**hybrid_kwargs, **hybrid_kwargs,
) )
elif args.plugin == "3d_cpu": elif args.plugin == "3d_cpu":
@ -242,7 +272,7 @@ def main():
microbatch_size=args.mbs, microbatch_size=args.mbs,
initial_scale=2**8, initial_scale=2**8,
precision="bf16", precision="bf16",
overlap_p2p=args.overlap, overlap_p2p=args.overlap_p2p,
use_fp8=args.use_fp8, use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm, fp8_communication=args.use_fp8_comm,
) )
@ -260,6 +290,7 @@ def main():
config = MODEL_CONFIGS[args.config] config = MODEL_CONFIGS[args.config]
else: else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42) torch.cuda.manual_seed(42)
dataset = RandomDataset( dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
@ -319,7 +350,7 @@ def main():
args.profile, args.profile,
args.ignore_steps, args.ignore_steps,
1, # avoid creating massive log files 1, # avoid creating massive log files
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", save_dir=f"./profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
nsys=args.nsys, nsys=args.nsys,
) as prof: ) as prof:
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
@ -334,7 +365,11 @@ def main():
return_loss=True, return_loss=True,
) )
loss = outputs["loss"] loss = outputs["loss"]
if dist.get_rank() == dist.get_world_size() - 1: if args.pp_style == "zbv":
if coordinator.is_master():
print(f"Step {step} loss: {loss}")
else:
if coordinator.is_last_process():
print(f"Step {step} loss: {loss}") print(f"Step {step} loss: {loss}")
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
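For reference, a minimal sketch of how the zbv branch above precomputes its schedule, assuming a 4-stage pipeline and illustrative 7b-style values (hidden_size 4096, 32 attention heads, --max_length 4096, global batch 32 with micro-batch size 1); the cost and memory coefficients simply mirror the heuristics in the benchmark, and the llama benchmark additionally scales the memory terms by 1.5:

from colossalai.pipeline.schedule.v_schedule import PipelineGraph

hidden_size, num_heads, seq_len = 4096, 32, 4096    # illustrative 7b-style values
mem_f = 34 * hidden_size + 5 * num_heads * seq_len  # 794624: activation kept after a forward (F) pass
mem_w = -32 * hidden_size                           # -131072: memory released by a weight-grad (W) pass
mem_b = -mem_w - mem_f                              # -663552: memory released by an input-grad (B) pass

scheduler_nodes = PipelineGraph(
    n_stage=4,                                      # --pp 4
    n_micro=32,                                     # --batch_size 32 // --mbs 1
    f_cost=1000, b_cost=1000, w_cost=1000, c_cost=1,
    f_mem=mem_f, b_mem=mem_b, w_mem=mem_w,
).get_v_schedule()                                  # handed to HybridParallelPlugin(scheduler_nodes=...)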


@ -11,6 +11,7 @@ from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator, get_profile_context from performance_evaluator import PerformanceEvaluator, get_profile_context
from tqdm import tqdm from tqdm import tqdm
from transformers import AutoConfig
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import colossalai import colossalai
@ -20,6 +21,7 @@ from colossalai.booster.plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer import PipelineGradientCheckpointConfig
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -85,7 +87,7 @@ def main():
parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code") parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument( parser.add_argument(
@ -120,7 +122,7 @@ def main():
num_ckpt_layers_per_stage=[19, 19, 19, 13], num_ckpt_layers_per_stage=[19, 19, 19, 13],
), ),
"num_layers_per_stage": [19, 20, 20, 21], "num_layers_per_stage": [19, 20, 20, 21],
"pp_style": "interleaved", # "pp_style": "interleaved",
} }
if args.custom_ckpt if args.custom_ckpt
else {} else {}
@ -129,7 +131,29 @@ def main():
# ============================== # ==============================
# Initialize Booster # Initialize Booster
# ============================== # ==============================
if args.config in MODEL_CONFIGS:
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
if args.plugin == "3d": if args.plugin == "3d":
if args.pp_style == "zbv":
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
mem_w = -32 * config.hidden_size
mem_b = -mem_w - mem_f
scheduler_nodes = PipelineGraph(
n_stage=args.pp,
n_micro=args.batch_size // args.mbs,
f_cost=1000,
b_cost=1000,
w_cost=1000,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
).get_v_schedule()
else:
scheduler_nodes = None
plugin = MoeHybridParallelPlugin( plugin = MoeHybridParallelPlugin(
ep_size=args.ep, ep_size=args.ep,
tp_size=args.tp, tp_size=args.tp,
@ -143,11 +167,13 @@ def main():
enable_fused_normalization=torch.cuda.is_available(), enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers, enable_flash_attention=args.xformers,
microbatch_size=args.mbs, microbatch_size=args.mbs,
num_microbatches=args.batch_size // args.mbs,
precision="bf16", precision="bf16",
enable_metadata_cache=not args.no_cache, enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather, overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8, use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm, fp8_communication=args.use_fp8_comm,
scheduler_nodes=scheduler_nodes,
**hybrid_kwargs, **hybrid_kwargs,
) )
else: else:
@ -183,8 +209,10 @@ def main():
with init_ctx: with init_ctx:
model = MixtralForCausalLM(config=config).to(torch.bfloat16) model = MixtralForCausalLM(config=config).to(torch.bfloat16)
# if args.grad_checkpoint:
# model.gradient_checkpointing_enable()
if args.grad_checkpoint: if args.grad_checkpoint:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model_numel = get_model_numel(model) model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
@ -229,6 +257,10 @@ def main():
return_loss=True, return_loss=True,
) )
loss = outputs["loss"] loss = outputs["loss"]
if args.pp_style == "zbv":
if dist.get_rank() == 0:
print(f"Step {step} loss: {loss}")
else:
if dist.get_rank() == dist.get_world_size() - 1: if dist.get_rank() == dist.get_world_size() - 1:
print(f"Step {step} loss: {loss}") print(f"Step {step} loss: {loss}")
optimizer.step() optimizer.step()
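Two zbv-specific details in the Mixtral benchmark above are easy to miss: the plugin is now given num_microbatches explicitly so that it matches the n_micro the V-schedule was generated for, and the loss is printed from rank 0 because the chunk that produces it sits on pipeline stage 0 under zbv. A short sketch of the micro-batch bookkeeping, reusing the benchmark's args namespace:

assert args.batch_size % args.mbs == 0, "global batch must split evenly into micro-batches"
n_micro = args.batch_size // args.mbs
# the same value feeds both the offline schedule and the plugin:
#   PipelineGraph(n_micro=n_micro, ...) and MoeHybridParallelPlugin(num_microbatches=n_micro, ...)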


@ -21,11 +21,16 @@ def divide(x: float, y: float) -> float:
def all_reduce_mean(x: float, world_size: int) -> float: def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1: if world_size == 1:
return x return x
# BUG: RuntimeError: Invalid scalar type when using dist.all_reduce(tensor, group=gloo_group)
# # Use CPU tensor to avoid OOM/weird NCCL error
# gloo_group = dist.new_group(backend="gloo")
# tensor = torch.tensor([x], device="cpu")
# dist.all_reduce(tensor, group=gloo_group)
# tensor = tensor / world_size
# return tensor.item()
# Use CPU tensor to avoid OOM/weird NCCl error tensor = torch.tensor([x], device=torch.cuda.current_device(), dtype=torch.float)
gloo_group = dist.new_group(backend="gloo") dist.all_reduce(tensor)
tensor = torch.tensor([x], device="cpu")
dist.all_reduce(tensor, group=gloo_group)
tensor = tensor / world_size tensor = tensor / world_size
return tensor.item() return tensor.item()
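The gloo/CPU fallback is kept only as a comment; the live all_reduce_mean now averages a CUDA float tensor over the default process group, so every rank must call it collectively. A usage sketch (step_loss is a hypothetical per-rank scalar, e.g. this rank's step loss):

world_size = dist.get_world_size()
avg_loss = all_reduce_mean(step_loss, world_size)  # identical value returned on every rank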


@ -8,12 +8,14 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.testing import assert_close from torch.testing import assert_close
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.models.mixtral.configuration_mixtral import MixtralConfig from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai import colossalai
from colossalai.booster.booster import Booster from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin, MoeHybridParallelPlugin
from colossalai.cluster import ProcessGroupMesh from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import OptimizerWrapper from colossalai.interface import OptimizerWrapper
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
@ -756,10 +758,11 @@ def run_with_hybridplugin(test_config):
@parameterize( @parameterize(
"config", "config",
[ [
(0, 1, 4, 1, 1), # (0, 1, 4, 1, 1),
(1, 2, 2, 1, 1), # (1, 2, 2, 1, 1),
(1, 2, 1, 2, 1), (1, 1, 2, 2, 1),
(1, 2, 1, 1, 2), # (1, 2, 1, 2, 1),
# (1, 2, 1, 1, 2),
], ],
) )
def run_with_booster_moehybridplugin(config: Tuple[int, ...]): def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
@ -790,6 +793,8 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
seed_all(10086) seed_all(10086)
torch_model = MixtralModel(config).to(dtype).cuda() torch_model = MixtralModel(config).to(dtype).cuda()
# TODO: Support MixtralForCausalLM
# torch_model = MixtralForCausalLM(config).to(dtype).cuda()
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
# init schedule # init schedule
h, a, s = config.hidden_size, config.num_attention_heads, 1024 h, a, s = config.hidden_size, config.num_attention_heads, 1024
@ -892,7 +897,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
# =================================================================================== # ===================================================================================
# run normal model with all dp(different) inputs # run normal model with all dp(different) inputs
all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)] all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group) dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
torch_output_sum = 0 torch_output_sum = 0
for input_data_ in all_inputs: for input_data_ in all_inputs:
@ -905,6 +910,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
p.grad /= dp_size p.grad /= dp_size
torch_optimizer.step() torch_optimizer.step()
torch_optimizer.zero_grad() torch_optimizer.zero_grad()
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
print(f"rank {dist.get_rank()} config {test_config} test passed") print(f"rank {dist.get_rank()} config {test_config} test passed")
clear_layout_converter() clear_layout_converter()
@ -912,11 +918,169 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
torch.cuda.empty_cache() torch.cuda.empty_cache()
@parameterize(
"config",
[
(1, 2, 2, 1), # Pass
# TODO: only pp + tp acceleration is supported for now; pp-only and no-tp hybrid configs will be supported in the future;
# (0, 4, 1, 1),
# (1, 2, 1, 2),
# (1, 1, 2, 2),
],
)
def run_with_booster_hybridplugin(config: Tuple[int, ...]):
stage, pp_size, tp_size, sp_size = config
num_microbatches = pp_size
dist.get_world_size()
rank = dist.get_rank()
dtype, precision = torch.float16, "fp16"
torch.cuda.set_device(dist.get_rank())
########
# init base model
########
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
config = LlamaConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=NUM_LAYERS,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
attn_implementation="flash_attention_2",
)
# init model with the same seed
seed_all(10086)
torch_model = LlamaModel(config).to(dtype).cuda()
# TODO: support LlamaForCausalLM
# torch_model = LlamaForCausalLM(config).to(dtype).cuda()
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
# init schedule
h, a, s = config.hidden_size, config.num_attention_heads, 1024
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f
graph = PipelineGraph(
n_stage=pp_size,
n_micro=num_microbatches,
f_cost=1,
b_cost=1,
w_cost=1,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()
# init HybridParallelPlugin
plugin = HybridParallelPlugin(
pp_size=pp_size,
num_microbatches=pp_size,
tp_size=tp_size,
sp_size=sp_size,
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
overlap_communication=False,
initial_scale=1,
precision=precision,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
dp_size = plugin.dp_size
booster = Booster(plugin=plugin)
########
# init pp model
########
parallel_model = deepcopy(torch_model)
parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
# create different input along dp axis
seed_all(1453 + rank)
torch_model.train()
parallel_model.train()
for _ in range(2):
# gen random input
input_embeddings = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
dist.all_reduce(
input_embeddings, group=plugin.pp_group
)  # pp inputs beyond the first stage don't matter, but they need to be replicated for the torch-model check
dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input
dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input
# run the model with hybrid parallel
if booster.plugin.stage_manager is not None:
# for test with pp
data_iter = iter([{"inputs_embeds": input_embeddings}])
sharded_output = booster.execute_pipeline(
data_iter,
parallel_model,
lambda x, y: x.last_hidden_state.mean(),
parallel_optimizer,
return_loss=True,
return_outputs=True,
)
# stage 0 chunk 0
parallel_output = None
if (
booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
):
parallel_output = sharded_output["loss"]
else:
parallel_output = torch.tensor(12345.0, device="cuda")
# broadcast along pp axis
dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)
else:
# for test without pp
parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
dist.all_reduce(parallel_output, group=plugin.dp_group)
# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
torch_output.backward()
torch_output_sum += torch_output.detach()
# avg dp grads follows zero optimizer
for p in torch_model.parameters():
if p.grad is not None:
p.grad /= dp_size
torch_optimizer.step()
torch_optimizer.zero_grad()
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
disable_existing_loggers() disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# run_fwd_bwd_vschedule_with_optim()
run_with_booster_moehybridplugin() run_with_booster_moehybridplugin()
run_with_booster_hybridplugin()
@pytest.mark.dist @pytest.mark.dist
@ -928,5 +1092,6 @@ def test_pp():
) )
# python -m pytest -s tests/test_pipeline/test_schedule/test_zerobubble_pp.py
if __name__ == "__main__": if __name__ == "__main__":
test_pp() test_pp()
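In run_with_booster_hybridplugin above, only chunk 0 of the first pipeline stage receives a real loss under zbv; every other rank fills a placeholder and then obtains the value via a broadcast along the pp group before the dp all-reduce. Condensed from the test, as a reading aid:

src = dist.get_process_group_ranks(plugin.pp_group)[0]
if booster.plugin.stage_manager.is_first_stage(ignore_chunk=True) and rank == src:
    parallel_output = sharded_output["loss"]                 # produced by the folded-back last chunk
else:
    parallel_output = torch.tensor(12345.0, device="cuda")   # placeholder, overwritten below
dist.broadcast(parallel_output, src=src, group=plugin.pp_group)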


@ -8,7 +8,8 @@ from torch.testing import assert_close
import colossalai import colossalai
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum
from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@ -117,6 +118,93 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool):
assert_close(x_for_unshard.grad, x_for_shard.grad) assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_base = LinearWithGradAccum.from_native_module(
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
)
assert linear_base.weight.shape == torch.Size([128, 32])
assert linear_base.bias.shape == torch.Size([128])
assert linear_copy.weight is linear_base.weight
assert linear_copy.bias is linear_base.bias
linear.load_state_dict(linear_base.state_dict())
linear_base.load_state_dict(linear.state_dict())
# check computation correctness
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard.requires_grad_(True)
# run forward
out = linear(x_for_unshard)
gather_out = linear_base(x_for_shard)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
assert_close(linear.weight.grad, linear_base.weight.grad)
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_base = LinearWithGradAccum.from_native_module(
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
)
assert linear_base.weight.shape == torch.Size([128, 32])
assert linear_base.bias.shape == torch.Size([128])
assert linear_copy.weight is linear_base.weight
assert linear_copy.bias is linear_base.bias
linear.load_state_dict(linear_base.state_dict())
linear_base.load_state_dict(linear.state_dict())
# check computation correctness
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard.requires_grad_(True)
# run forward
out = linear(x_for_unshard)
gather_out = linear_base(x_for_shard)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
# with use_zbv=True the dW computation is deferred, so the weight grad is still None here
assert linear_base.weight.grad is None
# after flush + pop (the deferred dW computation), the weight grad should exist and match
WeightGradStore.flush(chunk=0)  # flush buffered dW tasks into the chunk-0 queue
WeightGradStore.pop(chunk=0)  # execute the queued dW computation
assert_close(linear.weight.grad, linear_base.weight.grad)
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool): def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext() ctx = LazyInitContext() if lazy_init else nullcontext()
@ -182,6 +270,8 @@ def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
check_linear_1d_col(lazy_init, seq_parallel_mode, overlap) check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
check_linear_1d_row(lazy_init, seq_parallel_mode) check_linear_1d_row(lazy_init, seq_parallel_mode)
check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap) check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)
check_linear_without_weight_grad_store(lazy_init, seq_parallel_mode)
check_linear_with_weight_grad_store(lazy_init, seq_parallel_mode)
def check_dist_linear(rank, world_size, port): def check_dist_linear(rank, world_size, port):
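The two new checks pin down the LinearWithGradAccum behaviour: without use_zbv it matches a plain nn.Linear end to end, while with use_zbv=True the weight gradient is deferred through WeightGradStore until the scheduler's W slot. A minimal sketch of the deferred-dW pattern, following the names used in the test above:

out = linear_base(x)              # forward as usual
out.sum().backward()              # dx is computed now; the dW task is queued instead of executed
assert linear_base.weight.grad is None
WeightGradStore.flush(chunk=0)    # move the queued dW task(s) into the chunk-0 queue
WeightGradStore.pop(chunk=0)      # execute them, normally at the W step of the zbv schedule
assert linear_base.weight.grad is not None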


@ -277,32 +277,33 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16", "precision": "fp16",
"initial_scale": 1, "initial_scale": 1,
}, },
{ # # TODO: assert layer error
"tp_size": 2, # {
"pp_size": 2, # "tp_size": 2,
"pp_style": "zbv", # "pp_size": 2,
"num_model_chunks": 2, # "pp_style": "zbv",
"num_microbatches": 4, # "num_model_chunks": 2,
"enable_all_optimization": False, # "num_microbatches": 4,
"precision": "fp16", # "enable_all_optimization": False,
"zero_stage": 0, # "precision": "fp16",
"initial_scale": 1, # "zero_stage": 0,
"enable_gradient_checkpointing": True, # "initial_scale": 1,
"parallel_output": False, # "enable_gradient_checkpointing": True,
}, # "parallel_output": False,
{ # },
"tp_size": 2, # {
"pp_size": 2, # "tp_size": 2,
"pp_style": "zbv", # "pp_size": 2,
"num_model_chunks": 2, # "pp_style": "zbv",
"num_microbatches": 4, # "num_model_chunks": 2,
"enable_all_optimization": False, # "num_microbatches": 4,
"precision": "fp16", # "enable_all_optimization": False,
"zero_stage": 1, # "precision": "fp16",
"initial_scale": 1, # "zero_stage": 1,
"enable_gradient_checkpointing": True, # "initial_scale": 1,
"parallel_output": False, # "enable_gradient_checkpointing": True,
}, # "parallel_output": False,
# },
], ],
) )
def run_llama_test(test_config): def run_llama_test(test_config):