mirror of https://github.com/hpcaitech/ColossalAI
[fix] use tree_flatten replace dict traverse;
parent
26783776f1
commit
c6d6ee39bd
|
@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
from torch.nn import Module, ModuleList
|
from torch.nn import Module, ModuleList
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_flatten, tree_map
|
||||||
|
|
||||||
from colossalai.accelerator import get_accelerator
|
from colossalai.accelerator import get_accelerator
|
||||||
from colossalai.interface import OptimizerWrapper
|
from colossalai.interface import OptimizerWrapper
|
||||||
|
@ -489,26 +489,38 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
|
|
||||||
# For chunk 0 stage 0, use micro_batch as input_obj_
|
# For chunk 0 stage 0, use micro_batch as input_obj_
|
||||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
for k, v in micro_batch.items():
|
# for k, v in micro_batch.items():
|
||||||
if v.requires_grad:
|
# if v.requires_grad:
|
||||||
input_obj_.append(micro_batch[k])
|
# input_obj_.append(micro_batch[k])
|
||||||
output_obj_.append(output_obj[k]) # y
|
# output_obj_.append(output_obj[k]) # y
|
||||||
output_obj_grad_.append(output_obj_grad[k]) # dy
|
# output_obj_grad_.append(output_obj_grad[k]) # dy
|
||||||
|
|
||||||
|
input_obj_, _ = tree_flatten(micro_batch)
|
||||||
|
output_obj_, _ = tree_flatten(output_obj) # y
|
||||||
|
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||||
|
|
||||||
# For loss backward; output_obj is loss; output_obj_grad should be None
|
# For loss backward; output_obj is loss; output_obj_grad should be None
|
||||||
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
assert output_obj_grad is None
|
assert output_obj_grad is None
|
||||||
for k, v in input_obj.items():
|
# for k, v in input_obj.items():
|
||||||
if v.requires_grad:
|
# if v.requires_grad:
|
||||||
input_obj_.append(input_obj[k])
|
# input_obj_.append(input_obj[k])
|
||||||
output_obj_.append(output_obj) # LOSS
|
input_obj_, _ = tree_flatten(input_obj)
|
||||||
output_obj_grad_.append(output_obj_grad) # None
|
# output_obj_.append(output_obj) # LOSS
|
||||||
|
# output_obj_grad_.append(output_obj_grad) # None
|
||||||
|
output_obj_, _ = tree_flatten(output_obj) # LOSS
|
||||||
|
output_obj_grad_, _ = tree_flatten(output_obj_grad) # None
|
||||||
|
|
||||||
# For other chunk stage, use input_obj as input_obj_;
|
# For other chunk stage, use input_obj as input_obj_;
|
||||||
else:
|
else:
|
||||||
for k, v in input_obj.items():
|
# for k, v in input_obj.items():
|
||||||
if v.requires_grad:
|
# if v.requires_grad:
|
||||||
input_obj_.append(input_obj[k])
|
# input_obj_.append(input_obj[k])
|
||||||
output_obj_.append(output_obj[k]) # y
|
# output_obj_.append(output_obj[k]) # y
|
||||||
output_obj_grad_.append(output_obj_grad[k]) # dy
|
# output_obj_grad_.append(output_obj_grad[k]) # dy
|
||||||
|
input_obj_, _ = tree_flatten(input_obj)
|
||||||
|
output_obj_, _ = tree_flatten(output_obj) # y
|
||||||
|
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||||
|
|
||||||
optimizer.backward_by_grad(
|
optimizer.backward_by_grad(
|
||||||
tensor=output_obj_,
|
tensor=output_obj_,
|
||||||
|
@ -560,10 +572,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
output_obj_.append(output_obj) # LOSS
|
output_obj_.append(output_obj) # LOSS
|
||||||
output_obj_grad_.append(None) # None
|
output_obj_grad_.append(None) # None
|
||||||
else:
|
else:
|
||||||
for k, v in output_obj.items():
|
# for k, v in output_obj.items():
|
||||||
if v.requires_grad:
|
# if v.requires_grad:
|
||||||
output_obj_.append(output_obj[k])
|
# output_obj_.append(output_obj[k])
|
||||||
output_obj_grad_.append(output_obj_grad[k])
|
# output_obj_grad_.append(output_obj_grad[k])
|
||||||
|
output_obj_, _ = tree_flatten(output_obj) # y
|
||||||
|
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
|
||||||
|
|
||||||
optimizer.backward_by_grad(
|
optimizer.backward_by_grad(
|
||||||
tensor=output_obj_,
|
tensor=output_obj_,
|
||||||
|
|
Loading…
Reference in New Issue