@@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.cuda
 import torch.distributed
 from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_map
@@ -496,29 +495,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             inputs=list(model_chunk[model_chunk_id].parameters()),
             retain_graph=False,
         )
-        # if model_chunk_id == 0:
-        #     optimizer.backward_by_grad(
-        #         tensor=output_obj,
-        #         grad=output_obj_grad,
-        #         inputs=list(model_chunk[model_chunk_id].parameters()),
-        #         retain_graph=False,
-        #     )
-
-        # else:
-        #     if self.stage_manager.is_first_stage(ignore_chunk=True):
-        #         optimizer.backward_by_grad(
-        #             tensor=output_obj,
-        #             grad=None,
-        #             inputs=list(model_chunk[model_chunk_id].parameters()),
-        #             retain_graph=False,
-        #         )
-        #     else:
-        #         optimizer.backward_by_grad(
-        #             tensor=output_obj,
-        #             grad=output_obj_grad,
-        #             inputs=list(model_chunk[model_chunk_id].parameters()),
-        #             retain_graph=False,
-        #         )
 
     def schedule_f(
         self,
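Note on the surviving call above: passing `inputs=list(model_chunk[model_chunk_id].parameters())` is what restricts the backward pass to weight gradients, i.e. the W half of the zero-bubble B/W split. A minimal sketch of that mechanism with plain `torch.autograd` (a standalone toy, not the scheduler's `optimizer.backward_by_grad` wrapper):

```python
import torch
from torch import nn

layer = nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
y = layer(x)
grad_y = torch.ones_like(y)

# B step: activation gradient only (dx); keep the graph alive for the W step.
dx = torch.autograd.grad(y, x, grad_outputs=grad_y, retain_graph=True)[0]

# W step: weight gradients only (dW), accumulated into .grad as optimizers expect;
# `inputs=` limits the pass exactly as the backward_by_grad call above does.
torch.autograd.backward(y, grad_tensors=grad_y, inputs=list(layer.parameters()))

print(dx.shape, layer.weight.grad.shape)  # torch.Size([2, 4]) torch.Size([4, 4])
```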
@@ -557,6 +533,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # not last stage; recv from next
             else:
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
+                input_obj.requires_grad_()
 
         # Step2: fwd step
         output_obj = self.forward_step(
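The added `requires_grad_()` matters because a tensor that arrives through P2P communication (or was detached before sending, as in the next hunk) is a fresh leaf with `requires_grad=False`; without it the stage's local graph records nothing and there is no `.grad` to send backward. A standalone illustration with toy tensors, not the scheduler's buffers:

```python
import torch
from torch import nn

layer = nn.Linear(4, 4)

recv = torch.randn(2, 4)   # stands in for a tensor popped from recv_forward_buffer
recv.requires_grad_()      # make it a grad-accumulating leaf, as the hunk does

out = layer(recv)
out.sum().backward()
print(recv.grad is not None)  # True; recv.grad is what the bwd step sends upstream
```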
@@ -567,21 +544,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss=accum_loss,
             outputs=outputs,
         )
 
-        detached_output_obj = output_obj.clone()
+        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+            # We should not detach bwd LOSS
+            detached_output_obj = output_obj.clone()
+        else:
+            detached_output_obj = output_obj.clone().detach()
 
         # Step3: send fwd
         # add output to send_fwd_buffer
         if model_chunk_id == 0:
             # is last stage; send to local_send_forward_buffer
             if self.stage_manager.is_last_stage(ignore_chunk=True):
-                detached_output_obj = detached_output_obj.detach()
                 self.local_send_forward_buffer.append(detached_output_obj)
             else:
                 self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
         else:
             # is first stage; end of fwd; append LOSS to local_send_backward_buffer
             if self.stage_manager.is_first_stage(ignore_chunk=True):
+                self.local_send_backward_buffer.append(detached_output_obj)
                 pass
             else:
                 self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
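On the new detach rule: the loss produced on chunk 1's first stage has to stay attached to the graph so the backward pass can start from it, while every other output is detached before sending so each stage's autograd graph stays local. A quick standalone demo of the difference between `clone()` and `clone().detach()`:

```python
import torch

x = torch.randn(3, requires_grad=True)
loss = (x * 2).sum()

attached = loss.clone()            # still in the graph; a valid backward root
detached = loss.clone().detach()   # cut from the graph; safe to hand to another stage

print(attached.requires_grad, detached.requires_grad)  # True False
attached.backward()                # ok; detached.backward() would raise RuntimeError
print(x.grad)                      # tensor([2., 2., 2.])
```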
@@ -624,7 +605,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         else:
             # chunk1, is first stage; recv LOSS from local send bwd buffer
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                output_tensor_grad = None
+                output_tensor_grad = self.local_send_backward_buffer.pop(0)
             # chunk1, not first stage; recv output_grad from recv_backward_buffer
             else:
                 output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
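This hunk is the consumer side of the `local_send_backward_buffer.append(...)` added in `schedule_f` above: on the first stage, chunk 1's backward now pops the attached loss locally instead of defaulting to `None`. A toy model of that producer/consumer pairing (assumed minimal names, not the real scheduler class):

```python
local_send_backward_buffer = []

def fake_schedule_f(loss):
    # chunk 1, first stage: end of fwd; stash the (attached) LOSS locally
    local_send_backward_buffer.append(loss)

def fake_schedule_b():
    # chunk 1, first stage: recv LOSS from the local buffer, not via P2P
    return local_send_backward_buffer.pop(0)

fake_schedule_f("LOSS")        # placeholder payload
print(fake_schedule_b())       # LOSS, the same object the fwd pass produced
```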