from typing import List, Optional

import torch
from torch.fx.node import Node

from .region import Region
from .util import GlobalRuntimeInfo, requires_upload_p_in_fwd


class SynPreFwdPostBwdOP(torch.autograd.Function):
    """
    A customized synchronous upload and offload operation.

    The op is an identity on its tensor input; it exists so that region
    parameters can be freed / uploaded as a side effect at a fixed point of
    the forward and backward passes.

    Args:
        input_: input tensor, returned unchanged.
        fwd_info: information dict, which contains region indices
            that need to be uploaded (``"h2d_rid"``) or freed (``"d2h_rid"``)
            during forward pass.
        bwd_info: information dict, which contains region indices
            that need to be uploaded (``"h2d_rid"``) during backward pass.
    """

    @staticmethod
    def forward(ctx, input_, fwd_info, bwd_info):
        # Keep the backward instructions around until backward() runs.
        ctx.bwd_info = bwd_info

        # Free first, then upload.
        d2h_rid = fwd_info.get("d2h_rid", None)
        if d2h_rid is not None:
            free_region = GlobalRuntimeInfo().region_list[d2h_rid]
            assert isinstance(free_region, Region)
            free_region.free_cuda_data()

        h2d_rid = fwd_info.get("h2d_rid", None)
        if h2d_rid is not None:
            h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
            assert isinstance(h2d_region, Region)
            h2d_region.move_param_to_cuda()

        return input_

    @staticmethod
    def backward(ctx, grad_output):
        h2d_rid = ctx.bwd_info.get("h2d_rid", None)
        if h2d_rid is not None:
            pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
            assert isinstance(pref_region, Region)
            pref_region.move_param_to_cuda()

        # One gradient slot per forward() argument; the two info dicts are
        # not differentiable.
        return grad_output, None, None


class AsynPreFwdPostBwdOP(torch.autograd.Function):
    """
    A customized asynchronous prefetch and offload operation.

    The op is an identity on its tensor input. Uploads are issued on
    ``GlobalRuntimeInfo().h2d_stream`` and tracked with CUDA events stored in
    ``fwd_prefetch_event_map`` / ``bwd_prefetch_event_map`` keyed by region
    index, so that a later op can wait on them.

    Args:
        input_: input tensor, returned unchanged.
        fwd_info: information dict, which contains region indices
            that need to be prefetched (``"h2d_rid"``) or waited
            (``"sync_rid"``) during forward pass. A ``"d2h_rid"`` entry may be
            present but is not read by this op.
        bwd_info: information dict, which contains region indices
            that need to be prefetched (``"h2d_rid"``) or waited
            (``"sync_rid"``) during backward pass.
    """

    @staticmethod
    def forward(ctx, input_, fwd_info, bwd_info):
        # Keep the backward instructions around until backward() runs.
        ctx.bwd_info = bwd_info

        sync_rid = fwd_info.get("sync_rid", None)
        if sync_rid is not None:
            prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
            if prefetch_event:
                # Event.wait() without a stream makes the current stream wait.
                prefetch_event.wait()

        h2d_rid = fwd_info.get("h2d_rid", None)
        if h2d_rid is not None:
            pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
            assert isinstance(pref_region, Region)
            master_stream = torch.cuda.current_stream()
            with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
                # Order the copy after the work already queued on the
                # compute stream.
                GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
                pref_region.move_param_to_cuda()

            # Record completion of the copy so a later op can wait on it.
            prefetch_event = torch.cuda.Event()
            prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
            GlobalRuntimeInfo().fwd_prefetch_event_map[h2d_rid] = prefetch_event

        return input_

    @staticmethod
    def backward(ctx, grad_output):
        sync_rid = ctx.bwd_info.get("sync_rid", None)
        if sync_rid is not None:
            wait_region = GlobalRuntimeInfo().region_list[sync_rid]
            assert isinstance(wait_region, Region)
            prefetch_event = GlobalRuntimeInfo().bwd_prefetch_event_map.get(sync_rid, None)
            if prefetch_event:
                prefetch_event.wait()
            else:
                # No prefetch was recorded for this region: upload it now.
                wait_region.move_param_to_cuda()

        h2d_rid = ctx.bwd_info.get("h2d_rid", None)
        if h2d_rid is not None:
            pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
            assert isinstance(pref_region, Region)
            master_stream = torch.cuda.current_stream()
            with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
                # Order the copy after the work already queued on the
                # compute stream.
                GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
                pref_region.move_param_to_cuda()

            # Record completion of the copy so a later op can wait on it.
            prefetch_event = torch.cuda.Event()
            prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
            GlobalRuntimeInfo().bwd_prefetch_event_map[h2d_rid] = prefetch_event

        # One gradient slot per forward() argument; the two info dicts are
        # not differentiable.
        return grad_output, None, None


def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
    """
    Convert Upload and Offload operation into runtime action.

    Args:
        tensor(torch.Tensor): input tensor.
        fwd_info(dict): information dict, which contains region indices
            that need to be uploaded, or freed during forward pass.
        bwd_info(dict): information dict, which contains region indices
            that need to be uploaded during backward pass.

    Returns:
        The result of ``SynPreFwdPostBwdOP.apply``.
    """
    with torch._C.DisableTorchFunction():
        ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
    return ret


def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
    """
    Convert Prefetch and Offload operation into runtime action.

    Args:
        tensor(torch.Tensor): input tensor.
        fwd_info(dict): information dict, which contains region indices
            that need to be prefetched, waited, or freed during forward pass.
        bwd_info(dict): information dict, which contains region indices
            that need to be prefetched or waited during backward pass.

    Returns:
        The result of ``AsynPreFwdPostBwdOP.apply``.
    """
    with torch._C.DisableTorchFunction():
        ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
    return ret


def replace_node_users(orig_node: Node, inserted_node: Node, rep_user_nodes: Optional[List[Node]] = None):
    """
    Redirect users of ``orig_node`` so that they consume ``inserted_node``.

    Every occurrence of ``orig_node`` in a user's positional arguments and
    keyword-argument values (including occurrences nested in lists/tuples)
    is replaced. ``inserted_node`` itself is never rewritten, since it is
    expected to keep ``orig_node`` as its own input.

    Args:
        orig_node(Node): the node whose users should be redirected.
        inserted_node(Node): the node that takes the place of ``orig_node``.
        rep_user_nodes(List[Node], optional): restrict the rewrite to these
            users. Defaults to all current users of ``orig_node``.
    """
    # Snapshot the users: rewriting a user's inputs mutates `orig_node.users`.
    if rep_user_nodes is None:
        user_list = list(orig_node.users)
    else:
        user_list = list(rep_user_nodes)

    for user in user_list:
        if user is inserted_node:
            continue
        # The origin node may be a positional or a keyword argument of the
        # user node; `replace_input_with` covers both and keeps the use-def
        # bookkeeping of the graph consistent.
        user.replace_input_with(orig_node, inserted_node)


def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
    """
    This pass is used to add the synchronous upload and offload spec apply node to the origin graph.

    For each region, an action node is inserted after the last node of the
    previous region (or after the first graph node for the first region)
    whenever there is something to upload or free at that point.

    Args:
        gm(torch.fx.GraphModule): the graph module to modify in place.
        region_list(List[Region]): regions of the graph, in execution order.

    Returns:
        The same ``gm`` object that was passed in.
    """
    mod_graph = gm.graph
    last_inp_node = tuple(mod_graph.nodes)[0]

    for r_idx, region in enumerate(region_list):
        prev_region = region_list[r_idx - 1] if r_idx > 0 else None
        prev_needs_offload = prev_region is not None and prev_region.need_offload

        # forward upload
        fwd_info = {}
        if requires_upload_p_in_fwd(region_list[region.shared_rid]):
            fwd_info["h2d_rid"] = region.r_id

        # forward offload
        if prev_needs_offload:
            fwd_info["d2h_rid"] = r_idx - 1

        bwd_info = {}
        # backward upload
        if prev_needs_offload:
            bwd_info["h2d_rid"] = prev_region.r_id

        if fwd_info or bwd_info:
            with mod_graph.inserting_after(last_inp_node):
                new_node = mod_graph.create_node(
                    "call_function",
                    convert_fwd_upload_bwd_offload_to_action,
                    args=(last_inp_node, fwd_info, bwd_info),
                )
            replace_node_users(last_inp_node, new_node)

        last_inp_node = region.nodes[-1]

    return gm


def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
    """
    This pass is used to add the asynchronous prefetch and offload spec apply node to the origin graph.

    A synchronous upload node for the first region that owns parameters is
    inserted right after the first graph node. Then, for each region, a
    prefetch/wait action node is inserted after the last node of the previous
    region whenever there is something to do at that point. Finally a
    backward-prefetch node is appended after the last region if that region
    has a ``bwd_prefetch_region``.

    Args:
        gm(torch.fx.GraphModule): the graph module to modify in place.
        region_list(List[Region]): regions of the graph, in execution order.

    Returns:
        The same ``gm`` object that was passed in.
    """
    mod_graph = gm.graph

    # upload parameters of the first region
    last_inp_node = tuple(mod_graph.nodes)[0]
    first_region_with_p = next((region for region in region_list if region.param_size), None)
    # Without any parameter-holding region there is nothing to upload up
    # front, so the initial upload node is skipped.
    if first_region_with_p is not None:
        fwd_info = {"h2d_rid": first_region_with_p.r_id}
        with mod_graph.inserting_after(last_inp_node):
            upload_apply_node = mod_graph.create_node(
                "call_function",
                convert_fwd_upload_bwd_offload_to_action,
                args=(last_inp_node, fwd_info, {}),
            )
        replace_node_users(last_inp_node, upload_apply_node)
        last_inp_node = upload_apply_node

    for r_idx, region in enumerate(region_list):
        prev_region = region_list[r_idx - 1] if r_idx > 0 else None
        prev_needs_offload = prev_region is not None and prev_region.need_offload

        # forward prefetch
        fwd_info = {}
        if region.param_size:
            fwd_info["sync_rid"] = region.r_id
        fwd_prefetch_region = region.fwd_prefetch_region
        if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
            fwd_info["h2d_rid"] = fwd_prefetch_region.r_id

        # forward offload
        if prev_needs_offload:
            fwd_info["d2h_rid"] = r_idx - 1

        bwd_info = {}
        # backward prefetch
        if prev_needs_offload:
            bwd_info["sync_rid"] = r_idx - 1
        if prev_region is not None and prev_region.bwd_prefetch_region:
            bwd_info["h2d_rid"] = prev_region.bwd_prefetch_region.r_id

        if fwd_info or bwd_info:
            with mod_graph.inserting_after(last_inp_node):
                new_node = mod_graph.create_node(
                    "call_function",
                    convert_fwd_prefetch_bwd_offload_to_action,
                    args=(last_inp_node, fwd_info, bwd_info),
                )
            replace_node_users(last_inp_node, new_node)

        last_inp_node = region.nodes[-1]

    # Backward prefetch requested by the last region; guarded so that an
    # empty `region_list` does not touch an unbound loop variable.
    if region_list and region_list[-1].bwd_prefetch_region:
        bwd_info = {"h2d_rid": region_list[-1].bwd_prefetch_region.r_id}
        with mod_graph.inserting_after(last_inp_node):
            new_node = mod_graph.create_node(
                "call_function",
                convert_fwd_prefetch_bwd_offload_to_action,
                args=(last_inp_node, {}, bwd_info),
            )
        replace_node_users(last_inp_node, new_node)

    return gm