diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index d6aee7c1e..8fcb2aa56 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 import torch
 import torch.cuda
 from torch.nn import Module, ModuleList
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_flatten, tree_map
 
 from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
@@ -489,26 +489,38 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         # For chunk 0 stage 0, use micro_batch as input_obj_
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            for k, v in micro_batch.items():
-                if v.requires_grad:
-                    input_obj_.append(micro_batch[k])
-                    output_obj_.append(output_obj[k])  # y
-                    output_obj_grad_.append(output_obj_grad[k])  # dy
+            # for k, v in micro_batch.items():
+            #     if v.requires_grad:
+            #         input_obj_.append(micro_batch[k])
+            #         output_obj_.append(output_obj[k])  # y
+            #         output_obj_grad_.append(output_obj_grad[k])  # dy
+
+            input_obj_, _ = tree_flatten(micro_batch)
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+
         # For loss backward; output_obj is loss; output_obj_grad should be None
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
-            for k, v in input_obj.items():
-                if v.requires_grad:
-                    input_obj_.append(input_obj[k])
-                    output_obj_.append(output_obj)  # LOSS
-                    output_obj_grad_.append(output_obj_grad)  # None
+            # for k, v in input_obj.items():
+            #     if v.requires_grad:
+            #         input_obj_.append(input_obj[k])
+            input_obj_, _ = tree_flatten(input_obj)
+            # output_obj_.append(output_obj)  # LOSS
+            # output_obj_grad_.append(output_obj_grad)  # None
+            output_obj_, _ = tree_flatten(output_obj)  # LOSS
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # None
+
         # For other chunk stage, use input_obj as input_obj_;
         else:
-            for k, v in input_obj.items():
-                if v.requires_grad:
-                    input_obj_.append(input_obj[k])
-                    output_obj_.append(output_obj[k])  # y
-                    output_obj_grad_.append(output_obj_grad[k])  # dy
+            # for k, v in input_obj.items():
+            #     if v.requires_grad:
+            #         input_obj_.append(input_obj[k])
+            #         output_obj_.append(output_obj[k])  # y
+            #         output_obj_grad_.append(output_obj_grad[k])  # dy
+            input_obj_, _ = tree_flatten(input_obj)
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
 
         optimizer.backward_by_grad(
             tensor=output_obj_,
@@ -560,10 +572,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(None)  # None
         else:
-            for k, v in output_obj.items():
-                if v.requires_grad:
-                    output_obj_.append(output_obj[k])
-                    output_obj_grad_.append(output_obj_grad[k])
+            # for k, v in output_obj.items():
+            #     if v.requires_grad:
+            #         output_obj_.append(output_obj[k])
+            #         output_obj_grad_.append(output_obj_grad[k])
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
 
         optimizer.backward_by_grad(
             tensor=output_obj_,
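
Note: a minimal standalone sketch (not part of the patch) of the torch.utils._pytree.tree_flatten behavior the new code relies on. One behavioral difference from the replaced loops, visible in the diff itself: the old code skipped entries with requires_grad=False, while tree_flatten returns every leaf, so the flattened lists may now include non-grad tensors. The dict keys and tensor shapes below are made up for illustration.

import torch
from torch.utils._pytree import tree_flatten

micro_batch = {
    "hidden_states": torch.randn(2, 4, requires_grad=True),
    "attention_mask": torch.ones(2, 4),  # requires_grad=False
}

# tree_flatten returns (leaves, treespec): the leaf tensors in the
# dict's insertion order, plus a spec that can rebuild the structure.
leaves, spec = tree_flatten(micro_batch)
print([t.requires_grad for t in leaves])  # [True, False]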