[fix] fix zerobubble pp for shardformer type input;

pull/6065/head
duanjunwen 2024-09-18 07:14:34 +00:00
parent 9bc3b6e220
commit 3dbad102cf
3 changed files with 224 additions and 66 deletions

View File

@ -131,6 +131,16 @@ def retain_grad(x: Any) -> None:
x.retain_grad()
def require_grad(x: Any) -> None:
"""Call requires_grad_() on a tensor that does not already require grad.
Args:
x (Any): Object to be called.
"""
if isinstance(x, torch.Tensor) and not x.requires_grad:
x.requires_grad_()
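The v-schedule below applies this helper through `tree_map` over whole input dicts. A minimal sketch (not part of the patch, and assuming `tree_map` here is `torch.utils._pytree.tree_map`) of why received activations need it: a tensor arriving from another stage is a leaf with `requires_grad=False`, so without this call the local forward would record no graph.

```python
import torch
from torch.utils._pytree import tree_map

def require_grad(x):
    # same logic as above: touch only tensors that do not already require grad
    if isinstance(x, torch.Tensor) and not x.requires_grad:
        x.requires_grad_()

# stands in for an input dict received from the previous stage via p2p
input_obj = {"hidden_states": torch.randn(2, 8)}
tree_map(require_grad, input_obj)

(input_obj["hidden_states"] * 2).sum().backward()
print(input_obj["hidden_states"].grad is not None)  # True: the local graph was recorded
```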
def detach(x: Any) -> Any:
"""Call detach() on a tensor.
@ -145,6 +155,34 @@ def detach(x: Any) -> Any:
return x
def clone(x: Any) -> Any:
"""Call clone() on a tensor.
Args:
x (Any): Object to be called.
Returns:
Any: The cloned object.
"""
if isinstance(x, torch.Tensor):
return x.clone()
return x
def deallocate(x: Any) -> Any:
"""Free the storage backing a tensor's .data field.
Args:
x (Any): Object to be called.
Returns:
Any: The emptied storage for tensors, or the object itself otherwise.
"""
if isinstance(x, torch.Tensor):
return x.data.untyped_storage().resize_(0)
return x
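A short illustration (a sketch, not part of the patch) of the trick `deallocate` relies on: resizing a tensor's untyped storage to zero frees the activation memory while keeping the shape metadata and `grad_fn`, so a later backward with an explicit gradient still succeeds because autograd never reads the freed `.data`.

```python
import torch

x = torch.randn(4, 4, requires_grad=True)
y = x.mm(x)                             # autograd saves what mm's backward needs (x), not y
y.data.untyped_storage().resize_(0)     # free y's activation memory; y.grad_fn survives

dy = torch.ones(4, 4)                   # stands in for the gradient received from the next stage
torch.autograd.backward(y, grad_tensors=dy)
print(x.grad.shape)                     # torch.Size([4, 4])
```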
def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
"""Merge micro batches into a batch.

View File

@ -12,7 +12,7 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
from colossalai.pipeline.stage_manager import PipelineStageManager
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
from ._utils import clone, detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule
AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
@ -39,6 +39,20 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
out.data.untyped_storage().resize_(0)
def require_grad(tensor):
"""Mark a received tensor as requiring grad.
This should be called on activations received from another pipeline stage so that
the local forward pass records an autograd graph and gradients can later flow
back through them.
"""
if tensor is None:
return
assert isinstance(tensor, torch.Tensor), "expected Tensor, found %s." % type(tensor).__name__
assert tensor._base is None, "cannot call requires_grad_() on a view of another tensor."
tensor.requires_grad_()
class ZeroBubbleVPipeScheduler(PipelineSchedule):
def __init__(
self,
@ -409,6 +423,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
micro_batch: Optional[dict],
input_obj: Optional[dict],
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
@ -427,18 +442,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
# Load input ids, attention mask and labels
# micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# for the first stage, input_obj is None
# for the first stage, input_obj is None, so we use micro_batch as input_obj
# for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
# fwd calculate
output_obj = model_chunk[model_chunk_id](input_obj)
if isinstance(model_chunk, ModuleList):
# fwd for ModuleList model
if input_obj is None:
output_obj = model_chunk[model_chunk_id](**micro_batch)
else:
output_obj = model_chunk[model_chunk_id](**input_obj)
else:
# fwd for shardformer
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
internal_inputs = {} if input_obj is None else input_obj
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
# chunk 1 on the first pipeline rank holds the model's last layers, so compute the loss here
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
loss = criterion(output_obj) / self.num_microbatch
loss = criterion(output_obj, micro_batch) / self.num_microbatch
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
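`model_forward` is pulled in from `._utils` by the import change above. A minimal sketch of the behavior the shardformer branch relies on, assuming the helper simply merges the micro batch and the inter-stage tensors into one keyword-argument call (illustrative, not the library's actual definition):

```python
from typing import Any, Optional

import torch.nn as nn

def model_forward(model: nn.Module, data: Any, internal_inputs: Optional[dict]) -> Any:
    # micro-batch fields (input ids, attention mask, ...) and the hidden states
    # received from the previous stage are forwarded together as kwargs
    if internal_inputs is None:
        internal_inputs = {}
    return model(**data, **internal_inputs)
```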
@ -472,19 +496,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# backward-b step: compute only dx = w * dy (activation gradient, no weight gradient)
# Retain the grad on the input_obj.
if input_obj is None:
return None
else:
tree_map(retain_grad, input_obj)
input_obj_ = input_obj["hidden_states"]
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss; so output_obj_grad should be None
assert output_obj_grad is None
output_obj_ = output_obj
else:
output_obj_ = output_obj["hidden_states"]
optimizer.backward_by_grad(
tensor=output_obj,
tensor=output_obj_,
grad=output_obj_grad,
inputs=input_obj,
inputs=input_obj_,
retain_graph=True,
)
return input_obj.grad
return input_obj_.grad
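`backward_b_step` above runs only the activation-gradient half of the backward pass; `backward_w_step` below runs the weight-gradient half against the same graph. A standalone sketch of this B/W split for a single linear layer, using plain `torch.autograd.grad` instead of the optimizer wrapper (all names here are illustrative):

```python
import torch

x = torch.randn(2, 4, requires_grad=True)   # stands in for received hidden states
w = torch.randn(4, 3, requires_grad=True)   # stands in for a stage's weight
y = x @ w
dy = torch.randn_like(y)                    # gradient arriving from the next stage

# backward-b: only dL/dx, keep the graph alive for the later W pass
(dx,) = torch.autograd.grad(y, inputs=(x,), grad_outputs=dy, retain_graph=True)

# backward-w: only dL/dw, the graph can now be released
(dw,) = torch.autograd.grad(y, inputs=(w,), grad_outputs=dy, retain_graph=False)
```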
def backward_w_step(
self,
@ -511,8 +541,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss
output_obj_grad = None
output_obj_ = output_obj
else:
output_obj_ = output_obj["hidden_states"]
optimizer.backward_by_grad(
tensor=output_obj,
tensor=output_obj_,
grad=output_obj_grad,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
@ -543,9 +576,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# Step1: recv fwd
if model_chunk_id == 0:
# is first stage; get input from func param
# is first stage; get input from microbatch
if self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj = micro_batch
input_obj = None
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
else:
@ -557,45 +590,68 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
# Here, let input_obj.requires_grad_()
tree_map(torch.Tensor.requires_grad_, input_obj)
if input_obj is not None:
tree_map(require_grad, input_obj)
# Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
# tree_map(torch.Tensor.requires_grad_, micro_batch)
# Step2: fwd step
output_obj = self.forward_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
micro_batch=micro_batch,
input_obj=input_obj,
criterion=criterion,
accum_loss=accum_loss,
outputs=outputs,
)
# Step3: deallocate output for bwd b & w; (do not detach output)
deallocate_output_obj = tree_map(clone, output_obj)
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not deallocate bwd LOSS
pass
else:
# deallocate output
tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), deallocate_output_obj)
# add input and output object for backward b
if input_obj is not None:
self.input_tensors[model_chunk_id].append(input_obj)
else:
self.input_tensors[model_chunk_id].append(micro_batch)
# for bwd b & w, we only need the graph (grad_fn) of output_obj;
# the loss chunk keeps its intact clone, other chunks keep the storage-freed clone
self.output_tensors[model_chunk_id].append(deallocate_output_obj)
self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
# Step4: detach output for send fwd;
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not detach bwd LOSS
pass
else:
detached_output_obj = output_obj.clone().detach()
# detach output
output_obj = tree_map(detach, output_obj)
# Step5: send fwd
# add output to send_fwd_buffer
if model_chunk_id == 0:
if model_chunk_id == 0: # chunk 0
# is last stage; send to local_send_forward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.local_send_forward_buffer.append(detached_output_obj)
self.local_send_forward_buffer.append(output_obj)
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
else:
# is first stage; end of fwd; append LOSS to local_send_backward_buffer
self.send_forward_buffer[model_chunk_id].append(output_obj)
else: # chunk 1
# is first stage; end of fwd; do nothing
if self.stage_manager.is_first_stage(ignore_chunk=True):
pass
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
# add input and output object for backward b
self.input_tensors[model_chunk_id].append(input_obj)
# detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True)
self.output_tensors[model_chunk_id].append(output_obj)
# add output object for backward w
self.output_tensors_dw[model_chunk_id].append(output_obj)
self.send_forward_buffer[model_chunk_id].append(output_obj)
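The clone / deallocate / detach sequence in `schedule_f` condenses, for a single plain tensor, to the sketch below (the real code applies each step via `tree_map` over dict outputs): the clone keeps a handle to the autograd graph for backward-b/w, its storage is freed since only `grad_fn` is needed, and a detached copy is what actually gets sent downstream.

```python
import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
inp = torch.randn(2, 8, requires_grad=True)

out = layer(inp)                          # forward: out carries the graph (grad_fn)
saved = out.clone()                       # graph handle kept for backward-b and backward-w
saved.data.untyped_storage().resize_(0)   # free the clone's activation memory
send_obj = out.detach()                   # graph-free copy that goes to the next stage
```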
def schedule_b(
self,
@ -603,9 +659,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
# input_obj: Optional[dict],
# output_obj: Union[dict, torch.Tensor],
# output_obj_grad: Optional[dict],
):
"""A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
@ -616,7 +669,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
Returns:
Nothing.
"""
# Step1: recv bwd
if model_chunk_id == 0:
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
@ -645,7 +697,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# we save output_tensor_grad here
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
# _wait_p2p(recv_bwd_handles)
# Step2: bwd step
input_object_grad = self.backward_b_step(
model_chunk=model_chunk,
@ -777,8 +828,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# communication
communication_func = self.communication_map[scheduled_node.type]
communication_func(scheduled_node.chunk)
if scheduled_node.type == "F":
elif scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,
model_chunk=model_chunk,

View File

@ -1,4 +1,6 @@
from copy import deepcopy
from functools import partial
from types import MethodType
from typing import Tuple
import pytest
@ -16,7 +18,8 @@ from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
class MlpModel(nn.Module):
@ -24,10 +27,32 @@ class MlpModel(nn.Module):
super().__init__()
self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
def forward(self, x):
def forward(
self,
hidden_states,
):
for layer in self.layers:
x = layer(x)
return x
hidden_states = layer(hidden_states)
return hidden_states
def pp_linear_fwd(
forward,
data: torch.Tensor = None,
hidden_states: torch.Tensor = None,
stage_mgr: PipelineStageManager = None,
model_chunk_id: int = None,
):
with stage_mgr.switch_model_chunk_id(model_chunk_id):
# fwd end
if stage_mgr.is_first_stage() and model_chunk_id == 1:
return {"hidden_states": forward(hidden_states)}
# fwd start
elif stage_mgr.is_first_stage() and model_chunk_id == 0:
return {"hidden_states": forward(hidden_states)}
# fwd middle
else:
return {"hidden_states": forward(hidden_states)}
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
@ -510,15 +535,15 @@ def run_fwd_bwd_iter_input(test_config):
"precision": "bf16",
"num_model_chunk": 2,
},
{
"batch_size": 8,
"tp_size": 1,
"pp_size": 4,
"num_microbatches": 8,
"zero_stage": 1,
"precision": "bf16",
"num_model_chunk": 2,
},
# {
# "batch_size": 8,
# "tp_size": 1,
# "pp_size": 4,
# "num_microbatches": 8,
# "zero_stage": 1,
# "precision": "bf16",
# "num_model_chunk": 2,
# },
],
)
def run_fwd_bwd_vschedule_with_optim(test_config):
@ -562,6 +587,10 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
# init loss func
def criterion(x, *args, **kwargs):
x = x["hidden_states"]
return (x * x).mean()
def criterion_base(x, *args, **kwargs):
return (x * x).mean()
# init model and input
@ -572,9 +601,10 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
before_init_memory = torch.cuda.memory_allocated() / 1024**3
print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
input_base = [t.clone() for t in data_iter]
# data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
data_iter = {"hidden_states": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
# input_base = [t.clone() for t in data_iter]
input_base = {k: v.clone() for k, v in data_iter.items()}
model_base = deepcopy(model)
if rank == 0:
@ -582,24 +612,44 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 0 or idx == 7:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
elif rank == 1:
# layer 1 & 6 to chunk 1 on rank1
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 1 or idx == 6:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
elif rank == 2:
# layer 2 & 5 to chunk 2 on rank2
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 2 or idx == 5:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
else:
# layer 3 & 4 to chunk 3 on rank3
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 3 or idx == 4:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
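The hard-coded layer picks above follow the zbv V-schedule placement: chunk 0 of rank r holds layer r and chunk 1 holds the mirrored layer from the other end of the model. A hypothetical helper (not part of the test) that reproduces the same mapping:

```python
def zbv_layer_ids(rank: int, pp_size: int = 4) -> tuple:
    """Layer indices owned by `rank` in a 2 * pp_size layer model under the V schedule."""
    return (rank, 2 * pp_size - 1 - rank)

assert zbv_layer_ids(0) == (0, 7)   # rank 0: layers 0 & 7
assert zbv_layer_ids(1) == (1, 6)   # rank 1: layers 1 & 6
assert zbv_layer_ids(3) == (3, 4)   # rank 3: layers 3 & 4
```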
# init optimizer
@ -612,7 +662,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
torch.cuda.synchronize()
result = scheduler.forward_backward_step(
model_chunk=local_chunk,
data_iter=iter(data_iter),
data_iter=iter([data_iter]),
criterion=criterion,
optimizer=optimizer_pp,
return_loss=True,
@ -643,8 +693,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
# Fwd bwd for base
##########################
# fwd & bwd
output_base = model_base(input_base[0])
loss_base = criterion(output_base)
output_base = model_base(input_base["hidden_states"])
loss_base = criterion_base(output_base)
loss_base.backward()
optimizer_base.step()
@ -654,7 +704,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
# only chunk 1 on stage 0 holds the loss and output
if rank == 0:
assert_close(result["loss"], loss_base)
assert_close(result["outputs"], output_base)
assert_close(result["outputs"]["hidden_states"], output_base)
# print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
##########################
@ -727,6 +777,7 @@ def run_with_hybridplugin(test_config):
{
"pp_style": "zbv",
"tp_size": 1,
"ep_size": 1,
"pp_size": 4,
"num_microbatches": 4,
"zero_stage": 1,
@ -737,7 +788,7 @@ def run_with_hybridplugin(test_config):
)
def run_with_moehybridplugin(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
test_config["use_lazy_init"] = False
# test_config["use_lazy_init"] = False
test_config["initial_scale"] = 2**16
model_list = [
"transformers_bert",
@ -749,6 +800,7 @@ def run_with_moehybridplugin(test_config):
# base param
model = model_fn()
data = data_gen_fn()
print(f"data {data}")
criterion = loss_fn
optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
@ -787,7 +839,7 @@ def run_with_moehybridplugin(test_config):
# plugin = MoeHybridParallelPlugin(
# **test_config
# )
# model_pp, optimizer_pp, criterion, data_pp = plugin.configure(
# model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure(
# model = model_pp,
# optimizer = optimizer_pp,
# criterion = criterion,
@ -806,16 +858,34 @@ def run_with_moehybridplugin(test_config):
# TODO:6) support booster & Hybrid base 4)
# TODO:7) support booster & MoEHybrid base 4)
@parameterize(
"test_config",
[
{
"pp_style": "zbv",
"tp_size": 1,
"ep_size": 1,
"pp_size": 4,
"num_microbatches": 4,
"zero_stage": 1,
"precision": "bf16",
"num_model_chunks": 2,
},
],
)
def run_with_booster_moehybridplugin(test_config):
pass
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# run_fwd_bwd_iter_input()
# run_fwd_bwd_vschedule_with_optim()
run_fwd_bwd_vschedule_with_optim()
# run_with_moehybridplugin()
run_with_moehybridplugin()
# run_with_booster_moehybridplugin()
@pytest.mark.dist