mirror of https://github.com/hpcaitech/ColossalAI
[feat] support zbv in mixtral benchmark; (#6083)
* [feat] support zbv in mixtral benchmark; * [fix] MixtralForCausalLMPolicy get_held_layer support zbv; * [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv; * [feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forward for zbv * [fix] fix llama, mixtral benchmark zbv loss none bug; update mixtral & llama policy and modeling; * [feat] Linear1D_COL/ROW support zbv WeightGradStore; * [feat] support use_zbv in llama, mixtral modeling; only replace Linear1D_Col/Row policy; * [fix] fix test case; moe error in second iter * [feat]EPMixtralSparseMoeBlock (op in MOE) support zbv; * [fix] fix bwd b; now bwd w only for Layer replaced by Linear1D_Col/Row; other layer perform a fully bwd; * [fix] debug zbv llama test; * [fix] rm use_zbv flag in Shardconfig; rm debug info; * [fix] add & fix llama test * [feat] support meta cache, meta_grad_send, meta_tensor_send; fix runtime too long in Recv Bwd; benchmark for llama + Hybrid(tp+pp); * [fix\ fix fail case test_shard_llama * [fix] fix test_shard_llama * [fix] fix llama modeling policy; * [fix] fix test_shard_llama ci; * [fix] fix test zerobubble * [fix] fix handle name; rm useless comments; * [fix] fix send recv signature; * [fix] fix comment in llama & benchmark * [feat] support no tensor parallel Linear in shardformer; Add test for use weightGradStore and not use WeightGradStore * [fix] fix linear (no tp) ops func name;feature/zerobubble
parent
dac0e07b13
commit
aed20fb2df
|
@ -1,16 +1,18 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
|
import torch.distributed
|
||||||
from torch.nn import Module, ModuleList
|
from torch.nn import Module, ModuleList
|
||||||
from torch.utils._pytree import tree_flatten, tree_map
|
from torch.utils._pytree import tree_flatten, tree_map
|
||||||
|
|
||||||
from colossalai.accelerator import get_accelerator
|
from colossalai.accelerator import get_accelerator
|
||||||
from colossalai.interface import OptimizerWrapper
|
from colossalai.interface import OptimizerWrapper
|
||||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||||
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
|
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
from colossalai.pipeline.weight_grad_store import WeightGradStore
|
||||||
|
|
||||||
from ._utils import (
|
from ._utils import (
|
||||||
clone,
|
clone,
|
||||||
|
@ -61,11 +63,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
self.do_post_validation = False
|
self.do_post_validation = False
|
||||||
|
|
||||||
# P2PMeta cache
|
# P2PMeta cache
|
||||||
# self.enable_metadata_cache = enable_metadata_cache
|
self.enable_metadata_cache = enable_metadata_cache
|
||||||
# self.send_tensor_metadata = True
|
self.send_tensor_metadata = True
|
||||||
# self.send_grad_metadata = True
|
self.send_grad_metadata = True
|
||||||
# self.tensor_metadata_recv = None
|
self.tensor_metadata_recv = None
|
||||||
# self.grad_metadata_recv = None
|
self.grad_metadata_recv = None
|
||||||
|
|
||||||
# P2P communication
|
# P2P communication
|
||||||
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
||||||
|
@ -104,8 +106,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# dy buffer for local send bwd
|
# dy buffer for local send bwd
|
||||||
self.local_send_backward_buffer = []
|
self.local_send_backward_buffer = []
|
||||||
|
|
||||||
|
# wait pp buffer
|
||||||
|
self.wait_handles = []
|
||||||
|
|
||||||
def assert_buffer_empty(self):
|
def assert_buffer_empty(self):
|
||||||
# assert buuffer is empty at end
|
# assert buffer is empty at end
|
||||||
assert len(self.input_tensors[0]) == 0
|
assert len(self.input_tensors[0]) == 0
|
||||||
assert len(self.input_tensors[1]) == 0
|
assert len(self.input_tensors[1]) == 0
|
||||||
assert len(self.output_tensors[0]) == 0
|
assert len(self.output_tensors[0]) == 0
|
||||||
|
@ -201,7 +206,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
|
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
|
||||||
return model_chunk_id
|
return model_chunk_id
|
||||||
|
|
||||||
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
|
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> List:
|
||||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||||
For ZBV.
|
For ZBV.
|
||||||
|
|
||||||
|
@ -220,7 +225,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
|
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
|
||||||
#################
|
#################
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
return None, []
|
# return None, []
|
||||||
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
# chunk = 0 & not is_first_stage
|
# chunk = 0 & not is_first_stage
|
||||||
|
@ -228,9 +234,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
#################
|
#################
|
||||||
else:
|
else:
|
||||||
prev_rank = self.stage_manager.get_prev_rank()
|
prev_rank = self.stage_manager.get_prev_rank()
|
||||||
input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank)
|
input_tensor, wait_handles = self.comm.recv_forward(
|
||||||
|
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv
|
||||||
|
)
|
||||||
|
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||||
|
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||||
return input_tensor, wait_handles
|
# return input_tensor, wait_handles
|
||||||
|
return wait_handles
|
||||||
|
|
||||||
else:
|
else:
|
||||||
################
|
################
|
||||||
|
@ -238,7 +249,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# do nothing; cause u get y from local_send_forward_buffer in schedule f
|
# do nothing; cause u get y from local_send_forward_buffer in schedule f
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
return None, []
|
# return None, []
|
||||||
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
# chunk = 1 & not is_last_stage
|
# chunk = 1 & not is_last_stage
|
||||||
|
@ -246,11 +258,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
################
|
################
|
||||||
else:
|
else:
|
||||||
next_rank = self.stage_manager.get_next_rank()
|
next_rank = self.stage_manager.get_next_rank()
|
||||||
input_tensor, wait_handles = self.comm.recv_forward(next_rank)
|
input_tensor, wait_handles = self.comm.recv_forward(
|
||||||
|
next_rank, metadata_recv=self.tensor_metadata_recv
|
||||||
|
)
|
||||||
|
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
||||||
|
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
||||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||||
return input_tensor, wait_handles
|
# return input_tensor, wait_handles
|
||||||
|
return wait_handles
|
||||||
|
|
||||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
|
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
||||||
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
||||||
For ZBV.
|
For ZBV.
|
||||||
|
|
||||||
|
@ -270,7 +287,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# do nothing; Already get dy from local_send_backward_buffer in schedule b
|
# do nothing; Already get dy from local_send_backward_buffer in schedule b
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
return None, []
|
# return None, []
|
||||||
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
# chunk = 0 & not is_last_stage
|
# chunk = 0 & not is_last_stage
|
||||||
|
@ -278,9 +296,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
################
|
################
|
||||||
else:
|
else:
|
||||||
next_rank = self.stage_manager.get_next_rank()
|
next_rank = self.stage_manager.get_next_rank()
|
||||||
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank)
|
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
||||||
|
next_rank, metadata_recv=self.grad_metadata_recv
|
||||||
|
)
|
||||||
|
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||||
|
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||||
return output_tensor_grad, wait_handles
|
# return output_tensor_grad, wait_handles
|
||||||
|
return wait_handles
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# bwd chunk1 is left V;
|
# bwd chunk1 is left V;
|
||||||
|
@ -289,7 +312,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# do nothing; get loss from local
|
# do nothing; get loss from local
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
return None, []
|
# return None, []
|
||||||
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
# chunk = 1 & not first stage
|
# chunk = 1 & not first stage
|
||||||
|
@ -297,9 +321,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
################
|
################
|
||||||
else:
|
else:
|
||||||
prev_rank = self.stage_manager.get_prev_rank()
|
prev_rank = self.stage_manager.get_prev_rank()
|
||||||
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank)
|
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
||||||
|
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv
|
||||||
|
)
|
||||||
|
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
||||||
|
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
||||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||||
return output_tensor_grad, wait_handles
|
# return output_tensor_grad, wait_handles
|
||||||
|
return wait_handles
|
||||||
|
|
||||||
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
||||||
"""Sends the input tensor to the next stage in pipeline.
|
"""Sends the input tensor to the next stage in pipeline.
|
||||||
|
@ -329,7 +358,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
else:
|
else:
|
||||||
next_rank = self.stage_manager.get_next_rank()
|
next_rank = self.stage_manager.get_next_rank()
|
||||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||||
send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank)
|
send_handles = self.comm.send_forward(
|
||||||
|
output_object=output_tensor, next_rank=next_rank, send_metadata=self.send_tensor_metadata
|
||||||
|
)
|
||||||
|
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||||
return send_handles
|
return send_handles
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -347,7 +379,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
else:
|
else:
|
||||||
prev_rank = self.stage_manager.get_prev_rank()
|
prev_rank = self.stage_manager.get_prev_rank()
|
||||||
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
||||||
send_handles = self.comm.send_forward(output_tensor, prev_rank)
|
send_handles = self.comm.send_forward(
|
||||||
|
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata
|
||||||
|
)
|
||||||
|
self.send_tensor_metadata = not self.enable_metadata_cache
|
||||||
return send_handles
|
return send_handles
|
||||||
|
|
||||||
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
|
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
|
||||||
|
@ -379,7 +414,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
else:
|
else:
|
||||||
prev_rank = self.stage_manager.get_prev_rank()
|
prev_rank = self.stage_manager.get_prev_rank()
|
||||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||||
send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
|
send_handles = self.comm.send_backward(
|
||||||
|
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
|
||||||
|
)
|
||||||
|
self.send_grad_metadata = not self.enable_metadata_cache
|
||||||
return send_handles
|
return send_handles
|
||||||
|
|
||||||
# bwd chunk1 is left V;
|
# bwd chunk1 is left V;
|
||||||
|
@ -398,7 +436,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
else:
|
else:
|
||||||
next_rank = self.stage_manager.get_next_rank()
|
next_rank = self.stage_manager.get_next_rank()
|
||||||
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
||||||
send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
|
send_handles = self.comm.send_backward(
|
||||||
|
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata
|
||||||
|
)
|
||||||
|
self.send_grad_metadata = not self.enable_metadata_cache
|
||||||
return send_handles
|
return send_handles
|
||||||
|
|
||||||
def forward_step(
|
def forward_step(
|
||||||
|
@ -432,7 +473,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
internal_inputs = {} if input_obj is None else input_obj
|
internal_inputs = {} if input_obj is None else input_obj
|
||||||
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||||
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
|
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
|
||||||
|
|
||||||
# last layer in model
|
# last layer in model
|
||||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
loss = criterion(output_obj, micro_batch) / self.num_microbatch
|
loss = criterion(output_obj, micro_batch) / self.num_microbatch
|
||||||
|
@ -479,11 +519,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
output_obj_grad_ = []
|
output_obj_grad_ = []
|
||||||
|
|
||||||
# For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx.
|
# For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx.
|
||||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
# if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
return None
|
# return None
|
||||||
|
|
||||||
# For loss backward; output_obj is loss; output_obj_grad should be None
|
# For loss backward; output_obj is loss; output_obj_grad should be None
|
||||||
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
assert output_obj_grad is None
|
assert output_obj_grad is None
|
||||||
input_obj_, _ = tree_flatten(input_obj)
|
input_obj_, _ = tree_flatten(input_obj)
|
||||||
output_obj_.append(output_obj) # LOSS
|
output_obj_.append(output_obj) # LOSS
|
||||||
|
@ -504,17 +544,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
ctx = optimizer.no_sync()
|
ctx = optimizer.no_sync()
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
ctx = model_chunk.no_sync()
|
ctx = model_chunk.no_sync()
|
||||||
|
|
||||||
with ctx:
|
with ctx:
|
||||||
optimizer.backward_by_grad(
|
optimizer.backward_by_grad(
|
||||||
tensor=output_obj_,
|
tensor=output_obj_,
|
||||||
grad=output_obj_grad_,
|
grad=output_obj_grad_,
|
||||||
inputs=input_obj_,
|
# inputs=input_obj_,
|
||||||
retain_graph=True,
|
retain_graph=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Format output_obj_grad
|
# Format output_obj_grad
|
||||||
input_obj_grad = {}
|
input_obj_grad = dict()
|
||||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
@ -651,10 +689,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# Do not release_tensor_data loss, release_tensor_data other output_obj;
|
# Do not release_tensor_data loss, release_tensor_data other output_obj;
|
||||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
self.output_tensors[model_chunk_id].append(output_obj)
|
self.output_tensors[model_chunk_id].append(output_obj)
|
||||||
self.output_tensors_dw[model_chunk_id].append(output_obj)
|
# self.output_tensors_dw[model_chunk_id].append(output_obj)
|
||||||
else:
|
else:
|
||||||
self.output_tensors[model_chunk_id].append(output_obj)
|
self.output_tensors[model_chunk_id].append(output_obj)
|
||||||
self.output_tensors_dw[model_chunk_id].append(output_obj)
|
# self.output_tensors_dw[model_chunk_id].append(output_obj)
|
||||||
|
|
||||||
# add output to send_fwd_buffer
|
# add output to send_fwd_buffer
|
||||||
if model_chunk_id == 0: # chunk 0
|
if model_chunk_id == 0: # chunk 0
|
||||||
|
@ -706,15 +744,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
||||||
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
||||||
|
|
||||||
# save output_tensor_grad for dw
|
# # save output_tensor_grad for dw
|
||||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
# if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
# we save loss here
|
# # we save loss here
|
||||||
self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
|
# self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
|
||||||
else:
|
# else:
|
||||||
# we save output_tensor_grad here
|
# # we save output_tensor_grad here
|
||||||
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
|
# self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
|
||||||
|
# the_output_obj_grad = []
|
||||||
|
# if isinstance(output_obj, dict):
|
||||||
|
# for (k, v) in output_obj.items():
|
||||||
|
# the_output_obj_grad.append(v.requires_grad)
|
||||||
|
# else:
|
||||||
|
# the_output_obj_grad.append(output_obj.requires_grad)
|
||||||
|
|
||||||
# Step2: bwd step
|
|
||||||
input_object_grad = self.backward_b_step(
|
input_object_grad = self.backward_b_step(
|
||||||
model_chunk=model_chunk,
|
model_chunk=model_chunk,
|
||||||
model_chunk_id=model_chunk_id,
|
model_chunk_id=model_chunk_id,
|
||||||
|
@ -739,6 +782,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# send to next
|
# send to next
|
||||||
else:
|
else:
|
||||||
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
|
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
|
||||||
|
WeightGradStore.flush(chunk=model_chunk_id)
|
||||||
|
|
||||||
def schedule_w(
|
def schedule_w(
|
||||||
self,
|
self,
|
||||||
|
@ -758,16 +802,17 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# get y & dy from buffer
|
# get y & dy from buffer
|
||||||
output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
|
# output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
|
||||||
output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
|
# output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
|
||||||
|
WeightGradStore.pop(chunk=model_chunk_id)
|
||||||
|
|
||||||
self.backward_w_step(
|
# self.backward_w_step(
|
||||||
model_chunk=model_chunk,
|
# model_chunk=model_chunk,
|
||||||
model_chunk_id=model_chunk_id,
|
# model_chunk_id=model_chunk_id,
|
||||||
optimizer=optimizer,
|
# optimizer=optimizer,
|
||||||
output_obj=output_obj,
|
# output_obj=output_obj,
|
||||||
output_obj_grad=output_obj_grad,
|
# output_obj_grad=output_obj_grad,
|
||||||
)
|
# )
|
||||||
|
|
||||||
def run_forward_only(
|
def run_forward_only(
|
||||||
self,
|
self,
|
||||||
|
@ -844,7 +889,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||||
# communication
|
# communication
|
||||||
communication_func = self.communication_map[scheduled_node.type]
|
communication_func = self.communication_map[scheduled_node.type]
|
||||||
communication_func(scheduled_node.chunk)
|
wait_handle = communication_func(scheduled_node.chunk)
|
||||||
|
self.wait_handles.append(wait_handle)
|
||||||
elif scheduled_node.type == "F":
|
elif scheduled_node.type == "F":
|
||||||
self.schedule_f(
|
self.schedule_f(
|
||||||
scheduled_node=scheduled_node,
|
scheduled_node=scheduled_node,
|
||||||
|
@ -868,6 +914,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
model_chunk_id=scheduled_node.chunk,
|
model_chunk_id=scheduled_node.chunk,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
)
|
)
|
||||||
|
for h in self.wait_handles:
|
||||||
|
for hh in h:
|
||||||
|
hh.wait()
|
||||||
|
|
||||||
# return loss & output
|
# return loss & output
|
||||||
if outputs is not None:
|
if outputs is not None:
|
||||||
|
@ -907,5 +956,4 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assert_buffer_empty()
|
self.assert_buffer_empty()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -223,7 +223,6 @@ class PipelineStageManager:
|
||||||
|
|
||||||
# calculate the num_layers per stage
|
# calculate the num_layers per stage
|
||||||
layers_per_stage = [quotient] * num_stages * num_model_chunks
|
layers_per_stage = [quotient] * num_stages * num_model_chunks
|
||||||
|
|
||||||
# deal with the rest layers
|
# deal with the rest layers
|
||||||
if remainder > 0:
|
if remainder > 0:
|
||||||
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
|
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
|
||||||
|
|
|
@ -0,0 +1,32 @@
|
||||||
|
import queue
|
||||||
|
|
||||||
|
|
||||||
|
class WeightGradStore:
|
||||||
|
|
||||||
|
cache = []
|
||||||
|
weight_grad_queue = [queue.Queue(), queue.Queue()]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def put(cls, total_input, grad_output, weight, func):
|
||||||
|
# func(total_input, grad_output, weight.main_grad)
|
||||||
|
cls.cache.append((total_input, grad_output, weight, func))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def flush(cls, chunk=0):
|
||||||
|
cls.weight_grad_queue[chunk].put(cls.cache)
|
||||||
|
cls.cache = []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pop(cls, chunk=0):
|
||||||
|
# print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}")
|
||||||
|
if cls.weight_grad_queue[chunk].qsize() > 0:
|
||||||
|
stored_grads = cls.weight_grad_queue[chunk].get()
|
||||||
|
for total_input, grad_output, weight, func in stored_grads:
|
||||||
|
if weight.grad is not None:
|
||||||
|
func(total_input, grad_output, weight.grad)
|
||||||
|
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
|
||||||
|
else:
|
||||||
|
grad_weight = func(total_input, grad_output)
|
||||||
|
weight.grad = grad_weight
|
||||||
|
else:
|
||||||
|
raise Exception("Pop empty queue.")
|
|
@ -2,7 +2,7 @@ from ._operation import all_to_all_comm
|
||||||
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
|
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
|
||||||
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
|
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
|
||||||
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
|
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
|
||||||
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
|
from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
|
||||||
from .loss import cross_entropy_1d, dist_cross_entropy
|
from .loss import cross_entropy_1d, dist_cross_entropy
|
||||||
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
||||||
from .parallel_module import ParallelModule
|
from .parallel_module import ParallelModule
|
||||||
|
@ -11,6 +11,7 @@ from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Embedding1D",
|
"Embedding1D",
|
||||||
"VocabParallelEmbedding1D",
|
"VocabParallelEmbedding1D",
|
||||||
|
"LinearWithGradAccum",
|
||||||
"Linear1D_Col",
|
"Linear1D_Col",
|
||||||
"Linear1D_Row",
|
"Linear1D_Row",
|
||||||
"GPT2FusedLinearConv1D_Col",
|
"GPT2FusedLinearConv1D_Col",
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
|
import functools
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from colossalai.pipeline.weight_grad_store import WeightGradStore
|
||||||
|
|
||||||
from .utils import is_share_sp_tp
|
from .utils import is_share_sp_tp
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -125,12 +129,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False):
|
||||||
ctx.save_for_backward(input_, weight, bias)
|
ctx.save_for_backward(input_, weight, bias)
|
||||||
ctx.use_bias = bias is not None
|
ctx.use_bias = bias is not None
|
||||||
ctx.process_group = process_group
|
ctx.process_group = process_group
|
||||||
ctx.async_grad_allreduce = async_grad_allreduce
|
ctx.async_grad_allreduce = async_grad_allreduce
|
||||||
ctx.fp8_communication = fp8_communication
|
ctx.fp8_communication = fp8_communication
|
||||||
|
ctx.use_zbv = use_zbv
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
output = F.linear(input_, weight, bias)
|
output = F.linear(input_, weight, bias)
|
||||||
else:
|
else:
|
||||||
|
@ -143,6 +148,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||||
input, weight, bias = ctx.saved_tensors
|
input, weight, bias = ctx.saved_tensors
|
||||||
use_bias = ctx.use_bias
|
use_bias = ctx.use_bias
|
||||||
fp8_communication = ctx.fp8_communication
|
fp8_communication = ctx.fp8_communication
|
||||||
|
use_zbv = ctx.use_zbv
|
||||||
|
|
||||||
|
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
|
||||||
|
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
|
||||||
|
|
||||||
|
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
|
||||||
|
return wgrad_gemm_func(_grad_output_.t(), _input_)
|
||||||
|
|
||||||
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
|
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
|
||||||
if use_bias:
|
if use_bias:
|
||||||
|
@ -164,9 +176,35 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||||
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
||||||
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||||
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
|
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
|
||||||
|
|
||||||
if _grad_accum_fusion_available and weight.grad is not None:
|
if _grad_accum_fusion_available and weight.grad is not None:
|
||||||
grad = weight.grad
|
grad = weight.grad
|
||||||
|
if use_zbv:
|
||||||
|
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
|
||||||
|
if grad.dtype == torch.float32:
|
||||||
|
WeightGradStore.put(
|
||||||
|
total_input,
|
||||||
|
grad_output,
|
||||||
|
weight,
|
||||||
|
functools.partial(
|
||||||
|
execute_w_pass_grad_accum,
|
||||||
|
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
grad_weight = None
|
||||||
|
elif grad.dtype in (torch.float16, torch.bfloat16):
|
||||||
|
WeightGradStore.put(
|
||||||
|
total_input,
|
||||||
|
grad_output,
|
||||||
|
weight,
|
||||||
|
functools.partial(
|
||||||
|
execute_w_pass_grad_accum,
|
||||||
|
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
grad_weight = None
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
|
||||||
|
else:
|
||||||
if grad.dtype == torch.float32:
|
if grad.dtype == torch.float32:
|
||||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||||
grad_weight = None
|
grad_weight = None
|
||||||
|
@ -175,6 +213,18 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||||
grad_weight = None
|
grad_weight = None
|
||||||
else:
|
else:
|
||||||
grad_weight = grad_output.t().matmul(total_input)
|
grad_weight = grad_output.t().matmul(total_input)
|
||||||
|
else:
|
||||||
|
if use_zbv:
|
||||||
|
WeightGradStore.put(
|
||||||
|
total_input,
|
||||||
|
grad_output,
|
||||||
|
weight,
|
||||||
|
functools.partial(
|
||||||
|
execute_w_pass,
|
||||||
|
wgrad_gemm_func=torch.matmul,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
grad_weight = None
|
||||||
else:
|
else:
|
||||||
grad_weight = grad_output.t().matmul(total_input)
|
grad_weight = grad_output.t().matmul(total_input)
|
||||||
|
|
||||||
|
@ -182,6 +232,104 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||||
|
|
||||||
if ctx.async_grad_allreduce and not fp8_communication:
|
if ctx.async_grad_allreduce and not fp8_communication:
|
||||||
handle.wait()
|
handle.wait()
|
||||||
|
return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
class LinearWithGradAccum(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
Linear layer baseline (no tensor parallel version).
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
|
||||||
|
ctx.save_for_backward(input_, weight, bias)
|
||||||
|
ctx.use_bias = bias is not None
|
||||||
|
ctx.async_grad_allreduce = async_grad_allreduce
|
||||||
|
ctx.use_zbv = use_zbv
|
||||||
|
if bias is not None:
|
||||||
|
output = F.linear(input_, weight, bias)
|
||||||
|
else:
|
||||||
|
output = F.linear(input_, weight)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
input, weight, bias = ctx.saved_tensors
|
||||||
|
use_bias = ctx.use_bias
|
||||||
|
use_zbv = ctx.use_zbv
|
||||||
|
|
||||||
|
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
|
||||||
|
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
|
||||||
|
|
||||||
|
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
|
||||||
|
return wgrad_gemm_func(_grad_output_.t(), _input_)
|
||||||
|
|
||||||
|
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
|
||||||
|
if use_bias:
|
||||||
|
bias.view(bias.shape)
|
||||||
|
|
||||||
|
total_input = input.contiguous()
|
||||||
|
grad_input = grad_output.matmul(weight)
|
||||||
|
grad_output = grad_output.contiguous()
|
||||||
|
# Convert the tensor shapes to 2D for execution compatibility
|
||||||
|
if len(grad_output.shape) > 2:
|
||||||
|
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||||
|
total_input = total_input.view(-1, total_input.shape[-1])
|
||||||
|
|
||||||
|
if _grad_accum_fusion_available and weight.grad is not None:
|
||||||
|
grad = weight.grad
|
||||||
|
if use_zbv:
|
||||||
|
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
|
||||||
|
if grad.dtype == torch.float32:
|
||||||
|
WeightGradStore.put(
|
||||||
|
total_input,
|
||||||
|
grad_output,
|
||||||
|
weight,
|
||||||
|
functools.partial(
|
||||||
|
execute_w_pass_grad_accum,
|
||||||
|
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
grad_weight = None
|
||||||
|
elif grad.dtype in (torch.float16, torch.bfloat16):
|
||||||
|
WeightGradStore.put(
|
||||||
|
total_input,
|
||||||
|
grad_output,
|
||||||
|
weight,
|
||||||
|
functools.partial(
|
||||||
|
execute_w_pass_grad_accum,
|
||||||
|
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
grad_weight = None
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
|
||||||
|
else:
|
||||||
|
if grad.dtype == torch.float32:
|
||||||
|
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
|
||||||
|
grad_weight = None
|
||||||
|
elif grad.dtype == torch.float16:
|
||||||
|
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
|
||||||
|
grad_weight = None
|
||||||
|
else:
|
||||||
|
grad_weight = grad_output.t().matmul(total_input)
|
||||||
|
else:
|
||||||
|
if use_zbv:
|
||||||
|
WeightGradStore.put(
|
||||||
|
total_input,
|
||||||
|
grad_output,
|
||||||
|
weight,
|
||||||
|
functools.partial(
|
||||||
|
execute_w_pass,
|
||||||
|
wgrad_gemm_func=torch.matmul,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
grad_weight = None
|
||||||
|
else:
|
||||||
|
grad_weight = grad_output.t().matmul(total_input)
|
||||||
|
|
||||||
|
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||||
|
|
||||||
return grad_input, grad_weight, grad_bias, None, None, None, None
|
return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||||
|
|
||||||
|
@ -1043,12 +1191,18 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
|
def linear_with_async_comm(
|
||||||
|
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
|
||||||
|
):
|
||||||
return LinearWithAsyncCommunication.apply(
|
return LinearWithAsyncCommunication.apply(
|
||||||
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
|
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False):
|
||||||
|
return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)
|
||||||
|
|
||||||
|
|
||||||
def linear_gather_forward_reducescatter_backward(
|
def linear_gather_forward_reducescatter_backward(
|
||||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
|
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
|
||||||
):
|
):
|
||||||
|
|
|
@ -28,6 +28,7 @@ from ._operation import (
|
||||||
linear_gather_forward_reducescatter_backward,
|
linear_gather_forward_reducescatter_backward,
|
||||||
linear_reducescatter_forward_gather_backward,
|
linear_reducescatter_forward_gather_backward,
|
||||||
linear_with_async_comm,
|
linear_with_async_comm,
|
||||||
|
linear_with_grad_accum,
|
||||||
reduce_forward,
|
reduce_forward,
|
||||||
reducescatter_forward_gather_backward,
|
reducescatter_forward_gather_backward,
|
||||||
split_forward_gather_backward,
|
split_forward_gather_backward,
|
||||||
|
@ -35,7 +36,148 @@ from ._operation import (
|
||||||
from .parallel_module import PaddingParallelModule, ParallelModule
|
from .parallel_module import PaddingParallelModule, ParallelModule
|
||||||
from .utils import create_randomizer_with_offset
|
from .utils import create_randomizer_with_offset
|
||||||
|
|
||||||
__all__ = ["Linear1D_Col", "Linear1D_Row"]
|
__all__ = ["LinearWithGradAccum", "Linear1D_Col", "Linear1D_Row"]
|
||||||
|
|
||||||
|
|
||||||
|
class LinearWithGradAccum(ParallelModule):
|
||||||
|
r"""Linear layer with no parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_features (int): size of each input sample.
|
||||||
|
out_features (int): size of each output sample.
|
||||||
|
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||||
|
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||||
|
device (`torch.device`): The device of parameters, defaults to None.
|
||||||
|
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
||||||
|
to all GPUs, otherwise, every GPU will have its output
|
||||||
|
which is :math:`Y_i = XA_i`, defaults to False
|
||||||
|
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
|
||||||
|
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
|
||||||
|
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||||
|
which is preserved for kernel fusion, defaults to False
|
||||||
|
weight_initializer (`typing.Callable`):
|
||||||
|
The initializer of weight, defaults to kaiming uniform initializer.
|
||||||
|
bias_initializer (`typing.Callable`):
|
||||||
|
The initializer of bias, defaults to xavier uniform initializer.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
out_features: int,
|
||||||
|
bias: bool = True,
|
||||||
|
dtype: torch.dtype = None,
|
||||||
|
device: torch.device = None,
|
||||||
|
skip_bias_add: bool = False,
|
||||||
|
weight: Optional[Parameter] = None,
|
||||||
|
bias_: Optional[Parameter] = None,
|
||||||
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||||
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||||
|
use_zbv: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(weight=weight, bias_=bias_, **kwargs)
|
||||||
|
|
||||||
|
# Keep input parameters
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
self.skip_bias_add = skip_bias_add
|
||||||
|
self.device = device
|
||||||
|
self.use_zbv = use_zbv
|
||||||
|
|
||||||
|
if skip_bias_add and not bias:
|
||||||
|
raise ValueError("cannot skip bias addition if bias is None")
|
||||||
|
|
||||||
|
# offset the seed with randomizer index and rank
|
||||||
|
seed = torch.random.initial_seed()
|
||||||
|
|
||||||
|
self.randomizer = create_randomizer_with_offset(seed, process_group=None)
|
||||||
|
|
||||||
|
# sanity check
|
||||||
|
if weight is not None:
|
||||||
|
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
|
||||||
|
else:
|
||||||
|
assert bias_ is None, "bias_ must be None if weight is None"
|
||||||
|
|
||||||
|
# Parameters.
|
||||||
|
if weight is None:
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||||
|
else:
|
||||||
|
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||||
|
self.weight = weight
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
if bias_ is None:
|
||||||
|
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||||
|
else:
|
||||||
|
bias_.data = bias_.data.to(device=device, dtype=dtype)
|
||||||
|
self.bias = bias_
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
|
if weight is None:
|
||||||
|
# init weights
|
||||||
|
self.reset_parameters(weight_initializer, bias_initializer)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_native_module(module: nn.Linear, **kwargs) -> ParallelModule:
|
||||||
|
r"""
|
||||||
|
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||||
|
"""
|
||||||
|
LazyInitContext.materialize(module)
|
||||||
|
# get the attributes
|
||||||
|
in_features = module.in_features
|
||||||
|
out_features = module.out_features
|
||||||
|
bias = module.bias is not None
|
||||||
|
device = module.weight.device
|
||||||
|
|
||||||
|
linear_1d = LinearWithGradAccum(
|
||||||
|
in_features=in_features,
|
||||||
|
out_features=out_features,
|
||||||
|
bias=bias,
|
||||||
|
device=device,
|
||||||
|
weight=module.weight,
|
||||||
|
bias_=module.bias,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return linear_1d
|
||||||
|
|
||||||
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||||
|
with self.randomizer.fork_rng(enable_cpu=True):
|
||||||
|
fan_in, fan_out = self.in_features, self.out_features
|
||||||
|
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||||
|
if self.bias is not None:
|
||||||
|
bias_initializer(self.bias, fan_in=fan_in)
|
||||||
|
|
||||||
|
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||||
|
assert (
|
||||||
|
input_.shape[-1] == self.weight.shape[-1]
|
||||||
|
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||||
|
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set up backprop all-reduce.
|
||||||
|
input_parallel = input_
|
||||||
|
|
||||||
|
# Matrix multiply.
|
||||||
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
|
output_parallel = linear_with_grad_accum(
|
||||||
|
input_parallel,
|
||||||
|
self.weight,
|
||||||
|
bias,
|
||||||
|
False,
|
||||||
|
use_zbv=self.use_zbv,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = output_parallel
|
||||||
|
|
||||||
|
if self.skip_bias_add:
|
||||||
|
return output, self.bias
|
||||||
|
else:
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class Linear1D_Col(ParallelModule):
|
class Linear1D_Col(ParallelModule):
|
||||||
|
@ -85,6 +227,7 @@ class Linear1D_Col(ParallelModule):
|
||||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||||
fp8_communication: bool = False,
|
fp8_communication: bool = False,
|
||||||
|
use_zbv: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(weight=weight, bias_=bias_, **kwargs)
|
super().__init__(weight=weight, bias_=bias_, **kwargs)
|
||||||
|
@ -100,6 +243,7 @@ class Linear1D_Col(ParallelModule):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
self.fp8_communication = fp8_communication
|
self.fp8_communication = fp8_communication
|
||||||
|
self.use_zbv = use_zbv
|
||||||
|
|
||||||
if skip_bias_add and not bias:
|
if skip_bias_add and not bias:
|
||||||
raise ValueError("cannot skip bias addition if bias is None")
|
raise ValueError("cannot skip bias addition if bias is None")
|
||||||
|
@ -201,13 +345,18 @@ class Linear1D_Col(ParallelModule):
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
|
|
||||||
if self.seq_parallel_mode == "split_gather":
|
if self.seq_parallel_mode == "split_gather":
|
||||||
input_parallel = gather_forward_reducescatter_backward(
|
input_parallel = gather_forward_reducescatter_backward(
|
||||||
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||||
)
|
)
|
||||||
output_parallel = linear_with_async_comm(
|
output_parallel = linear_with_async_comm(
|
||||||
input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
|
input_parallel,
|
||||||
|
self.weight,
|
||||||
|
bias,
|
||||||
|
self.process_group,
|
||||||
|
False,
|
||||||
|
fp8_communication=self.fp8_communication,
|
||||||
|
use_zbv=self.use_zbv,
|
||||||
)
|
)
|
||||||
elif self.seq_parallel_mode == "ring":
|
elif self.seq_parallel_mode == "ring":
|
||||||
output_parallel = linear_gather_forward_reducescatter_backward(
|
output_parallel = linear_gather_forward_reducescatter_backward(
|
||||||
|
@ -215,9 +364,14 @@ class Linear1D_Col(ParallelModule):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output_parallel = linear_with_async_comm(
|
output_parallel = linear_with_async_comm(
|
||||||
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
|
input_parallel,
|
||||||
|
self.weight,
|
||||||
|
bias,
|
||||||
|
self.process_group,
|
||||||
|
True,
|
||||||
|
fp8_communication=self.fp8_communication,
|
||||||
|
use_zbv=self.use_zbv,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.gather_output:
|
if self.gather_output:
|
||||||
# All-gather across the partitions.
|
# All-gather across the partitions.
|
||||||
output = gather_forward_split_backward(
|
output = gather_forward_split_backward(
|
||||||
|
@ -273,6 +427,7 @@ class Linear1D_Row(ParallelModule):
|
||||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||||
stream_chunk_num: int = 1,
|
stream_chunk_num: int = 1,
|
||||||
fp8_communication: bool = False,
|
fp8_communication: bool = False,
|
||||||
|
use_zbv: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -288,6 +443,7 @@ class Linear1D_Row(ParallelModule):
|
||||||
self.seq_parallel_dim = seq_parallel_dim
|
self.seq_parallel_dim = seq_parallel_dim
|
||||||
self.num_partitions = dist.get_world_size(self.process_group)
|
self.num_partitions = dist.get_world_size(self.process_group)
|
||||||
self.fp8_communication = fp8_communication
|
self.fp8_communication = fp8_communication
|
||||||
|
self.use_zbv = use_zbv
|
||||||
|
|
||||||
if skip_bias_add and not bias:
|
if skip_bias_add and not bias:
|
||||||
raise ValueError("cannot skip bias addition if bias is None")
|
raise ValueError("cannot skip bias addition if bias is None")
|
||||||
|
@ -429,10 +585,14 @@ class Linear1D_Row(ParallelModule):
|
||||||
output = torch.cat(output_parallel_list, dim=-1)
|
output = torch.cat(output_parallel_list, dim=-1)
|
||||||
else:
|
else:
|
||||||
if self.seq_parallel_mode is None:
|
if self.seq_parallel_mode is None:
|
||||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
output_parallel = linear_with_async_comm(
|
||||||
|
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
|
||||||
|
)
|
||||||
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
|
||||||
elif self.seq_parallel_mode == "split_gather":
|
elif self.seq_parallel_mode == "split_gather":
|
||||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
output_parallel = linear_with_async_comm(
|
||||||
|
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
|
||||||
|
)
|
||||||
output = reducescatter_forward_gather_backward(
|
output = reducescatter_forward_gather_backward(
|
||||||
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
|
||||||
)
|
)
|
||||||
|
@ -445,7 +605,9 @@ class Linear1D_Row(ParallelModule):
|
||||||
ring=True,
|
ring=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
output_parallel = linear_with_async_comm(
|
||||||
|
input_, self.weight, None, self.process_group, False, use_zbv=self.use_zbv
|
||||||
|
)
|
||||||
output = reduce_forward(output_parallel, self.process_group)
|
output = reduce_forward(output_parallel, self.process_group)
|
||||||
|
|
||||||
if not self.skip_bias_add:
|
if not self.skip_bias_add:
|
||||||
|
|
|
@ -82,7 +82,7 @@ class LlamaPipelineForwards:
|
||||||
elif input_ids is not None:
|
elif input_ids is not None:
|
||||||
batch_size, seq_length = input_ids.shape[:2]
|
batch_size, seq_length = input_ids.shape[:2]
|
||||||
elif inputs_embeds is not None:
|
elif inputs_embeds is not None:
|
||||||
batch_size, seq_length, _ = inputs_embeds.shape[:2]
|
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
|
|
|
@ -60,6 +60,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
|
||||||
moe_dp_group: ProcessGroup,
|
moe_dp_group: ProcessGroup,
|
||||||
ep_group: ProcessGroup,
|
ep_group: ProcessGroup,
|
||||||
fp8_communication: bool = False,
|
fp8_communication: bool = False,
|
||||||
|
use_zbv: bool = False,
|
||||||
):
|
):
|
||||||
assert tp_group is not None
|
assert tp_group is not None
|
||||||
assert moe_dp_group is not None
|
assert moe_dp_group is not None
|
||||||
|
@ -70,6 +71,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
|
||||||
self.ep_rank = dist.get_rank(ep_group)
|
self.ep_rank = dist.get_rank(ep_group)
|
||||||
self.ep_group = ep_group
|
self.ep_group = ep_group
|
||||||
self.fp8_communication = fp8_communication
|
self.fp8_communication = fp8_communication
|
||||||
|
self.use_zbv = use_zbv
|
||||||
|
|
||||||
if self.num_experts % self.ep_size != 0:
|
if self.num_experts % self.ep_size != 0:
|
||||||
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
|
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
|
||||||
|
@ -89,13 +91,13 @@ class EPMixtralSparseMoeBlock(ParallelModule):
|
||||||
if self.tp_group.size() > 1:
|
if self.tp_group.size() > 1:
|
||||||
for expert in held_experts:
|
for expert in held_experts:
|
||||||
expert.w1 = Linear1D_Col.from_native_module(
|
expert.w1 = Linear1D_Col.from_native_module(
|
||||||
expert.w1, self.tp_group, fp8_communication=self.fp8_communication
|
expert.w1, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
|
||||||
)
|
)
|
||||||
expert.w3 = Linear1D_Col.from_native_module(
|
expert.w3 = Linear1D_Col.from_native_module(
|
||||||
expert.w3, self.tp_group, fp8_communication=self.fp8_communication
|
expert.w3, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
|
||||||
)
|
)
|
||||||
expert.w2 = Linear1D_Row.from_native_module(
|
expert.w2 = Linear1D_Row.from_native_module(
|
||||||
expert.w2, self.tp_group, fp8_communication=self.fp8_communication
|
expert.w2, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
|
||||||
)
|
)
|
||||||
|
|
||||||
for p in self.experts.parameters():
|
for p in self.experts.parameters():
|
||||||
|
@ -399,6 +401,7 @@ class MixtralPipelineForwards:
|
||||||
|
|
||||||
if output_router_logits and past_router_logits is not None:
|
if output_router_logits and past_router_logits is not None:
|
||||||
all_router_logits = past_router_logits + all_router_logits
|
all_router_logits = past_router_logits + all_router_logits
|
||||||
|
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.is_last_stage():
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
|
@ -512,7 +515,6 @@ class MixtralPipelineForwards:
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
logits = logits.float()
|
logits = logits.float()
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
# Shift so that tokens < n predict n
|
# Shift so that tokens < n predict n
|
||||||
|
|
|
@ -60,6 +60,8 @@ class LlamaPolicy(Policy):
|
||||||
else:
|
else:
|
||||||
norm_cls = RMSNorm
|
norm_cls = RMSNorm
|
||||||
|
|
||||||
|
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||||
|
|
||||||
sp_mode = self.shard_config.sequence_parallelism_mode or None
|
sp_mode = self.shard_config.sequence_parallelism_mode or None
|
||||||
sp_size = self.shard_config.sequence_parallel_size or None
|
sp_size = self.shard_config.sequence_parallel_size or None
|
||||||
sp_group = self.shard_config.sequence_parallel_process_group or None
|
sp_group = self.shard_config.sequence_parallel_process_group or None
|
||||||
|
@ -126,37 +128,65 @@ class LlamaPolicy(Policy):
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.q_proj",
|
suffix="self_attn.q_proj",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.k_proj",
|
suffix="self_attn.k_proj",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.v_proj",
|
suffix="self_attn.v_proj",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.o_proj",
|
suffix="self_attn.o_proj",
|
||||||
target_module=Linear1D_Row,
|
target_module=Linear1D_Row,
|
||||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="mlp.gate_proj",
|
suffix="mlp.gate_proj",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="mlp.up_proj",
|
suffix="mlp.up_proj",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="mlp.down_proj",
|
suffix="mlp.down_proj",
|
||||||
target_module=Linear1D_Row,
|
target_module=Linear1D_Row,
|
||||||
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
kwargs=dict(
|
||||||
|
seq_parallel_mode=sp_mode,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -265,7 +295,6 @@ class LlamaPolicy(Policy):
|
||||||
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
|
||||||
):
|
):
|
||||||
held_layers.append(module.norm)
|
held_layers.append(module.norm)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||||
if stage_manager.is_first_stage():
|
if stage_manager.is_first_stage():
|
||||||
|
@ -385,6 +414,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
||||||
from transformers import LlamaForSequenceClassification
|
from transformers import LlamaForSequenceClassification
|
||||||
|
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
|
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
# add a new item for sequence classification
|
# add a new item for sequence classification
|
||||||
|
@ -397,6 +427,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
||||||
kwargs=dict(
|
kwargs=dict(
|
||||||
gather_output=True,
|
gather_output=True,
|
||||||
fp8_communication=self.shard_config.fp8_communication,
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
|
@ -52,6 +52,7 @@ class MixtralPolicy(Policy):
|
||||||
sp_group = self.shard_config.sequence_parallel_process_group or None
|
sp_group = self.shard_config.sequence_parallel_process_group or None
|
||||||
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
sp_partial_derived = sp_mode in ["split_gather", "ring"]
|
||||||
tp_size = self.shard_config.tensor_parallel_size
|
tp_size = self.shard_config.tensor_parallel_size
|
||||||
|
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||||
|
|
||||||
# modified for both SP and TP
|
# modified for both SP and TP
|
||||||
num_q_heads = self.model.config.num_attention_heads
|
num_q_heads = self.model.config.num_attention_heads
|
||||||
|
@ -124,27 +125,43 @@ class MixtralPolicy(Policy):
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.q_proj",
|
suffix="self_attn.q_proj",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.k_proj",
|
suffix="self_attn.k_proj",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.v_proj",
|
suffix="self_attn.v_proj",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.o_proj",
|
suffix="self_attn.o_proj",
|
||||||
target_module=Linear1D_Row,
|
target_module=Linear1D_Row,
|
||||||
kwargs={"fp8_communication": self.shard_config.fp8_communication},
|
kwargs={
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="block_sparse_moe.gate",
|
suffix="block_sparse_moe.gate",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},
|
kwargs={
|
||||||
|
"gather_output": True,
|
||||||
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
|
},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -179,6 +196,7 @@ class MixtralPolicy(Policy):
|
||||||
"tp_group": self.shard_config.tensor_parallel_process_group,
|
"tp_group": self.shard_config.tensor_parallel_process_group,
|
||||||
"moe_dp_group": self.shard_config.moe_dp_group,
|
"moe_dp_group": self.shard_config.moe_dp_group,
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": use_zbv,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
@ -313,6 +331,7 @@ class MixtralModelPolicy(MixtralPolicy):
|
||||||
class MixtralForCausalLMPolicy(MixtralPolicy):
|
class MixtralForCausalLMPolicy(MixtralPolicy):
|
||||||
def module_policy(self):
|
def module_policy(self):
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
|
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||||
# TODO: assign pg mesh from plugin to all modules
|
# TODO: assign pg mesh from plugin to all modules
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
# add a new item for causal lm
|
# add a new item for causal lm
|
||||||
|
@ -322,9 +341,13 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="lm_head",
|
suffix="lm_head",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
|
kwargs=dict(
|
||||||
|
gather_output=True,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
policy.update(new_item)
|
policy.update(new_item)
|
||||||
|
@ -343,7 +366,9 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
|
||||||
"""Get pipeline layers for current stage."""
|
"""Get pipeline layers for current stage."""
|
||||||
stage_manager = self.pipeline_stage_manager
|
stage_manager = self.pipeline_stage_manager
|
||||||
held_layers = super().get_held_layers()
|
held_layers = super().get_held_layers()
|
||||||
if stage_manager.is_last_stage():
|
if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
held_layers.append(self.model.lm_head)
|
||||||
|
elif stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
held_layers.append(self.model.lm_head)
|
held_layers.append(self.model.lm_head)
|
||||||
return held_layers
|
return held_layers
|
||||||
|
|
||||||
|
@ -369,6 +394,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
|
||||||
from transformers import MixtralForSequenceClassification
|
from transformers import MixtralForSequenceClassification
|
||||||
|
|
||||||
policy = super().module_policy()
|
policy = super().module_policy()
|
||||||
|
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
# add a new item for sequence classification
|
# add a new item for sequence classification
|
||||||
|
@ -378,7 +404,11 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="score",
|
suffix="score",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
|
kwargs=dict(
|
||||||
|
gather_output=True,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,6 +21,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
from colossalai.lazy import LazyInitContext
|
from colossalai.lazy import LazyInitContext
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
|
||||||
from colossalai.shardformer import PipelineGradientCheckpointConfig
|
from colossalai.shardformer import PipelineGradientCheckpointConfig
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
@ -39,6 +40,7 @@ MODEL_CONFIGS = {
|
||||||
),
|
),
|
||||||
"5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8),
|
"5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8),
|
||||||
"7b": LlamaConfig(max_position_embeddings=4096),
|
"7b": LlamaConfig(max_position_embeddings=4096),
|
||||||
|
# "7b": LlamaConfig(num_hidden_layers=4, max_position_embeddings=4096),
|
||||||
"13b": LlamaConfig(
|
"13b": LlamaConfig(
|
||||||
hidden_size=5120,
|
hidden_size=5120,
|
||||||
intermediate_size=13824,
|
intermediate_size=13824,
|
||||||
|
@ -91,7 +93,7 @@ def main():
|
||||||
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
|
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
|
||||||
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
|
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
|
||||||
|
|
||||||
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
|
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
|
||||||
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
|
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
|
||||||
parser.add_argument("--profile", action="store_true", help="Profile the code")
|
parser.add_argument("--profile", action="store_true", help="Profile the code")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -106,6 +108,7 @@ def main():
|
||||||
parser.add_argument("--no_cache", action="store_true")
|
parser.add_argument("--no_cache", action="store_true")
|
||||||
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
|
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
|
||||||
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
|
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
|
||||||
|
parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p")
|
||||||
parser.add_argument("--overlap_allgather", action="store_true")
|
parser.add_argument("--overlap_allgather", action="store_true")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sp_mode",
|
"--sp_mode",
|
||||||
|
@ -126,9 +129,12 @@ def main():
|
||||||
{
|
{
|
||||||
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
|
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
|
||||||
num_ckpt_layers_per_stage=[19, 19, 19, 13],
|
num_ckpt_layers_per_stage=[19, 19, 19, 13],
|
||||||
|
# num_ckpt_layers_per_stage=[48, 48, 48, 48],
|
||||||
),
|
),
|
||||||
"num_layers_per_stage": [19, 20, 20, 21],
|
"num_layers_per_stage": [19, 20, 20, 21],
|
||||||
"pp_style": "interleaved",
|
# "num_layers_per_stage": [48, 48, 48, 48],
|
||||||
|
# "pp_style": "interleaved",
|
||||||
|
"pp_style": "1f1b",
|
||||||
}
|
}
|
||||||
if args.custom_ckpt
|
if args.custom_ckpt
|
||||||
else {}
|
else {}
|
||||||
|
@ -137,6 +143,11 @@ def main():
|
||||||
# ==============================
|
# ==============================
|
||||||
# Initialize Booster
|
# Initialize Booster
|
||||||
# ==============================
|
# ==============================
|
||||||
|
if args.config in MODEL_CONFIGS:
|
||||||
|
config = MODEL_CONFIGS[args.config]
|
||||||
|
else:
|
||||||
|
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
|
||||||
|
|
||||||
use_empty_init = True
|
use_empty_init = True
|
||||||
if args.plugin == "gemini":
|
if args.plugin == "gemini":
|
||||||
plugin = GeminiPlugin(
|
plugin = GeminiPlugin(
|
||||||
|
@ -210,6 +221,24 @@ def main():
|
||||||
fp8_communication=args.use_fp8_comm,
|
fp8_communication=args.use_fp8_comm,
|
||||||
)
|
)
|
||||||
elif args.plugin == "3d":
|
elif args.plugin == "3d":
|
||||||
|
if args.pp_style == "zbv":
|
||||||
|
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
|
||||||
|
mem_w = -32 * config.hidden_size
|
||||||
|
mem_b = -mem_w - mem_f
|
||||||
|
scheduler_nodes = PipelineGraph(
|
||||||
|
n_stage=args.pp,
|
||||||
|
n_micro=args.batch_size // args.mbs,
|
||||||
|
f_cost=1000,
|
||||||
|
b_cost=1000,
|
||||||
|
w_cost=1000,
|
||||||
|
c_cost=1,
|
||||||
|
f_mem=mem_f * 1.5,
|
||||||
|
b_mem=mem_b * 1.5,
|
||||||
|
w_mem=mem_w * 1.5,
|
||||||
|
).get_v_schedule()
|
||||||
|
else:
|
||||||
|
scheduler_nodes = None
|
||||||
|
|
||||||
plugin = HybridParallelPlugin(
|
plugin = HybridParallelPlugin(
|
||||||
tp_size=args.tp,
|
tp_size=args.tp,
|
||||||
pp_size=args.pp,
|
pp_size=args.pp,
|
||||||
|
@ -227,6 +256,7 @@ def main():
|
||||||
overlap_allgather=args.overlap_allgather,
|
overlap_allgather=args.overlap_allgather,
|
||||||
use_fp8=args.use_fp8,
|
use_fp8=args.use_fp8,
|
||||||
fp8_communication=args.use_fp8_comm,
|
fp8_communication=args.use_fp8_comm,
|
||||||
|
scheduler_nodes=scheduler_nodes,
|
||||||
**hybrid_kwargs,
|
**hybrid_kwargs,
|
||||||
)
|
)
|
||||||
elif args.plugin == "3d_cpu":
|
elif args.plugin == "3d_cpu":
|
||||||
|
@ -242,7 +272,7 @@ def main():
|
||||||
microbatch_size=args.mbs,
|
microbatch_size=args.mbs,
|
||||||
initial_scale=2**8,
|
initial_scale=2**8,
|
||||||
precision="bf16",
|
precision="bf16",
|
||||||
overlap_p2p=args.overlap,
|
overlap_p2p=args.overlap_p2p,
|
||||||
use_fp8=args.use_fp8,
|
use_fp8=args.use_fp8,
|
||||||
fp8_communication=args.use_fp8_comm,
|
fp8_communication=args.use_fp8_comm,
|
||||||
)
|
)
|
||||||
|
@ -260,6 +290,7 @@ def main():
|
||||||
config = MODEL_CONFIGS[args.config]
|
config = MODEL_CONFIGS[args.config]
|
||||||
else:
|
else:
|
||||||
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
|
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
|
||||||
|
|
||||||
torch.cuda.manual_seed(42)
|
torch.cuda.manual_seed(42)
|
||||||
dataset = RandomDataset(
|
dataset = RandomDataset(
|
||||||
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
|
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
|
||||||
|
@ -319,7 +350,7 @@ def main():
|
||||||
args.profile,
|
args.profile,
|
||||||
args.ignore_steps,
|
args.ignore_steps,
|
||||||
1, # avoid creating massive log files
|
1, # avoid creating massive log files
|
||||||
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
|
save_dir=f"./profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
|
||||||
nsys=args.nsys,
|
nsys=args.nsys,
|
||||||
) as prof:
|
) as prof:
|
||||||
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
|
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
|
||||||
|
@ -334,7 +365,11 @@ def main():
|
||||||
return_loss=True,
|
return_loss=True,
|
||||||
)
|
)
|
||||||
loss = outputs["loss"]
|
loss = outputs["loss"]
|
||||||
if dist.get_rank() == dist.get_world_size() - 1:
|
if args.pp_style == "zbv":
|
||||||
|
if coordinator.is_master():
|
||||||
|
print(f"Step {step} loss: {loss}")
|
||||||
|
else:
|
||||||
|
if coordinator.is_last_process():
|
||||||
print(f"Step {step} loss: {loss}")
|
print(f"Step {step} loss: {loss}")
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
|
@ -11,6 +11,7 @@ from data_utils import RandomDataset
|
||||||
from model_utils import format_numel_str, get_model_numel
|
from model_utils import format_numel_str, get_model_numel
|
||||||
from performance_evaluator import PerformanceEvaluator, get_profile_context
|
from performance_evaluator import PerformanceEvaluator, get_profile_context
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoConfig
|
||||||
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
|
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
@ -20,6 +21,7 @@ from colossalai.booster.plugin import MoeHybridParallelPlugin
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
from colossalai.lazy import LazyInitContext
|
from colossalai.lazy import LazyInitContext
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
from colossalai.pipeline.schedule.v_schedule import PipelineGraph
|
||||||
from colossalai.shardformer import PipelineGradientCheckpointConfig
|
from colossalai.shardformer import PipelineGradientCheckpointConfig
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
@ -85,7 +87,7 @@ def main():
|
||||||
parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
|
parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
|
||||||
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
|
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
|
||||||
|
|
||||||
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
|
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"])
|
||||||
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
|
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
|
||||||
parser.add_argument("--profile", action="store_true", help="Profile the code")
|
parser.add_argument("--profile", action="store_true", help="Profile the code")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -120,7 +122,7 @@ def main():
|
||||||
num_ckpt_layers_per_stage=[19, 19, 19, 13],
|
num_ckpt_layers_per_stage=[19, 19, 19, 13],
|
||||||
),
|
),
|
||||||
"num_layers_per_stage": [19, 20, 20, 21],
|
"num_layers_per_stage": [19, 20, 20, 21],
|
||||||
"pp_style": "interleaved",
|
# "pp_style": "interleaved",
|
||||||
}
|
}
|
||||||
if args.custom_ckpt
|
if args.custom_ckpt
|
||||||
else {}
|
else {}
|
||||||
|
@ -129,7 +131,29 @@ def main():
|
||||||
# ==============================
|
# ==============================
|
||||||
# Initialize Booster
|
# Initialize Booster
|
||||||
# ==============================
|
# ==============================
|
||||||
|
if args.config in MODEL_CONFIGS:
|
||||||
|
config = MODEL_CONFIGS[args.config]
|
||||||
|
else:
|
||||||
|
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
|
||||||
|
|
||||||
if args.plugin == "3d":
|
if args.plugin == "3d":
|
||||||
|
if args.pp_style == "zbv":
|
||||||
|
mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length
|
||||||
|
mem_w = -32 * config.hidden_size
|
||||||
|
mem_b = -mem_w - mem_f
|
||||||
|
scheduler_nodes = PipelineGraph(
|
||||||
|
n_stage=args.pp,
|
||||||
|
n_micro=args.batch_size // args.mbs,
|
||||||
|
f_cost=1000,
|
||||||
|
b_cost=1000,
|
||||||
|
w_cost=1000,
|
||||||
|
c_cost=1,
|
||||||
|
f_mem=mem_f,
|
||||||
|
b_mem=mem_b,
|
||||||
|
w_mem=mem_w,
|
||||||
|
).get_v_schedule()
|
||||||
|
else:
|
||||||
|
scheduler_nodes = None
|
||||||
plugin = MoeHybridParallelPlugin(
|
plugin = MoeHybridParallelPlugin(
|
||||||
ep_size=args.ep,
|
ep_size=args.ep,
|
||||||
tp_size=args.tp,
|
tp_size=args.tp,
|
||||||
|
@ -143,11 +167,13 @@ def main():
|
||||||
enable_fused_normalization=torch.cuda.is_available(),
|
enable_fused_normalization=torch.cuda.is_available(),
|
||||||
enable_flash_attention=args.xformers,
|
enable_flash_attention=args.xformers,
|
||||||
microbatch_size=args.mbs,
|
microbatch_size=args.mbs,
|
||||||
|
num_microbatches=args.batch_size // args.mbs,
|
||||||
precision="bf16",
|
precision="bf16",
|
||||||
enable_metadata_cache=not args.no_cache,
|
enable_metadata_cache=not args.no_cache,
|
||||||
overlap_allgather=args.overlap_allgather,
|
overlap_allgather=args.overlap_allgather,
|
||||||
use_fp8=args.use_fp8,
|
use_fp8=args.use_fp8,
|
||||||
fp8_communication=args.use_fp8_comm,
|
fp8_communication=args.use_fp8_comm,
|
||||||
|
scheduler_nodes=scheduler_nodes,
|
||||||
**hybrid_kwargs,
|
**hybrid_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -183,8 +209,10 @@ def main():
|
||||||
with init_ctx:
|
with init_ctx:
|
||||||
model = MixtralForCausalLM(config=config).to(torch.bfloat16)
|
model = MixtralForCausalLM(config=config).to(torch.bfloat16)
|
||||||
|
|
||||||
|
# if args.grad_checkpoint:
|
||||||
|
# model.gradient_checkpointing_enable()
|
||||||
if args.grad_checkpoint:
|
if args.grad_checkpoint:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
|
|
||||||
model_numel = get_model_numel(model)
|
model_numel = get_model_numel(model)
|
||||||
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
||||||
|
@ -229,6 +257,10 @@ def main():
|
||||||
return_loss=True,
|
return_loss=True,
|
||||||
)
|
)
|
||||||
loss = outputs["loss"]
|
loss = outputs["loss"]
|
||||||
|
if args.pp_style == "zbv":
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
print(f"Step {step} loss: {loss}")
|
||||||
|
else:
|
||||||
if dist.get_rank() == dist.get_world_size() - 1:
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
print(f"Step {step} loss: {loss}")
|
print(f"Step {step} loss: {loss}")
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
|
@ -21,11 +21,16 @@ def divide(x: float, y: float) -> float:
|
||||||
def all_reduce_mean(x: float, world_size: int) -> float:
|
def all_reduce_mean(x: float, world_size: int) -> float:
|
||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
return x
|
return x
|
||||||
|
# BUG: RuntimeError: Invalid scalar type when use dist.all_reduce(tensor, group=gloo_group)
|
||||||
|
# # Use CPU tensor to avoid OOM/weird NCCl error
|
||||||
|
# gloo_group = dist.new_group(backend="gloo")
|
||||||
|
# tensor = torch.tensor([x], device="cpu")
|
||||||
|
# dist.all_reduce(tensor, group=gloo_group)
|
||||||
|
# tensor = tensor / world_size
|
||||||
|
# return tensor.item()
|
||||||
|
|
||||||
# Use CPU tensor to avoid OOM/weird NCCl error
|
tensor = torch.tensor([x], device=torch.cuda.current_device(), dtype=torch.float)
|
||||||
gloo_group = dist.new_group(backend="gloo")
|
dist.all_reduce(tensor)
|
||||||
tensor = torch.tensor([x], device="cpu")
|
|
||||||
dist.all_reduce(tensor, group=gloo_group)
|
|
||||||
tensor = tensor / world_size
|
tensor = tensor / world_size
|
||||||
return tensor.item()
|
return tensor.item()
|
||||||
|
|
||||||
|
|
|
@ -8,12 +8,14 @@ import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.testing import assert_close
|
from torch.testing import assert_close
|
||||||
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaModel
|
||||||
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
|
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
|
||||||
from transformers.models.mixtral.modeling_mixtral import MixtralModel
|
from transformers.models.mixtral.modeling_mixtral import MixtralModel
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.booster.booster import Booster
|
from colossalai.booster.booster import Booster
|
||||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin, MoeHybridParallelPlugin
|
||||||
from colossalai.cluster import ProcessGroupMesh
|
from colossalai.cluster import ProcessGroupMesh
|
||||||
from colossalai.interface import OptimizerWrapper
|
from colossalai.interface import OptimizerWrapper
|
||||||
from colossalai.logging import disable_existing_loggers
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
@ -756,10 +758,11 @@ def run_with_hybridplugin(test_config):
|
||||||
@parameterize(
|
@parameterize(
|
||||||
"config",
|
"config",
|
||||||
[
|
[
|
||||||
(0, 1, 4, 1, 1),
|
# (0, 1, 4, 1, 1),
|
||||||
(1, 2, 2, 1, 1),
|
# (1, 2, 2, 1, 1),
|
||||||
(1, 2, 1, 2, 1),
|
(1, 1, 2, 2, 1),
|
||||||
(1, 2, 1, 1, 2),
|
# (1, 2, 1, 2, 1),
|
||||||
|
# (1, 2, 1, 1, 2),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
||||||
|
@ -790,6 +793,8 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
||||||
seed_all(10086)
|
seed_all(10086)
|
||||||
|
|
||||||
torch_model = MixtralModel(config).to(dtype).cuda()
|
torch_model = MixtralModel(config).to(dtype).cuda()
|
||||||
|
# TODO: Support MixtralForCausalLM
|
||||||
|
# torch_model = MixtralForCausalLM(config).to(dtype).cuda()
|
||||||
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
|
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
|
||||||
# init schedule
|
# init schedule
|
||||||
h, a, s = config.hidden_size, config.num_attention_heads, 1024
|
h, a, s = config.hidden_size, config.num_attention_heads, 1024
|
||||||
|
@ -892,7 +897,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
||||||
|
|
||||||
# ===================================================================================
|
# ===================================================================================
|
||||||
# run normal model with all dp(different) inputs
|
# run normal model with all dp(different) inputs
|
||||||
all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
|
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
|
||||||
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
|
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
|
||||||
torch_output_sum = 0
|
torch_output_sum = 0
|
||||||
for input_data_ in all_inputs:
|
for input_data_ in all_inputs:
|
||||||
|
@ -905,6 +910,7 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
||||||
p.grad /= dp_size
|
p.grad /= dp_size
|
||||||
torch_optimizer.step()
|
torch_optimizer.step()
|
||||||
torch_optimizer.zero_grad()
|
torch_optimizer.zero_grad()
|
||||||
|
|
||||||
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
||||||
print(f"rank {dist.get_rank()} config {test_config} test passed")
|
print(f"rank {dist.get_rank()} config {test_config} test passed")
|
||||||
clear_layout_converter()
|
clear_layout_converter()
|
||||||
|
@ -912,11 +918,169 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
"config",
|
||||||
|
[
|
||||||
|
(1, 2, 2, 1), # Pass
|
||||||
|
# TODO: only support pp + tp accleration; Will support fully pp and None tp Hybrid in furture;
|
||||||
|
# (0, 4, 1, 1),
|
||||||
|
# (1, 2, 1, 2),
|
||||||
|
# (1, 1, 2, 2),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
||||||
|
stage, pp_size, tp_size, sp_size = config
|
||||||
|
num_microbatches = pp_size
|
||||||
|
dist.get_world_size()
|
||||||
|
rank = dist.get_rank()
|
||||||
|
dtype, precision = torch.float16, "fp16"
|
||||||
|
torch.cuda.set_device(dist.get_rank())
|
||||||
|
|
||||||
|
########
|
||||||
|
# init base model
|
||||||
|
########
|
||||||
|
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
|
||||||
|
config = LlamaConfig(
|
||||||
|
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
|
||||||
|
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
|
||||||
|
num_hidden_layers=NUM_LAYERS,
|
||||||
|
num_attention_heads=NUM_HEADS,
|
||||||
|
num_key_value_heads=NUM_HEADS,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
)
|
||||||
|
|
||||||
|
# init model with the same seed
|
||||||
|
seed_all(10086)
|
||||||
|
|
||||||
|
torch_model = LlamaModel(config).to(dtype).cuda()
|
||||||
|
# TODO: Support MixtralForCausalLM
|
||||||
|
# torch_model = MixtralForCausalLM(config).to(dtype).cuda()
|
||||||
|
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
|
||||||
|
# init schedule
|
||||||
|
h, a, s = config.hidden_size, config.num_attention_heads, 1024
|
||||||
|
mem_f = 34 * h + 5 * a * s
|
||||||
|
mem_w = -32 * h
|
||||||
|
mem_b = -mem_w - mem_f
|
||||||
|
graph = PipelineGraph(
|
||||||
|
n_stage=pp_size,
|
||||||
|
n_micro=num_microbatches,
|
||||||
|
f_cost=1,
|
||||||
|
b_cost=1,
|
||||||
|
w_cost=1,
|
||||||
|
c_cost=1,
|
||||||
|
f_mem=mem_f,
|
||||||
|
b_mem=mem_b,
|
||||||
|
w_mem=mem_w,
|
||||||
|
)
|
||||||
|
|
||||||
|
zbv_schedule = graph.get_v_schedule()
|
||||||
|
|
||||||
|
# init HybridParallelPlugin
|
||||||
|
plugin = HybridParallelPlugin(
|
||||||
|
pp_size=pp_size,
|
||||||
|
num_microbatches=pp_size,
|
||||||
|
tp_size=tp_size,
|
||||||
|
sp_size=sp_size,
|
||||||
|
zero_stage=stage,
|
||||||
|
enable_sequence_parallelism=sp_size > 1,
|
||||||
|
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
|
||||||
|
overlap_communication=False,
|
||||||
|
initial_scale=1,
|
||||||
|
precision=precision,
|
||||||
|
find_unused_parameters=True,
|
||||||
|
pp_style="zbv",
|
||||||
|
scheduler_nodes=zbv_schedule,
|
||||||
|
num_model_chunks=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
dp_size = plugin.dp_size
|
||||||
|
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
|
########
|
||||||
|
# init pp model
|
||||||
|
########
|
||||||
|
|
||||||
|
parallel_model = deepcopy(torch_model)
|
||||||
|
parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
|
||||||
|
parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
|
||||||
|
# create different input along dp axis
|
||||||
|
seed_all(1453 + rank)
|
||||||
|
|
||||||
|
torch_model.train()
|
||||||
|
parallel_model.train()
|
||||||
|
for _ in range(2):
|
||||||
|
# gen random input
|
||||||
|
input_embeddings = torch.rand(
|
||||||
|
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
|
||||||
|
).cuda()
|
||||||
|
dist.all_reduce(
|
||||||
|
input_embeddings, group=plugin.pp_group
|
||||||
|
) # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
|
||||||
|
|
||||||
|
dist.all_reduce(input_embeddings, group=plugin.tp_group) # tp group duplicate input
|
||||||
|
dist.all_reduce(input_embeddings, group=plugin.sp_group) # sp group duplicate input
|
||||||
|
|
||||||
|
# run the model with hybrid parallel
|
||||||
|
if booster.plugin.stage_manager is not None:
|
||||||
|
# for test with pp
|
||||||
|
data_iter = iter([{"inputs_embeds": input_embeddings}])
|
||||||
|
sharded_output = booster.execute_pipeline(
|
||||||
|
data_iter,
|
||||||
|
parallel_model,
|
||||||
|
lambda x, y: x.last_hidden_state.mean(),
|
||||||
|
parallel_optimizer,
|
||||||
|
return_loss=True,
|
||||||
|
return_outputs=True,
|
||||||
|
)
|
||||||
|
# stage 0 chunk 0
|
||||||
|
parallel_output = None
|
||||||
|
if (
|
||||||
|
booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
|
||||||
|
and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
|
||||||
|
):
|
||||||
|
parallel_output = sharded_output["loss"]
|
||||||
|
else:
|
||||||
|
parallel_output = torch.tensor(12345.0, device="cuda")
|
||||||
|
# broadcast along pp axis
|
||||||
|
dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# for test without pp
|
||||||
|
parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
|
||||||
|
parallel_optimizer.backward(parallel_output)
|
||||||
|
parallel_optimizer.step()
|
||||||
|
parallel_optimizer.zero_grad()
|
||||||
|
dist.all_reduce(parallel_output, group=plugin.dp_group)
|
||||||
|
|
||||||
|
# ===================================================================================
|
||||||
|
# run normal model with all dp(different) inputs
|
||||||
|
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
|
||||||
|
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
|
||||||
|
torch_output_sum = 0
|
||||||
|
for input_data_ in all_inputs:
|
||||||
|
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
|
||||||
|
torch_output.backward()
|
||||||
|
torch_output_sum += torch_output.detach()
|
||||||
|
# avg dp grads follows zero optimizer
|
||||||
|
for p in torch_model.parameters():
|
||||||
|
if p.grad is not None:
|
||||||
|
p.grad /= dp_size
|
||||||
|
torch_optimizer.step()
|
||||||
|
torch_optimizer.zero_grad()
|
||||||
|
|
||||||
|
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
||||||
|
print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
|
||||||
|
clear_layout_converter()
|
||||||
|
Randomizer.reset_index()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
# run_fwd_bwd_vschedule_with_optim()
|
|
||||||
run_with_booster_moehybridplugin()
|
run_with_booster_moehybridplugin()
|
||||||
|
run_with_booster_hybridplugin()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
@ -928,5 +1092,6 @@ def test_pp():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# python -m pytest -s tests/test_pipeline/test_schedule/test_zerobubble_pp.py
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_pp()
|
test_pp()
|
||||||
|
|
|
@ -8,7 +8,8 @@ from torch.testing import assert_close
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.lazy import LazyInitContext
|
from colossalai.lazy import LazyInitContext
|
||||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
|
from colossalai.pipeline.weight_grad_store import WeightGradStore
|
||||||
|
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum
|
||||||
from colossalai.tensor.d_tensor import is_distributed_tensor
|
from colossalai.tensor.d_tensor import is_distributed_tensor
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
@ -117,6 +118,93 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool):
|
||||||
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
||||||
|
|
||||||
|
|
||||||
|
def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
|
||||||
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
|
|
||||||
|
linear = nn.Linear(32, 128).cuda()
|
||||||
|
with ctx:
|
||||||
|
linear_copy = nn.Linear(32, 128).cuda()
|
||||||
|
linear_base = LinearWithGradAccum.from_native_module(
|
||||||
|
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=False
|
||||||
|
)
|
||||||
|
assert linear_base.weight.shape == torch.Size([128, 32])
|
||||||
|
assert linear_base.bias.shape == torch.Size([128])
|
||||||
|
assert linear_copy.weight is linear_base.weight
|
||||||
|
assert linear_copy.bias is linear_base.bias
|
||||||
|
|
||||||
|
linear.load_state_dict(linear_base.state_dict())
|
||||||
|
linear_base.load_state_dict(linear.state_dict())
|
||||||
|
|
||||||
|
# check computation correctness
|
||||||
|
# [batch_size, seq_len, hidden_size]
|
||||||
|
x = torch.rand(2, 4, 32).cuda()
|
||||||
|
x_for_unshard = x.expand_as(x.clone())
|
||||||
|
x_for_unshard.requires_grad_(True)
|
||||||
|
x_for_shard = x.expand_as(x.clone())
|
||||||
|
x_for_shard.requires_grad_(True)
|
||||||
|
|
||||||
|
# run forward
|
||||||
|
out = linear(x_for_unshard)
|
||||||
|
gather_out = linear_base(x_for_shard)
|
||||||
|
assert_close(out, gather_out)
|
||||||
|
|
||||||
|
# check backward correctness
|
||||||
|
out.sum().backward()
|
||||||
|
gather_out.sum().backward()
|
||||||
|
assert_close(linear.weight.grad, linear_base.weight.grad)
|
||||||
|
# check the input gradients
|
||||||
|
assert x_for_shard.grad is not None
|
||||||
|
assert x_for_unshard.grad is not None
|
||||||
|
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
||||||
|
|
||||||
|
|
||||||
|
def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: bool):
|
||||||
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
|
|
||||||
|
linear = nn.Linear(32, 128).cuda()
|
||||||
|
with ctx:
|
||||||
|
linear_copy = nn.Linear(32, 128).cuda()
|
||||||
|
linear_base = LinearWithGradAccum.from_native_module(
|
||||||
|
linear_copy, parallel_input=False, seq_parallel_mode=seq_parallel_mode, use_zbv=True
|
||||||
|
)
|
||||||
|
assert linear_base.weight.shape == torch.Size([128, 32])
|
||||||
|
assert linear_base.bias.shape == torch.Size([128])
|
||||||
|
assert linear_copy.weight is linear_base.weight
|
||||||
|
assert linear_copy.bias is linear_base.bias
|
||||||
|
|
||||||
|
linear.load_state_dict(linear_base.state_dict())
|
||||||
|
linear_base.load_state_dict(linear.state_dict())
|
||||||
|
|
||||||
|
# check computation correctness
|
||||||
|
# [batch_size, seq_len, hidden_size]
|
||||||
|
x = torch.rand(2, 4, 32).cuda()
|
||||||
|
x_for_unshard = x.expand_as(x.clone())
|
||||||
|
x_for_unshard.requires_grad_(True)
|
||||||
|
x_for_shard = x.expand_as(x.clone())
|
||||||
|
x_for_shard.requires_grad_(True)
|
||||||
|
|
||||||
|
# run forward
|
||||||
|
out = linear(x_for_unshard)
|
||||||
|
gather_out = linear_base(x_for_shard)
|
||||||
|
assert_close(out, gather_out)
|
||||||
|
|
||||||
|
# check backward correctness
|
||||||
|
out.sum().backward()
|
||||||
|
gather_out.sum().backward()
|
||||||
|
|
||||||
|
# Weight grad is None before we do WeightGradStore pop
|
||||||
|
assert linear_base.weight.grad is None
|
||||||
|
# after WeightGradStore pop (dw computation complete), we assert weight grad
|
||||||
|
WeightGradStore.flush(chunk=0) # flush buffer to chunk 0 Queue
|
||||||
|
WeightGradStore.pop(chunk=0)
|
||||||
|
assert_close(linear.weight.grad, linear_base.weight.grad)
|
||||||
|
|
||||||
|
# check the input gradients
|
||||||
|
assert x_for_shard.grad is not None
|
||||||
|
assert x_for_unshard.grad is not None
|
||||||
|
assert_close(x_for_unshard.grad, x_for_shard.grad)
|
||||||
|
|
||||||
|
|
||||||
def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
|
def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
|
||||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
|
|
||||||
|
@ -182,6 +270,8 @@ def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
|
||||||
check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
|
check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
|
||||||
check_linear_1d_row(lazy_init, seq_parallel_mode)
|
check_linear_1d_row(lazy_init, seq_parallel_mode)
|
||||||
check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)
|
check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)
|
||||||
|
check_linear_without_weight_grad_store(lazy_init, seq_parallel_mode)
|
||||||
|
check_linear_with_weight_grad_store(lazy_init, seq_parallel_mode)
|
||||||
|
|
||||||
|
|
||||||
def check_dist_linear(rank, world_size, port):
|
def check_dist_linear(rank, world_size, port):
|
||||||
|
|
|
@ -277,32 +277,33 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
"precision": "fp16",
|
"precision": "fp16",
|
||||||
"initial_scale": 1,
|
"initial_scale": 1,
|
||||||
},
|
},
|
||||||
{
|
# # TODO: assert layer error
|
||||||
"tp_size": 2,
|
# {
|
||||||
"pp_size": 2,
|
# "tp_size": 2,
|
||||||
"pp_style": "zbv",
|
# "pp_size": 2,
|
||||||
"num_model_chunks": 2,
|
# "pp_style": "zbv",
|
||||||
"num_microbatches": 4,
|
# "num_model_chunks": 2,
|
||||||
"enable_all_optimization": False,
|
# "num_microbatches": 4,
|
||||||
"precision": "fp16",
|
# "enable_all_optimization": False,
|
||||||
"zero_stage": 0,
|
# "precision": "fp16",
|
||||||
"initial_scale": 1,
|
# "zero_stage": 0,
|
||||||
"enable_gradient_checkpointing": True,
|
# "initial_scale": 1,
|
||||||
"parallel_output": False,
|
# "enable_gradient_checkpointing": True,
|
||||||
},
|
# "parallel_output": False,
|
||||||
{
|
# },
|
||||||
"tp_size": 2,
|
# {
|
||||||
"pp_size": 2,
|
# "tp_size": 2,
|
||||||
"pp_style": "zbv",
|
# "pp_size": 2,
|
||||||
"num_model_chunks": 2,
|
# "pp_style": "zbv",
|
||||||
"num_microbatches": 4,
|
# "num_model_chunks": 2,
|
||||||
"enable_all_optimization": False,
|
# "num_microbatches": 4,
|
||||||
"precision": "fp16",
|
# "enable_all_optimization": False,
|
||||||
"zero_stage": 1,
|
# "precision": "fp16",
|
||||||
"initial_scale": 1,
|
# "zero_stage": 1,
|
||||||
"enable_gradient_checkpointing": True,
|
# "initial_scale": 1,
|
||||||
"parallel_output": False,
|
# "enable_gradient_checkpointing": True,
|
||||||
},
|
# "parallel_output": False,
|
||||||
|
# },
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def run_llama_test(test_config):
|
def run_llama_test(test_config):
|
||||||
|
|
Loading…
Reference in New Issue