|
|
|
@ -54,7 +54,7 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
|
|
|
|
|
return meta_info |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo: |
|
|
|
|
def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo: |
|
|
|
|
""" |
|
|
|
|
This method is used to construct `MetaInto` for shape consistency node |
|
|
|
|
""" |
|
|
|
@ -62,8 +62,8 @@ def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_s
|
|
|
|
|
# extract node index and user node index |
|
|
|
|
args = node.args |
|
|
|
|
node_index, user_node_index = args[3], args[4] |
|
|
|
|
origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[ |
|
|
|
|
node_index][user_node_index] |
|
|
|
|
origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][ |
|
|
|
|
user_node_index] |
|
|
|
|
|
|
|
|
|
return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) |
|
|
|
|
|
|
|
|
@ -98,16 +98,16 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M
|
|
|
|
|
return meta_info |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict, |
|
|
|
|
comm_actions_dict: Dict): |
|
|
|
|
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, |
|
|
|
|
comm_actions_dict: Dict) -> GraphModule: |
|
|
|
|
""" |
|
|
|
|
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph. |
|
|
|
|
""" |
|
|
|
|
for node in gm.graph.nodes: |
|
|
|
|
if node.target == runtime_apply: |
|
|
|
|
setattr(node, 'best_metainfo', |
|
|
|
|
_runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict)) |
|
|
|
|
setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) |
|
|
|
|
elif node.target == runtime_comm_spec_apply: |
|
|
|
|
setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) |
|
|
|
|
else: |
|
|
|
|
pass |
|
|
|
|
return gm |
|
|
|
|