mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix cross_bwdB_bwdW
parent 6977bf5365
commit 59819ae4ae
@@ -99,7 +99,7 @@ class DualPipeGraph(object):
            else:
                stage_pipe_temp.append(node)
        stage_pipe = stage_pipe_temp[::-1]  # node from last fully B to ...

        # print(f"stage_pipe {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in stage_pipe]}")
        if chunk == 0:
            # get first d
            for node in stage_pipe:
@@ -1201,12 +1201,107 @@ class DualPipeGraph(object):

        ########### Pipe_Stage 3.2 ###########
        def cross_bwdB_bwdW(pipeline_schedule: List[List[ScheduledNode]]):
            for stage in range(0, self.n_stage // 2):
                first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=0)
                # print(f"stage {stage} Up first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ")
                u_queue_w, u_queue_b, d_queue_w = [], [], []
                ### 1.Get W nodes, then merge up/down W nodes ###
                # get up W nodes: [first_u: mbs//2]
            # for stage in range(0, self.n_stage // 2):
            #     first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=0)
            #     # print(f"stage {stage} Up first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ")
            #     u_queue_w, u_queue_b, d_queue_w = [], [], []
            #     ### 1.Get W nodes, then merge up/down W nodes ###
            #     # get up W nodes: [first_u: mbs//2]
            #     for _ in range(first_u, self.n_micro // 2):
            #         curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
            #         u_queue_w.append(
            #             ScheduledNode(
            #                 type="W",
            #                 chunk=0,
            #                 stage=stage,
            #                 minibatch=_,
            #                 start_time=curr_time,
            #                 completion_time=curr_time + self.one_time_unit,
            #             )
            #         )
            #         curr_time += self.one_time_unit
            #     # get down W nodes: [first_d: mbs//2] Bwd W to W Queue
            #     for _ in range(first_d, self.n_micro // 2):
            #         curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
            #         d_queue_w.append(
            #             ScheduledNode(
            #                 type="W",
            #                 chunk=1,
            #                 stage=stage,
            #                 minibatch=_,
            #                 start_time=curr_time,
            #                 completion_time=curr_time + self.one_time_unit,
            #             )
            #         )
            #         curr_time += self.one_time_unit
            #     ### 2.Get B nodes, then cross with W ###
            #     for _ in range(last_u, self.n_micro // 2):
            #         curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
            #         u_queue_b.append(
            #             ScheduledNode(
            #                 type="B",
            #                 chunk=0,
            #                 stage=stage,
            #                 minibatch=_ + 1,
            #                 start_time=curr_time,
            #                 completion_time=curr_time + self.one_time_unit,
            #             )
            #         )
            #         curr_time += self.one_time_unit
            #     # if stage % 2 == 0: u_queue_w first, then d_queue_w
            #     if stage % 2 == 0:
            #         w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w)
            #         wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b)
            #         # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
            #         cut_idx = len(wb_nodes)
            #         for _ in range(len(wb_nodes)):
            #             if (
            #                 wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
            #                 and wb_nodes[_].type == "B"
            #                 and wb_nodes[_].chunk == 0
            #             ):
            #                 cut_idx = _
            #                 break
            #         wb_nodes = wb_nodes[: cut_idx + 1]
            #         # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
            #     # else: d_queue_w first, then u_queue_w
            #     else:
            #         w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w)
            #         wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b)
            #         # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
            #         cut_idx = len(wb_nodes)
            #         for _ in range(len(wb_nodes)):
            #             if (
            #                 wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
            #                 and wb_nodes[_].type == "B"
            #                 and wb_nodes[_].chunk == 0
            #             ):
            #                 cut_idx = _
            #                 break
            #         wb_nodes = wb_nodes[: cut_idx + 1]
            #         # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")

            for stage in range(self.n_stage // 2, self.n_stage):
                first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=1)
                print(f"stage {stage} Down first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ")
                d_queue_w, d_queue_b, u_queue_w = [], [], []
                ### 1.Get W nodes, then merge down/up W nodes ###
                # get down W nodes: [last_d: mbs] chunk 1
                for _ in range(last_d, self.n_micro // 2):
                    curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
                    d_queue_w.append(
                        ScheduledNode(
                            type="W",
                            chunk=1,
                            stage=stage,
                            minibatch=_,
                            start_time=curr_time,
                            completion_time=curr_time + self.one_time_unit,
                        )
                    )
                    curr_time += self.one_time_unit
                # print(f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]}")
                # get up W nodes: [first_u: mbs//2] chunk 0
                for _ in range(first_u, self.n_micro // 2):
                    curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
                    u_queue_w.append(
@@ -1220,27 +1315,13 @@ class DualPipeGraph(object):
                        )
                    )
                    curr_time += self.one_time_unit
                # get down W nodes: [first_d: mbs//2] Bwd W to W Queue
                for _ in range(first_d, self.n_micro // 2):
                    curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
                    d_queue_w.append(
                        ScheduledNode(
                            type="W",
                            chunk=1,
                            stage=stage,
                            minibatch=_,
                            start_time=curr_time,
                            completion_time=curr_time + self.one_time_unit,
                        )
                    )
                    curr_time += self.one_time_unit
                ### 2.Get B nodes, then cross with W ###
                for _ in range(last_u, self.n_micro // 2):
                for _ in range(last_d, self.n_micro // 2):
                    curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
                    u_queue_b.append(
                    d_queue_b.append(
                        ScheduledNode(
                            type="B",
                            chunk=0,
                            chunk=1,
                            stage=stage,
                            minibatch=_ + 1,
                            start_time=curr_time,
@@ -1248,121 +1329,41 @@ class DualPipeGraph(object):
                        )
                    )
                    curr_time += self.one_time_unit
                # if stage % 2 == 0: u_queue_w first, then d_queue_w
                if stage % 2 == 0:
                    w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w)
                    wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b)
                    # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
                    cut_idx = len(wb_nodes)
                    for _ in range(len(wb_nodes)):
                        if (
                            wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
                            and wb_nodes[_].type == "B"
                            and wb_nodes[_].chunk == 0
                        ):
                            cut_idx = _
                            break
                    wb_nodes = wb_nodes[: cut_idx + 1]
                    # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
                # else: d_queue_w first, then u_queue_w
                else:
                    w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w)
                    wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b)
                    # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
                    cut_idx = len(wb_nodes)
                    for _ in range(len(wb_nodes)):
                        if (
                            wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
                            and wb_nodes[_].type == "B"
                            and wb_nodes[_].chunk == 0
                        ):
                            cut_idx = _
                            break
                    wb_nodes = wb_nodes[: cut_idx + 1]
                    # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")

            for stage in range(self.n_stage // 2, self.n_stage):
                first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=1)
                print(f"stage {stage} Down first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ")
                d_queue_w, d_queue_b, u_queue_w = [], [], []
                ### 1.Get W nodes, then merge down/up W nodes ###
                # get down W nodes: [first_d: mbs//2] chunk 1
                for _ in range(self.n_micro // 2, first_d):
                    curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
                    d_queue_w.append(
                        ScheduledNode(
                            type="W",
                            chunk=1,
                            stage=stage,
                            minibatch=_,
                            start_time=curr_time,
                            completion_time=curr_time + self.one_time_unit,
                        )
                    )
                    curr_time += self.one_time_unit
                # get up W nodes: [first_u: mbs//2] chunk 0
                for _ in range(self.n_micro // 2, first_u):
                    curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
                    d_queue_w.append(
                        ScheduledNode(
                            type="W",
                            chunk=0,
                            stage=stage,
                            minibatch=_,
                            start_time=curr_time,
                            completion_time=curr_time + self.one_time_unit,
                        )
                    )
                    curr_time += self.one_time_unit
                ### 2.Get B nodes, then cross with W ###
                for _ in range(self.n_micro // 2, last_d):
                    curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
                    u_queue_b.append(
                        ScheduledNode(
                            type="B",
                            chunk=1,
                            stage=stage,
                            minibatch=_ + 1,
                            start_time=curr_time,
                            completion_time=curr_time + self.one_time_unit,
                        )
                    )
                    curr_time += self.one_time_unit
                print(f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]}")
                # print(
                #     f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]} d_queue_b {[_.minibatch for _ in d_queue_b]} u_queue_w {[_.minibatch for _ in u_queue_w]}"
                # )
                if stage % 2 == 0:
                    w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w)
                    wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b)
                    # clean w nodes, let it stop at mbs // 2 - 1, chunk 1, type 'B'
                    cut_idx = len(wb_nodes)
                    for _ in range(len(wb_nodes)):
                        if (
                            wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
                            and wb_nodes[_].type == "B"
                            and wb_nodes[_].chunk == 1
                        ):
                            cut_idx = _
                            break
                    wb_nodes = wb_nodes[: cut_idx + 1]
                    # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
                # else: d_queue_w first, then u_queue_w
                else:
                    w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w)
                    wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b)
                    # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
                    cut_idx = len(wb_nodes)
                    for _ in range(len(wb_nodes)):
                        if (
                            wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
                            and wb_nodes[_].type == "B"
                            and wb_nodes[_].chunk == 1
                        ):
                            cut_idx = _
                            break
                    wb_nodes = wb_nodes[: cut_idx + 1]
                    # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
                # if stage % 2 == 0:
                #     w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w)
                #     wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b)
                #     # clean w nodes, let it stop at mbs // 2 - 1, chunk 1, type 'B'
                #     cut_idx = len(wb_nodes)
                #     for _ in range(len(wb_nodes)):
                #         if (
                #             wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
                #             and wb_nodes[_].type == "B"
                #             and wb_nodes[_].chunk == 1
                #         ):
                #             cut_idx = _
                #             break
                #     wb_nodes = wb_nodes[: cut_idx + 1]
                #     # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
                # # else: d_queue_w first, then u_queue_w
                # else:
                #     w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w)
                #     wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b)
                #     # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
                #     cut_idx = len(wb_nodes)
                #     for _ in range(len(wb_nodes)):
                #         if (
                #             wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
                #             and wb_nodes[_].type == "B"
                #             and wb_nodes[_].chunk == 1
                #         ):
                #             cut_idx = _
                #             break
                #     wb_nodes = wb_nodes[: cut_idx + 1]
                #     # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")

        ########### Pipe_Stage 3.3 ###########
        def bwdW_step(pipeline_schedule: List[List[ScheduledNode]]):
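The merging above is done by self.cross_merge_nodes(...), which this diff calls but does not define. The surrounding comments ("if stage % 2 == 0: u_queue_w first, then d_queue_w") suggest it interleaves two node lists, taking from its first argument first. A minimal sketch under that assumption; the real helper in the repository may behave differently:

from typing import List

def cross_merge_nodes(queue_a: List["ScheduledNode"], queue_b: List["ScheduledNode"]) -> List["ScheduledNode"]:
    # Assumed behaviour: alternate elements from queue_a and queue_b (queue_a first),
    # then append whatever remains of the longer queue.
    merged = []
    for a, b in zip(queue_a, queue_b):
        merged.append(a)
        merged.append(b)
    shorter = min(len(queue_a), len(queue_b))
    merged.extend(queue_a[shorter:] or queue_b[shorter:])
    return merged

Under this reading, merging the W queue of one chunk with the W queue of the other chunk, and then merging the result with a B queue, produces the crossed B/W ordering that the name cross_bwdB_bwdW appears to refer to.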
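Both the if and the else branches end with the same scan-and-slice: find the first "B" node of the boundary micro-batch (self.n_micro // 2 - 1) for the relevant chunk and drop everything after it. Purely as an illustration (this helper does not exist in the commit), that pattern could be factored out as:

from typing import List

def cut_after_boundary(wb_nodes: List["ScheduledNode"], boundary_minibatch: int, chunk: int) -> List["ScheduledNode"]:
    # Keep everything up to and including the first "B" node of the given chunk
    # whose minibatch equals boundary_minibatch; if none is found, keep the whole
    # list (mirrors cut_idx defaulting to len(wb_nodes) in the diff).
    cut_idx = len(wb_nodes)
    for i, node in enumerate(wb_nodes):
        if node.minibatch == boundary_minibatch and node.type == "B" and node.chunk == chunk:
            cut_idx = i
            break
    return wb_nodes[: cut_idx + 1]

For the down stages this would be called as wb_nodes = cut_after_boundary(wb_nodes, self.n_micro // 2 - 1, chunk=1), matching the chunk == 1 checks used in the down-stage branches above.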