mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #6065 from duanjunwen/dev/zero_bubble
[Feat] Support zero bubble with shardformer inputpull/6075/head
commit
8501202a35
|
@ -29,6 +29,7 @@ from colossalai.logging import get_dist_logger
|
|||
from colossalai.nn.optimizer import cast_to_distributed
|
||||
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
|
||||
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
|
||||
|
@ -207,6 +208,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
custom_policy: Policy = None,
|
||||
pp_style: str = "1f1b",
|
||||
num_model_chunks: int = 1,
|
||||
scheduler_nodes: List = None,
|
||||
num_layers_per_stage: Optional[List[int]] = None,
|
||||
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
|
@ -282,8 +284,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
self.custom_policy = custom_policy
|
||||
assert zero_stage in (0, 1, 2)
|
||||
if self.pp_size > 1:
|
||||
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
|
||||
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
||||
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
|
||||
assert (
|
||||
pp_style == "interleaved" or pp_style == "zbv"
|
||||
) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
||||
assert (
|
||||
num_microbatches is not None or microbatch_size is not None
|
||||
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
||||
|
@ -293,7 +297,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
self.stage_manager = PipelineStageManager(
|
||||
self.pg_mesh,
|
||||
pipeline_axis=self.pp_axis,
|
||||
enable_interleave=pp_style == "interleaved",
|
||||
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
|
||||
num_model_chunks=num_model_chunks,
|
||||
num_layers_per_stage=num_layers_per_stage,
|
||||
)
|
||||
|
@ -315,6 +319,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
microbatch_size=microbatch_size,
|
||||
enable_metadata_cache=enable_metadata_cache,
|
||||
)
|
||||
elif pp_style == "zbv":
|
||||
self.schedule = ZeroBubbleVPipeScheduler(
|
||||
schedule=scheduler_nodes,
|
||||
stage_manager=self.stage_manager,
|
||||
num_model_chunks=num_model_chunks,
|
||||
num_microbatch=num_microbatches,
|
||||
overlap_p2p=overlap_p2p,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
|
@ -131,6 +131,16 @@ def retain_grad(x: Any) -> None:
|
|||
x.retain_grad()
|
||||
|
||||
|
||||
def require_grad(x: Any) -> None:
|
||||
"""Call require_grad on a tensor.
|
||||
|
||||
Args:
|
||||
x (Any): Object to be called.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor) and not x.requires_grad:
|
||||
x.requires_grad_()
|
||||
|
||||
|
||||
def detach(x: Any) -> Any:
|
||||
"""Call detach() on a tensor.
|
||||
|
||||
|
@ -145,6 +155,34 @@ def detach(x: Any) -> Any:
|
|||
return x
|
||||
|
||||
|
||||
def clone(x: Any) -> Any:
|
||||
"""Call clone() on a tensor.
|
||||
|
||||
Args:
|
||||
x (Any): Object to be called.
|
||||
|
||||
Returns:
|
||||
Any: The cloned object.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.clone()
|
||||
return x
|
||||
|
||||
|
||||
def release_tensor_data(x: Any) -> Any:
|
||||
"""Call untyped_storage().resize_(0) on a tensor. Use to release tensor.data and keep grad_fn.
|
||||
|
||||
Args:
|
||||
x (Any): Object to be called.
|
||||
|
||||
Returns:
|
||||
Any: The deallocate .data object.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.data.untyped_storage().resize_(0)
|
||||
return x
|
||||
|
||||
|
||||
def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
|
||||
"""Merge micro batches into a batch.
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
|||
import torch
|
||||
import torch.cuda
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.utils._pytree import tree_flatten, tree_map
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
|
@ -12,7 +12,18 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
|
|||
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
|
||||
from ._utils import (
|
||||
clone,
|
||||
detach,
|
||||
get_batch_size,
|
||||
get_micro_batch,
|
||||
merge_batch,
|
||||
model_forward,
|
||||
release_tensor_data,
|
||||
require_grad,
|
||||
retain_grad,
|
||||
to_device,
|
||||
)
|
||||
from .base import PipelineSchedule
|
||||
|
||||
AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
|
||||
|
@ -24,21 +35,6 @@ def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
|
|||
req.wait()
|
||||
|
||||
|
||||
def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
|
||||
"""Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
|
||||
|
||||
This method should be called right after the output tensor has been
|
||||
sent to the next pipeline stage. At this point, the output tensor is
|
||||
only useful for its '.grad_fn' field, and not its '.data'.
|
||||
"""
|
||||
if (out is None) or (not deallocate_pipeline_outputs):
|
||||
return
|
||||
assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
|
||||
assert out._base is None, "counter-productive to free a view of another tensor."
|
||||
# out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
|
||||
out.data.untyped_storage().resize_(0)
|
||||
|
||||
|
||||
class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -409,6 +405,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
micro_batch: Optional[dict],
|
||||
input_obj: Optional[dict],
|
||||
criterion: Callable,
|
||||
accum_loss: Optional[torch.Tensor] = None,
|
||||
|
@ -427,18 +424,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
|
||||
"""
|
||||
# Load input ids, attention mask and labels
|
||||
# micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
||||
|
||||
# for the first stage, input_obj is None
|
||||
# for the first stage, input_obj is None; So,we use micro_batch as input_obj
|
||||
# for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
|
||||
# Only attention_mask from micro_batch is used
|
||||
|
||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||
# fwd calculate
|
||||
output_obj = model_chunk[model_chunk_id](input_obj)
|
||||
# fwd calculate
|
||||
internal_inputs = {} if input_obj is None else input_obj
|
||||
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
|
||||
|
||||
# last layer in model
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
loss = criterion(output_obj) / self.num_microbatch
|
||||
loss = criterion(output_obj, micro_batch) / self.num_microbatch
|
||||
if accum_loss is not None:
|
||||
accum_loss.add_(loss.detach())
|
||||
if outputs is not None:
|
||||
|
@ -452,6 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
optimizer: OptimizerWrapper,
|
||||
micro_batch: Optional[dict],
|
||||
input_obj: Optional[dict],
|
||||
output_obj: Union[dict, torch.Tensor],
|
||||
output_obj_grad: Optional[dict],
|
||||
|
@ -462,7 +460,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
optimizer (OptimizerWrapper): Optimizer to update the model
|
||||
input_obj (Optional[dict]): x.
|
||||
input_obj (Optional[Tuple(dict)]): x. (microbatch, input_obj)
|
||||
output_obj (Union[dict, torch.Tensor]): y.
|
||||
output_obj_grad (dict): dy.
|
||||
|
||||
|
@ -471,20 +469,52 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
"""
|
||||
# calculate bwd b step ; only dx = w*dy;
|
||||
|
||||
# Retain the grad on the input_obj.
|
||||
tree_map(retain_grad, input_obj)
|
||||
# Retain the grad on the input_obj. No need retain_grad microbatch
|
||||
if input_obj is not None:
|
||||
tree_map(retain_grad, input_obj)
|
||||
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# loss backward; output_obj is loss; so output_obj_grad should be None
|
||||
# x, y, dy list for backward_by_grad; Type: list[tensor];
|
||||
input_obj_ = []
|
||||
output_obj_ = []
|
||||
output_obj_grad_ = []
|
||||
|
||||
# For chunk 0 stage 0, use micro_batch as input_obj_
|
||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
input_obj_, _ = tree_flatten(micro_batch)
|
||||
output_obj_, _ = tree_flatten(output_obj) # y
|
||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||
|
||||
# For loss backward; output_obj is loss; output_obj_grad should be None
|
||||
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
assert output_obj_grad is None
|
||||
input_obj_, _ = tree_flatten(input_obj)
|
||||
output_obj_.append(output_obj) # LOSS
|
||||
output_obj_grad_.append(output_obj_grad) # None
|
||||
|
||||
# For other chunk stage, use input_obj as input_obj_;
|
||||
else:
|
||||
input_obj_, _ = tree_flatten(input_obj)
|
||||
output_obj_, _ = tree_flatten(output_obj) # y
|
||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj,
|
||||
grad=output_obj_grad,
|
||||
inputs=input_obj,
|
||||
tensor=output_obj_,
|
||||
grad=output_obj_grad_,
|
||||
inputs=input_obj_,
|
||||
retain_graph=True,
|
||||
)
|
||||
return input_obj.grad
|
||||
|
||||
# Format output_obj_grad
|
||||
input_obj_grad = {}
|
||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
for k, v in micro_batch.items():
|
||||
if isinstance(v, torch.Tensor) and v.grad is not None:
|
||||
input_obj_grad[k] = v.grad
|
||||
else:
|
||||
for k, v in input_obj.items():
|
||||
if isinstance(v, torch.Tensor) and v.grad is not None:
|
||||
input_obj_grad[k] = v.grad
|
||||
return input_obj_grad
|
||||
|
||||
def backward_w_step(
|
||||
self,
|
||||
|
@ -508,12 +538,21 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
"""
|
||||
# calculate bwd w step ; only dw = x*dy;
|
||||
|
||||
# y, dy list for w backward_by_grad; Type: list[tensor];
|
||||
output_obj_ = []
|
||||
output_obj_grad_ = []
|
||||
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# loss backward; output_obj is loss
|
||||
output_obj_grad = None
|
||||
# loss backward; output_obj is loss;
|
||||
output_obj_.append(output_obj) # LOSS
|
||||
output_obj_grad_.append(None) # None
|
||||
else:
|
||||
output_obj_, _ = tree_flatten(output_obj) # y
|
||||
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj,
|
||||
grad=output_obj_grad,
|
||||
tensor=output_obj_,
|
||||
grad=output_obj_grad_,
|
||||
inputs=list(model_chunk[model_chunk_id].parameters()),
|
||||
retain_graph=False,
|
||||
)
|
||||
|
@ -543,9 +582,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
||||
# Step1: recv fwd
|
||||
if model_chunk_id == 0:
|
||||
# is first stage; get input from func param
|
||||
# is first stage; get input from microbatch
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
input_obj = micro_batch
|
||||
input_obj = None
|
||||
else:
|
||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
else:
|
||||
|
@ -557,55 +596,75 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
|
||||
# Here, let input_obj.requires_grad_()
|
||||
tree_map(torch.Tensor.requires_grad_, input_obj)
|
||||
# if input_obj is not None:
|
||||
if not isinstance(input_obj, torch.Tensor):
|
||||
tree_map(require_grad, input_obj)
|
||||
|
||||
# Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
|
||||
# tree_map(torch.Tensor.requires_grad_, micro_batch)
|
||||
|
||||
# Step2: fwd step
|
||||
output_obj = self.forward_step(
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=model_chunk_id,
|
||||
micro_batch=micro_batch,
|
||||
input_obj=input_obj,
|
||||
criterion=criterion,
|
||||
accum_loss=accum_loss,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
# Step3:
|
||||
# 3-1:detach output; detach output for send fwd;
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# We should not detach bwd LOSS
|
||||
pass
|
||||
else:
|
||||
detached_output_obj = output_obj.clone().detach()
|
||||
# detach output
|
||||
detached_output_obj = tree_map(detach, output_obj)
|
||||
# 3-2 clone detached_output_obj
|
||||
detached_output_obj = tree_map(clone, detached_output_obj)
|
||||
|
||||
# 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output)
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# We should not release_tensor_data bwd LOSS
|
||||
pass
|
||||
else:
|
||||
# release_tensor_data output
|
||||
tree_map(release_tensor_data, output_obj)
|
||||
|
||||
# add input and output object for backward b
|
||||
self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
|
||||
|
||||
# for bwd b&w, we only need the graph(grad_fn) of output_obj
|
||||
# Do not release_tensor_data loss, release_tensor_data other output_obj;
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.output_tensors[model_chunk_id].append(output_obj)
|
||||
self.output_tensors_dw[model_chunk_id].append(output_obj)
|
||||
else:
|
||||
self.output_tensors[model_chunk_id].append(output_obj)
|
||||
self.output_tensors_dw[model_chunk_id].append(output_obj)
|
||||
|
||||
# Step3: send fwd
|
||||
# add output to send_fwd_buffer
|
||||
if model_chunk_id == 0:
|
||||
if model_chunk_id == 0: # chunk 0
|
||||
# is last stage; send to local_send_forward_buffer
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
self.local_send_forward_buffer.append(detached_output_obj)
|
||||
else:
|
||||
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
||||
else:
|
||||
# is first stage; end of fwd; append LOSS to local_send_backward_buffer
|
||||
else: # chunk 1
|
||||
# is first stage; end of fwd; do nothing
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
pass
|
||||
else:
|
||||
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
||||
|
||||
# add input and output object for backward b
|
||||
self.input_tensors[model_chunk_id].append(input_obj)
|
||||
# detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
|
||||
deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True)
|
||||
self.output_tensors[model_chunk_id].append(output_obj)
|
||||
# add output object for backward w
|
||||
self.output_tensors_dw[model_chunk_id].append(output_obj)
|
||||
|
||||
def schedule_b(
|
||||
self,
|
||||
scheduled_node,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
optimizer: OptimizerWrapper,
|
||||
# input_obj: Optional[dict],
|
||||
# output_obj: Union[dict, torch.Tensor],
|
||||
# output_obj_grad: Optional[dict],
|
||||
):
|
||||
"""A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
|
||||
|
||||
|
@ -616,25 +675,24 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
Returns:
|
||||
Nothing.
|
||||
"""
|
||||
|
||||
# Step1: recv bwd
|
||||
if model_chunk_id == 0:
|
||||
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
output_tensor_grad = self.local_send_backward_buffer.pop(0)
|
||||
# chunk 0 not last stage; recv output_grad from recv_backward_buffer
|
||||
# chunk0 not last stage; recv output_grad from recv_backward_buffer
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||
else:
|
||||
# chunk1, is first stage; recv LOSS from local send bwd buffer
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
output_tensor_grad = None
|
||||
# chunk1, not first stage; recv output_grad from recv_backward_buffer
|
||||
# chunk1, not first stage; recv output_grad from recv_backward_buffer
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||
|
||||
# get input and output object from buffer;
|
||||
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
||||
micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
|
||||
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
||||
|
||||
# save output_tensor_grad for dw
|
||||
|
@ -645,12 +703,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# we save output_tensor_grad here
|
||||
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
|
||||
|
||||
# _wait_p2p(recv_bwd_handles)
|
||||
# Step2: bwd step
|
||||
input_object_grad = self.backward_b_step(
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=model_chunk_id,
|
||||
optimizer=optimizer,
|
||||
micro_batch=micro_batch,
|
||||
input_obj=input_obj,
|
||||
output_obj=output_obj,
|
||||
output_obj_grad=output_tensor_grad,
|
||||
|
@ -777,8 +835,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# communication
|
||||
communication_func = self.communication_map[scheduled_node.type]
|
||||
communication_func(scheduled_node.chunk)
|
||||
|
||||
if scheduled_node.type == "F":
|
||||
elif scheduled_node.type == "F":
|
||||
self.schedule_f(
|
||||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
@ -14,6 +16,7 @@ from colossalai.logging import disable_existing_loggers
|
|||
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
|
||||
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
@ -23,10 +26,32 @@ class MlpModel(nn.Module):
|
|||
super().__init__()
|
||||
self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
|
||||
|
||||
def forward(self, x):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
hidden_states = layer(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def pp_linear_fwd(
|
||||
forward,
|
||||
data: torch.Tensor = None,
|
||||
hidden_states: torch.Tensor = None,
|
||||
stage_mgr: PipelineStageManager = None,
|
||||
model_chunk_id: int = None,
|
||||
):
|
||||
with stage_mgr.switch_model_chunk_id(model_chunk_id):
|
||||
# fwd end
|
||||
if stage_mgr.is_first_stage() and model_chunk_id == 1:
|
||||
return forward(hidden_states)
|
||||
# fwd start
|
||||
elif stage_mgr.is_first_stage() and model_chunk_id == 0:
|
||||
return {"hidden_states": forward(data)}
|
||||
# fwd middle
|
||||
else:
|
||||
return {"hidden_states": forward(hidden_states)}
|
||||
|
||||
|
||||
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
|
@ -561,19 +586,24 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|||
|
||||
# init loss func
|
||||
def criterion(x, *args, **kwargs):
|
||||
x = x["hidden_states"]
|
||||
return (x * x).mean()
|
||||
|
||||
def criterion_base(x, *args, **kwargs):
|
||||
return (x * x).mean()
|
||||
|
||||
# init model and input
|
||||
batch_size = test_config["batch_size"]
|
||||
num_layers = 8
|
||||
assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
|
||||
in_dim = out_dim = 4096
|
||||
in_dim = out_dim = 1024
|
||||
before_init_memory = torch.cuda.memory_allocated() / 1024**3
|
||||
print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
|
||||
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
||||
data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
|
||||
|
||||
input_base = [t.clone() for t in data_iter]
|
||||
# data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
|
||||
data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
|
||||
# input_base = [t.clone() for t in data_iter]
|
||||
input_base = {k: v.clone() for k, v in data_iter.items()}
|
||||
model_base = deepcopy(model)
|
||||
|
||||
if rank == 0:
|
||||
|
@ -581,24 +611,44 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|||
local_chunk = torch.nn.ModuleList().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 0 or idx == 7:
|
||||
sub_model._forward = sub_model.forward
|
||||
sub_model.forward = MethodType(
|
||||
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
|
||||
sub_model._forward,
|
||||
)
|
||||
local_chunk.append(sub_model)
|
||||
elif rank == 1:
|
||||
# layer 1 & 6 to chunk 1 on rank1
|
||||
local_chunk = torch.nn.ModuleList().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 1 or idx == 6:
|
||||
sub_model._forward = sub_model.forward
|
||||
sub_model.forward = MethodType(
|
||||
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
|
||||
sub_model._forward,
|
||||
)
|
||||
local_chunk.append(sub_model)
|
||||
elif rank == 2:
|
||||
# layer 2 & 5 to chunk 2 on rank2
|
||||
local_chunk = torch.nn.ModuleList().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 2 or idx == 5:
|
||||
sub_model._forward = sub_model.forward
|
||||
sub_model.forward = MethodType(
|
||||
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
|
||||
sub_model._forward,
|
||||
)
|
||||
local_chunk.append(sub_model)
|
||||
else:
|
||||
# layer 3 & 4 to chunk 3 on rank3
|
||||
local_chunk = torch.nn.ModuleList().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 3 or idx == 4:
|
||||
sub_model._forward = sub_model.forward
|
||||
sub_model.forward = MethodType(
|
||||
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
|
||||
sub_model._forward,
|
||||
)
|
||||
local_chunk.append(sub_model)
|
||||
|
||||
# init optimizer
|
||||
|
@ -611,7 +661,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|||
torch.cuda.synchronize()
|
||||
result = scheduler.forward_backward_step(
|
||||
model_chunk=local_chunk,
|
||||
data_iter=iter(data_iter),
|
||||
data_iter=iter([data_iter]),
|
||||
criterion=criterion,
|
||||
optimizer=optimizer_pp,
|
||||
return_loss=True,
|
||||
|
@ -624,26 +674,28 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|||
|
||||
# assert memory
|
||||
if rank != 0:
|
||||
# w.grad hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
|
||||
# output hid_dim * hid_dim * 4(fp32) / 1024**3
|
||||
# optim state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
|
||||
print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 / 1024**3)}")
|
||||
assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 / 1024**3)
|
||||
# w.grad: hid_dim * hid_dim * microbatch * 4(fp32) * 2 (2 layer in each stage) / 1024**3
|
||||
# output: hid_dim * hid_dim * microbatch * 4(fp32) / 1024**3
|
||||
# optim: state hid_dim * hid_dim * 4(fp32) * 2 (2 layer in each stage) / 1024**3
|
||||
print(
|
||||
f" num_microbatch {num_microbatch} rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 5 * batch_size / 1024**3)}"
|
||||
)
|
||||
assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 5 * batch_size / 1024**3)
|
||||
else:
|
||||
# rank0 will also hold output;
|
||||
print(
|
||||
f"rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
|
||||
f" num_microbatch {num_microbatch} rank {rank}: {round((after_pp_step_memory - after_init_memory), 5)} <= {round((in_dim * in_dim * 4 * 5 * batch_size / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5)}"
|
||||
)
|
||||
assert round((after_pp_step_memory - after_init_memory), 5) <= round(
|
||||
(in_dim * in_dim * 4 * 5 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
|
||||
(in_dim * in_dim * 4 * 5 * batch_size / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5
|
||||
)
|
||||
|
||||
##########################
|
||||
# Fwd bwd for base
|
||||
##########################
|
||||
# fwd & bwd
|
||||
output_base = model_base(input_base[0])
|
||||
loss_base = criterion(output_base)
|
||||
output_base = model_base(input_base["data"])
|
||||
loss_base = criterion_base(output_base)
|
||||
loss_base.backward()
|
||||
optimizer_base.step()
|
||||
|
||||
|
@ -653,7 +705,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|||
# only chunk 1 stage 0 hold loss and output
|
||||
if rank == 0:
|
||||
assert_close(result["loss"], loss_base)
|
||||
assert_close(result["outputs"], output_base)
|
||||
assert_close(result["outputs"]["hidden_states"], output_base)
|
||||
|
||||
# print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
|
||||
##########################
|
||||
|
@ -724,28 +776,108 @@ def run_with_hybridplugin(test_config):
|
|||
"test_config",
|
||||
[
|
||||
{
|
||||
"batch_size": 8,
|
||||
"pp_style": "zbv",
|
||||
"tp_size": 1,
|
||||
"ep_size": 1,
|
||||
"pp_size": 4,
|
||||
"num_microbatches": 4,
|
||||
"zero_stage": 1,
|
||||
"precision": "bf16",
|
||||
"num_model_chunk": 2,
|
||||
"num_model_chunks": 2,
|
||||
},
|
||||
],
|
||||
)
|
||||
def run_with_moehybridplugin(test_config):
|
||||
model_zoo.get_sub_registry("transformers_bert")
|
||||
test_config["use_lazy_init"] = False
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
|
||||
# test_config["use_lazy_init"] = False
|
||||
test_config["initial_scale"] = 2**16
|
||||
model_list = [
|
||||
"transformers_bert",
|
||||
]
|
||||
clear_layout_converter()
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
if name in model_list:
|
||||
# base param
|
||||
model = model_fn()
|
||||
data = data_gen_fn()
|
||||
print(f"data {data}")
|
||||
criterion = loss_fn
|
||||
optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
|
||||
|
||||
output = model(**data)
|
||||
loss = criterion(output)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
print(f"output {output}")
|
||||
|
||||
# # pp param
|
||||
# model_pp = deepcopy(model)
|
||||
# data_pp = deepcopy(data)
|
||||
# optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
|
||||
|
||||
# # init pipeline graph
|
||||
# h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024
|
||||
# mem_f = 34 * h + 5 * a * s
|
||||
# mem_w = -32 * h
|
||||
# mem_b = -mem_w - mem_f
|
||||
# graph = PipelineGraph(
|
||||
# n_stage=test_config["pp_size"],
|
||||
# n_micro=test_config["num_microbatches"],
|
||||
# f_cost=1,
|
||||
# b_cost=1,
|
||||
# w_cost=1,
|
||||
# c_cost=1,
|
||||
# f_mem=mem_f,
|
||||
# b_mem=mem_b,
|
||||
# w_mem=mem_w,
|
||||
# # max_mem=mem_f * (p * 2 + m_offset),
|
||||
# )
|
||||
|
||||
# zbv_schedule = graph.get_v_schedule()
|
||||
|
||||
# test_config["scheduler_nodes"] = zbv_schedule
|
||||
# plugin = MoeHybridParallelPlugin(
|
||||
# **test_config
|
||||
# )
|
||||
# model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure(
|
||||
# model = model_pp,
|
||||
# optimizer = optimizer_pp,
|
||||
# criterion = criterion,
|
||||
# dataloader = data_pp,
|
||||
# )
|
||||
|
||||
# output_pp = plugin.execute_pipeline(
|
||||
# data_iter=iter(data),
|
||||
# model=model,
|
||||
# criterion=criterion,
|
||||
# optimizer=optimizer,
|
||||
# return_loss = True,
|
||||
# return_outputs = True,
|
||||
# )
|
||||
|
||||
|
||||
# TODO:6) support booster & Hybrid base 4)
|
||||
|
||||
|
||||
# TODO:7) support booster & MoEHybrid base 4)
|
||||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"pp_style": "zbv",
|
||||
"tp_size": 1,
|
||||
"ep_size": 1,
|
||||
"pp_size": 4,
|
||||
"num_microbatches": 4,
|
||||
"zero_stage": 1,
|
||||
"precision": "bf16",
|
||||
"num_model_chunks": 2,
|
||||
},
|
||||
],
|
||||
)
|
||||
def run_with_booster_moehybridplugin(test_config):
|
||||
pass
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
@ -754,6 +886,7 @@ def run_dist(rank, world_size, port):
|
|||
# run_fwd_bwd_iter_input()
|
||||
run_fwd_bwd_vschedule_with_optim()
|
||||
# run_with_moehybridplugin()
|
||||
# run_with_booster_moehybridplugin()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
Loading…
Reference in New Issue