
[fix] fix requires_grad position, detach position, and input/output local buffer append position

pull/6034/head
duanjunwen, 3 months ago
commit e6e1a97a6d
Changed files:
  colossalai/pipeline/schedule/zero_bubble_pp.py (37 changed lines)
  tests/test_pipeline/test_schedule/test_zerobubble_pp.py (8 changed lines)

colossalai/pipeline/schedule/zero_bubble_pp.py (37 changed lines)

@@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
@@ -496,29 +495,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
)
# if model_chunk_id == 0:
# optimizer.backward_by_grad(
# tensor=output_obj,
# grad=output_obj_grad,
# inputs=list(model_chunk[model_chunk_id].parameters()),
# retain_graph=False,
# )
# else:
# if self.stage_manager.is_first_stage(ignore_chunk=True):
# optimizer.backward_by_grad(
# tensor=output_obj,
# grad=None,
# inputs=list(model_chunk[model_chunk_id].parameters()),
# retain_graph=False,
# )
# else:
# optimizer.backward_by_grad(
# tensor=output_obj,
# grad=output_obj_grad,
# inputs=list(model_chunk[model_chunk_id].parameters()),
# retain_graph=False,
# )
def schedule_f(
self,
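With the dead commented-out branching removed, backward for every chunk now goes through the single optimizer.backward_by_grad call kept above, feeding output_obj_grad into output_obj and restricting gradient accumulation to that chunk's parameters. As a rough plain-PyTorch sketch of what such a call amounts to (this is not the repository's implementation, and backward_by_grad_sketch is a made-up name):

    import torch

    def backward_by_grad_sketch(tensor, grad, inputs, retain_graph=False):
        # Propagate the upstream gradient `grad` through `tensor`, accumulating
        # .grad only on the listed `inputs`. When `tensor` is a scalar loss,
        # `grad` may be None and autograd implicitly seeds it with 1.0.
        torch.autograd.backward(
            tensors=tensor,
            grad_tensors=grad,
            inputs=inputs,
            retain_graph=retain_graph,
        )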
@@ -557,6 +533,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# not last stage; recv from next
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
input_obj.requires_grad_()
# Step2: fwd step
output_obj = self.forward_step(
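The input_obj.requires_grad_() call added above is the "requires_grad position" part of the fix: an activation received from the previous stage over p2p arrives as a plain leaf tensor, so without re-enabling gradients this stage could never produce the input gradient it has to send back during its backward pass. A minimal standalone illustration (shapes and names are placeholders, not the scheduler's real buffers):

    import torch

    received = torch.randn(4, 8)      # stands in for recv_forward_buffer.pop(0)
    received.requires_grad_()         # mirrors the added input_obj.requires_grad_()
    out = (received * 2).sum()
    out.backward()
    assert received.grad is not None  # gradient available to send to the previous stage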
@@ -567,21 +544,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
accum_loss=accum_loss,
outputs=outputs,
)
detached_output_obj = output_obj.clone()
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not detach bwd LOSS
detached_output_obj = output_obj.clone()
else:
detached_output_obj = output_obj.clone().detach()
# Step3: send fwd
# add output to send_fwd_buffer
if model_chunk_id == 0:
# is last stage; send to local_send_forward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
detached_output_obj = detached_output_obj.detach()
self.local_send_forward_buffer.append(detached_output_obj)
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
else:
# is first stage; end of fwd; append LOSS to local_send_backward_buffer
if self.stage_manager.is_first_stage(ignore_chunk=True):
self.local_send_backward_buffer.append(detached_output_obj)
pass
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
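This hunk moves the detach to where the output is cloned: only the chunk-1 output on the first stage (the loss that backward will start from) keeps its autograd graph, while every other output is cloned and detached before it is buffered for sending, cutting the graph at stage boundaries while the stage keeps its own undetached output_obj for its later backward. A small sketch of the rule with placeholder tensors rather than real pipeline activations:

    import torch

    x = torch.randn(4, 8, requires_grad=True)
    output_obj = (x * 3).sum()    # stand-in for a chunk's forward output

    is_loss_output = False        # True only for chunk 1 on the first stage
    if is_loss_output:
        detached_output_obj = output_obj.clone()           # keep the graph for backward
    else:
        detached_output_obj = output_obj.clone().detach()  # cut the graph before sending

    assert detached_output_obj.requires_grad is is_loss_output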
@@ -624,7 +605,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
else:
# chunk1, is first stage; recv LOSS from local send bwd buffer
if self.stage_manager.is_first_stage(ignore_chunk=True):
output_tensor_grad = self.local_send_backward_buffer.pop(0)
output_tensor_grad = None
# chunk1, not first stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
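The last hunk above is the matching consumer-side change: when chunk 1 on the first stage runs its backward, it no longer pops a tensor from local_send_backward_buffer but starts from the loss with output_tensor_grad = None, which autograd treats as an implicit seed of 1.0 for a scalar output. A short demonstration of that behaviour in plain PyTorch:

    import torch

    w = torch.randn(3, requires_grad=True)
    loss = (w ** 2).sum()
    torch.autograd.backward(loss, grad_tensors=None)  # equivalent to loss.backward()
    print(w.grad)                                     # 2 * w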

tests/test_pipeline/test_schedule/test_zerobubble_pp.py (8 changed lines)

@@ -44,7 +44,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
"test_config",
[
{
"batch_size": 4,
"batch_size": 8,
"tp_size": 1,
"pp_size": 4,
"num_microbatches": 4,
@@ -501,7 +501,7 @@ def run_fwd_bwd_iter_input(test_config):
"test_config",
[
{
"batch_size": 4,
"batch_size": 8,
"tp_size": 1,
"pp_size": 4,
"num_microbatches": 4,
@@ -689,13 +689,13 @@ def run_with_hybridplugin(test_config):
"test_config",
[
{
"batch_size": 4,
"batch_size": 8,
"tp_size": 1,
"pp_size": 4,
"num_microbatches": 4,
"zero_stage": 1,
"precision": "bf16",
"num_model_chunk": 4,
"num_model_chunk": 2,
},
],
)
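The test configs now pair a batch size of 8 with 4 microbatches and 2 model chunks per pipeline rank. As a hedged aside (these relations are assumptions about the test setup, not something stated in the diff), the batch has to split evenly across microbatches, and a V-style schedule with 2 chunks on each of 4 ranks gives 8 virtual stages:

    # Illustrative arithmetic only; the variable names mirror the test_config keys.
    batch_size = 8
    num_microbatches = 4
    pp_size = 4
    num_model_chunk = 2

    assert batch_size % num_microbatches == 0
    micro_batch_size = batch_size // num_microbatches  # 2 samples per microbatch
    virtual_stages = pp_size * num_model_chunk         # 8 virtual stages across 4 ranks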
