mirror of https://github.com/hpcaitech/ColossalAI
[fix] rm useless assign and comments;
parent 1b4bb2beeb
commit 283c9ff5d2
@@ -440,9 +440,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
            torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True)
        else:
            # commom bwd step
            # print(f"bwd output_obj {output_obj} output_obj_grad {output_obj_grad} input_obj {input_obj}")
            # BUG:output_obj_grad is None
            # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; tensor {output_obj};\n grad_tensors {output_obj_grad};\n inputs {input_obj}\n")
            torch.autograd.backward(
                tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
            )
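For context, the hunk above separates two backward variants: on the last stage output_obj is the loss itself, so autograd needs no incoming gradient, while the common bwd step feeds the gradient received from the next stage in as grad_tensors. A minimal standalone sketch of the same two torch.autograd.backward patterns, with toy tensors standing in for the scheduler's buffers (names are illustrative, not the scheduler's real state):

    import torch

    # hypothetical stand-ins for the scheduler's buffered activations
    input_obj = torch.randn(4, 8, requires_grad=True)   # activation received from the previous stage
    weight = torch.randn(8, 8, requires_grad=True)
    output_obj = input_obj @ weight                      # activation sent to the next stage

    # last stage: output_obj would be a scalar loss, so no grad_tensors is needed
    loss = output_obj.sum()
    torch.autograd.backward(loss, inputs=input_obj, retain_graph=True)

    # intermediate stage ("common" bwd step): the next stage sends back d(loss)/d(output_obj),
    # passed as grad_tensors so autograd can continue the chain rule locally
    output_obj_grad = torch.ones_like(output_obj)        # would normally arrive via p2p communication
    torch.autograd.backward(
        tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
    )

    print(input_obj.grad.shape)  # gradient w.r.t. the stage input, to be sent upstream
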
@@ -516,7 +514,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                input_obj = input_obj
            else:
                input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)

        else:
            # is last stage; recv from local
            if self.stage_manager.is_last_stage(ignore_chunk=True):
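The branch above decides where a chunk's forward input comes from: normally it is popped FIFO from the per-chunk recv_forward_buffer, while on the last stage the comment suggests the input is taken locally from the previous chunk instead of over p2p. A toy sketch of that buffering scheme, assuming one FIFO list per model chunk (only recv_forward_buffer and pop(0) come from the hunk; the helpers are illustrative):

    from collections import defaultdict

    # one FIFO list of received activations per model chunk, mirroring
    # self.recv_forward_buffer[model_chunk_id].pop(0) in the hunk above
    recv_forward_buffer = defaultdict(list)

    def put_recv(model_chunk_id, tensor):
        recv_forward_buffer[model_chunk_id].append(tensor)

    def get_forward_input(model_chunk_id, local_output=None, is_last_stage=False):
        # on the last stage the next chunk can reuse the previous chunk's
        # output locally instead of waiting on a p2p receive
        if is_last_stage and local_output is not None:
            return local_output
        return recv_forward_buffer[model_chunk_id].pop(0)
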
@@ -535,8 +532,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                outputs=outputs,
            )

            # print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")

            # add input and output object for backward b
            self.input_tensors[model_chunk_id].append(input_obj)
            self.output_tensors[model_chunk_id].append(output_obj)
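Zero-bubble scheduling splits the backward pass into a B step (gradient w.r.t. the stage input) and a deferred W step (gradient w.r.t. the weights), so the forward pass stashes each chunk's input/output pair, which is what the two append calls above do. A simplified, self-contained sketch of that bookkeeping (only the queue names come from the hunk; everything else is illustrative):

    import torch

    num_chunks = 2
    input_tensors = [[] for _ in range(num_chunks)]
    output_tensors = [[] for _ in range(num_chunks)]
    linear = torch.nn.Linear(8, 8)

    def forward_step(model_chunk_id, input_obj):
        output_obj = linear(input_obj)
        # stash the pair so the later backward-b step can compute d(loss)/d(input)
        input_tensors[model_chunk_id].append(input_obj)
        output_tensors[model_chunk_id].append(output_obj)
        return output_obj

    def backward_b_step(model_chunk_id, output_obj_grad):
        input_obj = input_tensors[model_chunk_id].pop(0)
        output_obj = output_tensors[model_chunk_id].pop(0)
        # B step: passing inputs=... makes autograd accumulate only the input gradient;
        # weight gradients (the W step) can be computed later because the graph is retained
        torch.autograd.backward(
            tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
        )
        return input_obj.grad

    x = torch.randn(4, 8, requires_grad=True)
    y = forward_step(0, x)
    grad_in = backward_b_step(0, torch.ones_like(y))
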
@@ -681,7 +676,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
        """
        it = 0
        # while we still have schedules_node in self.schedules
        # print(f"manger_stage {self.stage_manager.stage} schedule {self.schedules} \n")
        while it < len(self.schedules):
            scheduled_node = self.schedules[it]
            print(
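The loop above walks the precomputed zero-bubble schedule node by node and dispatches each one on this stage. A hypothetical dispatch skeleton, assuming each node carries a type tag such as "F", "B", or "W" (the field names are illustrative, not taken from this diff):

    from dataclasses import dataclass

    @dataclass
    class ScheduledNode:
        # illustrative node; a real schedule also carries chunk / stage / communication info
        type: str       # "F" forward, "B" input-grad backward, "W" weight-grad backward
        minibatch: int

    def run_schedule(schedules, handlers):
        it = 0
        while it < len(schedules):
            scheduled_node = schedules[it]
            handlers[scheduled_node.type](scheduled_node)
            it += 1

    # toy usage: one microbatch through F, then B, then W
    log = []
    handlers = {t: (lambda node, t=t: log.append((t, node.minibatch))) for t in ("F", "B", "W")}
    run_schedule([ScheduledNode("F", 0), ScheduledNode("B", 0), ScheduledNode("W", 0)], handlers)
    print(log)  # [('F', 0), ('B', 0), ('W', 0)]
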
@@ -1,6 +1,7 @@
from copy import deepcopy
from typing import Tuple

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -139,7 +140,7 @@ def test_run_fwd_bwd_base(
    ]

    scheduler = ZeroBubbleVPipeScheduler(
        schedule=zbv_schedule[rank],
        schedule=zbv_schedule[rank],  # hint: send whole schedule or local schedule only ?
        stage_manager=stage_manager,
        num_model_chunks=pp_size,
        num_microbatch=1,
@@ -226,7 +227,6 @@ def test_run_fwd_bwd_base(
        # layer 6
        assert_close(local_chunk[1].weight, model_base.layers[6].weight)
        assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)

    if rank == 2:
        # layer 2
        assert_close(local_chunk[0].weight, model_base.layers[2].weight)
@@ -234,7 +234,6 @@ def test_run_fwd_bwd_base(
        # layer 5
        assert_close(local_chunk[1].weight, model_base.layers[5].weight)
        assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)

    if rank == 3:
        # layer 3
        assert_close(local_chunk[0].weight, model_base.layers[3].weight)
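Taken together, the assertions in these hunks suggest a V-shaped chunk placement: with pp_size = 4 and 2 * pp_size layers, rank r holds layer r as its first local chunk and layer 2 * pp_size - 1 - r as its second (the hunks check rank 1 -> layer 6, rank 2 -> layers 2 and 5, rank 3 -> layers 3 and 4). A small sketch of that mapping, assuming exactly two chunks per rank:

    def vschedule_layer_ids(rank: int, pp_size: int) -> tuple:
        # (chunk 0, chunk 1) layer indices under the V-shaped placement described above
        return rank, 2 * pp_size - 1 - rank

    for r in range(4):
        print(r, vschedule_layer_ids(r, 4))  # 1 -> (1, 6), 2 -> (2, 5), 3 -> (3, 4)
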
@@ -244,7 +243,16 @@ def test_run_fwd_bwd_base(
        assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)


# @pytest.mark.dist
# Test iter input & multiple microbatch
def test_run_fwd_bwd_iter_input(
    rank: int,
    world_size: int,
    port: int,
):
    pass


@pytest.mark.dist
# @pytest.mark.parametrize("num_microbatch", [4])
# @pytest.mark.parametrize("batch_size", [4])
# @pytest.mark.parametrize("num_model_chunk", [2])