mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] Hook all meta information on ResNet nodes for auto activation checkpoint (#2248)
* [autoparallel] hook node meta on graph nodes for checkpoint solver * [autoparallel] polish code * [autoparallel] restore some node handlers * colossalai/auto_parallel/passes/meta_info_prop.py * [autoparallel] remove some unused import * [autoparallel] hook bwd_mem_outpull/2257/head
parent
c8c79102f0
commit
ab38aebace
|
@ -60,7 +60,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
|
|||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
|
||||
# store fwd_in, fwd_buffer, fwd_out
|
||||
fwd_in = [torch.zeros_like(input_op_data.data, device='meta')]
|
||||
fwd_in = [torch.zeros_like(input_op_data.data, device='meta'), torch.zeros_like(other_op_data.data, device='meta')]
|
||||
fwd_buffer = []
|
||||
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from typing import Callable, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
|
@ -71,25 +70,12 @@ class MetaInfo:
|
|||
if self._strategy is not None and self._target is not None:
|
||||
self.compute_metainfo()
|
||||
|
||||
def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
|
||||
def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
|
||||
"""
|
||||
Compute sharded meta tensor based on the given data and sharding spec.
|
||||
Compute sharded opdata based on the given data and sharding spec.
|
||||
"""
|
||||
shard_sequnce = sharding_spec.sharding_sequence
|
||||
device_mesh = sharding_spec.device_mesh
|
||||
shape = operation_data.data.shape
|
||||
|
||||
new_shape = []
|
||||
for dim, shard in zip(shape, shard_sequnce):
|
||||
if shard.is_replica:
|
||||
# replica
|
||||
new_shape.append(dim)
|
||||
else:
|
||||
# sharded according to device_mesh shape
|
||||
new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list])))
|
||||
|
||||
return OperationData(name=operation_data.name,
|
||||
data=torch.zeros(new_shape, device="meta"),
|
||||
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
|
||||
type=operation_data.type,
|
||||
logical_shape=operation_data.logical_shape)
|
||||
|
||||
|
@ -113,7 +99,7 @@ class MetaInfo:
|
|||
save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
|
||||
|
||||
# construct args for meta_func
|
||||
args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
|
||||
args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]
|
||||
|
||||
# construct kwargs
|
||||
if self.target in INPLACE_MODULE:
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.meta_profiler import MetaInfo
|
||||
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
|
||||
from colossalai.tensor.comm_spec import CommSpec
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
|
||||
def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
|
||||
target_sharding_spec: ShardingSpec) -> MetaInfo:
|
||||
# get comm_action_sequence and total_cost from shape_consistency_manager
|
||||
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
|
||||
origin_sharding_spec, target_sharding_spec)
|
||||
|
||||
meta_info = MetaInfo()
|
||||
# NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
|
||||
# get mem cost for MetaInfo
|
||||
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
|
||||
# extract user that has _meta_data and extract element length
|
||||
input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
|
||||
element_length = input_node._meta_data.element_size()
|
||||
|
||||
mem_cost.fwd.activation *= element_length
|
||||
mem_cost.fwd.temp *= element_length
|
||||
mem_cost.bwd.activation *= element_length
|
||||
mem_cost.bwd.temp *= element_length
|
||||
mem_cost.total.activation *= element_length
|
||||
|
||||
meta_info.memory_cost = mem_cost
|
||||
|
||||
# get computation cost for MetaInfo
|
||||
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
|
||||
total_cost['backward'] * element_length,
|
||||
total_cost['total'] * element_length)
|
||||
|
||||
# get tensor shape for MetaInfo
|
||||
origin_sharding_spec: ShardingSpec
|
||||
target_sharding_spec: ShardingSpec
|
||||
input_shape = origin_sharding_spec.get_sharded_shape_per_device()
|
||||
output_shape = target_sharding_spec.get_sharded_shape_per_device()
|
||||
|
||||
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
|
||||
meta_info.fwd_buffer = []
|
||||
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
|
||||
|
||||
return meta_info
|
||||
|
||||
|
||||
def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo:
|
||||
"""
|
||||
This method is used to construct `MetaInto` for shape consistency node
|
||||
"""
|
||||
|
||||
# extract node index and user node index
|
||||
args = node.args
|
||||
node_index, user_node_index = args[3], args[4]
|
||||
origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[
|
||||
node_index][user_node_index]
|
||||
|
||||
return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
|
||||
|
||||
|
||||
def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
|
||||
# extract node_index and op_data_name
|
||||
node_index, op_data_name = node.args[2], node.args[3]
|
||||
|
||||
comm_action = comm_actions_dict[node_index][op_data_name]
|
||||
if isinstance(comm_action.comm_spec, CommSpec):
|
||||
# this case is for all_reduce, there will be no memory cost
|
||||
meta_info = MetaInfo()
|
||||
meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)
|
||||
output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
|
||||
element_length = output_node._meta_data.element_size()
|
||||
|
||||
total_cost = comm_action.comm_spec.get_comm_cost()
|
||||
meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
|
||||
total_cost['backward'] * element_length,
|
||||
total_cost['total'] * element_length)
|
||||
|
||||
input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
|
||||
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
|
||||
meta_info.fwd_buffer = []
|
||||
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
|
||||
else:
|
||||
# this case will be handled by shape consistency manager
|
||||
origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
|
||||
'tgt_spec']
|
||||
meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
|
||||
|
||||
return meta_info
|
||||
|
||||
|
||||
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict,
|
||||
comm_actions_dict: Dict):
|
||||
"""
|
||||
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if node.target == runtime_apply:
|
||||
setattr(node, 'best_metainfo',
|
||||
_runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict))
|
||||
elif node.target == runtime_comm_spec_apply:
|
||||
setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
|
||||
else:
|
||||
pass
|
|
@ -1,15 +1,14 @@
|
|||
import uuid
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, NamedTuple, Tuple
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.node import Argument, Node, Target
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.meta_profiler import MetaInfo
|
||||
from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
|
||||
from colossalai.fx._compatibility import compatibility
|
||||
from colossalai.fx.profiler import GraphInfo
|
||||
from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
|
||||
|
||||
|
@ -68,7 +67,7 @@ class MetaInfoProp:
|
|||
"""
|
||||
graph_info = GraphInfo()
|
||||
out = _normalize_tuple(getattr(node, '_meta_data', None))
|
||||
graph_info.fwd_out = list(out)
|
||||
graph_info.fwd_out = list(out) if out[0] is not None else []
|
||||
node.meta = {**asdict(graph_info)}
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
|
@ -97,7 +96,7 @@ class MetaInfoProp:
|
|||
"""
|
||||
Handle other kind of nodes
|
||||
"""
|
||||
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}"
|
||||
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
|
||||
graph_info = GraphInfo()
|
||||
meta_info = node.best_metainfo
|
||||
meta_info: MetaInfo
|
||||
|
@ -158,5 +157,13 @@ class MetaInfoProp:
|
|||
memory_cost = meta_info.memory_cost
|
||||
graph_info.fwd_mem_tmp = memory_cost.fwd.temp
|
||||
graph_info.bwd_mem_tmp = memory_cost.bwd.temp
|
||||
graph_info.bwd_mem_out = memory_cost.bwd.activation
|
||||
|
||||
# fetch flop information
|
||||
# here we use fwd_time and bwd_time to deal with the case that
|
||||
# communication cost is a float
|
||||
compute_cost = meta_info.compute_cost
|
||||
graph_info.fwd_time = compute_cost.fwd
|
||||
graph_info.bwd_time = compute_cost.bwd
|
||||
|
||||
node.meta = {**asdict(graph_info)}
|
||||
|
|
|
@ -47,53 +47,6 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict:
|
|||
return rst
|
||||
|
||||
|
||||
def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
|
||||
"""
|
||||
This method is used to construct `MetaInto` for shape consistency node
|
||||
TODO: Actually we could attain the cost information from resharding cost in node
|
||||
handler, we should modify this part in the future.
|
||||
"""
|
||||
|
||||
def compute_shape(sharding_spec: ShardingSpec):
|
||||
shape = sharding_spec.entire_shape
|
||||
new_shape = []
|
||||
for dim, shard in sharding_spec.dim_partition_dict.items():
|
||||
new_shape.append(shape[dim] // len(shard))
|
||||
return new_shape
|
||||
|
||||
meta_info = MetaInfo()
|
||||
origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name(
|
||||
str(node.name))
|
||||
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
|
||||
origin_sharding_spec, target_sharding_spec)
|
||||
|
||||
# NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
|
||||
# get mem cost for MetaInfo
|
||||
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
|
||||
element_length = node._meta_data.element_size()
|
||||
mem_cost.fwd.activation *= element_length
|
||||
mem_cost.fwd.temp *= element_length
|
||||
mem_cost.bwd.activation *= element_length
|
||||
mem_cost.bwd.temp *= element_length
|
||||
mem_cost.total.activation *= element_length
|
||||
|
||||
meta_info.memory_cost = mem_cost
|
||||
|
||||
# get computation cost for MetaInfo
|
||||
compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total'])
|
||||
meta_info.compute_cost = compute_cost
|
||||
|
||||
# get tensor shape for MetaInfo
|
||||
input_shape = compute_shape(origin_sharding_spec)
|
||||
output_shape = compute_shape(target_sharding_spec)
|
||||
|
||||
meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
|
||||
meta_info.fwd_buffer = []
|
||||
meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
|
||||
|
||||
return meta_info
|
||||
|
||||
|
||||
def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
|
||||
"""
|
||||
This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
|
||||
|
@ -175,8 +128,6 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
|
|||
runtime_apply,
|
||||
args=(node, origin_dict_node, input_dict_node,
|
||||
node_to_index_dict[node], user_node_index))
|
||||
meta_info = construct_meta_info(node, user_node)
|
||||
setattr(shape_consistency_node, 'best_metainfo', meta_info)
|
||||
|
||||
new_args = list(user_node.args)
|
||||
new_kwargs = dict(user_node.kwargs)
|
||||
|
|
|
@ -138,8 +138,7 @@ class NodeHandler(ABC):
|
|||
return None
|
||||
|
||||
if self.node.op == 'call_module':
|
||||
submod = self.node.graph.owning_module.get_submodule(self.node.target)
|
||||
target = type(submod)
|
||||
target = self.node.graph.owning_module.get_submodule(self.node.target)
|
||||
elif self.node.op == 'call_function':
|
||||
target = self.node.target
|
||||
elif self.node.op == 'call_method':
|
||||
|
|
Loading…
Reference in New Issue