
Merge pull request #6114 from duanjunwen/dev/zero_bubble

[Zerobubble] Support LinearWithAsyncCommunication for shardformer policy
feature/zerobubble
duanjunwen 4 days ago committed by GitHub
commit 810cafb2f9
12 changed files (changed lines in parentheses):
  1. colossalai/pipeline/p2p.py (1)
  2. colossalai/pipeline/schedule/zero_bubble_pp.py (150)
  3. colossalai/shardformer/layer/attn.py (23)
  4. colossalai/shardformer/modeling/llama.py (1)
  5. colossalai/shardformer/modeling/mixtral.py (1)
  6. colossalai/shardformer/policies/bert.py (98)
  7. colossalai/shardformer/policies/llama.py (93)
  8. colossalai/shardformer/policies/mistral.py (71)
  9. colossalai/shardformer/policies/mixtral.py (76)
  10. examples/language/llama/benchmark.py (7)
  11. examples/language/mixtral/benchmark.py (2)
  12. tests/test_pipeline/test_schedule/test_zerobubble_pp.py (32)

colossalai/pipeline/p2p.py (1)

@@ -432,7 +432,6 @@ def _communicate(
            overlap_p2p=overlap_p2p,
            send_first=send_first if send_first != None else True,
        )
-
        if metadata_recv is not None:
            assert isinstance(metadata_recv, P2PMetadata)
            tree_spec = metadata_recv.tree_spec

colossalai/pipeline/schedule/zero_bubble_pp.py (150)

@@ -64,10 +64,28 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # P2PMeta cache
         self.enable_metadata_cache = enable_metadata_cache
-        self.send_tensor_metadata = True
-        self.send_grad_metadata = True
-        self.tensor_metadata_recv = None
-        self.grad_metadata_recv = None
+        # check send_tensor_metadata, send_grad_metadata
+        # pp4 as sample, we should follow this meta strategy
+        # send_tensor_meta(fwd)    send_grad_meta(bwd)
+        #         chunk0 | chunk1      chunk0 | chunk 1
+        # stage 0      T | F                F | T
+        # stage 1      T | T                T | T
+        # stage 2      T | T                T | T
+        # stage 3      F | T                F | T
+        if stage_manager.is_first_stage(ignore_chunk=True):
+            self.send_tensor_metadata = [True, False]
+            self.send_grad_metadata = [False, True]
+        elif stage_manager.is_last_stage(ignore_chunk=True):
+            self.send_tensor_metadata = [False, True]
+            self.send_grad_metadata = [True, False]
+        else:
+            self.send_tensor_metadata = [True, True]
+            self.send_grad_metadata = [True, True]
+        # meta cache buffer
+        self.tensor_metadata_recv = [None, None]  # [chunk 0 meta, chunk 1 meta]
+        self.grad_metadata_recv = [None, None]

         # P2P communication
         self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
@@ -96,10 +114,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.output_tensors_grad_dw = [[], []]

         # buffer for communication
-        self.send_forward_buffer = [[], []]
-        self.recv_forward_buffer = [[], []]
-        self.send_backward_buffer = [[], []]
-        self.recv_backward_buffer = [[], []]
+        self.send_forward_buffer = [[], []]  # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
+        self.recv_forward_buffer = [
+            [],
+            [],
+        ]  # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
+        self.send_backward_buffer = [[], []]  # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
+        self.recv_backward_buffer = [
+            [],
+            [],
+        ]  # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]

         # y buffer for local send fwd
         self.local_send_forward_buffer = []

@@ -225,7 +249,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # do nothing; cause u are chunk 0 in first rank, u have no prev rank;
             #################
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                # return None, []
                 return []

             ################

@@ -235,12 +258,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             else:
                 prev_rank = self.stage_manager.get_prev_rank()
                 input_tensor, wait_handles = self.comm.recv_forward(
-                    prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv
+                    prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
                 )
-                if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                    self.tensor_metadata_recv = create_send_metadata(input_tensor)
-                self.recv_forward_buffer[model_chunk_id].append(input_tensor)
-                # return input_tensor, wait_handles
+                if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
+                    self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
+                self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
                 return wait_handles

         else:

@@ -259,12 +281,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             else:
                 next_rank = self.stage_manager.get_next_rank()
                 input_tensor, wait_handles = self.comm.recv_forward(
-                    next_rank, metadata_recv=self.tensor_metadata_recv
+                    next_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
                 )
-                if self.enable_metadata_cache and self.tensor_metadata_recv is None:
-                    self.tensor_metadata_recv = create_send_metadata(input_tensor)
-                self.recv_forward_buffer[model_chunk_id].append(input_tensor)
-                # return input_tensor, wait_handles
+                if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
+                    self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
+                self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
                 return wait_handles

     def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:

@@ -287,7 +308,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # do nothing; Already get dy from local_send_backward_buffer in schedule b
             ################
             if self.stage_manager.is_last_stage(ignore_chunk=True):
-                # return None, []
                 return []

             ################

@@ -297,12 +317,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             else:
                 next_rank = self.stage_manager.get_next_rank()
                 output_tensor_grad, wait_handles = self.comm.recv_backward(
-                    next_rank, metadata_recv=self.grad_metadata_recv
+                    next_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
                 )
-                if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                    self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
-                self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
-                # return output_tensor_grad, wait_handles
+                if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
+                    self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
+                self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
                 return wait_handles

         else:

@@ -312,7 +331,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # do nothing; get loss from local
             ################
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                # return None, []
                 return []

             ################

@@ -322,12 +340,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             else:
                 prev_rank = self.stage_manager.get_prev_rank()
                 output_tensor_grad, wait_handles = self.comm.recv_backward(
-                    next_rank=prev_rank, metadata_recv=self.grad_metadata_recv
+                    next_rank=prev_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
                 )
-                if self.enable_metadata_cache and self.grad_metadata_recv is None:
-                    self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
-                self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
-                # return output_tensor_grad, wait_handles
+                if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
+                    self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
+                self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
                 return wait_handles

     def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
@@ -349,6 +366,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # do nothing; hold y on local_send_forward_buffer
             ################
             if self.stage_manager.is_last_stage(ignore_chunk=True):
+                self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
                 return []

             ################

@@ -359,9 +377,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 next_rank = self.stage_manager.get_next_rank()
                 output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
                 send_handles = self.comm.send_forward(
-                    output_object=output_tensor, next_rank=next_rank, send_metadata=self.send_tensor_metadata
+                    output_object=output_tensor,
+                    next_rank=next_rank,
+                    send_metadata=self.send_tensor_metadata[model_chunk_id],
                 )
-                self.send_tensor_metadata = not self.enable_metadata_cache
+                self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
                 return send_handles

         else:

@@ -370,6 +390,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part
             ################
             if self.stage_manager.is_first_stage(ignore_chunk=True):
+                self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
                 return []

             ################

@@ -380,9 +401,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 prev_rank = self.stage_manager.get_prev_rank()
                 output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
                 send_handles = self.comm.send_forward(
-                    output_tensor, prev_rank, send_metadata=self.send_tensor_metadata
+                    output_tensor, prev_rank, send_metadata=self.send_tensor_metadata[model_chunk_id]
                 )
-                self.send_tensor_metadata = not self.enable_metadata_cache
+                self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
                 return send_handles

     def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:

@@ -405,6 +426,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # do nothing; cause u are the first chunk in first stage; bwd end
             ################
             if self.stage_manager.is_first_stage(ignore_chunk=True):
+                self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
                 return []

             ################

@@ -415,9 +437,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 prev_rank = self.stage_manager.get_prev_rank()
                 input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
                 send_handles = self.comm.send_backward(
-                    input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
+                    input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
                 )
-                self.send_grad_metadata = not self.enable_metadata_cache
+                self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
                 return send_handles

         # bwd chunk1 is left V;

@@ -427,6 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
             ################
             if self.stage_manager.is_last_stage(ignore_chunk=True):
+                self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
                 return []

             ################

@@ -437,9 +460,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 next_rank = self.stage_manager.get_next_rank()
                 input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
                 send_handles = self.comm.send_backward(
-                    input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata
+                    input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
                 )
-                self.send_grad_metadata = not self.enable_metadata_cache
+                self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
                 return send_handles

     def forward_step(
@@ -519,8 +542,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_obj_grad_ = []

         # For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx.
-        # if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-        #     return None

         # For loss backward; output_obj is loss; output_obj_grad should be None
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):

@@ -633,9 +654,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         if model_chunk_id == 0:
             # is first stage; get input from microbatch
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                input_obj = None
+                input_obj = None  # (tensor, wait_handle)
             else:
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
+                for h in input_obj[1]:
+                    h.wait()
+                input_obj = input_obj[0]
         else:
             # is last stage; recv from local
             if self.stage_manager.is_last_stage(ignore_chunk=True):

@@ -643,7 +667,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # not last stage; recv from next
             else:
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
+                for h in input_obj[1]:
+                    h.wait()
+                input_obj = input_obj[0]

         # Here, let input_obj.requires_grad_()
         # if input_obj is not None:
         if not isinstance(input_obj, torch.Tensor):

@@ -689,10 +715,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Do not release_tensor_data loss, release_tensor_data other output_obj;
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             self.output_tensors[model_chunk_id].append(output_obj)
-            # self.output_tensors_dw[model_chunk_id].append(output_obj)
         else:
             self.output_tensors[model_chunk_id].append(output_obj)
-            # self.output_tensors_dw[model_chunk_id].append(output_obj)

         # add output to send_fwd_buffer
         if model_chunk_id == 0:  # chunk 0

@@ -732,6 +756,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # chunk0 not last stage; recv output_grad from recv_backward_buffer
             else:
                 output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
+                for h in output_tensor_grad[1]:
+                    h.wait()
+                output_tensor_grad = output_tensor_grad[0]
         else:
             # chunk1, is first stage; recv LOSS from local send bwd buffer
             if self.stage_manager.is_first_stage(ignore_chunk=True):

@@ -739,25 +766,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # chunk1, not first stage; recv output_grad from recv_backward_buffer
             else:
                 output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
+                for h in output_tensor_grad[1]:
+                    h.wait()
+                output_tensor_grad = output_tensor_grad[0]

         # get input and output object from buffer;
         input_obj = self.input_tensors[model_chunk_id].pop(0)
         output_obj = self.output_tensors[model_chunk_id].pop(0)

-        # # save output_tensor_grad for dw
-        # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-        #     # we save loss here
-        #     self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
-        # else:
-        #     # we save output_tensor_grad here
-        #     self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
-        # the_output_obj_grad = []
-        # if isinstance(output_obj, dict):
-        #     for (k, v) in output_obj.items():
-        #         the_output_obj_grad.append(v.requires_grad)
-        # else:
-        #     the_output_obj_grad.append(output_obj.requires_grad)
-
         input_object_grad = self.backward_b_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
@@ -800,20 +816,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         Returns:
             Nothing.
         """
-        # get y & dy from buffer
-        # output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
-        # output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)

         WeightGradStore.pop(chunk=model_chunk_id)

-        # self.backward_w_step(
-        #     model_chunk=model_chunk,
-        #     model_chunk_id=model_chunk_id,
-        #     optimizer=optimizer,
-        #     output_obj=output_obj,
-        #     output_obj_grad=output_obj_grad,
-        # )
-
     def run_forward_only(
         self,
         model_chunk: Union[ModuleList, Module],

@@ -890,6 +894,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
                 wait_handle = communication_func(scheduled_node.chunk)
-                self.wait_handles.append(wait_handle)
+                # We wait recv handle in fwd step and bwd step. Here only need to wait for send handle
+                if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}:
+                    self.wait_handles.append(wait_handle)
             elif scheduled_node.type == "F":
                 self.schedule_f(

@@ -914,10 +920,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
+        # wait here to ensure all communication is done
         for h in self.wait_handles:
             for hh in h:
                 hh.wait()

         # return loss & output
         if outputs is not None:
             outputs = merge_batch(outputs)
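
Review note: the buffering scheme introduced above can be summarized in a minimal, self-contained sketch (illustrative only; ChunkedP2PState and its methods are invented names, not ColossalAI APIs). The idea mirrored here is that metadata flags and caches are kept per model chunk, and each received object is stored together with its communication handles so the wait happens only when the forward or backward step actually consumes it.

# Illustrative sketch, not ColossalAI code.
from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple


@dataclass
class ChunkedP2PState:
    enable_metadata_cache: bool = True
    # one entry per model chunk (chunk 0, chunk 1)
    send_tensor_metadata: List[bool] = field(default_factory=lambda: [True, True])
    tensor_metadata_recv: List[Optional[Any]] = field(default_factory=lambda: [None, None])
    # each buffered item keeps its async wait handles alongside the received object
    recv_forward_buffer: List[List[Tuple[Any, List[Any]]]] = field(default_factory=lambda: [[], []])

    def on_recv_forward(self, chunk: int, obj: Any, handles: List[Any], meta: Any) -> None:
        # cache recv metadata the first time this chunk receives something
        if self.enable_metadata_cache and self.tensor_metadata_recv[chunk] is None:
            self.tensor_metadata_recv[chunk] = meta
        self.recv_forward_buffer[chunk].append((obj, handles))

    def pop_for_forward(self, chunk: int) -> Any:
        # wait on the handles only when the forward step actually needs the tensor
        obj, handles = self.recv_forward_buffer[chunk].pop(0)
        for h in handles:
            h.wait()
        return obj


class _DoneHandle:
    """Stand-in for a communication work handle."""

    def wait(self) -> None:
        pass


state = ChunkedP2PState()
state.on_recv_forward(chunk=0, obj="activation", handles=[_DoneHandle()], meta="tree_spec")
assert state.pop_for_forward(chunk=0) == "activation"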

colossalai/shardformer/layer/attn.py (23)

@@ -6,6 +6,7 @@ import torch.distributed
 import torch.distributed as dist
 import torch.nn.functional as F
 from einops import rearrange
+from packaging import version

 from colossalai.kernel.kernel_loader import (
     FlashAttentionDaoLoader,

@@ -642,9 +643,7 @@ class RingAttention(torch.autograd.Function):
             max_seqlen_q = max_seqlen_kv = max_seqlen
             cu_seqlens_half = cu_seqlens // 2
             max_seqlen_half = max_seqlen // 2
         misc_kwargs = {
-            "window_size": (-1, -1),
             "alibi_slopes": None,
             "softmax_scale": q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale,
             "dropout_p": dropout_p,

@@ -652,6 +651,13 @@ class RingAttention(torch.autograd.Function):
             "softcap": 0.0,
             "return_softmax": False,
         }
+        import flash_attn
+
+        if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
+            misc_kwargs["window_size_left"] = -1
+            misc_kwargs["window_size_right"] = -1
+        else:
+            misc_kwargs["window_size"] = (-1, -1)

         if (
             RingAttention.HALF_INDICES is not None

@@ -707,6 +713,19 @@ class RingAttention(torch.autograd.Function):
             # Helper to pass args to FA
             def _forward(q, k, v, causal):
+                if version.parse(flash_attn.__version__) > version.parse("2.6.3"):
+                    (out, softmax_lse, S_dmask, rng_state) = _flash_attn_forward(
+                        q,
+                        k,
+                        v,
+                        cu_seqlens_q if q.shape[0] == t else cu_seqlens_half,
+                        cu_seqlens_kv if k.shape[0] == t else cu_seqlens_half,
+                        max_seqlen_q if q.shape[0] == t else max_seqlen_half,
+                        max_seqlen_kv if k.shape[0] == t else max_seqlen_half,
+                        causal=causal,
+                        **misc_kwargs,
+                    )
+                else:
                     (
                         _,
                         _,
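
Review note: the version gate above can be read in isolation. According to this diff, flash-attn releases newer than 2.6.3 take window_size_left/window_size_right, while older ones take a window_size tuple. A small sketch of the same selection, written so it does not require flash-attn at import time (only the packaging library):

# Sketch of the version-dependent kwargs selection used above.
from packaging import version


def full_window_kwargs(flash_attn_version: str) -> dict:
    """Return the 'no sliding window' kwargs for a given flash-attn version string."""
    kwargs = {}
    if version.parse(flash_attn_version) > version.parse("2.6.3"):
        # newer builds split the tuple into two scalar arguments
        kwargs["window_size_left"] = -1
        kwargs["window_size_right"] = -1
    else:
        kwargs["window_size"] = (-1, -1)
    return kwargs


assert full_window_kwargs("2.6.3") == {"window_size": (-1, -1)}
assert full_window_kwargs("2.7.0") == {"window_size_left": -1, "window_size_right": -1}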

colossalai/shardformer/modeling/llama.py (1)

@@ -191,7 +191,6 @@ class LlamaPipelineForwards:
             num_model_chunks=stage_manager.num_model_chunks,
         )
         assert num_ckpt_layers <= end_idx - start_idx
-
         for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

colossalai/shardformer/modeling/mixtral.py (1)

@@ -381,7 +381,6 @@ class MixtralPipelineForwards:
                     output_router_logits,
                     use_cache,
                 )
-
                 hidden_states = layer_outputs[0]

                 if use_cache:

colossalai/shardformer/policies/bert.py (98)

@@ -75,6 +75,8 @@ class BertPolicy(Policy):
         sp_partial_derived = sp_mode == "split_gather"

+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
+
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0

@@ -97,6 +99,7 @@ class BertPolicy(Policy):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(

@@ -105,6 +108,7 @@ class BertPolicy(Policy):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(

@@ -113,6 +117,7 @@ class BertPolicy(Policy):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(

@@ -125,6 +130,7 @@ class BertPolicy(Policy):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(

@@ -138,6 +144,7 @@ class BertPolicy(Policy):
                             "seq_parallel_mode": sp_mode,
                             "skip_bias_add": self.enable_bias_gelu_fused,
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -146,6 +153,97 @@ class BertPolicy(Policy):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="output.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                ],
+            )
+            policy[BertEmbeddings] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="dropout",
+                        target_module=col_nn.DropoutForReplicatedInput,
+                    ),
+                ]
+            )
+            if self.enable_bias_gelu_fused:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_jit_fused_bert_intermediate_forward(),
+                    },
+                    policy=policy,
+                    target_key=BertIntermediate,
+                )
+        elif use_zbv:
+            policy[BertLayer] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.query",
+                        target_module=col_nn.LinearWithGradAccum,
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.key",
+                        target_module=col_nn.LinearWithGradAccum,
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.value",
+                        target_module=col_nn.LinearWithGradAccum,
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.dense",
+                        target_module=col_nn.LinearWithGradAccum,
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="intermediate.dense",
+                        target_module=col_nn.LinearWithGradAccum,
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "skip_bias_add": self.enable_bias_gelu_fused,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="output.dense",
+                        target_module=col_nn.LinearWithGradAccum,
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(

colossalai/shardformer/policies/llama.py (93)

@@ -9,6 +9,7 @@ from colossalai.shardformer.layer import (
     FusedRMSNorm,
     Linear1D_Col,
     Linear1D_Row,
+    LinearWithGradAccum,
     PaddingEmbedding,
     PaddingLMHead,
     RMSNorm,

@@ -104,7 +105,7 @@ class LlamaPolicy(Policy):
             policy=policy,
             target_key=LlamaModel,
         )
-
+        # enable tp, replace layer to tp Linear1D_Col,Linear1D_Row,
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 num_q_heads % tp_size == 0
@@ -191,6 +192,76 @@ class LlamaPolicy(Policy):
                 ],
             )
+        # not enable tp, replace layer to LinearWithGradAccum
+        elif use_zbv:
+            policy[LlamaDecoderLayer] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs=dict(
+                            seq_parallel_mode=sp_mode,
+                            fp8_communication=self.shard_config.fp8_communication,
+                            use_zbv=use_zbv,
+                        ),
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs=dict(
+                            seq_parallel_mode=sp_mode,
+                            fp8_communication=self.shard_config.fp8_communication,
+                            use_zbv=use_zbv,
+                        ),
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs=dict(
+                            seq_parallel_mode=sp_mode,
+                            fp8_communication=self.shard_config.fp8_communication,
+                            use_zbv=use_zbv,
+                        ),
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs=dict(
+                            seq_parallel_mode=sp_mode,
+                            fp8_communication=self.shard_config.fp8_communication,
+                            use_zbv=use_zbv,
+                        ),
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.gate_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs=dict(
+                            seq_parallel_mode=sp_mode,
+                            fp8_communication=self.shard_config.fp8_communication,
+                            use_zbv=use_zbv,
+                        ),
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.up_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs=dict(
+                            seq_parallel_mode=sp_mode,
+                            fp8_communication=self.shard_config.fp8_communication,
+                            use_zbv=use_zbv,
+                        ),
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.down_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs=dict(
+                            seq_parallel_mode=sp_mode,
+                            fp8_communication=self.shard_config.fp8_communication,
+                            use_zbv=use_zbv,
+                        ),
+                    ),
+                ],
+            )

         if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
@@ -416,6 +487,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         policy = super().module_policy()
         use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

+        # enable tp, replace layer to tp Linear1D_Col,Linear1D_Row,
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
             new_item = {

@@ -434,6 +506,25 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
                 )
             }
             policy.update(new_item)
+        # enable tp, replace layer to LinearWithGradAccum
+        elif use_zbv:
+            # add a new item for sequence classification
+            new_item = {
+                LlamaForSequenceClassification: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="score",
+                            target_module=LinearWithGradAccum,
+                            kwargs=dict(
+                                fp8_communication=self.shard_config.fp8_communication,
+                                use_zbv=use_zbv,
+                            ),
+                        )
+                    ]
+                )
+            }
+            policy.update(new_item)

         # to be confirmed
         if self.pipeline_stage_manager:
             # set None as default
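
Review note: the same selection pattern recurs in all four policies touched by this PR, so a compact, library-free sketch may help. With tensor parallelism enabled the projections map to Linear1D_Col/Linear1D_Row as before; when only the zero-bubble-V schedule is active (use_zbv, no TP) they fall back to LinearWithGradAccum. The module names below are the real ColossalAI layer names, but the dict-based "policy" is purely illustrative; real policies build ModulePolicyDescription objects.

# Illustrative only, not a real shardformer policy.
from typing import Dict


def pick_linear_targets(enable_tensor_parallelism: bool, use_zbv: bool) -> Dict[str, str]:
    projections = (
        "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj",
        "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj",
    )
    if enable_tensor_parallelism:
        # column-parallel for the fan-out projections, row-parallel for the fan-in ones
        row_parallel = {"self_attn.o_proj", "mlp.down_proj"}
        return {
            name: ("Linear1D_Row" if name in row_parallel else "Linear1D_Col")
            for name in projections
        }
    if use_zbv:
        # no TP: keep full-size linears but swap in the grad-accumulation variant,
        # presumably so the zero-bubble schedule can defer weight-gradient work
        return {name: "LinearWithGradAccum" for name in projections}
    return {}


assert pick_linear_targets(False, True)["mlp.down_proj"] == "LinearWithGradAccum"
assert pick_linear_targets(True, False)["self_attn.o_proj"] == "Linear1D_Row"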

colossalai/shardformer/policies/mistral.py (71)

@@ -10,6 +10,7 @@ from colossalai.shardformer.layer import (
     FusedRMSNorm,
     Linear1D_Col,
     Linear1D_Row,
+    LinearWithGradAccum,
     PaddingEmbedding,
     PaddingLMHead,
     VocabParallelEmbedding1D,

@@ -62,6 +63,8 @@ class MistralPolicy(Policy):
         if self.tie_weight:
             embedding_cls = PaddingEmbedding

+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
+
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn(

@@ -90,6 +93,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(

@@ -97,6 +101,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(

@@ -104,6 +109,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(

@@ -111,6 +117,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Row,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(

@@ -118,6 +125,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(

@@ -125,6 +133,7 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -132,6 +141,68 @@ class MistralPolicy(Policy):
                         target_module=Linear1D_Row,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                ],
+            )
+        elif use_zbv:
+            policy[MistralDecoderLayer] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.gate_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.up_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.down_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
                         },
                     ),
                 ],

colossalai/shardformer/policies/mixtral.py (76)

@@ -7,9 +7,18 @@ from torch import Tensor
 from torch.nn import Module
 from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel

-from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
-from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
-from colossalai.shardformer.layer.linear import Linear1D_Row
+from colossalai.shardformer.layer import (
+    FusedRMSNorm,
+    Linear1D_Col,
+    Linear1D_Row,
+    LinearWithGradAccum,
+    PaddingEmbedding,
+    VocabParallelEmbedding1D,
+)
+
+# from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+# from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
+# from colossalai.shardformer.layer.linear import Linear1D_Row
 from colossalai.shardformer.modeling.mixtral import (
     EPMixtralSparseMoeBlock,
     MixtralPipelineForwards,
@@ -166,6 +175,51 @@ class MixtralPolicy(Policy):
                 ],
             )
+        elif use_zbv:
+            policy[MixtralDecoderLayer] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="block_sparse_moe.gate",
+                        target_module=LinearWithGradAccum,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "use_zbv": use_zbv,
+                        },
+                    ),
+                ],
+            )

         if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
@@ -351,6 +405,22 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                 )
             }
             policy.update(new_item)
+        elif use_zbv:
+            new_item = {
+                MixtralForCausalLM: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="lm_head",
+                            target_module=LinearWithGradAccum,
+                            kwargs=dict(
+                                fp8_communication=self.shard_config.fp8_communication,
+                                use_zbv=use_zbv,
+                            ),
+                        )
+                    ],
+                )
+            }
+            policy.update(new_item)

         if self.pipeline_stage_manager:
             # set None as default

examples/language/llama/benchmark.py (7)

@@ -163,8 +163,6 @@ def main():
             enable_async_reduce=not args.disable_async_reduce,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
-            use_fp8=args.use_fp8,
-            fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(

@@ -179,8 +177,6 @@ def main():
            enable_flash_attention=args.xformers,
            use_fp8=args.use_fp8,
            fp8_communication=args.use_fp8_comm,
-            use_fp8=args.use_fp8,
-            fp8_communication=args.use_fp8_comm,
        )
    elif args.plugin == "fsdp":
        if use_empty_init:

@@ -192,7 +188,6 @@ def main():
                ),
                param_init_fn=empty_init(),
                fp8_communication=args.use_fp8_comm,
-                fp8_communication=args.use_fp8_comm,
            )
        else:
            plugin = TorchFSDPPlugin(

@@ -214,7 +209,6 @@ def main():
                cpu_offload=CPUOffload(offload_params=True),
                param_init_fn=empty_init(),
                fp8_communication=args.use_fp8_comm,
-                fp8_communication=args.use_fp8_comm,
            )
        else:
            plugin = TorchFSDPPlugin(

@@ -225,7 +219,6 @@ def main():
                ),
                cpu_offload=CPUOffload(offload_params=True),
                fp8_communication=args.use_fp8_comm,
-                fp8_communication=args.use_fp8_comm,
            )
    elif args.plugin == "3d":
        if args.pp_style == "zbv":
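
Review note: the removals in this file drop duplicated use_fp8/fp8_communication keyword arguments. If the duplicates really sat inside a single plugin call, Python would reject the module at parse time, since a repeated keyword argument in one call is a syntax error. A minimal reproduction (the exact error text can vary slightly across Python versions):

# Repeating a keyword argument in one call is a parse-time error.
def make_plugin(use_fp8: bool = False, fp8_communication: bool = False) -> dict:
    return {"use_fp8": use_fp8, "fp8_communication": fp8_communication}


make_plugin(use_fp8=True, fp8_communication=True)  # fine

# Uncommenting the call below fails before anything runs, e.g.
#   SyntaxError: keyword argument repeated: fp8_communication
# make_plugin(use_fp8=True, fp8_communication=True, fp8_communication=True)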

examples/language/mixtral/benchmark.py (2)

@@ -122,7 +122,7 @@ def main():
                 num_ckpt_layers_per_stage=[19, 19, 19, 13],
             ),
             "num_layers_per_stage": [19, 20, 20, 21],
-            # "pp_style": "interleaved",
+            "pp_style": "interleaved",
         }
         if args.custom_ckpt
         else {}

tests/test_pipeline/test_schedule/test_zerobubble_pp.py (32)

@@ -749,24 +749,17 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)


-# TODO:3) support booster & Hybrid base 2)
-def run_with_hybridplugin(test_config):
-    pass
-
-
-# TODO:4) support booster & MoEHybrid base 2)
 @parameterize(
     "config",
     [
-        # (0, 1, 4, 1, 1),
-        # (1, 2, 2, 1, 1),
+        (1, 2, 1, 1, 2),
         (1, 1, 2, 2, 1),
-        # (1, 2, 1, 2, 1),
-        # (1, 2, 1, 1, 2),
+        (1, 2, 1, 2, 1),
+        (1, 2, 2, 1, 1),
+        (1, 1, 4, 1, 1),
     ],
 )
 def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
-    test_config = config
     stage, ep_size, pp_size, tp_size, sp_size = config
     num_microbatches = pp_size
     dist.get_world_size()
@@ -876,7 +869,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
             return_outputs=True,
         )
         # stage 0 chunk 0
-        parallel_output = None
         if (
             booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
             and rank == dist.get_process_group_ranks(plugin.pp_group)[0]

@@ -910,9 +902,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
             p.grad /= dp_size
         torch_optimizer.step()
         torch_optimizer.zero_grad()
         assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
-    print(f"rank {dist.get_rank()} config {test_config} test passed")
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()

@@ -921,11 +911,11 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
 @parameterize(
     "config",
     [
-        (1, 2, 2, 1),  # Pass
-        # TODO: only support pp + tp accleration; Will support fully pp and None tp Hybrid in furture;
-        # (0, 4, 1, 1),
-        # (1, 2, 1, 2),
-        # (1, 1, 2, 2),
+        # Pass
+        (1, 2, 2, 1),
+        (1, 2, 1, 2),
+        (1, 1, 2, 2),
+        (1, 4, 1, 1),
     ],
 )
 def run_with_booster_hybridplugin(config: Tuple[int, ...]):

@@ -1034,7 +1024,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
             return_outputs=True,
         )
         # stage 0 chunk 0
-        parallel_output = None
         if (
             booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
             and rank == dist.get_process_group_ranks(plugin.pp_group)[0]

@@ -1068,9 +1057,8 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
             p.grad /= dp_size
         torch_optimizer.step()
         torch_optimizer.zero_grad()
         assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
-    print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
