mirror of https://github.com/hpcaitech/ColossalAI
Browse Source
* [autoparallel] hook node meta on graph nodes for checkpoint solver * [autoparallel] polish code * [autoparallel] restore some node handlers * colossalai/auto_parallel/passes/meta_info_prop.py * [autoparallel] remove some unused import * [autoparallel] hook bwd_mem_outpull/2257/head
Boyuan Yao
2 years ago
committed by
GitHub
6 changed files with 132 additions and 76 deletions
@ -0,0 +1,113 @@
|
||||
from typing import Dict |
||||
|
||||
import torch |
||||
from torch.fx import GraphModule |
||||
from torch.fx.node import Node |
||||
|
||||
from colossalai.auto_parallel.meta_profiler import MetaInfo |
||||
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply |
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem |
||||
from colossalai.tensor.comm_spec import CommSpec |
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager |
||||
from colossalai.tensor.sharding_spec import ShardingSpec |
||||
|
||||
shape_consistency_manager = ShapeConsistencyManager() |
||||
|
||||
|
||||
def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec, |
||||
target_sharding_spec: ShardingSpec) -> MetaInfo: |
||||
# get comm_action_sequence and total_cost from shape_consistency_manager |
||||
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( |
||||
origin_sharding_spec, target_sharding_spec) |
||||
|
||||
meta_info = MetaInfo() |
||||
# NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel |
||||
# get mem cost for MetaInfo |
||||
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) |
||||
# extract user that has _meta_data and extract element length |
||||
input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data')) |
||||
element_length = input_node._meta_data.element_size() |
||||
|
||||
mem_cost.fwd.activation *= element_length |
||||
mem_cost.fwd.temp *= element_length |
||||
mem_cost.bwd.activation *= element_length |
||||
mem_cost.bwd.temp *= element_length |
||||
mem_cost.total.activation *= element_length |
||||
|
||||
meta_info.memory_cost = mem_cost |
||||
|
||||
# get computation cost for MetaInfo |
||||
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, |
||||
total_cost['backward'] * element_length, |
||||
total_cost['total'] * element_length) |
||||
|
||||
# get tensor shape for MetaInfo |
||||
origin_sharding_spec: ShardingSpec |
||||
target_sharding_spec: ShardingSpec |
||||
input_shape = origin_sharding_spec.get_sharded_shape_per_device() |
||||
output_shape = target_sharding_spec.get_sharded_shape_per_device() |
||||
|
||||
meta_info.fwd_in = [torch.rand(input_shape, device='meta')] |
||||
meta_info.fwd_buffer = [] |
||||
meta_info.fwd_out = [torch.rand(output_shape, device='meta')] |
||||
|
||||
return meta_info |
||||
|
||||
|
||||
def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo: |
||||
""" |
||||
This method is used to construct `MetaInto` for shape consistency node |
||||
""" |
||||
|
||||
# extract node index and user node index |
||||
args = node.args |
||||
node_index, user_node_index = args[3], args[4] |
||||
origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[ |
||||
node_index][user_node_index] |
||||
|
||||
return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) |
||||
|
||||
|
||||
def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo: |
||||
# extract node_index and op_data_name |
||||
node_index, op_data_name = node.args[2], node.args[3] |
||||
|
||||
comm_action = comm_actions_dict[node_index][op_data_name] |
||||
if isinstance(comm_action.comm_spec, CommSpec): |
||||
# this case is for all_reduce, there will be no memory cost |
||||
meta_info = MetaInfo() |
||||
meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost) |
||||
output_node = next(n for n in node.users if hasattr(n, '_meta_data')) |
||||
element_length = output_node._meta_data.element_size() |
||||
|
||||
total_cost = comm_action.comm_spec.get_comm_cost() |
||||
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, |
||||
total_cost['backward'] * element_length, |
||||
total_cost['total'] * element_length) |
||||
|
||||
input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device() |
||||
meta_info.fwd_in = [torch.rand(input_shape, device='meta')] |
||||
meta_info.fwd_buffer = [] |
||||
meta_info.fwd_out = [torch.rand(output_shape, device='meta')] |
||||
else: |
||||
# this case will be handled by shape consistency manager |
||||
origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[ |
||||
'tgt_spec'] |
||||
meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) |
||||
|
||||
return meta_info |
||||
|
||||
|
||||
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict, |
||||
comm_actions_dict: Dict): |
||||
""" |
||||
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph. |
||||
""" |
||||
for node in gm.graph.nodes: |
||||
if node.target == runtime_apply: |
||||
setattr(node, 'best_metainfo', |
||||
_runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict)) |
||||
elif node.target == runtime_comm_spec_apply: |
||||
setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) |
||||
else: |
||||
pass |
Loading…
Reference in new issue