mirror of https://github.com/hpcaitech/ColossalAI
[feat] add test run_fwd_bwd automatic scheduling;
parent fd5526b76e
commit 1d75045c37
@@ -12,8 +12,8 @@ class ScheduledNode:
    chunk: int
    stage: int
    minibatch: int
-   start_time: int
-   completion_time: int
+   # start_time: int
+   # completion_time: int
    rollback: bool = False
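The timing fields become optional metadata here; the hand-written schedules in the test file below construct nodes without them. A minimal sketch of building a node after this change (import path as used by that test):

```python
from colossalai.pipeline.schedule.v_schedule import ScheduledNode

# One forward compute step for chunk 0 on stage 0, first microbatch;
# start_time/completion_time no longer need to be supplied.
node = ScheduledNode(type="F", chunk=0, stage=0, minibatch=0)
assert node.rollback is False  # dataclass default
```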
@@ -176,6 +176,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; chunk 0 on the first rank has no previous rank to receive from;
#################
if self.stage_manager.is_first_stage(ignore_chunk=True):
    self.recv_forward_buffer[model_chunk_id].append(None)
    return None, []

################

@@ -188,6 +189,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# metadata_recv=self.tensor_metadata_recv
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
#     self.tensor_metadata_recv = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles

else:

@@ -200,7 +202,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
#     self.tensor_metadata_recv = create_send_metadata(input_tensor)

self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, []

################

@@ -214,7 +216,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# metadata_recv=self.tensor_metadata_recv
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
#     self.tensor_metadata_recv = create_send_metadata(input_tensor)

self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles

def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
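The common pattern across these recv_forward hunks: every receive now lands in a per-chunk FIFO (`recv_forward_buffer`) instead of being handed straight to the compute step, so communication nodes and compute nodes can be scheduled independently. A minimal sketch of that contract, simplified to the chunk-0 path and assuming the attribute names above (the exact `comm.recv_forward` call shape is an assumption):

```python
# Sketch only: receives are staged in a FIFO keyed by model chunk.
def recv_forward_sketch(self, model_chunk_id: int):
    if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
        # Chunk 0 on the first rank reads the local microbatch instead;
        # a None placeholder keeps the FIFO aligned with the schedule.
        self.recv_forward_buffer[model_chunk_id].append(None)
        return None, []
    prev_rank = self.stage_manager.get_prev_rank()
    input_tensor, wait_handles = self.comm.recv_forward(prev_rank)  # assumed signature
    self.recv_forward_buffer[model_chunk_id].append(input_tensor)
    return input_tensor, wait_handles
```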
@@ -240,6 +242,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_tensor_grad = self.local_send_backward_buffer.pop(0)
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
#     self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, []

################

@@ -252,6 +255,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# metadata_recv=self.grad_metadata_recv
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
#     self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, wait_handles

else:

@@ -261,6 +265,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; get loss from local
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
    self.recv_backward_buffer[model_chunk_id].append(None)
    return None, []

################

@@ -268,16 +273,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# self.comm.recv_backward recv bwd from prev stage;
################
else:
    prev_rank = self.stage_manager.get_prev_rank()
    output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank)

    # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} output_tensor_grad {output_tensor_grad};\n buffer {self.recv_backward_buffer}")
    # metadata_recv=self.grad_metadata_recv
    # if self.enable_metadata_cache and self.grad_metadata_recv is None:
    #     self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
    self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
    return output_tensor_grad, wait_handles

-def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List:
+def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
    """Sends the input tensor to the next stage in pipeline.
    For ZBV.
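Note the V-turnaround special case: on the last stage, chunk 1's gradient for chunk 0 never touches P2P. `send_backward` parks it in `local_send_backward_buffer` and the matching `recv_backward` pops it in-process. Side by side, as a sketch (names from the hunk above and from the send_backward hunk further down):

```python
# Last stage: chunk 1 -> chunk 0 gradient handoff without any P2P traffic.
def send_backward_local_sketch(self, model_chunk_id: int):
    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
    self.local_send_backward_buffer.append(input_tensor_grad)
    return []  # no send handles: nothing went over the wire

def recv_backward_local_sketch(self, model_chunk_id: int):
    output_tensor_grad = self.local_send_backward_buffer.pop(0)
    self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
    return output_tensor_grad, []
```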
@@ -291,6 +296,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""

with self.stage_manager.switch_model_chunk_id(model_chunk_id):
    output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
    if model_chunk_id == 0:
        ################
        # chunk = 0 && is_last_stage
@@ -330,7 +336,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# self.send_tensor_metadata = not self.enable_metadata_cache
return send_handles

-def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List:
+def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
    """Sends the gradient tensor to the previous stage in pipeline.
    For ZBV.
@@ -359,6 +365,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Send dx to PREV stage;
################
else:
    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
    prev_rank = self.stage_manager.get_prev_rank()
    send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
    # send_metadata=self.send_grad_metadata
@@ -371,6 +378,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# hold dy to local_send_bwd_buffer;
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
    self.local_send_backward_buffer.append(input_tensor_grad)
    return []
@@ -379,6 +387,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Send dx to NEXT stage;
################
else:
    print(
        f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} send_backward_buffer {self.send_backward_buffer}"
    )
    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
    next_rank = self.stage_manager.get_next_rank()
    # print(f"send bwd input_tensor_grad {input_tensor_grad}")
    send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
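Taken together, the send/recv changes split communication from compute: schedule_b only appends its dx into `send_backward_buffer`, and a later SEND_BACKWARD node pops it and ships it. A sketch of the two halves (names from the hunks above; control flow simplified):

```python
# Compute half: schedule_b finishes a backward step and only buffers dx.
def schedule_b_tail_sketch(self, model_chunk_id: int, input_object_grad):
    self.send_backward_buffer[model_chunk_id].append(input_object_grad)

# Communication half: a SEND_BACKWARD node pops the buffered dx and sends it.
def send_backward_node_sketch(self, model_chunk_id: int):
    input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
    # Chunk 0's dx flows to the previous stage; chunk 1's dx flows down the V
    # to the next stage (see the two hunks above).
    if model_chunk_id == 0:
        dst_rank = self.stage_manager.get_prev_rank()
    else:
        dst_rank = self.stage_manager.get_next_rank()
    return self.comm.send_backward(input_tensor_grad, dst_rank)
```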
@@ -413,6 +425,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Only attention_mask from micro_batch is used

with self.stage_manager.switch_model_chunk_id(model_chunk_id):
    # fwd calculate
    output_obj = model_chunk[model_chunk_id](input_obj)
    # last layer in model
    if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -463,6 +476,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# common bwd step
# print(f"bwd output_obj {output_obj} output_obj_grad {output_obj_grad} input_obj {input_obj}")
# BUG: output_obj_grad is None
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; tensor {output_obj};\n grad_tensors {output_obj_grad};\n inputs {input_obj}\n")
torch.autograd.backward(
    tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
)
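This `torch.autograd.backward` call is the crux of the zero-bubble B/W split: `inputs=` restricts gradient accumulation to the activations (dx), while `retain_graph=True` keeps the graph alive so schedule_w can compute the weight gradients (dw) later. A self-contained illustration of the same mechanism in plain PyTorch:

```python
import torch

layer = torch.nn.Linear(8, 8)
x = torch.rand(4, 8, requires_grad=True)
y = layer(x)
grad_y = torch.ones_like(y)

# B step: only dx is produced; the graph is retained for the later W step.
torch.autograd.backward(tensors=y, grad_tensors=grad_y, inputs=x, retain_graph=True)
assert x.grad is not None and layer.weight.grad is None

# W step: compute dw from the retained graph, decoupled from the B step.
torch.autograd.backward(tensors=y, grad_tensors=grad_y, inputs=list(layer.parameters()))
assert layer.weight.grad is not None
```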
@@ -505,14 +519,21 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
    outputs: Optional[List[Any]] = None,
):
    # Step1: recv fwd
    # if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
    #     # first layer
    #     input_obj = input_obj
    # else:
    #     # other layer
    #     input_obj, wait_handles = self.recv_forward(model_chunk_id)
    #     # print(f"recv input_obj {input_obj}")
    #     _wait_p2p(wait_handles)

    if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
        # first layer
        input_obj = input_obj
        self.recv_forward_buffer[model_chunk_id].pop(0)  # pop the None placeholder
    else:
        # other layer
        input_obj, wait_handles = self.recv_forward(model_chunk_id)
        # print(f"recv input_obj {input_obj}")
        _wait_p2p(wait_handles)
        input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)

    # Step2: fwd step
    output_obj = self.forward_step(
        model_chunk=model_chunk,
@@ -522,6 +543,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
    accum_loss=accum_loss,
    outputs=outputs,
)

# print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")

# add input and output object for backward b
@@ -532,7 +554,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self.output_tensors_dw[model_chunk_id].append(output_obj)

# Step3: send fwd
-send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
+# add output to send_fwd_buffer
+self.send_forward_buffer[model_chunk_id].append(output_obj)
+# send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)

def schedule_b(
    self,
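After this change, schedule_f is pure compute between two FIFOs: pop the staged input, run the chunk forward, stash tensors for the later B and W steps, and buffer the output for a SEND_FORWARD node. A condensed sketch under those assumptions (forward_step's exact argument list is abridged from the hunks above):

```python
# Condensed flow of schedule_f after this commit (illustrative only).
def schedule_f_sketch(self, model_chunk, model_chunk_id: int, criterion):
    # Step1: take the input staged earlier by a RECV_FORWARD node.
    input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
    # Step2: forward compute for this chunk.
    output_obj = self.forward_step(
        model_chunk=model_chunk,
        model_chunk_id=model_chunk_id,
        input_obj=input_obj,
        criterion=criterion,
        accum_loss=None,
        outputs=None,
    )
    # Keep (x, y) for the later B step, and y again for the W step.
    self.input_tensors[model_chunk_id].append(input_obj)
    self.output_tensors[model_chunk_id].append(output_obj)
    self.output_tensors_dw[model_chunk_id].append(output_obj)
    # Step3: buffer y; a SEND_FORWARD node ships it later.
    self.send_forward_buffer[model_chunk_id].append(output_obj)
```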
@@ -545,17 +569,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
    # output_obj_grad: Optional[dict],
):
    # Step1: recv bwd
-   # not first stage and chunk 1
-   if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-       output_tensor_grad, recv_bwd_handles = None, []
-       # print(f"recv output_tensor_grad {output_tensor_grad}")
-   else:
-       output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
-       # print(f"recv output_tensor_grad {output_tensor_grad}")
+   # # not first stage and chunk 1
+   # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+   #     output_tensor_grad, recv_bwd_handles = None, []
+   #     # print(f"recv output_tensor_grad {output_tensor_grad}")
+   # else:
+   #     output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
+   #     # print(f"recv output_tensor_grad {output_tensor_grad}")
+   output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)

    # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n")

    # get input and output object from buffer;
-   input_obj = self.input_tensors[model_chunk_id].pop()
-   output_obj = self.output_tensors[model_chunk_id].pop()
+   input_obj = self.input_tensors[model_chunk_id].pop(0)
+   output_obj = self.output_tensors[model_chunk_id].pop(0)

    # save output_tensor_grad for dw
    if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -565,9 +592,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
    # we save output_tensor_grad here
    self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)

-   _wait_p2p(recv_bwd_handles)
+   # _wait_p2p(recv_bwd_handles)
    # print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}")
    # Step2: bwd step

    # print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}")

    input_object_grad = self.backward_b_step(
        model_chunk=model_chunk,
        model_chunk_id=model_chunk_id,
@@ -576,23 +606,23 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
        output_obj=output_obj,
        output_obj_grad=output_tensor_grad,
    )
    print(f"input_object_grad {input_object_grad}")
    # print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}")

    # Step3: send bwd
-   send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
+   # send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
+   self.send_backward_buffer[model_chunk_id].append(input_object_grad)

def schedule_w(
    self,
    scheduled_node,
    non_w_pending,
    model_chunk: Union[ModuleList, Module],
    model_chunk_id: int,
    # optimizer: OptimizerWrapper,
):

    # get y & dy from buffer
-   output_obj = self.output_tensors_dw[model_chunk_id].pop()
-   output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop()
+   output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
+   output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)

    self.backward_w_step(
        model_chunk=model_chunk,
@@ -605,6 +635,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
def run_forward_backward(
    self,
    model_chunk: Union[ModuleList, Module],
    input_obj: Optional[dict],
    data_iter: Iterable,
    criterion: Callable[..., Any],
    optimizer: Optional[OptimizerWrapper] = None,
@@ -615,19 +646,37 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# while we still have scheduled nodes in self.schedules
while it < len(self.schedules):
    scheduled_node = self.schedules[it]
    print(f"it {it}; scheduled_node {scheduled_node};")
    if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
        # communication
        if scheduled_node.type == "RECV_FORWARD":
-           self.recv_forward()
+           self.recv_forward(scheduled_node.chunk)
        elif scheduled_node.type == "RECV_BACKWARD":
-           self.recv_backward()
+           self.recv_backward(scheduled_node.chunk)
        elif scheduled_node.type == "SEND_FORWARD":
-           self.send_forward()
+           self.send_forward(scheduled_node.chunk)
        elif scheduled_node.type == "SEND_BACKWARD":
-           self.send_backward()
-       elif scheduled_node.type == "F":
-           self.schedule_f()
+           self.send_backward(scheduled_node.chunk)
+   if scheduled_node.type == "F":
+       self.schedule_f(
+           scheduled_node=scheduled_node,
+           model_chunk=model_chunk,
+           model_chunk_id=scheduled_node.chunk,
+           input_obj=input_obj,
+           criterion=criterion,
+           accum_loss=return_loss,
+           outputs=return_outputs,
+       )
    elif scheduled_node.type == "B":
-       self.schedule_b()
+       self.schedule_b(
+           scheduled_node=scheduled_node,
+           model_chunk=model_chunk,
+           model_chunk_id=scheduled_node.chunk,
+       )
    elif scheduled_node.type == "W":
-       self.schedule_w()
+       self.schedule_w(
+           scheduled_node=scheduled_node,
+           model_chunk=model_chunk,
+           model_chunk_id=scheduled_node.chunk,
+       )
    it += 1
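The driver is now a plain interpreter over ScheduledNode lists: communication nodes move tensors between P2P and the FIFOs, compute nodes pop from them. A compact restatement of the loop above (dispatch only; error handling and loss/output plumbing omitted):

```python
# Compact restatement of run_forward_backward's dispatch (illustrative).
COMM_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}

def run_schedule_sketch(self, model_chunk, input_obj, criterion):
    for node in self.schedules:
        if node.type in COMM_TYPES:
            handler = {
                "RECV_FORWARD": self.recv_forward,
                "RECV_BACKWARD": self.recv_backward,
                "SEND_FORWARD": self.send_forward,
                "SEND_BACKWARD": self.send_backward,
            }[node.type]
            handler(node.chunk)  # each comm handler takes the model chunk id
        elif node.type == "F":
            self.schedule_f(
                scheduled_node=node,
                model_chunk=model_chunk,
                model_chunk_id=node.chunk,
                input_obj=input_obj,
                criterion=criterion,
                accum_loss=None,
                outputs=None,
            )
        elif node.type == "B":
            self.schedule_b(scheduled_node=node, model_chunk=model_chunk, model_chunk_id=node.chunk)
        elif node.type == "W":
            self.schedule_w(scheduled_node=node, model_chunk=model_chunk, model_chunk_id=node.chunk)
```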
@@ -1176,17 +1176,8 @@ def model_chunk_dx_dw_comm_interleaved(
    print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")


-def run_fwd_bwd(
-    rank: int,
-    world_size: int,
-    port: int,
-):
-    pass


@rerun_if_address_is_in_use()
def test_dx_dw_dist():
    spawn(
        model_chunk_dx_dw_comm_interleaved,
        nprocs=4,
@@ -8,6 +8,7 @@ from torch.testing import assert_close

import colossalai
from colossalai.cluster import ProcessGroupMesh
+from colossalai.pipeline.schedule.v_schedule import ScheduledNode
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -34,6 +35,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
    return num_params, num_params_trainable


# Test baseline: an 8-layer MLP running the ZeroBubble pipeline on a 4-node pp group;
def test_zerobubble_pipeline_base(
    rank: int,
    world_size: int,
@@ -427,18 +429,187 @@ def test_zerobubble_pipeline_base(
    assert_close(chunk_3[1].weight.grad, model_base.layers[4].weight.grad)


# Test run_forward_backward with baseline;
def test_run_fwd_bwd_base(
    rank: int,
    world_size: int,
    port: int,
):
    # init dist
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
    rank = dist.get_rank()
    pp_size = world_size
    pg_mesh = ProcessGroupMesh(pp_size)

    # stage_manager
    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size)

    # schedule list
    zbv_schedule = [
        # stage 0
        [
            # chunk 0 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=0),
            ScheduledNode(type="F", chunk=0, stage=0, minibatch=0),
            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=0),
            # chunk 1 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=0),
            ScheduledNode(type="F", chunk=1, stage=0, minibatch=0),
            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=0),
            # chunk 1 bwd
            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=0),
            ScheduledNode(type="B", chunk=1, stage=0, minibatch=0),
            ScheduledNode(type="W", chunk=1, stage=0, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0),
            # chunk 0 bwd
            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=0),
            ScheduledNode(type="B", chunk=0, stage=0, minibatch=0),
            ScheduledNode(type="W", chunk=0, stage=0, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
        ],
        # stage 1
        [
            # chunk 0 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=0),
            ScheduledNode(type="F", chunk=0, stage=1, minibatch=0),
            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=0),
            # chunk 1 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=0),
            ScheduledNode(type="F", chunk=1, stage=1, minibatch=0),
            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=0),
            # chunk 1 bwd
            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=0),
            ScheduledNode(type="B", chunk=1, stage=1, minibatch=0),
            ScheduledNode(type="W", chunk=1, stage=1, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=0),
            # chunk 0 bwd
            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=0),
            ScheduledNode(type="B", chunk=0, stage=1, minibatch=0),
            ScheduledNode(type="W", chunk=0, stage=1, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=1, minibatch=0),
        ],
        # stage 2
        [
            # chunk 0 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=0),
            ScheduledNode(type="F", chunk=0, stage=2, minibatch=0),
            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=0),
            # chunk 1 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=0),
            ScheduledNode(type="F", chunk=1, stage=2, minibatch=0),
            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=0),
            # chunk 1 bwd
            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=0),
            ScheduledNode(type="B", chunk=1, stage=2, minibatch=0),
            ScheduledNode(type="W", chunk=1, stage=2, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=0),
            # chunk 0 bwd
            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=0),
            ScheduledNode(type="B", chunk=0, stage=2, minibatch=0),
            ScheduledNode(type="W", chunk=0, stage=2, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0),  # Send nothing
        ],
        # stage 3
        [
            # chunk 0 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=0),
            ScheduledNode(type="F", chunk=0, stage=3, minibatch=0),
            ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=0),
            # chunk 1 fwd
            ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=0),
            ScheduledNode(type="F", chunk=1, stage=3, minibatch=0),
            ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=0),
            # chunk 1 bwd
            ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=0),
            ScheduledNode(type="B", chunk=1, stage=3, minibatch=0),
            ScheduledNode(type="W", chunk=1, stage=3, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=0),
            # chunk 0 bwd
            ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=0),
            ScheduledNode(type="B", chunk=0, stage=3, minibatch=0),
            ScheduledNode(type="W", chunk=0, stage=3, minibatch=0),
            ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=0),
        ],
    ]
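The four per-stage lists differ only in their `stage` field, so for this single-microbatch, two-chunk case the table could also be generated. A hypothetical helper (not part of the commit) that reproduces the list above:

```python
# Hypothetical generator for the hand-written schedule above
# (single microbatch, two chunks per stage).
def make_zbv_stage_schedule(stage: int, minibatch: int = 0):
    nodes = []
    for chunk in (0, 1):  # fwd runs chunk 0, then chunk 1 down the V
        for t in ("RECV_FORWARD", "F", "SEND_FORWARD"):
            nodes.append(ScheduledNode(type=t, chunk=chunk, stage=stage, minibatch=minibatch))
    for chunk in (1, 0):  # bwd returns chunk 1 first, then chunk 0 back up the V
        for t in ("RECV_BACKWARD", "B", "W", "SEND_BACKWARD"):
            nodes.append(ScheduledNode(type=t, chunk=chunk, stage=stage, minibatch=minibatch))
    return nodes

zbv_schedule_generated = [make_zbv_stage_schedule(s) for s in range(4)]
```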
    scheduler = ZeroBubbleVPipeScheduler(
        schedule=zbv_schedule[rank],
        stage_manager=stage_manager,
        num_model_chunks=pp_size,
        num_microbatch=1,
        overlap_p2p=False,
    )

    # loss func
    def criterion(x, *args, **kwargs):
        return (x * x).mean()

    # init model and input
    num_layers = 8
    in_dim = out_dim = 8
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
    input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)

    input0.clone()
    deepcopy(model)
    if rank == 0:
        # layer 0 & 7 to chunk 0 on rank0
        local_chunk = torch.nn.ModuleList().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 0 or idx == 7:
                local_chunk.append(sub_model)
    elif rank == 1:
        # layer 1 & 6 to chunk 1 on rank1
        local_chunk = torch.nn.ModuleList().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 1 or idx == 6:
                local_chunk.append(sub_model)
    elif rank == 2:
        # layer 2 & 5 to chunk 2 on rank2
        local_chunk = torch.nn.ModuleList().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 2 or idx == 5:
                local_chunk.append(sub_model)
    else:
        # layer 3 & 4 to chunk 3 on rank3
        local_chunk = torch.nn.Sequential().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 3 or idx == 4:
                local_chunk.append(sub_model)
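These rank-by-rank branches implement the usual V-shaped placement: rank `i` holds layers `i` and `2 * pp_size - 1 - i` (rank 0 gets layers 0 & 7, rank 3 gets layers 3 & 4). An equivalent loop-free version, as a hypothetical refactor rather than the committed code (the original uses `Sequential` on rank 3, which serves the same purpose here):

```python
# Equivalent V-shaped layer placement: rank i gets layers i and (2*pp_size - 1 - i).
def build_local_chunk(model, rank: int, pp_size: int) -> torch.nn.ModuleList:
    local_chunk = torch.nn.ModuleList().to(rank)
    for idx in (rank, 2 * pp_size - 1 - rank):
        local_chunk.append(model.layers[idx])
    return local_chunk
```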
    print(
        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
    )

    torch.cuda.synchronize()
    scheduler.run_forward_backward(
        model_chunk=local_chunk,
        input_obj=input0,
        data_iter=None,
        criterion=criterion,
        optimizer=None,
        return_loss=None,
        return_outputs=None,
    )
# @pytest.mark.dist
# @pytest.mark.parametrize("num_microbatch", [4])
# @pytest.mark.parametrize("batch_size", [4])
# @pytest.mark.parametrize("num_model_chunk", [2])
@rerun_if_address_is_in_use()
def test_pp():
    # spawn(
    #     test_zerobubble_pipeline_base,
    #     nprocs=4,
    # )

    spawn(
-       test_zerobubble_pipeline_base,
+       test_run_fwd_bwd_base,
        nprocs=4,
    )


if __name__ == "__main__":
    test_pp()