mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix requir grad position and detach position and input&output local buffer append position;
parent
20503cdfdf
commit
e6e1a97a6d
|
@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
|||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.distributed
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
@ -496,29 +495,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
inputs=list(model_chunk[model_chunk_id].parameters()),
|
||||
retain_graph=False,
|
||||
)
|
||||
# if model_chunk_id == 0:
|
||||
# optimizer.backward_by_grad(
|
||||
# tensor=output_obj,
|
||||
# grad=output_obj_grad,
|
||||
# inputs=list(model_chunk[model_chunk_id].parameters()),
|
||||
# retain_graph=False,
|
||||
# )
|
||||
|
||||
# else:
|
||||
# if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# optimizer.backward_by_grad(
|
||||
# tensor=output_obj,
|
||||
# grad=None,
|
||||
# inputs=list(model_chunk[model_chunk_id].parameters()),
|
||||
# retain_graph=False,
|
||||
# )
|
||||
# else:
|
||||
# optimizer.backward_by_grad(
|
||||
# tensor=output_obj,
|
||||
# grad=output_obj_grad,
|
||||
# inputs=list(model_chunk[model_chunk_id].parameters()),
|
||||
# retain_graph=False,
|
||||
# )
|
||||
|
||||
def schedule_f(
|
||||
self,
|
||||
|
@ -557,6 +533,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# not last stage; recv from next
|
||||
else:
|
||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||
input_obj.requires_grad_()
|
||||
|
||||
# Step2: fwd step
|
||||
output_obj = self.forward_step(
|
||||
|
@ -567,21 +544,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
accum_loss=accum_loss,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
detached_output_obj = output_obj.clone()
|
||||
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
# We should not detach bwd LOSS
|
||||
detached_output_obj = output_obj.clone()
|
||||
else:
|
||||
detached_output_obj = output_obj.clone().detach()
|
||||
|
||||
# Step3: send fwd
|
||||
# add output to send_fwd_buffer
|
||||
if model_chunk_id == 0:
|
||||
# is last stage; send to local_send_forward_buffer
|
||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||
detached_output_obj = detached_output_obj.detach()
|
||||
self.local_send_forward_buffer.append(detached_output_obj)
|
||||
else:
|
||||
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
||||
else:
|
||||
# is first stage; end of fwd; append LOSS to local_send_backward_buffer
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
self.local_send_backward_buffer.append(detached_output_obj)
|
||||
pass
|
||||
else:
|
||||
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
||||
|
||||
|
@ -624,7 +605,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
else:
|
||||
# chunk1, is first stage; recv LOSS from local send bwd buffer
|
||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||
output_tensor_grad = self.local_send_backward_buffer.pop(0)
|
||||
output_tensor_grad = None
|
||||
# chunk1, not first stage; recv output_grad from recv_backward_buffer
|
||||
else:
|
||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||
|
|
|
@ -44,7 +44,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
|
|||
"test_config",
|
||||
[
|
||||
{
|
||||
"batch_size": 4,
|
||||
"batch_size": 8,
|
||||
"tp_size": 1,
|
||||
"pp_size": 4,
|
||||
"num_microbatches": 4,
|
||||
|
@ -501,7 +501,7 @@ def run_fwd_bwd_iter_input(test_config):
|
|||
"test_config",
|
||||
[
|
||||
{
|
||||
"batch_size": 4,
|
||||
"batch_size": 8,
|
||||
"tp_size": 1,
|
||||
"pp_size": 4,
|
||||
"num_microbatches": 4,
|
||||
|
@ -689,13 +689,13 @@ def run_with_hybridplugin(test_config):
|
|||
"test_config",
|
||||
[
|
||||
{
|
||||
"batch_size": 4,
|
||||
"batch_size": 8,
|
||||
"tp_size": 1,
|
||||
"pp_size": 4,
|
||||
"num_microbatches": 4,
|
||||
"zero_stage": 1,
|
||||
"precision": "bf16",
|
||||
"num_model_chunk": 4,
|
||||
"num_model_chunk": 2,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue