mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix zerobubble pp for shardformer type input;
parent
9bc3b6e220
commit
3dbad102cf
|
@ -131,6 +131,16 @@ def retain_grad(x: Any) -> None:
|
||||||
x.retain_grad()
|
x.retain_grad()
|
||||||
|
|
||||||
|
|
||||||
|
def require_grad(x: Any) -> None:
|
||||||
|
"""Call require_grad on a tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Any): Object to be called.
|
||||||
|
"""
|
||||||
|
if isinstance(x, torch.Tensor) and x.requires_grad:
|
||||||
|
x.requires_grad_()
|
||||||
|
|
||||||
|
|
||||||
def detach(x: Any) -> Any:
|
def detach(x: Any) -> Any:
|
||||||
"""Call detach() on a tensor.
|
"""Call detach() on a tensor.
|
||||||
|
|
||||||
|
@ -145,6 +155,34 @@ def detach(x: Any) -> Any:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def clone(x: Any) -> Any:
|
||||||
|
"""Call clone() on a tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Any): Object to be called.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The cloned object.
|
||||||
|
"""
|
||||||
|
if isinstance(x, torch.Tensor):
|
||||||
|
return x.clone()
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def deallocate(x: Any) -> Any:
|
||||||
|
"""Call deallocate() on a tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Any): Object to be called.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The deallocate .data object.
|
||||||
|
"""
|
||||||
|
if isinstance(x, torch.Tensor):
|
||||||
|
return x.data.untyped_storage().resize_(0)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
|
def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
|
||||||
"""Merge micro batches into a batch.
|
"""Merge micro batches into a batch.
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||||
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
|
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
|
||||||
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
|
from ._utils import clone, detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
|
||||||
from .base import PipelineSchedule
|
from .base import PipelineSchedule
|
||||||
|
|
||||||
AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
|
AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
|
||||||
|
@ -39,6 +39,20 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
|
||||||
out.data.untyped_storage().resize_(0)
|
out.data.untyped_storage().resize_(0)
|
||||||
|
|
||||||
|
|
||||||
|
def require_grad(tensor):
|
||||||
|
"""Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
|
||||||
|
|
||||||
|
This method should be called right after the output tensor has been
|
||||||
|
sent to the next pipeline stage. At this point, the output tensor is
|
||||||
|
only useful for its '.grad_fn' field, and not its '.data'.
|
||||||
|
"""
|
||||||
|
if tensor is None:
|
||||||
|
return
|
||||||
|
assert isinstance(tensor, torch.Tensor), "expected Tensor, found %s." % type(tensor).__name__
|
||||||
|
assert tensor._base is None, "counter-productive to free a view of another tensor."
|
||||||
|
tensor.requires_grad_()
|
||||||
|
|
||||||
|
|
||||||
class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -409,6 +423,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
self,
|
self,
|
||||||
model_chunk: Union[ModuleList, Module],
|
model_chunk: Union[ModuleList, Module],
|
||||||
model_chunk_id: int,
|
model_chunk_id: int,
|
||||||
|
micro_batch: Optional[dict],
|
||||||
input_obj: Optional[dict],
|
input_obj: Optional[dict],
|
||||||
criterion: Callable,
|
criterion: Callable,
|
||||||
accum_loss: Optional[torch.Tensor] = None,
|
accum_loss: Optional[torch.Tensor] = None,
|
||||||
|
@ -427,18 +442,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
|
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
|
||||||
"""
|
"""
|
||||||
# Load input ids, attention mask and labels
|
# Load input ids, attention mask and labels
|
||||||
# micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
# for the first stage, input_obj is None; So,we use micro_batch as input_obj
|
||||||
|
|
||||||
# for the first stage, input_obj is None
|
|
||||||
# for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
|
# for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
|
||||||
# Only attention_mask from micro_batch is used
|
# Only attention_mask from micro_batch is used
|
||||||
|
|
||||||
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
|
||||||
# fwd calculate
|
# fwd calculate
|
||||||
output_obj = model_chunk[model_chunk_id](input_obj)
|
if isinstance(model_chunk, ModuleList):
|
||||||
|
# fwd for ModuleList model
|
||||||
|
if input_obj is None:
|
||||||
|
output_obj = model_chunk[model_chunk_id](**micro_batch)
|
||||||
|
else:
|
||||||
|
output_obj = model_chunk[model_chunk_id](**input_obj)
|
||||||
|
else:
|
||||||
|
# fwd for shardformer
|
||||||
|
# NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers
|
||||||
|
internal_inputs = {} if input_obj is None else input_obj
|
||||||
|
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
|
||||||
|
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
|
||||||
|
|
||||||
# last layer in model
|
# last layer in model
|
||||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
loss = criterion(output_obj) / self.num_microbatch
|
loss = criterion(output_obj, micro_batch) / self.num_microbatch
|
||||||
if accum_loss is not None:
|
if accum_loss is not None:
|
||||||
accum_loss.add_(loss.detach())
|
accum_loss.add_(loss.detach())
|
||||||
if outputs is not None:
|
if outputs is not None:
|
||||||
|
@ -472,19 +496,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# calculate bwd b step ; only dx = w*dy;
|
# calculate bwd b step ; only dx = w*dy;
|
||||||
|
|
||||||
# Retain the grad on the input_obj.
|
# Retain the grad on the input_obj.
|
||||||
|
if input_obj is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
tree_map(retain_grad, input_obj)
|
tree_map(retain_grad, input_obj)
|
||||||
|
input_obj_ = input_obj["hidden_states"]
|
||||||
|
|
||||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
# loss backward; output_obj is loss; so output_obj_grad should be None
|
# loss backward; output_obj is loss; so output_obj_grad should be None
|
||||||
assert output_obj_grad is None
|
assert output_obj_grad is None
|
||||||
|
output_obj_ = output_obj
|
||||||
|
else:
|
||||||
|
output_obj_ = output_obj["hidden_states"]
|
||||||
optimizer.backward_by_grad(
|
optimizer.backward_by_grad(
|
||||||
tensor=output_obj,
|
tensor=output_obj_,
|
||||||
grad=output_obj_grad,
|
grad=output_obj_grad,
|
||||||
inputs=input_obj,
|
inputs=input_obj_,
|
||||||
retain_graph=True,
|
retain_graph=True,
|
||||||
)
|
)
|
||||||
return input_obj.grad
|
return input_obj_.grad
|
||||||
|
|
||||||
def backward_w_step(
|
def backward_w_step(
|
||||||
self,
|
self,
|
||||||
|
@ -511,8 +541,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
# loss backward; output_obj is loss
|
# loss backward; output_obj is loss
|
||||||
output_obj_grad = None
|
output_obj_grad = None
|
||||||
|
output_obj_ = output_obj
|
||||||
|
else:
|
||||||
|
output_obj_ = output_obj["hidden_states"]
|
||||||
optimizer.backward_by_grad(
|
optimizer.backward_by_grad(
|
||||||
tensor=output_obj,
|
tensor=output_obj_,
|
||||||
grad=output_obj_grad,
|
grad=output_obj_grad,
|
||||||
inputs=list(model_chunk[model_chunk_id].parameters()),
|
inputs=list(model_chunk[model_chunk_id].parameters()),
|
||||||
retain_graph=False,
|
retain_graph=False,
|
||||||
|
@ -543,9 +576,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
||||||
# Step1: recv fwd
|
# Step1: recv fwd
|
||||||
if model_chunk_id == 0:
|
if model_chunk_id == 0:
|
||||||
# is first stage; get input from func param
|
# is first stage; get input from microbatch
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
input_obj = micro_batch
|
input_obj = None
|
||||||
else:
|
else:
|
||||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||||
else:
|
else:
|
||||||
|
@ -557,45 +590,68 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||||
|
|
||||||
# Here, let input_obj.requires_grad_()
|
# Here, let input_obj.requires_grad_()
|
||||||
tree_map(torch.Tensor.requires_grad_, input_obj)
|
if input_obj is not None:
|
||||||
|
tree_map(require_grad, input_obj)
|
||||||
|
|
||||||
|
# Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
|
||||||
|
# tree_map(torch.Tensor.requires_grad_, micro_batch)
|
||||||
|
|
||||||
# Step2: fwd step
|
# Step2: fwd step
|
||||||
output_obj = self.forward_step(
|
output_obj = self.forward_step(
|
||||||
model_chunk=model_chunk,
|
model_chunk=model_chunk,
|
||||||
model_chunk_id=model_chunk_id,
|
model_chunk_id=model_chunk_id,
|
||||||
|
micro_batch=micro_batch,
|
||||||
input_obj=input_obj,
|
input_obj=input_obj,
|
||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
accum_loss=accum_loss,
|
accum_loss=accum_loss,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Step3: deallocate output for bwd b & w; (do not detach output)
|
||||||
|
deallocate_output_obj = tree_map(clone, output_obj)
|
||||||
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
# We should not deallocate bwd LOSS
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# deallocate output
|
||||||
|
tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), deallocate_output_obj)
|
||||||
|
|
||||||
|
# add input and output object for backward b
|
||||||
|
if input_obj is not None:
|
||||||
|
self.input_tensors[model_chunk_id].append(input_obj)
|
||||||
|
else:
|
||||||
|
self.input_tensors[model_chunk_id].append(micro_batch)
|
||||||
|
|
||||||
|
# for bwd b&w, we only need the graph(grad_fn) of output_obj
|
||||||
|
# Do not deallocate loss, deallocate other output_obj;
|
||||||
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
|
self.output_tensors[model_chunk_id].append(deallocate_output_obj)
|
||||||
|
self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
|
||||||
|
else:
|
||||||
|
self.output_tensors[model_chunk_id].append(deallocate_output_obj)
|
||||||
|
self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj)
|
||||||
|
|
||||||
|
# Step4: detach output for send fwd;
|
||||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
# We should not detach bwd LOSS
|
# We should not detach bwd LOSS
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
detached_output_obj = output_obj.clone().detach()
|
# detach output
|
||||||
|
output_obj = tree_map(detach, output_obj)
|
||||||
|
|
||||||
# Step3: send fwd
|
|
||||||
# add output to send_fwd_buffer
|
# add output to send_fwd_buffer
|
||||||
if model_chunk_id == 0:
|
if model_chunk_id == 0: # chunk 0
|
||||||
# is last stage; send to local_send_forward_buffer
|
# is last stage; send to local_send_forward_buffer
|
||||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
self.local_send_forward_buffer.append(detached_output_obj)
|
self.local_send_forward_buffer.append(output_obj)
|
||||||
else:
|
else:
|
||||||
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
self.send_forward_buffer[model_chunk_id].append(output_obj)
|
||||||
else:
|
else: # chunk 1
|
||||||
# is first stage; end of fwd; append LOSS to local_send_backward_buffer
|
# is first stage; end of fwd; do nothing
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
self.send_forward_buffer[model_chunk_id].append(output_obj)
|
||||||
|
|
||||||
# add input and output object for backward b
|
|
||||||
self.input_tensors[model_chunk_id].append(input_obj)
|
|
||||||
# detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj
|
|
||||||
deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True)
|
|
||||||
self.output_tensors[model_chunk_id].append(output_obj)
|
|
||||||
# add output object for backward w
|
|
||||||
self.output_tensors_dw[model_chunk_id].append(output_obj)
|
|
||||||
|
|
||||||
def schedule_b(
|
def schedule_b(
|
||||||
self,
|
self,
|
||||||
|
@ -603,9 +659,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
model_chunk: Union[ModuleList, Module],
|
model_chunk: Union[ModuleList, Module],
|
||||||
model_chunk_id: int,
|
model_chunk_id: int,
|
||||||
optimizer: OptimizerWrapper,
|
optimizer: OptimizerWrapper,
|
||||||
# input_obj: Optional[dict],
|
|
||||||
# output_obj: Union[dict, torch.Tensor],
|
|
||||||
# output_obj_grad: Optional[dict],
|
|
||||||
):
|
):
|
||||||
"""A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
|
"""A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
|
||||||
|
|
||||||
|
@ -616,7 +669,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
Returns:
|
Returns:
|
||||||
Nothing.
|
Nothing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Step1: recv bwd
|
# Step1: recv bwd
|
||||||
if model_chunk_id == 0:
|
if model_chunk_id == 0:
|
||||||
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
|
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
|
||||||
|
@ -645,7 +697,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# we save output_tensor_grad here
|
# we save output_tensor_grad here
|
||||||
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
|
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
|
||||||
|
|
||||||
# _wait_p2p(recv_bwd_handles)
|
|
||||||
# Step2: bwd step
|
# Step2: bwd step
|
||||||
input_object_grad = self.backward_b_step(
|
input_object_grad = self.backward_b_step(
|
||||||
model_chunk=model_chunk,
|
model_chunk=model_chunk,
|
||||||
|
@ -777,8 +828,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# communication
|
# communication
|
||||||
communication_func = self.communication_map[scheduled_node.type]
|
communication_func = self.communication_map[scheduled_node.type]
|
||||||
communication_func(scheduled_node.chunk)
|
communication_func(scheduled_node.chunk)
|
||||||
|
elif scheduled_node.type == "F":
|
||||||
if scheduled_node.type == "F":
|
|
||||||
self.schedule_f(
|
self.schedule_f(
|
||||||
scheduled_node=scheduled_node,
|
scheduled_node=scheduled_node,
|
||||||
model_chunk=model_chunk,
|
model_chunk=model_chunk,
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from functools import partial
|
||||||
|
from types import MethodType
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -16,7 +18,8 @@ from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from tests.kit.model_zoo import model_zoo
|
|
||||||
|
# from tests.kit.model_zoo import model_zoo
|
||||||
|
|
||||||
|
|
||||||
class MlpModel(nn.Module):
|
class MlpModel(nn.Module):
|
||||||
|
@ -24,10 +27,32 @@ class MlpModel(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
|
self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
):
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x)
|
hidden_states = layer(hidden_states)
|
||||||
return x
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def pp_linear_fwd(
|
||||||
|
forward,
|
||||||
|
data: torch.Tensor = None,
|
||||||
|
hidden_states: torch.Tensor = None,
|
||||||
|
stage_mgr: PipelineStageManager = None,
|
||||||
|
model_chunk_id: int = None,
|
||||||
|
):
|
||||||
|
with stage_mgr.switch_model_chunk_id(model_chunk_id):
|
||||||
|
# fwd end
|
||||||
|
if stage_mgr.is_first_stage() and model_chunk_id == 1:
|
||||||
|
return forward(hidden_states)
|
||||||
|
# fwd start
|
||||||
|
elif stage_mgr.is_first_stage() and model_chunk_id == 0:
|
||||||
|
return {"hidden_states": forward(hidden_states)}
|
||||||
|
# fwd middle
|
||||||
|
else:
|
||||||
|
return {"hidden_states": forward(hidden_states)}
|
||||||
|
|
||||||
|
|
||||||
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
|
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
|
||||||
|
@ -510,15 +535,15 @@ def run_fwd_bwd_iter_input(test_config):
|
||||||
"precision": "bf16",
|
"precision": "bf16",
|
||||||
"num_model_chunk": 2,
|
"num_model_chunk": 2,
|
||||||
},
|
},
|
||||||
{
|
# {
|
||||||
"batch_size": 8,
|
# "batch_size": 8,
|
||||||
"tp_size": 1,
|
# "tp_size": 1,
|
||||||
"pp_size": 4,
|
# "pp_size": 4,
|
||||||
"num_microbatches": 8,
|
# "num_microbatches": 8,
|
||||||
"zero_stage": 1,
|
# "zero_stage": 1,
|
||||||
"precision": "bf16",
|
# "precision": "bf16",
|
||||||
"num_model_chunk": 2,
|
# "num_model_chunk": 2,
|
||||||
},
|
# },
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def run_fwd_bwd_vschedule_with_optim(test_config):
|
def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
|
@ -562,6 +587,10 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
|
|
||||||
# init loss func
|
# init loss func
|
||||||
def criterion(x, *args, **kwargs):
|
def criterion(x, *args, **kwargs):
|
||||||
|
x = x["hidden_states"]
|
||||||
|
return (x * x).mean()
|
||||||
|
|
||||||
|
def criterion_base(x, *args, **kwargs):
|
||||||
return (x * x).mean()
|
return (x * x).mean()
|
||||||
|
|
||||||
# init model and input
|
# init model and input
|
||||||
|
@ -572,9 +601,10 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
before_init_memory = torch.cuda.memory_allocated() / 1024**3
|
before_init_memory = torch.cuda.memory_allocated() / 1024**3
|
||||||
print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
|
print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
|
||||||
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
||||||
data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
|
# data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
|
||||||
|
data_iter = {"hidden_states": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
|
||||||
input_base = [t.clone() for t in data_iter]
|
# input_base = [t.clone() for t in data_iter]
|
||||||
|
input_base = {k: v.clone() for k, v in data_iter.items()}
|
||||||
model_base = deepcopy(model)
|
model_base = deepcopy(model)
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
|
@ -582,24 +612,44 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
local_chunk = torch.nn.ModuleList().to(rank)
|
local_chunk = torch.nn.ModuleList().to(rank)
|
||||||
for idx, sub_model in enumerate(model.layers):
|
for idx, sub_model in enumerate(model.layers):
|
||||||
if idx == 0 or idx == 7:
|
if idx == 0 or idx == 7:
|
||||||
|
sub_model._forward = sub_model.forward
|
||||||
|
sub_model.forward = MethodType(
|
||||||
|
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
|
||||||
|
sub_model._forward,
|
||||||
|
)
|
||||||
local_chunk.append(sub_model)
|
local_chunk.append(sub_model)
|
||||||
elif rank == 1:
|
elif rank == 1:
|
||||||
# layer 1 & 6 to chunk 1 on rank1
|
# layer 1 & 6 to chunk 1 on rank1
|
||||||
local_chunk = torch.nn.ModuleList().to(rank)
|
local_chunk = torch.nn.ModuleList().to(rank)
|
||||||
for idx, sub_model in enumerate(model.layers):
|
for idx, sub_model in enumerate(model.layers):
|
||||||
if idx == 1 or idx == 6:
|
if idx == 1 or idx == 6:
|
||||||
|
sub_model._forward = sub_model.forward
|
||||||
|
sub_model.forward = MethodType(
|
||||||
|
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
|
||||||
|
sub_model._forward,
|
||||||
|
)
|
||||||
local_chunk.append(sub_model)
|
local_chunk.append(sub_model)
|
||||||
elif rank == 2:
|
elif rank == 2:
|
||||||
# layer 2 & 5 to chunk 2 on rank2
|
# layer 2 & 5 to chunk 2 on rank2
|
||||||
local_chunk = torch.nn.ModuleList().to(rank)
|
local_chunk = torch.nn.ModuleList().to(rank)
|
||||||
for idx, sub_model in enumerate(model.layers):
|
for idx, sub_model in enumerate(model.layers):
|
||||||
if idx == 2 or idx == 5:
|
if idx == 2 or idx == 5:
|
||||||
|
sub_model._forward = sub_model.forward
|
||||||
|
sub_model.forward = MethodType(
|
||||||
|
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
|
||||||
|
sub_model._forward,
|
||||||
|
)
|
||||||
local_chunk.append(sub_model)
|
local_chunk.append(sub_model)
|
||||||
else:
|
else:
|
||||||
# layer 3 & 4 to chunk 3 on rank3
|
# layer 3 & 4 to chunk 3 on rank3
|
||||||
local_chunk = torch.nn.ModuleList().to(rank)
|
local_chunk = torch.nn.ModuleList().to(rank)
|
||||||
for idx, sub_model in enumerate(model.layers):
|
for idx, sub_model in enumerate(model.layers):
|
||||||
if idx == 3 or idx == 4:
|
if idx == 3 or idx == 4:
|
||||||
|
sub_model._forward = sub_model.forward
|
||||||
|
sub_model.forward = MethodType(
|
||||||
|
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
|
||||||
|
sub_model._forward,
|
||||||
|
)
|
||||||
local_chunk.append(sub_model)
|
local_chunk.append(sub_model)
|
||||||
|
|
||||||
# init optimizer
|
# init optimizer
|
||||||
|
@ -612,7 +662,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
result = scheduler.forward_backward_step(
|
result = scheduler.forward_backward_step(
|
||||||
model_chunk=local_chunk,
|
model_chunk=local_chunk,
|
||||||
data_iter=iter(data_iter),
|
data_iter=iter([data_iter]),
|
||||||
criterion=criterion,
|
criterion=criterion,
|
||||||
optimizer=optimizer_pp,
|
optimizer=optimizer_pp,
|
||||||
return_loss=True,
|
return_loss=True,
|
||||||
|
@ -643,8 +693,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
# Fwd bwd for base
|
# Fwd bwd for base
|
||||||
##########################
|
##########################
|
||||||
# fwd & bwd
|
# fwd & bwd
|
||||||
output_base = model_base(input_base[0])
|
output_base = model_base(input_base["hidden_states"])
|
||||||
loss_base = criterion(output_base)
|
loss_base = criterion_base(output_base)
|
||||||
loss_base.backward()
|
loss_base.backward()
|
||||||
optimizer_base.step()
|
optimizer_base.step()
|
||||||
|
|
||||||
|
@ -654,7 +704,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
# only chunk 1 stage 0 hold loss and output
|
# only chunk 1 stage 0 hold loss and output
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
assert_close(result["loss"], loss_base)
|
assert_close(result["loss"], loss_base)
|
||||||
assert_close(result["outputs"], output_base)
|
assert_close(result["outputs"]["hidden_states"], output_base)
|
||||||
|
|
||||||
# print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
|
# print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
|
||||||
##########################
|
##########################
|
||||||
|
@ -727,6 +777,7 @@ def run_with_hybridplugin(test_config):
|
||||||
{
|
{
|
||||||
"pp_style": "zbv",
|
"pp_style": "zbv",
|
||||||
"tp_size": 1,
|
"tp_size": 1,
|
||||||
|
"ep_size": 1,
|
||||||
"pp_size": 4,
|
"pp_size": 4,
|
||||||
"num_microbatches": 4,
|
"num_microbatches": 4,
|
||||||
"zero_stage": 1,
|
"zero_stage": 1,
|
||||||
|
@ -737,7 +788,7 @@ def run_with_hybridplugin(test_config):
|
||||||
)
|
)
|
||||||
def run_with_moehybridplugin(test_config):
|
def run_with_moehybridplugin(test_config):
|
||||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
|
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
|
||||||
test_config["use_lazy_init"] = False
|
# test_config["use_lazy_init"] = False
|
||||||
test_config["initial_scale"] = 2**16
|
test_config["initial_scale"] = 2**16
|
||||||
model_list = [
|
model_list = [
|
||||||
"transformers_bert",
|
"transformers_bert",
|
||||||
|
@ -749,6 +800,7 @@ def run_with_moehybridplugin(test_config):
|
||||||
# base param
|
# base param
|
||||||
model = model_fn()
|
model = model_fn()
|
||||||
data = data_gen_fn()
|
data = data_gen_fn()
|
||||||
|
print(f"data {data}")
|
||||||
criterion = loss_fn
|
criterion = loss_fn
|
||||||
optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
|
optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
|
||||||
|
|
||||||
|
@ -787,7 +839,7 @@ def run_with_moehybridplugin(test_config):
|
||||||
# plugin = MoeHybridParallelPlugin(
|
# plugin = MoeHybridParallelPlugin(
|
||||||
# **test_config
|
# **test_config
|
||||||
# )
|
# )
|
||||||
# model_pp, optimizer_pp, criterion, data_pp = plugin.configure(
|
# model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure(
|
||||||
# model = model_pp,
|
# model = model_pp,
|
||||||
# optimizer = optimizer_pp,
|
# optimizer = optimizer_pp,
|
||||||
# criterion = criterion,
|
# criterion = criterion,
|
||||||
|
@ -806,16 +858,34 @@ def run_with_moehybridplugin(test_config):
|
||||||
|
|
||||||
# TODO:6) support booster & Hybrid base 4)
|
# TODO:6) support booster & Hybrid base 4)
|
||||||
|
|
||||||
|
|
||||||
# TODO:7) support booster & MoEHybrid base 4)
|
# TODO:7) support booster & MoEHybrid base 4)
|
||||||
|
@parameterize(
|
||||||
|
"test_config",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"pp_style": "zbv",
|
||||||
|
"tp_size": 1,
|
||||||
|
"ep_size": 1,
|
||||||
|
"pp_size": 4,
|
||||||
|
"num_microbatches": 4,
|
||||||
|
"zero_stage": 1,
|
||||||
|
"precision": "bf16",
|
||||||
|
"num_model_chunks": 2,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def run_with_booster_moehybridplugin(test_config):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
# run_fwd_bwd_iter_input()
|
# run_fwd_bwd_iter_input()
|
||||||
# run_fwd_bwd_vschedule_with_optim()
|
run_fwd_bwd_vschedule_with_optim()
|
||||||
# run_with_moehybridplugin()
|
# run_with_moehybridplugin()
|
||||||
run_with_moehybridplugin()
|
# run_with_booster_moehybridplugin()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
|
Loading…
Reference in New Issue