[feat] add fwd_bwd_step, run_fwd_only;

pull/6034/head
duanjunwen 3 months ago
parent 48ba22dbfd
commit 6af81d8c0d

@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.cuda
@ -696,6 +696,54 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj_grad=output_obj_grad,
)
def run_forward_only(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
assert self.forward_only
# prepare batch
self.load_batch(data_iter)
# prepare accum loss & output
accum_loss = None
# reset accum loss at fwd end;
if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
it = 0
# while we still have schedules_node in self.schedules
while it < len(self.schedules):
scheduled_node = self.schedules[it]
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication
if scheduled_node.type == "RECV_FORWARD":
self.recv_forward(scheduled_node.chunk)
elif scheduled_node.type == "SEND_FORWARD":
self.send_forward(scheduled_node.chunk)
if scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
criterion=criterion,
accum_loss=accum_loss,
outputs=outputs,
)
it += 1
# return loss & output
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
self,
model_chunk: Union[ModuleList, Module],
@ -704,7 +752,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
):
) -> Dict:
"""
Runs Zerobubble schedule, with communication between pipeline stages.
"""
@ -770,3 +818,37 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""
Args:
model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."
if self.forward_only:
result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
else:
result = self.run_forward_backward(
model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
)
return result

@ -644,10 +644,10 @@ def run_fwd_bwd_vschedule_with_optim(
graph = PipelineGraph(
n_stage=world_size,
n_micro=num_microbatch,
f_cost=6,
b_cost=6,
w_cost=6,
c_cost=6,
f_cost=1,
b_cost=1,
w_cost=1,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
@ -714,7 +714,7 @@ def run_fwd_bwd_vschedule_with_optim(
)
torch.cuda.synchronize()
result = scheduler.run_forward_backward(
result = scheduler.forward_backward_step(
model_chunk=local_chunk,
data_iter=iter(data_iter),
criterion=criterion,
@ -793,6 +793,25 @@ def run_fwd_bwd_vschedule_with_optim(
assert val_base[:2] == val_pp
# 4) support Hybrid base 3)
def run_with_hybrid(
rank: int,
world_size: int,
port: int,
num_microbatch: int,
batch_size: int,
num_model_chunk: int,
):
pass
# 5) support MoE base 3)
# 6) support booster & Hybrid base 4)
# 6) support booster & MoE base 4)
@pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4])
@pytest.mark.parametrize("batch_size", [4])

Loading…
Cancel
Save