Merge pull request #6065 from duanjunwen/dev/zero_bubble

[Feat] Support zero bubble with shardformer input
duanjunwen 2024-09-24 19:17:37 +08:00 committed by GitHub
commit 8501202a35
4 changed files with 328 additions and 88 deletions


@ -29,6 +29,7 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
@ -207,6 +208,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
scheduler_nodes: List = None,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
@ -282,8 +284,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
assert (
pp_style == "interleaved" or pp_style == "zbv"
) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@ -293,7 +297,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
)
@ -315,6 +319,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
)
elif pp_style == "zbv":
self.schedule = ZeroBubbleVPipeScheduler(
schedule=scheduler_nodes,
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
overlap_p2p=overlap_p2p,
)
else:
raise NotImplementedError()
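For orientation, a minimal sketch of how the new "zbv" style could be driven from user code, assuming the PipelineGraph / get_v_schedule() API exercised in the test file below and the plugin arguments added in this hunk (pp_style, num_model_chunks, scheduler_nodes). The import path, the exact set of required plugin arguments, and the cost/memory constants are illustrative assumptions, not tuned values:

# Hedged sketch: build a zero-bubble V schedule and hand it to the plugin.
from colossalai.booster.plugin import MoeHybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

pp_size, num_microbatches = 4, 4
graph = PipelineGraph(
    n_stage=pp_size,
    n_micro=num_microbatches,
    f_cost=1, b_cost=1, w_cost=1, c_cost=1,  # relative F / B / W / communication costs
    f_mem=2, b_mem=-1, w_mem=-1,             # per-node activation memory deltas (F + B + W = 0)
)
scheduler_nodes = graph.get_v_schedule()      # one list of ScheduledNode per pipeline rank

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=pp_size,
    ep_size=1,
    pp_style="zbv",
    num_model_chunks=2,                       # zbv runs two chunks (one "V") per rank
    num_microbatches=num_microbatches,
    scheduler_nodes=scheduler_nodes,
    zero_stage=1,
    precision="bf16",
)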


@ -131,6 +131,16 @@ def retain_grad(x: Any) -> None:
x.retain_grad()
def require_grad(x: Any) -> None:
"""Call require_grad on a tensor.
Args:
x (Any): Object to be called.
"""
if isinstance(x, torch.Tensor) and not x.requires_grad:
x.requires_grad_()
def detach(x: Any) -> Any:
"""Call detach() on a tensor.
@ -145,6 +155,34 @@ def detach(x: Any) -> Any:
return x
def clone(x: Any) -> Any:
"""Call clone() on a tensor.
Args:
x (Any): Object to be called.
Returns:
Any: The cloned object.
"""
if isinstance(x, torch.Tensor):
return x.clone()
return x
def release_tensor_data(x: Any) -> Any:
"""Call untyped_storage().resize_(0) on a tensor. Use to release tensor.data and keep grad_fn.
Args:
x (Any): Object to be called.
Returns:
Any: The object with its .data storage released.
"""
if isinstance(x, torch.Tensor):
return x.data.untyped_storage().resize_(0)
return x
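The new release_tensor_data helper generalises the deallocate_output_tensor trick that this PR removes from the scheduler (see the next file): once an output has been detached, cloned, and sent downstream, only its autograd graph is still needed, so its storage can be shrunk to zero bytes while grad_fn and the saved activations stay alive. A self-contained sketch of the idea, outside the scheduler:

import torch

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
y = x @ w                              # y.grad_fn holds everything backward needs

send_buffer = y.clone().detach()       # what would be sent to the next stage
y.data.untyped_storage().resize_(0)    # drop y's storage; shape metadata and grad_fn survive

# Later, when dy arrives from downstream, backward through y still works:
dy = torch.ones(4, 8)
torch.autograd.backward(y, grad_tensors=dy)
assert x.grad is not None and w.grad is not None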
def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
"""Merge micro batches into a batch.


@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.cuda
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
from torch.utils._pytree import tree_flatten, tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
@ -12,7 +12,18 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
from colossalai.pipeline.stage_manager import PipelineStageManager
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
from ._utils import (
clone,
detach,
get_batch_size,
get_micro_batch,
merge_batch,
model_forward,
release_tensor_data,
require_grad,
retain_grad,
to_device,
)
from .base import PipelineSchedule
AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
@ -24,21 +35,6 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
req.wait()
def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
"""Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
This method should be called right after the output tensor has been
sent to the next pipeline stage. At this point, the output tensor is
only useful for its '.grad_fn' field, and not its '.data'.
"""
if (out is None) or (not deallocate_pipeline_outputs):
return
assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
assert out._base is None, "counter-productive to free a view of another tensor."
# out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
out.data.untyped_storage().resize_(0)
class ZeroBubbleVPipeScheduler(PipelineSchedule):
def __init__(
self,
@ -409,6 +405,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
micro_batch: Optional[dict],
input_obj: Optional[dict],
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
@ -427,18 +424,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
# Load input ids, attention mask and labels
# micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# for the first stage, input_obj is None
# for the first stage, input_obj is None, so we use micro_batch as input_obj
# for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
# fwd calculate
output_obj = model_chunk[model_chunk_id](input_obj)
# fwd calculate
internal_inputs = {} if input_obj is None else input_obj
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
# last layer in model
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
loss = criterion(output_obj) / self.num_microbatch
loss = criterion(output_obj, micro_batch) / self.num_microbatch
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
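The forward step now feeds the model both the raw micro_batch (shardformer-style keyword inputs such as input_ids and attention_mask) and the internal_inputs received from the previous stage (e.g. hidden_states). A rough sketch of what model_forward from colossalai.pipeline.schedule._utils is assumed to do with the two dicts; the real helper may differ in details, so the function below is named as a stand-in:

from typing import Any, Dict, Optional
from torch.nn import Module

def model_forward_sketch(model: Module, micro_batch: Any, internal_inputs: Optional[Dict[str, Any]]) -> Any:
    # Assumed behaviour: merge micro-batch kwargs with stage-internal kwargs and call the model.
    internal_inputs = internal_inputs or {}
    if isinstance(micro_batch, dict):
        return model(**micro_batch, **internal_inputs)
    if micro_batch is None:
        return model(**internal_inputs)
    return model(micro_batch, **internal_inputs)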
@ -452,6 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
micro_batch: Optional[dict],
input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
@ -462,7 +460,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
optimizer (OptimizerWrapper): Optimizer to update the model
input_obj (Optional[dict]): x.
micro_batch (Optional[dict]): The micro batch for this step; used as the model input on chunk 0 of the first stage.
input_obj (Optional[dict]): x; the intermediate input received from the neighbouring stage (None on chunk 0 of the first stage).
output_obj (Union[dict, torch.Tensor]): y.
output_obj_grad (dict): dy.
@ -471,20 +469,52 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
# calculate bwd b step ; only dx = w*dy;
# Retain the grad on the input_obj.
tree_map(retain_grad, input_obj)
# Retain the grad on input_obj; no need to retain_grad on micro_batch
if input_obj is not None:
tree_map(retain_grad, input_obj)
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss; so output_obj_grad should be None
# x, y, dy list for backward_by_grad; Type: list[tensor];
input_obj_ = []
output_obj_ = []
output_obj_grad_ = []
# For chunk 0 stage 0, use micro_batch as input_obj_
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj_, _ = tree_flatten(micro_batch)
output_obj_, _ = tree_flatten(output_obj) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
# For loss backward; output_obj is loss; output_obj_grad should be None
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
assert output_obj_grad is None
input_obj_, _ = tree_flatten(input_obj)
output_obj_.append(output_obj) # LOSS
output_obj_grad_.append(output_obj_grad) # None
# For other chunk stage, use input_obj as input_obj_;
else:
input_obj_, _ = tree_flatten(input_obj)
output_obj_, _ = tree_flatten(output_obj) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=input_obj,
tensor=output_obj_,
grad=output_obj_grad_,
inputs=input_obj_,
retain_graph=True,
)
return input_obj.grad
# Format output_obj_grad
input_obj_grad = {}
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
for k, v in micro_batch.items():
if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad
else:
for k, v in input_obj.items():
if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad
return input_obj_grad
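backward_b_step now computes only dx (and, on chunk 0 of the first stage, the gradients of the raw micro-batch tensors), deliberately leaving the weight gradients to a later W step; this B/W split is what removes the pipeline bubble. A minimal sketch of the B half with plain autograd, independent of the scheduler's buffers (the linear layer and shapes are placeholders):

import torch

x = torch.randn(4, 8, requires_grad=True)     # stands in for input_obj / micro_batch tensors
layer = torch.nn.Linear(8, 8, bias=False)
y = layer(x)
dy = torch.ones_like(y)                       # stands in for output_obj_grad from downstream

# B step: gradients w.r.t. the inputs only; keep the graph alive for the W step.
torch.autograd.backward(y, grad_tensors=dy, inputs=[x], retain_graph=True)
dx = x.grad                                    # returned upstream as input_obj_grad
assert layer.weight.grad is None               # weight gradients are deferred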
def backward_w_step(
self,
@ -508,12 +538,21 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
# calculate bwd w step ; only dw = x*dy;
# y, dy list for w backward_by_grad; Type: list[tensor];
output_obj_ = []
output_obj_grad_ = []
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss
output_obj_grad = None
# loss backward; output_obj is loss;
output_obj_.append(output_obj) # LOSS
output_obj_grad_.append(None) # None
else:
output_obj_, _ = tree_flatten(output_obj) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
tensor=output_obj_,
grad=output_obj_grad_,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
)
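The matching W step replays backward on the same retained graph, this time restricting inputs to the chunk's parameters, so dw is produced after dx has already been sent upstream. Continuing the B-step sketch above (same placeholder layer and tensors):

# W step: gradients w.r.t. the parameters only; the graph can now be freed.
torch.autograd.backward(
    y,
    grad_tensors=dy,
    inputs=list(layer.parameters()),
    retain_graph=False,
)
assert layer.weight.grad is not None           # dw materialised after the decoupled B step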
@ -543,9 +582,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# Step1: recv fwd
if model_chunk_id == 0:
# is first stage; get input from func param
# is first stage; get input from microbatch
if self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj = micro_batch
input_obj = None
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
else:
@ -557,55 +596,75 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
# Here, let input_obj.requires_grad_()
tree_map(torch.Tensor.requires_grad_, input_obj)
# if input_obj is not None:
if not isinstance(input_obj, torch.Tensor):
tree_map(require_grad, input_obj)
# Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
# tree_map(torch.Tensor.requires_grad_, micro_batch)
# Step2: fwd step
output_obj = self.forward_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
micro_batch=micro_batch,
input_obj=input_obj,
criterion=criterion,
accum_loss=accum_loss,
outputs=outputs,
)
# Step3:
# 3-1: detach output for send fwd
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not detach bwd LOSS
pass
else:
detached_output_obj = output_obj.clone().detach()
# detach output
detached_output_obj = tree_map(detach, output_obj)
# 3-2 clone detached_output_obj
detached_output_obj = tree_map(clone, detached_output_obj)
# 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output)
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not release_tensor_data bwd LOSS
pass
else:
# release_tensor_data output
tree_map(release_tensor_data, output_obj)
# add input and output object for backward b
self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
# for bwd b&w, we only need the graph(grad_fn) of output_obj
# Do not release_tensor_data loss, release_tensor_data other output_obj;
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
self.output_tensors[model_chunk_id].append(output_obj)
self.output_tensors_dw[model_chunk_id].append(output_obj)
else:
self.output_tensors[model_chunk_id].append(output_obj)
self.output_tensors_dw[model_chunk_id].append(output_obj)
# Step4: send fwd
# add output to send_fwd_buffer
if model_chunk_id == 0:
if model_chunk_id == 0: # chunk 0
# is last stage; send to local_send_forward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.local_send_forward_buffer.append(detached_output_obj)
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
else:
# is first stage; end of fwd; append LOSS to local_send_backward_buffer
else: # chunk 1
# is first stage; end of fwd; do nothing
if self.stage_manager.is_first_stage(ignore_chunk=True):
pass
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
# add input and output object for backward b
self.input_tensors[model_chunk_id].append(input_obj)
# detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True)
self.output_tensors[model_chunk_id].append(output_obj)
# add output object for backward w
self.output_tensors_dw[model_chunk_id].append(output_obj)
def schedule_b(
self,
scheduled_node,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
# input_obj: Optional[dict],
# output_obj: Union[dict, torch.Tensor],
# output_obj_grad: Optional[dict],
):
"""A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
@ -616,25 +675,24 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
Returns:
Nothing.
"""
# Step1: recv bwd
if model_chunk_id == 0:
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
output_tensor_grad = self.local_send_backward_buffer.pop(0)
# chunk 0 not last stage; recv output_grad from recv_backward_buffer
# chunk0 not last stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
else:
# chunk1, is first stage; recv LOSS from local send bwd buffer
if self.stage_manager.is_first_stage(ignore_chunk=True):
output_tensor_grad = None
# chunk1, not first stage; recv output_grad from recv_backward_buffer
# chunk1, not first stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
# get input and output object from buffer;
input_obj = self.input_tensors[model_chunk_id].pop(0)
micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0)
# save output_tensor_grad for dw
@ -645,12 +703,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# we save output_tensor_grad here
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
# _wait_p2p(recv_bwd_handles)
# Step2: bwd step
input_object_grad = self.backward_b_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
optimizer=optimizer,
micro_batch=micro_batch,
input_obj=input_obj,
output_obj=output_obj,
output_obj_grad=output_tensor_grad,
@ -777,8 +835,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# communication
communication_func = self.communication_map[scheduled_node.type]
communication_func(scheduled_node.chunk)
if scheduled_node.type == "F":
elif scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,
model_chunk=model_chunk,


@ -1,4 +1,6 @@
from copy import deepcopy
from functools import partial
from types import MethodType
from typing import Tuple
import pytest
@ -14,6 +16,7 @@ from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@ -23,10 +26,32 @@ class MlpModel(nn.Module):
super().__init__()
self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
def forward(self, x):
def forward(
self,
hidden_states,
):
for layer in self.layers:
x = layer(x)
return x
hidden_states = layer(hidden_states)
return hidden_states
def pp_linear_fwd(
forward,
data: torch.Tensor = None,
hidden_states: torch.Tensor = None,
stage_mgr: PipelineStageManager = None,
model_chunk_id: int = None,
):
with stage_mgr.switch_model_chunk_id(model_chunk_id):
# fwd end
if stage_mgr.is_first_stage() and model_chunk_id == 1:
return forward(hidden_states)
# fwd start
elif stage_mgr.is_first_stage() and model_chunk_id == 0:
return {"hidden_states": forward(data)}
# fwd middle
else:
return {"hidden_states": forward(hidden_states)}
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
@ -561,19 +586,24 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
# init loss func
def criterion(x, *args, **kwargs):
x = x["hidden_states"]
return (x * x).mean()
def criterion_base(x, *args, **kwargs):
return (x * x).mean()
# init model and input
batch_size = test_config["batch_size"]
num_layers = 8
assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layers cannot be distributed over {num_model_chunk} chunks"
in_dim = out_dim = 4096
in_dim = out_dim = 1024
before_init_memory = torch.cuda.memory_allocated() / 1024**3
print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
input_base = [t.clone() for t in data_iter]
# data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
# input_base = [t.clone() for t in data_iter]
input_base = {k: v.clone() for k, v in data_iter.items()}
model_base = deepcopy(model)
if rank == 0:
@ -581,24 +611,44 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 0 or idx == 7:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
elif rank == 1:
# layer 1 & 6 to chunk 1 on rank1
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 1 or idx == 6:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
elif rank == 2:
# layer 2 & 5 to chunk 2 on rank2
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 2 or idx == 5:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
else:
# layer 3 & 4 to chunk 3 on rank3
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 3 or idx == 4:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
# init optimizer
@ -611,7 +661,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
torch.cuda.synchronize()
result = scheduler.forward_backward_step(
model_chunk=local_chunk,
data_iter=iter(data_iter),
data_iter=iter([data_iter]),
criterion=criterion,
optimizer=optimizer_pp,
return_loss=True,
@ -624,26 +674,28 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
# assert memory
if rank != 0:
# w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
# output hid_dim * hid_dim * 4(fp32) / 1024**3
# optim state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}")
assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3)
# w.grad: hid_dim * hid_dim * microbatch * 4(fp32) * 2 (2 layers in each stage) / 1024**3
# output: hid_dim * hid_dim * microbatch * 4(fp32) / 1024**3
# optim state: hid_dim * hid_dim * 4(fp32) * 2 (2 layers in each stage) / 1024**3
print(
f" num_microbatch {num_microbatch} rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 * batch_size / 1024**3)}"
)
assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 * batch_size / 1024**3)
else:
# rank0 will also hold output;
print(
f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
f" num_microbatch {num_microbatch} rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 * batch_size / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
)
assert round((after_pp_step_memory - after_init_memory), 5) <= round(
(in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
(in_dim * in_dim * 4 * 5 * batch_size / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
)
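As a rough sanity check of the updated bound (assuming fp32 tensors, the in_dim = 1024 set above, and an illustrative batch_size of 8; the actual batch size comes from test_config), the budget works out to a fraction of a GB per rank:

in_dim, batch_size = 1024, 8                  # batch_size is an assumed example value
bound_gb = in_dim * in_dim * 4 * 5 * batch_size / 1024**3
extra_rank0_gb = batch_size * in_dim * in_dim * 4 / 1024**3
print(f"non-rank0 bound ~= {bound_gb:.3f} GB, rank0 bound ~= {bound_gb + extra_rank0_gb:.3f} GB")
# non-rank0 bound ~= 0.156 GB, rank0 bound ~= 0.188 GB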
##########################
# Fwd bwd for base
##########################
# fwd & bwd
output_base = model_base(input_base[0])
loss_base = criterion(output_base)
output_base = model_base(input_base["data"])
loss_base = criterion_base(output_base)
loss_base.backward()
optimizer_base.step()
@ -653,7 +705,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
# only chunk 1 stage 0 hold loss and output
if rank == 0:
assert_close(result["loss"], loss_base)
assert_close(result["outputs"], output_base)
assert_close(result["outputs"]["hidden_states"], output_base)
# print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
##########################
@ -724,28 +776,108 @@ def run_with_hybridplugin(test_config):
"test_config",
[
{
"batch_size": 8,
"pp_style": "zbv",
"tp_size": 1,
"ep_size": 1,
"pp_size": 4,
"num_microbatches": 4,
"zero_stage": 1,
"precision": "bf16",
"num_model_chunk": 2,
"num_model_chunks": 2,
},
],
)
def run_with_moehybridplugin(test_config):
model_zoo.get_sub_registry("transformers_bert")
test_config["use_lazy_init"] = False
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
# test_config["use_lazy_init"] = False
test_config["initial_scale"] = 2**16
model_list = [
"transformers_bert",
]
clear_layout_converter()
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name in model_list:
# base param
model = model_fn()
data = data_gen_fn()
print(f"data {data}")
criterion = loss_fn
optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
output = model(**data)
loss = criterion(output)
loss.backward()
optimizer.step()
print(f"output {output}")
# # pp param
# model_pp = deepcopy(model)
# data_pp = deepcopy(data)
# optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
# # init pipeline graph
# h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024
# mem_f = 34 * h + 5 * a * s
# mem_w = -32 * h
# mem_b = -mem_w - mem_f
# graph = PipelineGraph(
# n_stage=test_config["pp_size"],
# n_micro=test_config["num_microbatches"],
# f_cost=1,
# b_cost=1,
# w_cost=1,
# c_cost=1,
# f_mem=mem_f,
# b_mem=mem_b,
# w_mem=mem_w,
# # max_mem=mem_f * (p * 2 + m_offset),
# )
# zbv_schedule = graph.get_v_schedule()
# test_config["scheduler_nodes"] = zbv_schedule
# plugin = MoeHybridParallelPlugin(
# **test_config
# )
# model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure(
# model = model_pp,
# optimizer = optimizer_pp,
# criterion = criterion,
# dataloader = data_pp,
# )
# output_pp = plugin.execute_pipeline(
# data_iter=iter(data),
# model=model,
# criterion=criterion,
# optimizer=optimizer,
# return_loss = True,
# return_outputs = True,
# )
# TODO 6): support booster & Hybrid, based on 4)
# TODO 7): support booster & MoEHybrid, based on 4)
@parameterize(
"test_config",
[
{
"pp_style": "zbv",
"tp_size": 1,
"ep_size": 1,
"pp_size": 4,
"num_microbatches": 4,
"zero_stage": 1,
"precision": "bf16",
"num_model_chunks": 2,
},
],
)
def run_with_booster_moehybridplugin(test_config):
pass
def run_dist(rank, world_size, port):
@ -754,6 +886,7 @@ def run_dist(rank, world_size, port):
# run_fwd_bwd_iter_input()
run_fwd_bwd_vschedule_with_optim()
# run_with_moehybridplugin()
# run_with_booster_moehybridplugin()
@pytest.mark.dist