[feat] add test run_fwd_bwd automatic scheduling;

pull/6034/head
duanjunwen 2024-08-26 11:21:56 +00:00
parent fd5526b76e
commit 1d75045c37
4 changed files with 259 additions and 48 deletions

View File

@ -12,8 +12,8 @@ class ScheduledNode:
chunk: int
stage: int
minibatch: int
start_time: int
completion_time: int
# start_time: int
# completion_time: int
rollback: bool = False

View File

@ -176,6 +176,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
#################
if self.stage_manager.is_first_stage(ignore_chunk=True):
self.recv_forward_buffer[model_chunk_id].append(None)
return None, []
################
@ -188,6 +189,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# metadata_recv=self.tensor_metadata_recv
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles
else:
@ -200,7 +202,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, []
################
@ -214,7 +216,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# metadata_recv=self.tensor_metadata_recv
# if self.enable_metadata_cache and self.tensor_metadata_recv is None:
# self.tensor_metadata_recv = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
return input_tensor, wait_handles
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
@ -240,6 +242,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_tensor_grad = self.local_send_backward_buffer.pop(0)
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, []
################
@ -252,6 +255,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# metadata_recv=self.grad_metadata_recv
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, wait_handles
else:
@ -261,6 +265,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; get loss from local
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
self.recv_backward_buffer[model_chunk_id].append(None)
return None, []
################
@ -268,16 +273,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# self.comm.recv_backward recv bwd from prev stage;
################
else:
prev_rank = self.stage_manager.get_prev_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank)
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} output_tensor_grad {output_tensor_grad};\n buffer {self.recv_backward_buffer}")
# metadata_recv=self.grad_metadata_recv
# if self.enable_metadata_cache and self.grad_metadata_recv is None:
# self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
return output_tensor_grad, wait_handles
def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = None) -> List:
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
"""Sends the input tensor to the next stage in pipeline.
For ZBV.
@ -291,6 +296,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
if model_chunk_id == 0:
################
# chunk = 0 && is_last_stage
@ -330,7 +336,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# self.send_tensor_metadata = not self.enable_metadata_cache
return send_handles
def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: int = None) -> List:
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
"""Sends the gradient tensor to the previous stage in pipeline.
For ZBV.
@ -359,6 +365,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Send dx to PREV stage;
################
else:
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
prev_rank = self.stage_manager.get_prev_rank()
send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
# send_metadata=self.send_grad_metadata
@ -371,6 +378,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# hold dy to local_send_bwd_buffer;
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
self.local_send_backward_buffer.append(input_tensor_grad)
return []
@ -379,6 +387,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Send dx to NEXT stage;
################
else:
print(
f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} send_backward_buffer {self.send_backward_buffer}"
)
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
next_rank = self.stage_manager.get_next_rank()
# print(f"send bwd input_tensor_grad {input_tensor_grad}")
send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
@ -413,6 +425,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
# fwd calculate
output_obj = model_chunk[model_chunk_id](input_obj)
# last layer in model
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@ -463,6 +476,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# commom bwd step
# print(f"bwd output_obj {output_obj} output_obj_grad {output_obj_grad} input_obj {input_obj}")
# BUG:output_obj_grad is None
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; tensor {output_obj};\n grad_tensors {output_obj_grad};\n inputs {input_obj}\n")
torch.autograd.backward(
tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
)
@ -505,14 +519,21 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
outputs: Optional[List[Any]] = None,
):
# Step1: recv fwd
# if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
# # first layer
# input_obj = input_obj
# else:
# # other layer
# input_obj, wait_handles = self.recv_forward(model_chunk_id)
# # print(f"recv input_obj {input_obj}")
# _wait_p2p(wait_handles)
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
# first layer
input_obj = input_obj
self.recv_forward_buffer[model_chunk_id].pop(0) # pop none
else:
# other layer
input_obj, wait_handles = self.recv_forward(model_chunk_id)
# print(f"recv input_obj {input_obj}")
_wait_p2p(wait_handles)
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
# Step2: fwd step
output_obj = self.forward_step(
model_chunk=model_chunk,
@ -522,6 +543,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
accum_loss=accum_loss,
outputs=outputs,
)
# print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")
# add input and output object for backward b
@ -532,7 +554,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
self.output_tensors_dw[model_chunk_id].append(output_obj)
# Step3: send fwd
send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
# add output to send_fwd_buffer
self.send_forward_buffer[model_chunk_id].append(output_obj)
# send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj)
def schedule_b(
self,
@ -545,17 +569,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# output_obj_grad: Optional[dict],
):
# Step1: recv bwd
# not first stage and chunk 1
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
output_tensor_grad, recv_bwd_handles = None, []
# print(f"recv output_tensor_grad {output_tensor_grad}")
else:
output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
# print(f"recv output_tensor_grad {output_tensor_grad}")
# # not first stage and chunk 1
# if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# output_tensor_grad, recv_bwd_handles = None, []
# # print(f"recv output_tensor_grad {output_tensor_grad}")
# else:
# output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id)
# # print(f"recv output_tensor_grad {output_tensor_grad}")
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n")
# get input and output object from buffer;
input_obj = self.input_tensors[model_chunk_id].pop()
output_obj = self.output_tensors[model_chunk_id].pop()
input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0)
# save output_tensor_grad for dw
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@ -565,9 +592,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# we save output_tensor_grad here
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
_wait_p2p(recv_bwd_handles)
# _wait_p2p(recv_bwd_handles)
# print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}")
# Step2: bwd step
# print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}")
input_object_grad = self.backward_b_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
@ -576,23 +606,23 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj=output_obj,
output_obj_grad=output_tensor_grad,
)
print(f"input_object_grad {input_object_grad}")
# print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}")
# Step3: send bwd
send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
# send_bwd_handles = self.send_backward(model_chunk_id=model_chunk_id, input_tensor_grad=input_object_grad)
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
def schedule_w(
self,
scheduled_node,
non_w_pending,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
# optimizer: OptimizerWrapper,
):
# get y & dy from buffer
output_obj = self.output_tensors_dw[model_chunk_id].pop()
output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop()
output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
self.backward_w_step(
model_chunk=model_chunk,
@ -605,6 +635,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
def run_forward_backward(
self,
model_chunk: Union[ModuleList, Module],
input_obj: Optional[dict],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
@ -615,19 +646,37 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# while we still have schedules_node in self.schedules
while it < len(self.schedules):
scheduled_node = self.schedules[it]
print(f"it {it}; scheduled_node {scheduled_node};")
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication
if scheduled_node.type == "RECV_FORWARD":
self.recv_forward()
self.recv_forward(scheduled_node.chunk)
elif scheduled_node.type == "RECV_BACKWARD":
self.recv_backward()
self.recv_backward(scheduled_node.chunk)
elif scheduled_node.type == "SEND_FORWARD":
self.send_forward()
self.send_forward(scheduled_node.chunk)
elif scheduled_node.type == "SEND_BACKWARD":
self.send_backward()
elif scheduled_node.type == "F":
self.schedule_f()
self.send_backward(scheduled_node.chunk)
if scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
input_obj=input_obj,
criterion=criterion,
accum_loss=return_loss,
outputs=return_outputs,
)
elif scheduled_node.type == "B":
self.schedule_b()
self.schedule_b(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
)
elif scheduled_node.type == "W":
self.schedule_w()
self.schedule_w(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
)
it += 1

View File

@ -1176,17 +1176,8 @@ def model_chunk_dx_dw_comm_interleaved(
print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
def run_fwd_bwd(
rank: int,
world_size: int,
port: int,
):
pass
@rerun_if_address_is_in_use()
def test_dx_dw_dist():
spawn(
model_chunk_dx_dw_comm_interleaved,
nprocs=4,

View File

@ -8,6 +8,7 @@ from torch.testing import assert_close
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
@ -34,6 +35,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
return num_params, num_params_trainable
# Test baseline; An 8 layer MLP do Zerobubble Pipeline on 4 node pp group;
def test_zerobubble_pipeline_base(
rank: int,
world_size: int,
@ -427,18 +429,187 @@ def test_zerobubble_pipeline_base(
assert_close(chunk_3[1].weight.grad, model_base.layers[4].weight.grad)
# Test run_forward_backward with baseline;
def test_run_fwd_bwd_base(
rank: int,
world_size: int,
port: int,
):
# init dist
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
rank = dist.get_rank()
pp_size = world_size
pg_mesh = ProcessGroupMesh(pp_size)
# stage_manager
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size)
# schedule list
zbv_schedule = [
# stage 0
[
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=0),
ScheduledNode(type="F", chunk=0, stage=0, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=0),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=0),
ScheduledNode(type="F", chunk=1, stage=0, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=0),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=0),
ScheduledNode(type="B", chunk=1, stage=0, minibatch=0),
ScheduledNode(type="W", chunk=1, stage=0, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=0),
ScheduledNode(type="B", chunk=0, stage=0, minibatch=0),
ScheduledNode(type="W", chunk=0, stage=0, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0),
],
# stage 1
[
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=0),
ScheduledNode(type="F", chunk=0, stage=1, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=0),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=0),
ScheduledNode(type="F", chunk=1, stage=1, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=0),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=0),
ScheduledNode(type="B", chunk=1, stage=1, minibatch=0),
ScheduledNode(type="W", chunk=1, stage=1, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=0),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=0),
ScheduledNode(type="B", chunk=0, stage=1, minibatch=0),
ScheduledNode(type="W", chunk=0, stage=1, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=0),
],
# stage 2
[
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=0),
ScheduledNode(type="F", chunk=0, stage=2, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=0),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=0),
ScheduledNode(type="F", chunk=1, stage=2, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=0),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=0),
ScheduledNode(type="B", chunk=1, stage=2, minibatch=0),
ScheduledNode(type="W", chunk=1, stage=2, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=0),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=0),
ScheduledNode(type="B", chunk=0, stage=2, minibatch=0),
ScheduledNode(type="W", chunk=0, stage=2, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0), # Send nothing
],
# stage 3
[
# chunk 0 fwd
ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=0),
ScheduledNode(type="F", chunk=0, stage=3, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=0),
# chunk 1 fwd
ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=0),
ScheduledNode(type="F", chunk=1, stage=3, minibatch=0),
ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=0),
# chunk 1 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=0),
ScheduledNode(type="B", chunk=1, stage=3, minibatch=0),
ScheduledNode(type="W", chunk=1, stage=3, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=0),
# chunk 0 bwd
ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=0),
ScheduledNode(type="B", chunk=0, stage=3, minibatch=0),
ScheduledNode(type="W", chunk=0, stage=3, minibatch=0),
ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=0),
],
]
scheduler = ZeroBubbleVPipeScheduler(
schedule=zbv_schedule[rank],
stage_manager=stage_manager,
num_model_chunks=pp_size,
num_microbatch=1,
overlap_p2p=False,
)
# loss func
def criterion(x, *args, **kwargs):
return (x * x).mean()
# init model and input
num_layers = 8
in_dim = out_dim = 8
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
input0.clone()
deepcopy(model)
if rank == 0:
# layer 0 & 7 to chunk 0 on rank0
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 0 or idx == 7:
local_chunk.append(sub_model)
elif rank == 1:
# layer 1 & 6 to chunk 1 on rank1
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 1 or idx == 6:
local_chunk.append(sub_model)
elif rank == 2:
# layer 2 & 5 to chunk 2 on rank2
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 2 or idx == 5:
local_chunk.append(sub_model)
else:
# layer 3 & 4 to chunk 3 on rank3
local_chunk = torch.nn.Sequential().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 3 or idx == 4:
local_chunk.append(sub_model)
print(
f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
)
torch.cuda.synchronize()
scheduler.run_forward_backward(
model_chunk=local_chunk,
input_obj=input0,
data_iter=None,
criterion=criterion,
optimizer=None,
return_loss=None,
return_outputs=None,
)
# @pytest.mark.dist
# @pytest.mark.parametrize("num_microbatch", [4])
# @pytest.mark.parametrize("batch_size", [4])
# @pytest.mark.parametrize("num_model_chunk", [2])
@rerun_if_address_is_in_use()
def test_pp():
# spawn(
# test_zerobubble_pipeline_base,
# nprocs=4,
# )
spawn(
test_zerobubble_pipeline_base,
test_run_fwd_bwd_base,
nprocs=4,
)
if __name__ == "__main__":
test_pp()