[autoparallel] Hook all meta information on ResNet nodes for auto activation checkpoint (#2248)

* [autoparallel] hook node meta on graph nodes for checkpoint solver

* [autoparallel] polish code

* [autoparallel] restore some node handlers

* colossalai/auto_parallel/passes/meta_info_prop.py

* [autoparallel] remove some unused import

* [autoparallel] hook bwd_mem_out
pull/2257/head
Boyuan Yao 2023-01-02 16:25:18 +08:00 committed by GitHub
parent c8c79102f0
commit ab38aebace
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 132 additions and 76 deletions

View File

@ -60,7 +60,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_op_data.data, device='meta')]
fwd_in = [torch.zeros_like(input_op_data.data, device='meta'), torch.zeros_like(other_op_data.data, device='meta')]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]

View File

@ -1,6 +1,5 @@
from typing import Callable, List
import numpy as np
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
@ -71,25 +70,12 @@ class MetaInfo:
if self._strategy is not None and self._target is not None:
self.compute_metainfo()
def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
"""
Compute sharded meta tensor based on the given data and sharding spec.
Compute sharded opdata based on the given data and sharding spec.
"""
shard_sequnce = sharding_spec.sharding_sequence
device_mesh = sharding_spec.device_mesh
shape = operation_data.data.shape
new_shape = []
for dim, shard in zip(shape, shard_sequnce):
if shard.is_replica:
# replica
new_shape.append(dim)
else:
# sharded according to device_mesh shape
new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list])))
return OperationData(name=operation_data.name,
data=torch.zeros(new_shape, device="meta"),
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape)
@ -113,7 +99,7 @@ class MetaInfo:
save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
# construct args for meta_func
args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]
# construct kwargs
if self.target in INPLACE_MODULE:

View File

@ -0,0 +1,113 @@
from typing import Dict
import torch
from torch.fx import GraphModule
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
target_sharding_spec: ShardingSpec) -> MetaInfo:
# get comm_action_sequence and total_cost from shape_consistency_manager
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
origin_sharding_spec, target_sharding_spec)
meta_info = MetaInfo()
# NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
# get mem cost for MetaInfo
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
# extract user that has _meta_data and extract element length
input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
element_length = input_node._meta_data.element_size()
mem_cost.fwd.activation *= element_length
mem_cost.fwd.temp *= element_length
mem_cost.bwd.activation *= element_length
mem_cost.bwd.temp *= element_length
mem_cost.total.activation *= element_length
meta_info.memory_cost = mem_cost
# get computation cost for MetaInfo
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
total_cost['backward'] * element_length,
total_cost['total'] * element_length)
# get tensor shape for MetaInfo
origin_sharding_spec: ShardingSpec
target_sharding_spec: ShardingSpec
input_shape = origin_sharding_spec.get_sharded_shape_per_device()
output_shape = target_sharding_spec.get_sharded_shape_per_device()
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
return meta_info
def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo:
"""
This method is used to construct `MetaInto` for shape consistency node
"""
# extract node index and user node index
args = node.args
node_index, user_node_index = args[3], args[4]
origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[
node_index][user_node_index]
return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
# extract node_index and op_data_name
node_index, op_data_name = node.args[2], node.args[3]
comm_action = comm_actions_dict[node_index][op_data_name]
if isinstance(comm_action.comm_spec, CommSpec):
# this case is for all_reduce, there will be no memory cost
meta_info = MetaInfo()
meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)
output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
element_length = output_node._meta_data.element_size()
total_cost = comm_action.comm_spec.get_comm_cost()
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
total_cost['backward'] * element_length,
total_cost['total'] * element_length)
input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
else:
# this case will be handled by shape consistency manager
origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
'tgt_spec']
meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
return meta_info
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict,
comm_actions_dict: Dict):
"""
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
"""
for node in gm.graph.nodes:
if node.target == runtime_apply:
setattr(node, 'best_metainfo',
_runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict))
elif node.target == runtime_comm_spec_apply:
setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
else:
pass

View File

@ -1,15 +1,14 @@
import uuid
from dataclasses import asdict
from typing import Any, Dict, List, NamedTuple, Tuple
from typing import List
import torch
import torch.fx
from torch.fx import GraphModule
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo
from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
@ -68,7 +67,7 @@ class MetaInfoProp:
"""
graph_info = GraphInfo()
out = _normalize_tuple(getattr(node, '_meta_data', None))
graph_info.fwd_out = list(out)
graph_info.fwd_out = list(out) if out[0] is not None else []
node.meta = {**asdict(graph_info)}
@compatibility(is_backward_compatible=False)
@ -97,7 +96,7 @@ class MetaInfoProp:
"""
Handle other kind of nodes
"""
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}"
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
graph_info = GraphInfo()
meta_info = node.best_metainfo
meta_info: MetaInfo
@ -158,5 +157,13 @@ class MetaInfoProp:
memory_cost = meta_info.memory_cost
graph_info.fwd_mem_tmp = memory_cost.fwd.temp
graph_info.bwd_mem_tmp = memory_cost.bwd.temp
graph_info.bwd_mem_out = memory_cost.bwd.activation
# fetch flop information
# here we use fwd_time and bwd_time to deal with the case that
# communication cost is a float
compute_cost = meta_info.compute_cost
graph_info.fwd_time = compute_cost.fwd
graph_info.bwd_time = compute_cost.bwd
node.meta = {**asdict(graph_info)}

View File

@ -47,53 +47,6 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict:
return rst
def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
"""
This method is used to construct `MetaInto` for shape consistency node
TODO: Actually we could attain the cost information from resharding cost in node
handler, we should modify this part in the future.
"""
def compute_shape(sharding_spec: ShardingSpec):
shape = sharding_spec.entire_shape
new_shape = []
for dim, shard in sharding_spec.dim_partition_dict.items():
new_shape.append(shape[dim] // len(shard))
return new_shape
meta_info = MetaInfo()
origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name(
str(node.name))
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
origin_sharding_spec, target_sharding_spec)
# NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
# get mem cost for MetaInfo
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
element_length = node._meta_data.element_size()
mem_cost.fwd.activation *= element_length
mem_cost.fwd.temp *= element_length
mem_cost.bwd.activation *= element_length
mem_cost.bwd.temp *= element_length
mem_cost.total.activation *= element_length
meta_info.memory_cost = mem_cost
# get computation cost for MetaInfo
compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total'])
meta_info.compute_cost = compute_cost
# get tensor shape for MetaInfo
input_shape = compute_shape(origin_sharding_spec)
output_shape = compute_shape(target_sharding_spec)
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
meta_info.fwd_buffer = []
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
return meta_info
def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
"""
This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
@ -175,8 +128,6 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
meta_info = construct_meta_info(node, user_node)
setattr(shape_consistency_node, 'best_metainfo', meta_info)
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)

View File

@ -138,8 +138,7 @@ class NodeHandler(ABC):
return None
if self.node.op == 'call_module':
submod = self.node.graph.owning_module.get_submodule(self.node.target)
target = type(submod)
target = self.node.graph.owning_module.get_submodule(self.node.target)
elif self.node.op == 'call_function':
target = self.node.target
elif self.node.op == 'call_method':