|
|
|
@ -1,4 +1,5 @@
|
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from torch.fx.node import Node
|
|
|
|
|
|
|
|
|
@ -23,13 +24,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
|
|
|
|
|
ctx.bwd_info = bwd_info
|
|
|
|
|
d2h_rid = fwd_info.get('d2h_rid', None)
|
|
|
|
|
if d2h_rid is not None:
|
|
|
|
|
free_region = GlobalRuntimeInfo.region_list[d2h_rid]
|
|
|
|
|
free_region = GlobalRuntimeInfo().region_list[d2h_rid]
|
|
|
|
|
assert isinstance(free_region, Region)
|
|
|
|
|
free_region.free_cuda_data()
|
|
|
|
|
|
|
|
|
|
h2d_rid = fwd_info.get('h2d_rid', None)
|
|
|
|
|
if h2d_rid is not None:
|
|
|
|
|
h2d_region = GlobalRuntimeInfo.region_list[h2d_rid]
|
|
|
|
|
h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
|
|
|
|
|
assert isinstance(h2d_region, Region)
|
|
|
|
|
h2d_region.move_param_to_cuda()
|
|
|
|
|
|
|
|
|
@ -40,7 +41,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
|
|
|
|
|
|
|
|
|
|
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
|
|
|
|
|
if h2d_rid is not None:
|
|
|
|
|
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
|
|
|
|
|
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
|
|
|
|
|
assert isinstance(pref_region, Region)
|
|
|
|
|
pref_region.move_param_to_cuda()
|
|
|
|
|
|
|
|
|
@ -65,23 +66,22 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
|
|
|
|
|
|
|
|
|
|
sync_rid = fwd_info.get('sync_rid', None)
|
|
|
|
|
if sync_rid is not None:
|
|
|
|
|
prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get(
|
|
|
|
|
sync_rid, None)
|
|
|
|
|
prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
|
|
|
|
|
if prefetch_event:
|
|
|
|
|
prefetch_event.wait()
|
|
|
|
|
|
|
|
|
|
h2d_rid = fwd_info.get('h2d_rid', None)
|
|
|
|
|
if h2d_rid is not None:
|
|
|
|
|
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
|
|
|
|
|
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
|
|
|
|
|
assert isinstance(pref_region, Region)
|
|
|
|
|
master_stream = torch.cuda.current_stream()
|
|
|
|
|
with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
|
|
|
|
|
GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
|
|
|
|
|
with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
|
|
|
|
|
GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
|
|
|
|
|
pref_region.move_param_to_cuda()
|
|
|
|
|
|
|
|
|
|
prefetch_event = torch.cuda.Event()
|
|
|
|
|
prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
|
|
|
|
|
GlobalRuntimeInfo.fwd_prefetch_event_map[h2d_rid] = prefetch_event
|
|
|
|
|
prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
|
|
|
|
|
GlobalRuntimeInfo().fwd_prefetch_event_map[h2d_rid] = prefetch_event
|
|
|
|
|
|
|
|
|
|
return input_
|
|
|
|
|
|
|
|
|
@ -90,10 +90,9 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
|
|
|
|
|
|
|
|
|
|
sync_rid = ctx.bwd_info.get('sync_rid', None)
|
|
|
|
|
if sync_rid is not None:
|
|
|
|
|
wait_region = GlobalRuntimeInfo.region_list[sync_rid]
|
|
|
|
|
wait_region = GlobalRuntimeInfo().region_list[sync_rid]
|
|
|
|
|
assert isinstance(wait_region, Region)
|
|
|
|
|
prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get(
|
|
|
|
|
sync_rid, None)
|
|
|
|
|
prefetch_event = GlobalRuntimeInfo().bwd_prefetch_event_map.get(sync_rid, None)
|
|
|
|
|
if prefetch_event:
|
|
|
|
|
prefetch_event.wait()
|
|
|
|
|
else:
|
|
|
|
@ -101,16 +100,16 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
|
|
|
|
|
|
|
|
|
|
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
|
|
|
|
|
if h2d_rid is not None:
|
|
|
|
|
pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
|
|
|
|
|
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
|
|
|
|
|
assert isinstance(pref_region, Region)
|
|
|
|
|
master_stream = torch.cuda.current_stream()
|
|
|
|
|
with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
|
|
|
|
|
GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
|
|
|
|
|
with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
|
|
|
|
|
GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
|
|
|
|
|
pref_region.move_param_to_cuda()
|
|
|
|
|
|
|
|
|
|
prefetch_event = torch.cuda.Event()
|
|
|
|
|
prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
|
|
|
|
|
GlobalRuntimeInfo.bwd_prefetch_event_map[h2d_rid] = prefetch_event
|
|
|
|
|
prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
|
|
|
|
|
GlobalRuntimeInfo().bwd_prefetch_event_map[h2d_rid] = prefetch_event
|
|
|
|
|
return grad_output, None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -129,6 +128,7 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
|
|
|
|
|
ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
|
|
|
|
|
'''
|
|
|
|
|
Convert Prefetch and Offload operation into runtime action.
|
|
|
|
@ -189,7 +189,8 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R
|
|
|
|
|
|
|
|
|
|
if fwd_info or bwd_info:
|
|
|
|
|
with mod_graph.inserting_after(last_inp_node):
|
|
|
|
|
new_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
|
|
|
|
|
new_node = mod_graph.create_node('call_function',
|
|
|
|
|
convert_fwd_upload_bwd_offload_to_action,
|
|
|
|
|
args=(last_inp_node, fwd_info, bwd_info))
|
|
|
|
|
replace_node_users(last_inp_node, new_node)
|
|
|
|
|
|
|
|
|
@ -206,11 +207,11 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
|
|
|
|
|
|
|
|
|
|
# upload parameters of the first region
|
|
|
|
|
last_inp_node = tuple(mod_graph.nodes)[0]
|
|
|
|
|
first_region_with_p = [
|
|
|
|
|
region for region in region_list if region.param_size][0]
|
|
|
|
|
first_region_with_p = [region for region in region_list if region.param_size][0]
|
|
|
|
|
fwd_info = {"h2d_rid": first_region_with_p.r_id}
|
|
|
|
|
with mod_graph.inserting_after(last_inp_node):
|
|
|
|
|
upload_apply_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
|
|
|
|
|
upload_apply_node = mod_graph.create_node('call_function',
|
|
|
|
|
convert_fwd_upload_bwd_offload_to_action,
|
|
|
|
|
args=(last_inp_node, fwd_info, {}))
|
|
|
|
|
replace_node_users(last_inp_node, upload_apply_node)
|
|
|
|
|
last_inp_node = upload_apply_node
|
|
|
|
@ -225,19 +226,20 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
|
|
|
|
|
fwd_info['h2d_rid'] = fwd_prefetch_region.r_id
|
|
|
|
|
|
|
|
|
|
# forward offload
|
|
|
|
|
if r_idx > 0 and region_list[r_idx-1].need_offload:
|
|
|
|
|
if r_idx > 0 and region_list[r_idx - 1].need_offload:
|
|
|
|
|
fwd_info['d2h_rid'] = r_idx - 1
|
|
|
|
|
|
|
|
|
|
bwd_info = {}
|
|
|
|
|
# backward prefetch
|
|
|
|
|
if r_idx > 0 and region_list[r_idx-1].need_offload:
|
|
|
|
|
if r_idx > 0 and region_list[r_idx - 1].need_offload:
|
|
|
|
|
bwd_info['sync_rid'] = r_idx - 1
|
|
|
|
|
if r_idx > 0 and region_list[r_idx-1].bwd_prefetch_region:
|
|
|
|
|
bwd_info['h2d_rid'] = region_list[r_idx-1].bwd_prefetch_region.r_id
|
|
|
|
|
if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
|
|
|
|
|
bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id
|
|
|
|
|
|
|
|
|
|
if fwd_info or bwd_info:
|
|
|
|
|
with mod_graph.inserting_after(last_inp_node):
|
|
|
|
|
new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
|
|
|
|
|
new_node = mod_graph.create_node('call_function',
|
|
|
|
|
convert_fwd_prefetch_bwd_offload_to_action,
|
|
|
|
|
args=(last_inp_node, fwd_info, bwd_info))
|
|
|
|
|
replace_node_users(last_inp_node, new_node)
|
|
|
|
|
|
|
|
|
@ -246,7 +248,8 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
|
|
|
|
|
if region.bwd_prefetch_region:
|
|
|
|
|
bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
|
|
|
|
|
with mod_graph.inserting_after(last_inp_node):
|
|
|
|
|
new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
|
|
|
|
|
new_node = mod_graph.create_node('call_function',
|
|
|
|
|
convert_fwd_prefetch_bwd_offload_to_action,
|
|
|
|
|
args=(last_inp_node, {}, bwd_info))
|
|
|
|
|
replace_node_users(last_inp_node, new_node)
|
|
|
|
|
# gm.graph.print_tabular()
|
|
|
|
|