[fix] fix requires_grad position, detach position, and input & output local buffer append positions;

pull/6034/head
duanjunwen 2024-09-04 03:31:08 +00:00
parent 20503cdfdf
commit e6e1a97a6d
2 changed files with 13 additions and 32 deletions

View File

@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
import torch.cuda import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
@ -496,29 +495,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
inputs=list(model_chunk[model_chunk_id].parameters()), inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False, retain_graph=False,
) )
# if model_chunk_id == 0:
# optimizer.backward_by_grad(
# tensor=output_obj,
# grad=output_obj_grad,
# inputs=list(model_chunk[model_chunk_id].parameters()),
# retain_graph=False,
# )
# else:
# if self.stage_manager.is_first_stage(ignore_chunk=True):
# optimizer.backward_by_grad(
# tensor=output_obj,
# grad=None,
# inputs=list(model_chunk[model_chunk_id].parameters()),
# retain_graph=False,
# )
# else:
# optimizer.backward_by_grad(
# tensor=output_obj,
# grad=output_obj_grad,
# inputs=list(model_chunk[model_chunk_id].parameters()),
# retain_graph=False,
# )
def schedule_f( def schedule_f(
self, self,
@ -557,6 +533,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# not last stage; recv from next # not last stage; recv from next
else: else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
input_obj.requires_grad_()
# Step2: fwd step # Step2: fwd step
output_obj = self.forward_step( output_obj = self.forward_step(
@ -567,21 +544,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
accum_loss=accum_loss, accum_loss=accum_loss,
outputs=outputs, outputs=outputs,
) )
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
detached_output_obj = output_obj.clone() # We should not detach bwd LOSS
detached_output_obj = output_obj.clone()
else:
detached_output_obj = output_obj.clone().detach()
# Step3: send fwd # Step3: send fwd
# add output to send_fwd_buffer # add output to send_fwd_buffer
if model_chunk_id == 0: if model_chunk_id == 0:
# is last stage; send to local_send_forward_buffer # is last stage; send to local_send_forward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True): if self.stage_manager.is_last_stage(ignore_chunk=True):
detached_output_obj = detached_output_obj.detach()
self.local_send_forward_buffer.append(detached_output_obj) self.local_send_forward_buffer.append(detached_output_obj)
else: else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj) self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
else: else:
# is first stage; end of fwd; append LOSS to local_send_backward_buffer # is first stage; end of fwd; append LOSS to local_send_backward_buffer
if self.stage_manager.is_first_stage(ignore_chunk=True): if self.stage_manager.is_first_stage(ignore_chunk=True):
self.local_send_backward_buffer.append(detached_output_obj) pass
else: else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj) self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
@ -624,7 +605,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else: else:
# chunk1, is first stage; recv LOSS from local send bwd buffer # chunk1, is first stage; recv LOSS from local send bwd buffer
if self.stage_manager.is_first_stage(ignore_chunk=True): if self.stage_manager.is_first_stage(ignore_chunk=True):
output_tensor_grad = self.local_send_backward_buffer.pop(0) output_tensor_grad = None
# chunk1, not first stage; recv output_grad from recv_backward_buffer # chunk1, not first stage; recv output_grad from recv_backward_buffer
else: else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)

View File

@ -44,7 +44,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
"test_config", "test_config",
[ [
{ {
"batch_size": 4, "batch_size": 8,
"tp_size": 1, "tp_size": 1,
"pp_size": 4, "pp_size": 4,
"num_microbatches": 4, "num_microbatches": 4,
@ -501,7 +501,7 @@ def run_fwd_bwd_iter_input(test_config):
"test_config", "test_config",
[ [
{ {
"batch_size": 4, "batch_size": 8,
"tp_size": 1, "tp_size": 1,
"pp_size": 4, "pp_size": 4,
"num_microbatches": 4, "num_microbatches": 4,
@ -689,13 +689,13 @@ def run_with_hybridplugin(test_config):
"test_config", "test_config",
[ [
{ {
"batch_size": 4, "batch_size": 8,
"tp_size": 1, "tp_size": 1,
"pp_size": 4, "pp_size": 4,
"num_microbatches": 4, "num_microbatches": 4,
"zero_stage": 1, "zero_stage": 1,
"precision": "bf16", "precision": "bf16",
"num_model_chunk": 4, "num_model_chunk": 2,
}, },
], ],
) )