|
|
|
@ -8,7 +8,7 @@ from torch.utils._pytree import tree_flatten, tree_map
|
|
|
|
|
|
|
|
|
|
from colossalai.accelerator import get_accelerator
|
|
|
|
|
from colossalai.interface import OptimizerWrapper
|
|
|
|
|
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
|
|
|
|
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
|
|
|
|
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
|
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
|
|
|
from colossalai.pipeline.weight_grad_store import WeightGradStore
|
|
|
|
@ -62,11 +62,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
self.do_post_validation = False
|
|
|
|
|
|
|
|
|
|
# P2PMeta cache
|
|
|
|
|
# self.enable_metadata_cache = enable_metadata_cache
|
|
|
|
|
# self.send_tensor_metadata = True
|
|
|
|
|
# self.send_grad_metadata = True
|
|
|
|
|
# self.tensor_metadata_recv = None
|
|
|
|
|
# self.grad_metadata_recv = None
|
|
|
|
|
self.enable_metadata_cache = enable_metadata_cache
|
|
|
|
|
self.send_tensor_metadata = True
|
|
|
|
|
self.send_grad_metadata = True
|
|
|
|
|
self.tensor_metadata_recv = None
|
|
|
|
|
self.grad_metadata_recv = None
|
|
|
|
|
|
|
|
|
|
# P2P communication
|
|
|
|
|
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
|
|
|
@ -105,8 +105,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# dy buffer for local send bwd
|
|
|
|
|
self.local_send_backward_buffer = []
|
|
|
|
|
|
|
|
|
|
# wait pp buffer
|
|
|
|
|
self.send_handles = []
|
|
|
|
|
|
|
|
|
|
def assert_buffer_empty(self):
|
|
|
|
|
# assert buuffer is empty at end
|
|
|
|
|
# assert buffer is empty at end
|
|
|
|
|
assert len(self.input_tensors[0]) == 0
|
|
|
|
|
assert len(self.input_tensors[1]) == 0
|
|
|
|
|
assert len(self.output_tensors[0]) == 0
|
|
|
|
@ -125,6 +128,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
assert len(self.recv_backward_buffer[1]) == 0
|
|
|
|
|
assert len(self.local_send_forward_buffer) == 0
|
|
|
|
|
assert len(self.local_send_backward_buffer) == 0
|
|
|
|
|
# assert len(self.send_handles) == 0
|
|
|
|
|
|
|
|
|
|
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
|
|
|
|
|
"""Load a batch from data iterator.
|
|
|
|
@ -221,7 +225,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
|
|
|
|
|
#################
|
|
|
|
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
return None, []
|
|
|
|
|
# return None, []
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
################
|
|
|
|
|
# chunk = 0 & not is_first_stage
|
|
|
|
@ -229,9 +234,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
#################
|
|
|
|
|
else:
|
|
|
|
|
prev_rank = self.stage_manager.get_prev_rank()
|
|
|
|
|
input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank)
|
|
|
|
|
input_tensor, wait_handles = self.comm.recv_forward(
|
|
|
|
|
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv
|
|
|
|
|
)
|
|
|
|
|
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
|
|
|
|
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
|
|
|
|
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
|
|
|
|
return input_tensor, wait_handles
|
|
|
|
|
# return input_tensor, wait_handles
|
|
|
|
|
return wait_handles
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
################
|
|
|
|
@ -239,7 +249,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# do nothing; cause u get y from local_send_forward_buffer in schedule f
|
|
|
|
|
################
|
|
|
|
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
|
|
|
|
return None, []
|
|
|
|
|
# return None, []
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
################
|
|
|
|
|
# chunk = 1 & not is_last_stage
|
|
|
|
@ -247,9 +258,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
################
|
|
|
|
|
else:
|
|
|
|
|
next_rank = self.stage_manager.get_next_rank()
|
|
|
|
|
input_tensor, wait_handles = self.comm.recv_forward(next_rank)
|
|
|
|
|
input_tensor, wait_handles = self.comm.recv_forward(
|
|
|
|
|
next_rank, metadata_recv=self.tensor_metadata_recv
|
|
|
|
|
)
|
|
|
|
|
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
|
|
|
|
|
self.tensor_metadata_recv = create_send_metadata(input_tensor)
|
|
|
|
|
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
|
|
|
|
return input_tensor, wait_handles
|
|
|
|
|
# return input_tensor, wait_handles
|
|
|
|
|
return wait_handles
|
|
|
|
|
|
|
|
|
|
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]:
|
|
|
|
|
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
|
|
|
|
@ -271,7 +287,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# do nothing; Already get dy from local_send_backward_buffer in schedule b
|
|
|
|
|
################
|
|
|
|
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
|
|
|
|
return None, []
|
|
|
|
|
# return None, []
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
################
|
|
|
|
|
# chunk = 0 & not is_last_stage
|
|
|
|
@ -279,9 +296,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
################
|
|
|
|
|
else:
|
|
|
|
|
next_rank = self.stage_manager.get_next_rank()
|
|
|
|
|
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank)
|
|
|
|
|
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
|
|
|
|
next_rank, metadata_recv=self.grad_metadata_recv
|
|
|
|
|
)
|
|
|
|
|
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
|
|
|
|
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
|
|
|
|
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
|
|
|
|
return output_tensor_grad, wait_handles
|
|
|
|
|
# return output_tensor_grad, wait_handles
|
|
|
|
|
return wait_handles
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
# bwd chunk1 is left V;
|
|
|
|
@ -290,7 +312,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# do nothing; get loss from local
|
|
|
|
|
################
|
|
|
|
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
return None, []
|
|
|
|
|
# return None, []
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
################
|
|
|
|
|
# chunk = 1 & not first stage
|
|
|
|
@ -298,9 +321,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
################
|
|
|
|
|
else:
|
|
|
|
|
prev_rank = self.stage_manager.get_prev_rank()
|
|
|
|
|
output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank)
|
|
|
|
|
output_tensor_grad, wait_handles = self.comm.recv_backward(
|
|
|
|
|
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv
|
|
|
|
|
)
|
|
|
|
|
if self.enable_metadata_cache and self.grad_metadata_recv is None:
|
|
|
|
|
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
|
|
|
|
|
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
|
|
|
|
return output_tensor_grad, wait_handles
|
|
|
|
|
# return output_tensor_grad, wait_handles
|
|
|
|
|
return wait_handles
|
|
|
|
|
|
|
|
|
|
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
|
|
|
|
"""Sends the input tensor to the next stage in pipeline.
|
|
|
|
@ -330,7 +358,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
else:
|
|
|
|
|
next_rank = self.stage_manager.get_next_rank()
|
|
|
|
|
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
|
|
|
|
send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank)
|
|
|
|
|
send_handles = self.comm.send_forward(
|
|
|
|
|
output_object=output_tensor, next_rank=next_rank, send_metadata=self.send_tensor_metadata
|
|
|
|
|
)
|
|
|
|
|
self.send_tensor_metadata = not self.enable_metadata_cache
|
|
|
|
|
return send_handles
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
@ -348,7 +379,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
else:
|
|
|
|
|
prev_rank = self.stage_manager.get_prev_rank()
|
|
|
|
|
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
|
|
|
|
|
send_handles = self.comm.send_forward(output_tensor, prev_rank)
|
|
|
|
|
send_handles = self.comm.send_forward(
|
|
|
|
|
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata
|
|
|
|
|
)
|
|
|
|
|
self.send_tensor_metadata = not self.enable_metadata_cache
|
|
|
|
|
return send_handles
|
|
|
|
|
|
|
|
|
|
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
|
|
|
|
@ -380,7 +414,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
else:
|
|
|
|
|
prev_rank = self.stage_manager.get_prev_rank()
|
|
|
|
|
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
|
|
|
|
send_handles = self.comm.send_backward(input_tensor_grad, prev_rank)
|
|
|
|
|
send_handles = self.comm.send_backward(
|
|
|
|
|
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
|
|
|
|
|
)
|
|
|
|
|
self.send_grad_metadata = not self.enable_metadata_cache
|
|
|
|
|
return send_handles
|
|
|
|
|
|
|
|
|
|
# bwd chunk1 is left V;
|
|
|
|
@ -399,7 +436,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
else:
|
|
|
|
|
next_rank = self.stage_manager.get_next_rank()
|
|
|
|
|
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
|
|
|
|
|
send_handles = self.comm.send_backward(input_tensor_grad, next_rank)
|
|
|
|
|
send_handles = self.comm.send_backward(
|
|
|
|
|
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata
|
|
|
|
|
)
|
|
|
|
|
self.send_grad_metadata = not self.enable_metadata_cache
|
|
|
|
|
return send_handles
|
|
|
|
|
|
|
|
|
|
def forward_step(
|
|
|
|
@ -479,11 +519,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
output_obj_grad_ = []
|
|
|
|
|
|
|
|
|
|
# For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx.
|
|
|
|
|
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
return None
|
|
|
|
|
# if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
# return None
|
|
|
|
|
|
|
|
|
|
# For loss backward; output_obj is loss; output_obj_grad should be None
|
|
|
|
|
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
|
|
|
assert output_obj_grad is None
|
|
|
|
|
input_obj_, _ = tree_flatten(input_obj)
|
|
|
|
|
output_obj_.append(output_obj) # LOSS
|
|
|
|
@ -510,7 +550,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
tensor=output_obj_,
|
|
|
|
|
grad=output_obj_grad_,
|
|
|
|
|
# inputs=input_obj_,
|
|
|
|
|
# retain_graph=True,
|
|
|
|
|
retain_graph=False,
|
|
|
|
|
)
|
|
|
|
|
# Format output_obj_grad
|
|
|
|
|
input_obj_grad = dict()
|
|
|
|
@ -712,6 +752,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
# else:
|
|
|
|
|
# # we save output_tensor_grad here
|
|
|
|
|
# self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
|
|
|
|
|
# the_output_obj_grad = []
|
|
|
|
|
# if isinstance(output_obj, dict):
|
|
|
|
|
# for (k, v) in output_obj.items():
|
|
|
|
|
# the_output_obj_grad.append(v.requires_grad)
|
|
|
|
|
# else:
|
|
|
|
|
# the_output_obj_grad.append(output_obj.requires_grad)
|
|
|
|
|
|
|
|
|
|
input_object_grad = self.backward_b_step(
|
|
|
|
|
model_chunk=model_chunk,
|
|
|
|
@ -844,7 +890,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
|
|
|
|
# communication
|
|
|
|
|
communication_func = self.communication_map[scheduled_node.type]
|
|
|
|
|
communication_func(scheduled_node.chunk)
|
|
|
|
|
wait_handle = communication_func(scheduled_node.chunk)
|
|
|
|
|
self.send_handles.append(wait_handle)
|
|
|
|
|
elif scheduled_node.type == "F":
|
|
|
|
|
self.schedule_f(
|
|
|
|
|
scheduled_node=scheduled_node,
|
|
|
|
@ -868,6 +915,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
model_chunk_id=scheduled_node.chunk,
|
|
|
|
|
optimizer=optimizer,
|
|
|
|
|
)
|
|
|
|
|
for h in self.send_handles:
|
|
|
|
|
for hh in h:
|
|
|
|
|
hh.wait()
|
|
|
|
|
|
|
|
|
|
# return loss & output
|
|
|
|
|
if outputs is not None:
|
|
|
|
@ -907,5 +957,4 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.assert_buffer_empty()
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|