[fix] rm useless assign and comments;

pull/6034/head
duanjunwen 2024-08-27 07:31:58 +00:00
parent 1b4bb2beeb
commit 283c9ff5d2
2 changed files with 12 additions and 10 deletions

[File 1 of 2: the ZeroBubbleVPipeScheduler implementation]

@@ -440,9 +440,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True)
else:
# common bwd step
# print(f"bwd output_obj {output_obj} output_obj_grad {output_obj_grad} input_obj {input_obj}")
# BUG:output_obj_grad is None
# print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; tensor {output_obj};\n grad_tensors {output_obj_grad};\n inputs {input_obj}\n")
torch.autograd.backward(
tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
)
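
The `inputs=` argument is what makes this backward a zero-bubble B-pass: gradients are accumulated only into the listed tensors, so the weight-gradient pass (W) can be deferred. A minimal, self-contained sketch of that behavior, using stand-in tensors rather than the scheduler's own objects:

```python
import torch

w = torch.randn(4, 4, requires_grad=True)  # stand-in for a layer weight
x = torch.randn(2, 4, requires_grad=True)  # stand-in for input_obj
y = x @ w                                  # stand-in for output_obj
gy = torch.ones_like(y)                    # stand-in for output_obj_grad

# B-pass: with inputs=x, gradients accumulate only into x; w.grad stays None.
torch.autograd.backward(tensors=y, grad_tensors=gy, inputs=x, retain_graph=True)
assert x.grad is not None and w.grad is None

# W-pass later, reusing the retained graph to fill in the weight gradient.
torch.autograd.backward(tensors=y, grad_tensors=gy, inputs=w)
assert w.grad is not None
```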
@@ -516,7 +514,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
input_obj = input_obj
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
else:
# is last stage; recv from local
if self.stage_manager.is_last_stage(ignore_chunk=True):
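
The branch above encodes the V-schedule's data routing: chunk 0 on the first stage reads the microbatch, intermediate positions pop their oldest p2p receive, and chunk 1 on the last stage takes chunk 0's output locally (the turn of the V, no p2p needed). A hedged, self-contained sketch of that routing; everything except `recv_forward_buffer` and `model_chunk_id` is a stand-in name:

```python
from collections import deque

recv_forward_buffer = {0: deque(), 1: deque()}  # per-chunk FIFO of p2p receives

def get_forward_input(model_chunk_id, is_first_stage, is_last_stage,
                      micro_batch, local_chunk0_output):
    if model_chunk_id == 0 and is_first_stage:
        return micro_batch                # first stage feeds chunk 0 from the data loader
    if model_chunk_id == 1 and is_last_stage:
        return local_chunk0_output        # last stage feeds chunk 1 locally
    return recv_forward_buffer[model_chunk_id].popleft()  # oldest received activation
```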
@@ -535,8 +532,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
outputs=outputs,
)
# print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")
# add input and output object for backward b
self.input_tensors[model_chunk_id].append(input_obj)
self.output_tensors[model_chunk_id].append(output_obj)
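
These two appends are the activation stash that the later B-pass pops. A small self-contained sketch of the FIFO pairing, with the structure assumed from the diff and simplified to module-level lists:

```python
num_model_chunks = 2
input_tensors = {c: [] for c in range(num_model_chunks)}
output_tensors = {c: [] for c in range(num_model_chunks)}

def stash_for_backward(chunk, input_obj, output_obj):
    # forward order matches backward-b order per chunk, so plain FIFO lists suffice
    input_tensors[chunk].append(input_obj)
    output_tensors[chunk].append(output_obj)

def pop_for_backward(chunk):
    # the oldest stashed pair belongs to the microbatch whose backward runs now
    return input_tensors[chunk].pop(0), output_tensors[chunk].pop(0)
```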
@@ -681,7 +676,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
it = 0
# while we still have schedules_node in self.schedules
# print(f"manger_stage {self.stage_manager.stage} schedule {self.schedules} \n")
while it < len(self.schedules):
scheduled_node = self.schedules[it]
print(

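The `while it < len(self.schedules)` loop above is the scheduler's whole runtime: a precomputed node list is walked in order, and each node dispatches one phase. A toy, self-contained sketch; the node fields and the F/B/W names follow zero-bubble convention, and the handlers are stand-ins:

```python
from dataclasses import dataclass

@dataclass
class ScheduledNode:
    type: str        # "F" forward, "B" input-grad backward, "W" weight-grad backward
    chunk: int
    minibatch: int

schedules = [ScheduledNode("F", 0, 0), ScheduledNode("B", 0, 0), ScheduledNode("W", 0, 0)]

it = 0
while it < len(schedules):
    scheduled_node = schedules[it]
    # the real scheduler dispatches to schedule_f / schedule_b / schedule_w here,
    # plus the send/recv communication nodes interleaved by the planner
    print(f"step {it}: {scheduled_node.type} chunk={scheduled_node.chunk}")
    it += 1
```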
[File 2 of 2: the ZeroBubbleVPipeScheduler forward/backward test]

@@ -1,6 +1,7 @@
from copy import deepcopy
from typing import Tuple
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -139,7 +140,7 @@ def test_run_fwd_bwd_base(
]
scheduler = ZeroBubbleVPipeScheduler(
schedule=zbv_schedule[rank],
schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ?
stage_manager=stage_manager,
num_model_chunks=pp_size,
num_microbatch=1,
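
The removed hint asked whether to pass the whole schedule or only this rank's slice; the call site answers it: `zbv_schedule[rank]` hands each pipeline rank its stage-local node list. A hedged sketch of driving the scheduler after construction; the `run_forward_backward` signature and the `input_embeddings` / `criterion` / `optimizer` names are assumptions inferred from the rest of this test, not guaranteed:

```python
result = scheduler.run_forward_backward(
    model_chunk=local_chunk,             # this rank's two V-chunks (assumed nn.ModuleList)
    data_iter=iter([input_embeddings]),  # single microbatch, matching num_microbatch=1
    criterion=criterion,
    optimizer=optimizer,
    return_loss=True,
    return_outputs=True,
)
```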
@@ -226,7 +227,6 @@ def test_run_fwd_bwd_base(
# layer 6
assert_close(local_chunk[1].weight, model_base.layers[6].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
if rank == 2:
# layer 2
assert_close(local_chunk[0].weight, model_base.layers[2].weight)
@@ -234,7 +234,6 @@ def test_run_fwd_bwd_base(
# layer 5
assert_close(local_chunk[1].weight, model_base.layers[5].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
if rank == 3:
# layer 3
assert_close(local_chunk[0].weight, model_base.layers[3].weight)
@@ -244,7 +243,16 @@ def test_run_fwd_bwd_base(
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
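
The rank-by-rank asserts above all follow one layout rule: in a V schedule over `pp_size` stages and `2 * pp_size` layers, rank `r` holds layer `r` as chunk 0 and layer `2 * pp_size - 1 - r` as chunk 1, which is why rank 1 checks layers 1 and 6, rank 2 layers 2 and 5, and rank 3 layers 3 and 4. A tiny script that prints the mapping:

```python
pp_size = 4  # matches the 8-layer model checked above

for r in range(pp_size):
    chunk0, chunk1 = r, 2 * pp_size - 1 - r
    print(f"rank {r}: chunk 0 -> layer {chunk0}, chunk 1 -> layer {chunk1}")
# rank 0: layers 0 and 7 ... rank 3: layers 3 and 4 (the bottom of the V)
```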
# @pytest.mark.dist
# Test iter input & multiple microbatch
def test_run_fwd_bwd_iter_input(
rank: int,
world_size: int,
port: int,
):
pass
@pytest.mark.dist
# @pytest.mark.parametrize("num_microbatch", [4])
# @pytest.mark.parametrize("batch_size", [4])
# @pytest.mark.parametrize("num_model_chunk", [2])
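
For completeness, dist tests in this suite are normally launched by spawning one process per pipeline rank via ColossalAI's testing helpers; a hedged sketch, with the wrapper name `test_pp` and the decorator placement assumed since the diff cuts off after the commented-out parametrize lines:

```python
import pytest
from colossalai.testing import rerun_if_address_is_in_use, spawn

@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pp():
    # spawn passes (rank, world_size, port) to the target, matching
    # test_run_fwd_bwd_base's signature; 4 ranks == the pp_size used above
    spawn(test_run_fwd_bwd_base, nprocs=4)
```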