[fix] fix cross_bwdB_bwdW

pull/6229/head
duanjunwen 2025-02-28 17:32:53 +08:00
parent 6977bf5365
commit 59819ae4ae
1 changed files with 136 additions and 135 deletions

View File

@@ -99,7 +99,7 @@ class DualPipeGraph(object):
else:
stage_pipe_temp.append(node)
stage_pipe = stage_pipe_temp[::-1] # node from last fully B to ...
# print(f"stage_pipe {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in stage_pipe]}")
if chunk == 0:
# get first d
for node in stage_pipe:
@@ -1201,12 +1201,107 @@ class DualPipeGraph(object):
########### Pipe_Stage 3.2 ###########
def cross_bwdB_bwdW(pipeline_schedule: List[List[ScheduledNode]]):
for stage in range(0, self.n_stage // 2):
first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=0)
# print(f"stage {stage} Up first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ")
u_queue_w, u_queue_b, d_queue_w = [], [], []
### 1.Get W nodes, then merge up/down W nodes ###
# get up W nodes: [first_u: mbs//2]
# for stage in range(0, self.n_stage // 2):
# first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=0)
# # print(f"stage {stage} Up first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ")
# u_queue_w, u_queue_b, d_queue_w = [], [], []
# ### 1.Get W nodes, then merge up/down W nodes ###
# # get up W nodes: [first_u: mbs//2]
# for _ in range(first_u, self.n_micro // 2):
# curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
# u_queue_w.append(
# ScheduledNode(
# type="W",
# chunk=0,
# stage=stage,
# minibatch=_,
# start_time=curr_time,
# completion_time=curr_time + self.one_time_unit,
# )
# )
# curr_time += self.one_time_unit
# # get down W nodes: [first_d: mbs//2] Bwd W to W Queue
# for _ in range(first_d, self.n_micro // 2):
# curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
# d_queue_w.append(
# ScheduledNode(
# type="W",
# chunk=1,
# stage=stage,
# minibatch=_,
# start_time=curr_time,
# completion_time=curr_time + self.one_time_unit,
# )
# )
# curr_time += self.one_time_unit
# ### 2.Get B nodes, then cross with W ###
# for _ in range(last_u, self.n_micro // 2):
# curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
# u_queue_b.append(
# ScheduledNode(
# type="B",
# chunk=0,
# stage=stage,
# minibatch=_ + 1,
# start_time=curr_time,
# completion_time=curr_time + self.one_time_unit,
# )
# )
# curr_time += self.one_time_unit
# # if stage % 2 == 0: u_queue_w first, then d_queue_w
# if stage % 2 == 0:
# w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w)
# wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b)
# # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
# cut_idx = len(wb_nodes)
# for _ in range(len(wb_nodes)):
# if (
# wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
# and wb_nodes[_].type == "B"
# and wb_nodes[_].chunk == 0
# ):
# cut_idx = _
# break
# wb_nodes = wb_nodes[: cut_idx + 1]
# # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
# # else: d_queue_w first, then u_queue_w
# else:
# w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w)
# wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b)
# # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
# cut_idx = len(wb_nodes)
# for _ in range(len(wb_nodes)):
# if (
# wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
# and wb_nodes[_].type == "B"
# and wb_nodes[_].chunk == 0
# ):
# cut_idx = _
# break
# wb_nodes = wb_nodes[: cut_idx + 1]
# # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
for stage in range(self.n_stage // 2, self.n_stage):
first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=1)
print(f"stage {stage} Down first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ")
d_queue_w, d_queue_b, u_queue_w = [], [], []
### 1.Get W nodes, then merge down/up W nodes ###
# get down W nodes: [last_d: mbs//2] chunk 1
for _ in range(last_d, self.n_micro // 2):
curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
d_queue_w.append(
ScheduledNode(
type="W",
chunk=1,
stage=stage,
minibatch=_,
start_time=curr_time,
completion_time=curr_time + self.one_time_unit,
)
)
curr_time += self.one_time_unit
# print(f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]}")
# get up W nodes: [first_u: mbs//2] chunk 0
for _ in range(first_u, self.n_micro // 2):
curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
u_queue_w.append(
@@ -1220,27 +1315,13 @@ class DualPipeGraph(object):
)
)
curr_time += self.one_time_unit
# get down W nodes: [first_d: mbs//2] Bwd W to W Queue
for _ in range(first_d, self.n_micro // 2):
curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
d_queue_w.append(
ScheduledNode(
type="W",
chunk=1,
stage=stage,
minibatch=_,
start_time=curr_time,
completion_time=curr_time + self.one_time_unit,
)
)
curr_time += self.one_time_unit
### 2.Get B nodes, then cross with W ###
for _ in range(last_u, self.n_micro // 2):
for _ in range(last_d, self.n_micro // 2):
curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
u_queue_b.append(
d_queue_b.append(
ScheduledNode(
type="B",
chunk=0,
chunk=1,
stage=stage,
minibatch=_ + 1,
start_time=curr_time,
@@ -1248,121 +1329,41 @@ class DualPipeGraph(object):
)
)
curr_time += self.one_time_unit
# if stage % 2 == 0: u_queue_w first, then d_queue_w
if stage % 2 == 0:
w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w)
wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b)
# clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
cut_idx = len(wb_nodes)
for _ in range(len(wb_nodes)):
if (
wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
and wb_nodes[_].type == "B"
and wb_nodes[_].chunk == 0
):
cut_idx = _
break
wb_nodes = wb_nodes[: cut_idx + 1]
# print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
# else: d_queue_w first, then u_queue_w
else:
w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w)
wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b)
# clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
cut_idx = len(wb_nodes)
for _ in range(len(wb_nodes)):
if (
wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
and wb_nodes[_].type == "B"
and wb_nodes[_].chunk == 0
):
cut_idx = _
break
wb_nodes = wb_nodes[: cut_idx + 1]
# print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
for stage in range(self.n_stage // 2, self.n_stage):
first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=1)
print(f"stage {stage} Down first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ")
d_queue_w, d_queue_b, u_queue_w = [], [], []
### 1.Get W nodes, then merge down/up W nodes ###
# get down W nodes: [first_d: mbs//2] chunk 1
for _ in range(self.n_micro // 2, first_d):
curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
d_queue_w.append(
ScheduledNode(
type="W",
chunk=1,
stage=stage,
minibatch=_,
start_time=curr_time,
completion_time=curr_time + self.one_time_unit,
)
)
curr_time += self.one_time_unit
# get up W nodes: [first_u: mbs//2] chunk 0
for _ in range(self.n_micro // 2, first_u):
curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
d_queue_w.append(
ScheduledNode(
type="W",
chunk=0,
stage=stage,
minibatch=_,
start_time=curr_time,
completion_time=curr_time + self.one_time_unit,
)
)
curr_time += self.one_time_unit
### 2.Get B nodes, then cross with W ###
for _ in range(self.n_micro // 2, last_d):
curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
u_queue_b.append(
ScheduledNode(
type="B",
chunk=1,
stage=stage,
minibatch=_ + 1,
start_time=curr_time,
completion_time=curr_time + self.one_time_unit,
)
)
curr_time += self.one_time_unit
print(f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]}")
# print(
# f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]} d_queue_b {[_.minibatch for _ in d_queue_b]} u_queue_w {[_.minibatch for _ in u_queue_w]}"
# )
if stage % 2 == 0:
w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w)
wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b)
# clean w nodes, let it stop at mbs // 2 - 1, chunk 1, type 'B'
cut_idx = len(wb_nodes)
for _ in range(len(wb_nodes)):
if (
wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
and wb_nodes[_].type == "B"
and wb_nodes[_].chunk == 1
):
cut_idx = _
break
wb_nodes = wb_nodes[: cut_idx + 1]
# print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
# else: d_queue_w first, then u_queue_w
else:
w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w)
wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b)
# clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
cut_idx = len(wb_nodes)
for _ in range(len(wb_nodes)):
if (
wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
and wb_nodes[_].type == "B"
and wb_nodes[_].chunk == 1
):
cut_idx = _
break
wb_nodes = wb_nodes[: cut_idx + 1]
# print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
# if stage % 2 == 0:
# w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w)
# wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b)
# # clean w nodes, let it stop at mbs // 2 - 1, chunk 1, type 'B'
# cut_idx = len(wb_nodes)
# for _ in range(len(wb_nodes)):
# if (
# wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
# and wb_nodes[_].type == "B"
# and wb_nodes[_].chunk == 1
# ):
# cut_idx = _
# break
# wb_nodes = wb_nodes[: cut_idx + 1]
# # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
# # else: d_queue_w first, then u_queue_w
# else:
# w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w)
# wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b)
# # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
# cut_idx = len(wb_nodes)
# for _ in range(len(wb_nodes)):
# if (
# wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
# and wb_nodes[_].type == "B"
# and wb_nodes[_].chunk == 1
# ):
# cut_idx = _
# break
# wb_nodes = wb_nodes[: cut_idx + 1]
# # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
########### Pipe_Stage 3.3 ###########
def bwdW_step(pipeline_schedule: List[List[ScheduledNode]]):