[fix] use tree_flatten to replace dict traversal

pull/6065/head
duanjunwen 2024-09-20 07:18:49 +00:00
parent 26783776f1
commit c6d6ee39bd
1 changed file with 34 additions and 20 deletions
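In short: the commit replaces the manual `for k, v in ...items()` traversals below with `torch.utils._pytree.tree_flatten`, which flattens any nested container (dict, list, tuple) into a flat list of leaves plus a spec that records the structure. A minimal, self-contained sketch of the call (the dict keys below are illustrative, not taken from the commit):

    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten

    # Illustrative micro-batch; keys are hypothetical, not from the scheduler.
    micro_batch = {"input_ids": torch.randn(2, 4), "labels": torch.randn(2, 4)}

    # leaves: the dict's tensor values as a flat list, in the dict's own order
    # spec:   the container structure, enough to rebuild the dict later
    leaves, spec = tree_flatten(micro_batch)
    assert len(leaves) == 2

    # tree_unflatten reuses the same leaf objects when rebuilding
    rebuilt = tree_unflatten(leaves, spec)
    assert all(rebuilt[k] is micro_batch[k] for k in micro_batch)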


@@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 import torch
 import torch.cuda
 from torch.nn import Module, ModuleList
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_flatten, tree_map
 
 from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
@@ -489,26 +489,38 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # For chunk 0 stage 0, use micro_batch as input_obj_
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            for k, v in micro_batch.items():
-                if v.requires_grad:
-                    input_obj_.append(micro_batch[k])
-                    output_obj_.append(output_obj[k])  # y
-                    output_obj_grad_.append(output_obj_grad[k])  # dy
+            # for k, v in micro_batch.items():
+            #     if v.requires_grad:
+            #         input_obj_.append(micro_batch[k])
+            #         output_obj_.append(output_obj[k])  # y
+            #         output_obj_grad_.append(output_obj_grad[k])  # dy
+            input_obj_, _ = tree_flatten(micro_batch)
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
         # For loss backward; output_obj is loss; output_obj_grad should be None
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
-            for k, v in input_obj.items():
-                if v.requires_grad:
-                    input_obj_.append(input_obj[k])
-            output_obj_.append(output_obj)  # LOSS
-            output_obj_grad_.append(output_obj_grad)  # None
+            # for k, v in input_obj.items():
+            #     if v.requires_grad:
+            #         input_obj_.append(input_obj[k])
+            input_obj_, _ = tree_flatten(input_obj)
+            # output_obj_.append(output_obj)  # LOSS
+            # output_obj_grad_.append(output_obj_grad)  # None
+            output_obj_, _ = tree_flatten(output_obj)  # LOSS
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # None
         # For other chunk stage, use input_obj as input_obj_;
         else:
-            for k, v in input_obj.items():
-                if v.requires_grad:
-                    input_obj_.append(input_obj[k])
-                    output_obj_.append(output_obj[k])  # y
-                    output_obj_grad_.append(output_obj_grad[k])  # dy
+            # for k, v in input_obj.items():
+            #     if v.requires_grad:
+            #         input_obj_.append(input_obj[k])
+            #         output_obj_.append(output_obj[k])  # y
+            #         output_obj_grad_.append(output_obj_grad[k])  # dy
+            input_obj_, _ = tree_flatten(input_obj)
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
 
         optimizer.backward_by_grad(
             tensor=output_obj_,
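The three flattened lists above must stay index-aligned: output_obj_grad_[i] is the upstream gradient for output_obj_[i], which holds because containers with identical structure flatten to identically ordered leaves. A rough analogue of the backward call using plain autograd; backward_by_grad_sketch is a hypothetical stand-in, the real OptimizerWrapper.backward_by_grad is defined elsewhere in ColossalAI:

    import torch

    def backward_by_grad_sketch(tensors, grads, inputs, retain_graph=False):
        # Pair each flattened output with its flattened gradient and run one
        # backward pass; gradients for `inputs` accumulate in their .grad fields.
        torch.autograd.backward(
            tensors=tensors,       # flattened output_obj_
            grad_tensors=grads,    # flattened output_obj_grad_, index-aligned
            inputs=inputs,         # flattened input_obj_
            retain_graph=retain_graph,
        )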
@@ -560,10 +572,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(None)  # None
         else:
-            for k, v in output_obj.items():
-                if v.requires_grad:
-                    output_obj_.append(output_obj[k])
-                    output_obj_grad_.append(output_obj_grad[k])
+            # for k, v in output_obj.items():
+            #     if v.requires_grad:
+            #         output_obj_.append(output_obj[k])
+            #         output_obj_grad_.append(output_obj_grad[k])
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
 
         optimizer.backward_by_grad(
             tensor=output_obj_,
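One behavioral difference worth noting: the old loops skipped values with requires_grad=False, while tree_flatten returns every leaf, so the flattened lists can now contain non-differentiable tensors and the backward path has to tolerate them. The pairing itself stays safe because identical structures flatten identically; a small check, with illustrative key names:

    import torch
    from torch.utils._pytree import tree_flatten

    # Hypothetical output / output-gradient dicts with the same structure.
    y = {"hidden_states": torch.randn(2, 8), "aux": torch.randn(2)}
    dy = {"hidden_states": torch.ones(2, 8), "aux": torch.ones(2)}

    y_leaves, y_spec = tree_flatten(y)
    dy_leaves, dy_spec = tree_flatten(dy)

    # Identical structure -> identical spec -> index-aligned leaves, so
    # y_leaves[i] is always paired with its gradient dy_leaves[i].
    assert y_spec == dy_spec
    assert len(y_leaves) == len(dy_leaves)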