mirror of https://github.com/hpcaitech/ColossalAI
[feat] add run_fwd_bwd_with_microbatch (replaces the input argument) and its test; add assert_close tests for p and p.grad; all tests pass
parent
9e0bd1af00
commit
8b37323f16
|
@ -495,7 +495,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
scheduled_node,
|
||||
model_chunk: torch.nn.ModuleList,
|
||||
model_chunk_id: int,
|
||||
input_obj: Optional[dict],
|
||||
criterion: Callable,
|
||||
accum_loss: Optional[torch.Tensor] = None,
|
||||
outputs: Optional[List[Any]] = None,
|
||||
|
@ -506,7 +505,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
scheduled_node:
|
||||
model_chunk (ModuleList or Module): Model Chunk to be run;
|
||||
model_chunk_id (int): The current model chunk idx;
|
||||
input_obj (Optional[dict]): x;
|
||||
criterion (Callable): loss function;
|
||||
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
|
||||
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
|
||||
|
@ -518,7 +516,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
if model_chunk_id == 0:
|
||||
# is first stage; get input from func param
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
input_obj = input_obj
|
||||
input_obj = self.load_micro_batch(model_chunk_id=model_chunk_id)
|
||||
else:
|
||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
else:
|
||||
|
@ -671,7 +669,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
def run_forward_backward(
|
||||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
input_obj: Optional[dict],
|
||||
data_iter: Iterable,
|
||||
criterion: Callable[..., Any],
|
||||
optimizer: Optional[OptimizerWrapper] = None,
|
||||
|
@ -683,7 +680,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
"""
|
||||
# # prepare batch
|
||||
self.load_batch(data_iter)
|
||||
# print(f"self.batch {self.batch}; self.batch_size {self.batch_size}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}")
|
||||
print(
|
||||
f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}"
|
||||
)
|
||||
|
||||
it = 0
|
||||
# while we still have schedules_node in self.schedules
|
||||
|
@ -707,7 +706,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=scheduled_node.chunk,
|
||||
input_obj=input_obj,
|
||||
criterion=criterion,
|
||||
accum_loss=return_loss,
|
||||
outputs=return_outputs,
|
||||
|
|
|
@ -36,8 +36,8 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
|
|||
return num_params, num_params_trainable
|
||||
|
||||
|
||||
# Test run_forward_backward with baseline;
|
||||
def test_run_fwd_bwd_base(
|
||||
# Test iter input & multiple microbatch
|
||||
def test_run_fwd_bwd_iter_input(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
port: int,
|
||||
|
@ -47,7 +47,7 @@ def test_run_fwd_bwd_base(
|
|||
rank = dist.get_rank()
|
||||
pp_size = world_size
|
||||
pg_mesh = ProcessGroupMesh(pp_size)
|
||||
|
||||
num_microbatch = 4
|
||||
# stage_manager
|
||||
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size)
|
||||
|
||||
|
@ -55,6 +55,7 @@ def test_run_fwd_bwd_base(
|
|||
zbv_schedule = [
|
||||
# stage 0
|
||||
[
|
||||
# microbatch 0
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=0),
|
||||
ScheduledNode(type="F", chunk=0, stage=0, minibatch=0),
|
||||
|
@ -73,9 +74,67 @@ def test_run_fwd_bwd_base(
|
|||
ScheduledNode(type="B", chunk=0, stage=0, minibatch=0),
|
||||
ScheduledNode(type="W", chunk=0, stage=0, minibatch=0),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
|
||||
# microbatch 1
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=1),
|
||||
ScheduledNode(type="F", chunk=0, stage=0, minibatch=1),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=1),
|
||||
# chunk 1 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=1),
|
||||
ScheduledNode(type="F", chunk=1, stage=0, minibatch=1),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=1),
|
||||
# chunk 1 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=1),
|
||||
ScheduledNode(type="B", chunk=1, stage=0, minibatch=1),
|
||||
ScheduledNode(type="W", chunk=1, stage=0, minibatch=1),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=1),
|
||||
# chunk 0 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=1),
|
||||
ScheduledNode(type="B", chunk=0, stage=0, minibatch=1),
|
||||
ScheduledNode(type="W", chunk=0, stage=0, minibatch=1),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1),
|
||||
# microbatch 2
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=2),
|
||||
ScheduledNode(type="F", chunk=0, stage=0, minibatch=2),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=2),
|
||||
# chunk 1 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=2),
|
||||
ScheduledNode(type="F", chunk=1, stage=0, minibatch=2),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=2),
|
||||
# chunk 1 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=2),
|
||||
ScheduledNode(type="B", chunk=1, stage=0, minibatch=2),
|
||||
ScheduledNode(type="W", chunk=1, stage=0, minibatch=2),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=2),
|
||||
# chunk 0 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=2),
|
||||
ScheduledNode(type="B", chunk=0, stage=0, minibatch=2),
|
||||
ScheduledNode(type="W", chunk=0, stage=0, minibatch=2),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2),
|
||||
# microbatch 3
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=3),
|
||||
ScheduledNode(type="F", chunk=0, stage=0, minibatch=3),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=3),
|
||||
# chunk 1 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=3),
|
||||
ScheduledNode(type="F", chunk=1, stage=0, minibatch=3),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=3),
|
||||
# chunk 1 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=3),
|
||||
ScheduledNode(type="B", chunk=1, stage=0, minibatch=3),
|
||||
ScheduledNode(type="W", chunk=1, stage=0, minibatch=3),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=3),
|
||||
# chunk 0 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=3),
|
||||
ScheduledNode(type="B", chunk=0, stage=0, minibatch=3),
|
||||
ScheduledNode(type="W", chunk=0, stage=0, minibatch=3),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3),
|
||||
],
|
||||
# stage 1
|
||||
[
|
||||
# microbatch 0
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=0),
|
||||
ScheduledNode(type="F", chunk=0, stage=1, minibatch=0),
|
||||
|
@ -94,9 +153,67 @@ def test_run_fwd_bwd_base(
|
|||
ScheduledNode(type="B", chunk=0, stage=1, minibatch=0),
|
||||
ScheduledNode(type="W", chunk=0, stage=1, minibatch=0),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
|
||||
# microbatch 1
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=1),
|
||||
ScheduledNode(type="F", chunk=0, stage=1, minibatch=1),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=1),
|
||||
# chunk 1 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=1),
|
||||
ScheduledNode(type="F", chunk=1, stage=1, minibatch=1),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=1),
|
||||
# chunk 1 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=1),
|
||||
ScheduledNode(type="B", chunk=1, stage=1, minibatch=1),
|
||||
ScheduledNode(type="W", chunk=1, stage=1, minibatch=1),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=1),
|
||||
# chunk 0 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=1),
|
||||
ScheduledNode(type="B", chunk=0, stage=1, minibatch=1),
|
||||
ScheduledNode(type="W", chunk=0, stage=1, minibatch=1),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1),
|
||||
# microbatch 2
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=2),
|
||||
ScheduledNode(type="F", chunk=0, stage=1, minibatch=2),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=2),
|
||||
# chunk 1 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=2),
|
||||
ScheduledNode(type="F", chunk=1, stage=1, minibatch=2),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=2),
|
||||
# chunk 1 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=2),
|
||||
ScheduledNode(type="B", chunk=1, stage=1, minibatch=2),
|
||||
ScheduledNode(type="W", chunk=1, stage=1, minibatch=2),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=2),
|
||||
# chunk 0 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=2),
|
||||
ScheduledNode(type="B", chunk=0, stage=1, minibatch=2),
|
||||
ScheduledNode(type="W", chunk=0, stage=1, minibatch=2),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2),
|
||||
# microbatch 3
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=3),
|
||||
ScheduledNode(type="F", chunk=0, stage=1, minibatch=3),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=3),
|
||||
# chunk 1 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=3),
|
||||
ScheduledNode(type="F", chunk=1, stage=1, minibatch=3),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=3),
|
||||
# chunk 1 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=3),
|
||||
ScheduledNode(type="B", chunk=1, stage=1, minibatch=3),
|
||||
ScheduledNode(type="W", chunk=1, stage=1, minibatch=3),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=3),
|
||||
# chunk 0 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=3),
|
||||
ScheduledNode(type="B", chunk=0, stage=1, minibatch=3),
|
||||
ScheduledNode(type="W", chunk=0, stage=1, minibatch=3),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3),
|
||||
],
|
||||
# stage 2
|
||||
[
|
||||
# microbatch 0
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=0),
|
||||
ScheduledNode(type="F", chunk=0, stage=2, minibatch=0),
|
||||
|
@ -114,10 +231,68 @@ def test_run_fwd_bwd_base(
|
|||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=0),
|
||||
ScheduledNode(type="B", chunk=0, stage=2, minibatch=0),
|
||||
ScheduledNode(type="W", chunk=0, stage=2, minibatch=0),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0), # Send nothing
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0),
|
||||
# microbatch 1
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=1),
|
||||
ScheduledNode(type="F", chunk=0, stage=2, minibatch=1),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=1),
|
||||
# chunk 1 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=1),
|
||||
ScheduledNode(type="F", chunk=1, stage=2, minibatch=1),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=1),
|
||||
# chunk 1 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=1),
|
||||
ScheduledNode(type="B", chunk=1, stage=2, minibatch=1),
|
||||
ScheduledNode(type="W", chunk=1, stage=2, minibatch=1),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=1),
|
||||
# chunk 0 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=1),
|
||||
ScheduledNode(type="B", chunk=0, stage=2, minibatch=1),
|
||||
ScheduledNode(type="W", chunk=0, stage=2, minibatch=1),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=1),
|
||||
# microbatch 2
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=2),
|
||||
ScheduledNode(type="F", chunk=0, stage=2, minibatch=2),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=2),
|
||||
# chunk 1 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=2),
|
||||
ScheduledNode(type="F", chunk=1, stage=2, minibatch=2),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=2),
|
||||
# chunk 1 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=2),
|
||||
ScheduledNode(type="B", chunk=1, stage=2, minibatch=2),
|
||||
ScheduledNode(type="W", chunk=1, stage=2, minibatch=2),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=2),
|
||||
# chunk 0 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=2),
|
||||
ScheduledNode(type="B", chunk=0, stage=2, minibatch=2),
|
||||
ScheduledNode(type="W", chunk=0, stage=2, minibatch=2),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=2),
|
||||
# microbatch 3
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=3),
|
||||
ScheduledNode(type="F", chunk=0, stage=2, minibatch=3),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=3),
|
||||
# chunk 1 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=3),
|
||||
ScheduledNode(type="F", chunk=1, stage=2, minibatch=3),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=3),
|
||||
# chunk 1 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=3),
|
||||
ScheduledNode(type="B", chunk=1, stage=2, minibatch=3),
|
||||
ScheduledNode(type="W", chunk=1, stage=2, minibatch=3),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=3),
|
||||
# chunk 0 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=3),
|
||||
ScheduledNode(type="B", chunk=0, stage=2, minibatch=3),
|
||||
ScheduledNode(type="W", chunk=0, stage=2, minibatch=3),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=3),
|
||||
],
|
||||
# stage 3
|
||||
[
|
||||
# microbatch 0
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=0),
|
||||
ScheduledNode(type="F", chunk=0, stage=3, minibatch=0),
|
||||
|
@ -136,6 +311,63 @@ def test_run_fwd_bwd_base(
|
|||
ScheduledNode(type="B", chunk=0, stage=3, minibatch=0),
|
||||
ScheduledNode(type="W", chunk=0, stage=3, minibatch=0),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=0),
|
||||
# microbatch 1
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=1),
|
||||
ScheduledNode(type="F", chunk=0, stage=3, minibatch=1),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=1),
|
||||
# chunk 1 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=1),
|
||||
ScheduledNode(type="F", chunk=1, stage=3, minibatch=1),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=1),
|
||||
# chunk 1 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=1),
|
||||
ScheduledNode(type="B", chunk=1, stage=3, minibatch=1),
|
||||
ScheduledNode(type="W", chunk=1, stage=3, minibatch=1),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=1),
|
||||
# chunk 0 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=1),
|
||||
ScheduledNode(type="B", chunk=0, stage=3, minibatch=1),
|
||||
ScheduledNode(type="W", chunk=0, stage=3, minibatch=1),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=1),
|
||||
# microbatch 2
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=2),
|
||||
ScheduledNode(type="F", chunk=0, stage=3, minibatch=2),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=2),
|
||||
# chunk 1 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=2),
|
||||
ScheduledNode(type="F", chunk=1, stage=3, minibatch=2),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=2),
|
||||
# chunk 1 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=2),
|
||||
ScheduledNode(type="B", chunk=1, stage=3, minibatch=2),
|
||||
ScheduledNode(type="W", chunk=1, stage=3, minibatch=2),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=2),
|
||||
# chunk 0 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=2),
|
||||
ScheduledNode(type="B", chunk=0, stage=3, minibatch=2),
|
||||
ScheduledNode(type="W", chunk=0, stage=3, minibatch=2),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=2),
|
||||
# microbatch 3
|
||||
# chunk 0 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=3),
|
||||
ScheduledNode(type="F", chunk=0, stage=3, minibatch=3),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=3),
|
||||
# chunk 1 fwd
|
||||
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=3),
|
||||
ScheduledNode(type="F", chunk=1, stage=3, minibatch=3),
|
||||
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=3),
|
||||
# chunk 1 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=3),
|
||||
ScheduledNode(type="B", chunk=1, stage=3, minibatch=3),
|
||||
ScheduledNode(type="W", chunk=1, stage=3, minibatch=3),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=3),
|
||||
# chunk 0 bwd
|
||||
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=3),
|
||||
ScheduledNode(type="B", chunk=0, stage=3, minibatch=3),
|
||||
ScheduledNode(type="W", chunk=0, stage=3, minibatch=3),
|
||||
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=3),
|
||||
],
|
||||
]
|
||||
|
||||
|
@ -143,7 +375,7 @@ def test_run_fwd_bwd_base(
|
|||
schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ?
|
||||
stage_manager=stage_manager,
|
||||
num_model_chunks=pp_size,
|
||||
num_microbatch=1,
|
||||
num_microbatch=num_microbatch,
|
||||
overlap_p2p=False,
|
||||
)
|
||||
|
||||
|
@ -152,14 +384,15 @@ def test_run_fwd_bwd_base(
|
|||
return (x * x).mean()
|
||||
|
||||
# init model and input
|
||||
batch_size = 4
|
||||
num_layers = 8
|
||||
in_dim = out_dim = 8
|
||||
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
|
||||
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
||||
input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
|
||||
# data_iter = [input0]
|
||||
data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
|
||||
|
||||
input_base = input0.clone()
|
||||
[t.clone() for t in data_iter]
|
||||
model_base = deepcopy(model)
|
||||
|
||||
if rank == 0:
|
||||
|
@ -193,9 +426,7 @@ def test_run_fwd_bwd_base(
|
|||
torch.cuda.synchronize()
|
||||
scheduler.run_forward_backward(
|
||||
model_chunk=local_chunk,
|
||||
input_obj=input0,
|
||||
# data_iter=iter(data_iter),
|
||||
data_iter=None,
|
||||
data_iter=iter(data_iter),
|
||||
criterion=criterion,
|
||||
optimizer=None,
|
||||
return_loss=None,
|
||||
|
@ -206,8 +437,7 @@ def test_run_fwd_bwd_base(
|
|||
# Fwd bwd for base
|
||||
##########################
|
||||
# fwd & bwd
|
||||
output_base = model_base(input_base)
|
||||
# loss_base = output_base.mean()
|
||||
output_base = model_base(data_iter[0])
|
||||
loss_base = criterion(output_base)
|
||||
loss_base.backward()
|
||||
print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
|
||||
|
@ -245,15 +475,6 @@ def test_run_fwd_bwd_base(
|
|||
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
|
||||
|
||||
|
||||
# Test iter input & multiple microbatch
|
||||
def test_run_fwd_bwd_iter_input(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
port: int,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
# @pytest.mark.parametrize("num_microbatch", [4])
|
||||
# @pytest.mark.parametrize("batch_size", [4])
|
||||
|
@ -261,7 +482,7 @@ def test_run_fwd_bwd_iter_input(
|
|||
@rerun_if_address_is_in_use()
|
||||
def test_pp():
|
||||
spawn(
|
||||
test_run_fwd_bwd_base,
|
||||
test_run_fwd_bwd_iter_input,
|
||||
nprocs=4,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue