mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix requires_grad position, detach position, and input & output local buffer append position;
parent
20503cdfdf
commit
e6e1a97a6d
|
@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
import torch.distributed
|
|
||||||
from torch.nn import Module, ModuleList
|
from torch.nn import Module, ModuleList
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
|
@ -496,29 +495,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
inputs=list(model_chunk[model_chunk_id].parameters()),
|
inputs=list(model_chunk[model_chunk_id].parameters()),
|
||||||
retain_graph=False,
|
retain_graph=False,
|
||||||
)
|
)
|
||||||
# if model_chunk_id == 0:
|
|
||||||
# optimizer.backward_by_grad(
|
|
||||||
# tensor=output_obj,
|
|
||||||
# grad=output_obj_grad,
|
|
||||||
# inputs=list(model_chunk[model_chunk_id].parameters()),
|
|
||||||
# retain_graph=False,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# else:
|
|
||||||
# if self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
||||||
# optimizer.backward_by_grad(
|
|
||||||
# tensor=output_obj,
|
|
||||||
# grad=None,
|
|
||||||
# inputs=list(model_chunk[model_chunk_id].parameters()),
|
|
||||||
# retain_graph=False,
|
|
||||||
# )
|
|
||||||
# else:
|
|
||||||
# optimizer.backward_by_grad(
|
|
||||||
# tensor=output_obj,
|
|
||||||
# grad=output_obj_grad,
|
|
||||||
# inputs=list(model_chunk[model_chunk_id].parameters()),
|
|
||||||
# retain_graph=False,
|
|
||||||
# )
|
|
||||||
|
|
||||||
def schedule_f(
|
def schedule_f(
|
||||||
self,
|
self,
|
||||||
|
@ -557,6 +533,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# not last stage; recv from next
|
# not last stage; recv from next
|
||||||
else:
|
else:
|
||||||
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
|
||||||
|
input_obj.requires_grad_()
|
||||||
|
|
||||||
# Step2: fwd step
|
# Step2: fwd step
|
||||||
output_obj = self.forward_step(
|
output_obj = self.forward_step(
|
||||||
|
@ -567,21 +544,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
accum_loss=accum_loss,
|
accum_loss=accum_loss,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
)
|
)
|
||||||
|
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
detached_output_obj = output_obj.clone()
|
# We should not detach bwd LOSS
|
||||||
|
detached_output_obj = output_obj.clone()
|
||||||
|
else:
|
||||||
|
detached_output_obj = output_obj.clone().detach()
|
||||||
|
|
||||||
# Step3: send fwd
|
# Step3: send fwd
|
||||||
# add output to send_fwd_buffer
|
# add output to send_fwd_buffer
|
||||||
if model_chunk_id == 0:
|
if model_chunk_id == 0:
|
||||||
# is last stage; send to local_send_forward_buffer
|
# is last stage; send to local_send_forward_buffer
|
||||||
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
if self.stage_manager.is_last_stage(ignore_chunk=True):
|
||||||
|
detached_output_obj = detached_output_obj.detach()
|
||||||
self.local_send_forward_buffer.append(detached_output_obj)
|
self.local_send_forward_buffer.append(detached_output_obj)
|
||||||
else:
|
else:
|
||||||
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
||||||
else:
|
else:
|
||||||
# is first stage; end of fwd; append LOSS to local_send_backward_buffer
|
# is first stage; end of fwd; append LOSS to local_send_backward_buffer
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
self.local_send_backward_buffer.append(detached_output_obj)
|
pass
|
||||||
else:
|
else:
|
||||||
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
|
||||||
|
|
||||||
|
@ -624,7 +605,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
else:
|
else:
|
||||||
# chunk1, is first stage; recv LOSS from local send bwd buffer
|
# chunk1, is first stage; recv LOSS from local send bwd buffer
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
output_tensor_grad = self.local_send_backward_buffer.pop(0)
|
output_tensor_grad = None
|
||||||
# chunk1, not first stage; recv output_grad from recv_backward_buffer
|
# chunk1, not first stage; recv output_grad from recv_backward_buffer
|
||||||
else:
|
else:
|
||||||
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
|
||||||
|
|
|
@ -44,7 +44,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
|
||||||
"test_config",
|
"test_config",
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"batch_size": 4,
|
"batch_size": 8,
|
||||||
"tp_size": 1,
|
"tp_size": 1,
|
||||||
"pp_size": 4,
|
"pp_size": 4,
|
||||||
"num_microbatches": 4,
|
"num_microbatches": 4,
|
||||||
|
@ -501,7 +501,7 @@ def run_fwd_bwd_iter_input(test_config):
|
||||||
"test_config",
|
"test_config",
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"batch_size": 4,
|
"batch_size": 8,
|
||||||
"tp_size": 1,
|
"tp_size": 1,
|
||||||
"pp_size": 4,
|
"pp_size": 4,
|
||||||
"num_microbatches": 4,
|
"num_microbatches": 4,
|
||||||
|
@ -689,13 +689,13 @@ def run_with_hybridplugin(test_config):
|
||||||
"test_config",
|
"test_config",
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"batch_size": 4,
|
"batch_size": 8,
|
||||||
"tp_size": 1,
|
"tp_size": 1,
|
||||||
"pp_size": 4,
|
"pp_size": 4,
|
||||||
"num_microbatches": 4,
|
"num_microbatches": 4,
|
||||||
"zero_stage": 1,
|
"zero_stage": 1,
|
||||||
"precision": "bf16",
|
"precision": "bf16",
|
||||||
"num_model_chunk": 4,
|
"num_model_chunk": 2,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue