from typing import List

import torch
from torch.fx.node import Node

from .region import Region
from .util import GlobalRuntimeInfo, requires_upload_p_in_fwd


class SynPreFwdPostBwdOP(torch.autograd.Function):
    """
    A customized synchronous upload and offload operation.

    Args:
        input_: input tensor.
        fwd_info: information dict, which contains indices of regions
            whose parameters need to be uploaded or freed during the forward pass.
        bwd_info: information dict, which contains indices of regions
            whose parameters need to be uploaded during the backward pass.
    """

    @staticmethod
    def forward(ctx, input_, fwd_info, bwd_info):
        ctx.bwd_info = bwd_info
        d2h_rid = fwd_info.get("d2h_rid", None)
        if d2h_rid is not None:
            free_region = GlobalRuntimeInfo().region_list[d2h_rid]
            assert isinstance(free_region, Region)
            free_region.free_cuda_data()

        h2d_rid = fwd_info.get("h2d_rid", None)
        if h2d_rid is not None:
            h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
            assert isinstance(h2d_region, Region)
            h2d_region.move_param_to_cuda()

        return input_

    @staticmethod
    def backward(ctx, grad_output):
        h2d_rid = ctx.bwd_info.get("h2d_rid", None)
        if h2d_rid is not None:
            pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
            assert isinstance(pref_region, Region)
            pref_region.move_param_to_cuda()

        return grad_output, None, None
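
# ``SynPreFwdPostBwdOP`` above and ``AsynPreFwdPostBwdOP`` below both rely on the
# same identity-function trick: ``forward`` returns ``input_`` untouched, so
# splicing the op into the graph never changes numerics; its only job is to run
# memory-movement side effects at the right point of the forward and backward
# passes. A minimal sketch of the pattern (``SideEffectOP`` and ``log`` are
# illustrative names, not part of this module):
#
#     class SideEffectOP(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, input_, log):
#             ctx.log = log
#             log.append("fwd")      # side effect runs before downstream ops
#             return input_
#
#         @staticmethod
#         def backward(ctx, grad_output):
#             ctx.log.append("bwd")  # side effect runs before upstream grads
#             return grad_output, None
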
class AsynPreFwdPostBwdOP(torch.autograd.Function):
    """
    A customized asynchronous prefetch and offload operation.

    Args:
        input_: input tensor.
        fwd_info: information dict, which contains indices of regions
            whose parameters need to be prefetched, waited for, or freed
            during the forward pass.
        bwd_info: information dict, which contains indices of regions
            whose parameters need to be prefetched or waited for during
            the backward pass.
    """

    @staticmethod
    def forward(ctx, input_, fwd_info, bwd_info):
        ctx.bwd_info = bwd_info

        sync_rid = fwd_info.get("sync_rid", None)
        if sync_rid is not None:
            prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
            if prefetch_event:
                prefetch_event.wait()

        h2d_rid = fwd_info.get("h2d_rid", None)
        if h2d_rid is not None:
            pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
            assert isinstance(pref_region, Region)
            master_stream = torch.cuda.current_stream()
            with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
                GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
                pref_region.move_param_to_cuda()

            prefetch_event = torch.cuda.Event()
            prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
            GlobalRuntimeInfo().fwd_prefetch_event_map[h2d_rid] = prefetch_event

        return input_

    @staticmethod
    def backward(ctx, grad_output):
        sync_rid = ctx.bwd_info.get("sync_rid", None)
        if sync_rid is not None:
            wait_region = GlobalRuntimeInfo().region_list[sync_rid]
            assert isinstance(wait_region, Region)
            prefetch_event = GlobalRuntimeInfo().bwd_prefetch_event_map.get(sync_rid, None)
            if prefetch_event:
                prefetch_event.wait()
            else:
                wait_region.move_param_to_cuda()

        h2d_rid = ctx.bwd_info.get("h2d_rid", None)
        if h2d_rid is not None:
            pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
            assert isinstance(pref_region, Region)
            master_stream = torch.cuda.current_stream()
            with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream):
                GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream)
                pref_region.move_param_to_cuda()

            prefetch_event = torch.cuda.Event()
            prefetch_event.record(GlobalRuntimeInfo().h2d_stream)
            GlobalRuntimeInfo().bwd_prefetch_event_map[h2d_rid] = prefetch_event

        return grad_output, None, None
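
# ``AsynPreFwdPostBwdOP`` overlaps host-to-device copies with compute by issuing
# the copy on a dedicated side stream, recording a CUDA event on that stream,
# and waiting on the event only at the op that actually needs the parameters.
# A minimal standalone sketch of that record/wait pattern, assuming a CUDA
# device is available (``copy_stream`` and ``cpu_buf`` are illustrative names):
#
#     copy_stream = torch.cuda.Stream()
#     cpu_buf = torch.empty(1 << 20, pin_memory=True)
#     with torch.cuda.stream(copy_stream):
#         copy_stream.wait_stream(torch.cuda.current_stream())  # respect producers
#         gpu_buf = cpu_buf.to("cuda", non_blocking=True)
#     done = torch.cuda.Event()
#     done.record(copy_stream)
#     ...                # unrelated kernels on the default stream run here
#     done.wait()        # the default stream blocks only when gpu_buf is needed
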
def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
    """
    Convert the Upload and Offload operation into a runtime action.

    Args:
        tensor (torch.Tensor): input tensor.
        fwd_info (dict): information dict, which contains indices of regions
            whose parameters need to be uploaded or freed during the forward pass.
        bwd_info (dict): information dict, which contains indices of regions
            whose parameters need to be uploaded during the backward pass.
    """
    with torch._C.DisableTorchFunction():
        ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
    return ret


def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
    """
    Convert the Prefetch and Offload operation into a runtime action.

    Args:
        tensor (torch.Tensor): input tensor.
        fwd_info (dict): information dict, which contains indices of regions
            whose parameters need to be prefetched, waited for, or freed
            during the forward pass.
        bwd_info (dict): information dict, which contains indices of regions
            whose parameters need to be prefetched or waited for during
            the backward pass.
    """
    with torch._C.DisableTorchFunction():
        ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
    return ret
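
# ``torch._C.DisableTorchFunction()`` disables ``__torch_function__`` dispatch
# around the ``apply`` call, presumably so that tensor-wrapper overrides cannot
# intercept the identity op. At runtime each wrapper is invoked like an ordinary
# function on an activation tensor; a hypothetical call site (the region indices
# here are illustrative, not taken from this module):
#
#     out = convert_fwd_upload_bwd_offload_to_action(out, {"h2d_rid": 1}, {"h2d_rid": 0})
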
def replace_node_users(orig_node: Node, inserted_node: Node, rep_user_nodes: List[Node] = None):
    """
    Redirect the users of ``orig_node`` (or only ``rep_user_nodes``, if given)
    so that they consume ``inserted_node`` instead; ``inserted_node`` itself is
    skipped so it keeps consuming the original node.
    """
    user_list = list(orig_node.users.keys())
    if rep_user_nodes is not None:
        user_list = rep_user_nodes
    for user in user_list:
        if user == inserted_node:
            continue
        new_args = list(user.args)
        new_kwargs = dict(user.kwargs)
        # the origin node may be a positional or a keyword argument of the user node
        if orig_node in new_args:
            # substitute the origin node with the inserted node
            new_args[new_args.index(orig_node)] = inserted_node
            user.args = tuple(new_args)
        elif str(orig_node) in new_kwargs:
            # the keyword case is matched by the origin node's string name
            new_kwargs[str(orig_node)] = inserted_node
            user.kwargs = new_kwargs
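
# A small before/after sketch of what ``replace_node_users`` does to the graph,
# with illustrative node names: if the traced graph is ``x -> relu -> linear``
# and an action node ``act`` has been inserted after ``x``, then
#
#     replace_node_users(x, act)
#
# rewrites ``relu``'s args so the graph reads ``x -> act -> relu -> linear``,
# which is exactly how the passes below thread their action nodes into the
# dataflow without touching any other edge.
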
def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
    """
    This pass adds the synchronous upload and offload apply nodes to the original graph.
    """
    mod_graph = gm.graph
    last_inp_node = tuple(mod_graph.nodes)[0]

    for r_idx, region in enumerate(region_list):
        # forward upload
        fwd_info = {}
        if requires_upload_p_in_fwd(region_list[region.shared_rid]):
            fwd_info["h2d_rid"] = region.r_id

        # forward offload
        if r_idx > 0 and region_list[r_idx - 1].need_offload:
            fwd_info["d2h_rid"] = r_idx - 1

        bwd_info = {}
        # backward upload
        if r_idx > 0 and region_list[r_idx - 1].need_offload:
            bwd_info["h2d_rid"] = region_list[r_idx - 1].r_id

        if fwd_info or bwd_info:
            with mod_graph.inserting_after(last_inp_node):
                new_node = mod_graph.create_node(
                    "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info)
                )
            replace_node_users(last_inp_node, new_node)

        last_inp_node = region.nodes[-1]

    return gm
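
# A hedged usage sketch for the synchronous pass; ``build_regions`` stands in
# for whatever upstream component (e.g. the offload solver) partitions the
# traced graph into ``Region`` objects, and is not defined in this file:
#
#     gm = torch.fx.symbolic_trace(model)
#     region_list = build_regions(gm)          # hypothetical helper
#     gm = runtime_syn_offload_apply_pass(gm, region_list)
#     gm.recompile()                           # regenerate forward() from the graph
#     out = gm(inp)
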
def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
    """
    This pass adds the asynchronous prefetch and offload apply nodes to the original graph.
    """
    mod_graph = gm.graph

    # upload the parameters of the first region that owns parameters
    last_inp_node = tuple(mod_graph.nodes)[0]
    first_region_with_p = [region for region in region_list if region.param_size][0]
    fwd_info = {"h2d_rid": first_region_with_p.r_id}
    with mod_graph.inserting_after(last_inp_node):
        upload_apply_node = mod_graph.create_node(
            "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, {})
        )
    replace_node_users(last_inp_node, upload_apply_node)
    last_inp_node = upload_apply_node

    for r_idx, region in enumerate(region_list):
        # forward prefetch
        fwd_info = {}
        if region.param_size:
            fwd_info["sync_rid"] = region.r_id
            fwd_prefetch_region = region.fwd_prefetch_region
            if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
                fwd_info["h2d_rid"] = fwd_prefetch_region.r_id

        # forward offload
        if r_idx > 0 and region_list[r_idx - 1].need_offload:
            fwd_info["d2h_rid"] = r_idx - 1

        bwd_info = {}
        # backward prefetch
        if r_idx > 0 and region_list[r_idx - 1].need_offload:
            bwd_info["sync_rid"] = r_idx - 1
        if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
            bwd_info["h2d_rid"] = region_list[r_idx - 1].bwd_prefetch_region.r_id

        if fwd_info or bwd_info:
            with mod_graph.inserting_after(last_inp_node):
                new_node = mod_graph.create_node(
                    "call_function",
                    convert_fwd_prefetch_bwd_offload_to_action,
                    args=(last_inp_node, fwd_info, bwd_info),
                )
            replace_node_users(last_inp_node, new_node)

        last_inp_node = region.nodes[-1]

        if region.bwd_prefetch_region:
            bwd_info = {"h2d_rid": region.bwd_prefetch_region.r_id}
            with mod_graph.inserting_after(last_inp_node):
                new_node = mod_graph.create_node(
                    "call_function", convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, {}, bwd_info)
                )
            replace_node_users(last_inp_node, new_node)
    # gm.graph.print_tabular()
    return gm
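
# The asynchronous variant is applied the same way as the synchronous one; the
# only extra requirement is that the regions carry the ``fwd_prefetch_region``
# and ``bwd_prefetch_region`` links assigned by the upstream solver
# (``build_regions`` remains a hypothetical helper here):
#
#     gm = torch.fx.symbolic_trace(model)
#     region_list = build_regions(gm, prefetch=True)   # hypothetical helper
#     gm = runtime_asyn_offload_apply_pass(gm, region_list)
#     gm.recompile()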