mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] Hook all meta information on ResNet nodes for auto activation checkpoint (#2248)
* [autoparallel] hook node meta on graph nodes for checkpoint solver
* [autoparallel] polish code
* [autoparallel] restore some node handlers
* colossalai/auto_parallel/passes/meta_info_prop.py
* [autoparallel] remove some unused import
* [autoparallel] hook bwd_mem_out

pull/2257/head
parent c8c79102f0
commit ab38aebace
@@ -60,7 +60,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
     memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

     # store fwd_in, fwd_buffer, fwd_out
-    fwd_in = [torch.zeros_like(input_op_data.data, device='meta')]
+    fwd_in = [torch.zeros_like(input_op_data.data, device='meta'), torch.zeros_like(other_op_data.data, device='meta')]
     fwd_buffer = []
     fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]

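Why both operands now appear in fwd_in: for a binary elementwise op such as multiplication, the backward pass needs both inputs (d(a*b)/da = b and d(a*b)/db = a), so the activation-checkpoint solver has to budget memory for saving both. A minimal standalone illustration on meta tensors, not part of the commit (shapes and dtype are arbitrary):

import torch

input_data = torch.zeros(8, 16, device='meta')    # stand-in for input_op_data.data
other_data = torch.zeros(8, 16, device='meta')    # stand-in for other_op_data.data

# both operands are tracked as saved forward inputs, mirroring the new fwd_in
fwd_in = [torch.zeros_like(input_data, device='meta'),
          torch.zeros_like(other_data, device='meta')]

saved_bytes = sum(t.numel() * t.element_size() for t in fwd_in)
print(saved_bytes)    # 1024 bytes for two 8x16 fp32 activations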
@@ -1,6 +1,5 @@
 from typing import Callable, List

-import numpy as np
 import torch

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
@@ -71,25 +70,12 @@ class MetaInfo:
         if self._strategy is not None and self._target is not None:
             self.compute_metainfo()

-    def compute_sharded_tensor(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
+    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
         """
-        Compute sharded meta tensor based on the given data and sharding spec.
+        Compute sharded opdata based on the given data and sharding spec.
         """
-        shard_sequnce = sharding_spec.sharding_sequence
-        device_mesh = sharding_spec.device_mesh
-        shape = operation_data.data.shape
-
-        new_shape = []
-        for dim, shard in zip(shape, shard_sequnce):
-            if shard.is_replica:
-                # replica
-                new_shape.append(dim)
-            else:
-                # sharded according to device_mesh shape
-                new_shape.append(dim // np.prod(np.array([device_mesh.mesh_shape[i] for i in shard.shard_list])))
-
         return OperationData(name=operation_data.name,
-                             data=torch.zeros(new_shape, device="meta"),
+                             data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
                              type=operation_data.type,
                              logical_shape=operation_data.logical_shape)
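For reference, the hand-rolled loop deleted above computed the per-device shape by dividing each sharded dimension by the product of the device-mesh extents it is split over, which is what `get_sharded_shape_per_device()` now provides. A standalone sketch of that arithmetic, illustration only and not the ColossalAI API (the helper, mesh shape, and partition dict are made up):

from functools import reduce
from operator import mul

def sharded_shape_per_device(global_shape, dim_partition, mesh_shape):
    """Divide every sharded dimension by the product of the mesh axes it is split over."""
    new_shape = list(global_shape)
    for dim, mesh_axes in dim_partition.items():
        new_shape[dim] //= reduce(mul, (mesh_shape[i] for i in mesh_axes), 1)
    return tuple(new_shape)

# A (64, 128) tensor on a 2x4 mesh: dim 0 sharded over mesh axis 0, dim 1 over mesh axis 1.
print(sharded_shape_per_device((64, 128), {0: [0], 1: [1]}, (2, 4)))    # (32, 32)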
@@ -113,7 +99,7 @@ class MetaInfo:
         save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION

         # construct args for meta_func
-        args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
+        args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]

         # construct kwargs
         if self.target in INPLACE_MODULE:
@@ -0,0 +1,113 @@
+from typing import Dict
+
+import torch
+from torch.fx import GraphModule
+from torch.fx.node import Node
+
+from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
+from colossalai.tensor.comm_spec import CommSpec
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+shape_consistency_manager = ShapeConsistencyManager()
+
+
+def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
+                         target_sharding_spec: ShardingSpec) -> MetaInfo:
+    # get comm_action_sequence and total_cost from shape_consistency_manager
+    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
+        origin_sharding_spec, target_sharding_spec)
+
+    meta_info = MetaInfo()
+    # NOTE: the cost in shape_consistency_manager.mem_cost is counted in number of elements (numel)
+    # get mem cost for MetaInfo
+    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
+    # extract the input node that has _meta_data and use it to get the element length
+    input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
+    element_length = input_node._meta_data.element_size()
+
+    mem_cost.fwd.activation *= element_length
+    mem_cost.fwd.temp *= element_length
+    mem_cost.bwd.activation *= element_length
+    mem_cost.bwd.temp *= element_length
+    mem_cost.total.activation *= element_length
+
+    meta_info.memory_cost = mem_cost
+
+    # get computation cost for MetaInfo
+    meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
+                                            total_cost['backward'] * element_length,
+                                            total_cost['total'] * element_length)
+
+    # get tensor shape for MetaInfo
+    origin_sharding_spec: ShardingSpec
+    target_sharding_spec: ShardingSpec
+    input_shape = origin_sharding_spec.get_sharded_shape_per_device()
+    output_shape = target_sharding_spec.get_sharded_shape_per_device()
+
+    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+    meta_info.fwd_buffer = []
+    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+
+    return meta_info
+
+
+def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo:
+    """
+    This method is used to construct `MetaInfo` for shape consistency nodes.
+    """
+
+    # extract node index and user node index
+    args = node.args
+    node_index, user_node_index = args[3], args[4]
+    origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[
+        node_index][user_node_index]
+
+    return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+
+
+def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
+    # extract node_index and op_data_name
+    node_index, op_data_name = node.args[2], node.args[3]
+
+    comm_action = comm_actions_dict[node_index][op_data_name]
+    if isinstance(comm_action.comm_spec, CommSpec):
+        # this case is for all_reduce, there will be no memory cost
+        meta_info = MetaInfo()
+        meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost())
+        output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
+        element_length = output_node._meta_data.element_size()
+
+        total_cost = comm_action.comm_spec.get_comm_cost()
+        meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
+                                                total_cost['backward'] * element_length,
+                                                total_cost['total'] * element_length)
+
+        input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
+        meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+        meta_info.fwd_buffer = []
+        meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+    else:
+        # this case will be handled by the shape consistency manager
+        origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
+            'tgt_spec']
+        meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
+
+    return meta_info
+
+
+def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict,
+                       comm_actions_dict: Dict):
+    """
+    This pass attaches metainfo to every communication node (runtime_apply, runtime_comm_spec_apply) in the graph.
+    """
+    for node in gm.graph.nodes:
+        if node.target == runtime_apply:
+            setattr(node, 'best_metainfo',
+                    _runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict))
+        elif node.target == runtime_comm_spec_apply:
+            setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
+        else:
+            pass
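The byte accounting in `_construct_meta_info` hinges on the NOTE in the new file: the shape-consistency manager reports costs counted in numbers of elements, so multiplying by `element_size()` converts them to bytes. A tiny standalone illustration, not part of the commit (tensor shape and dtype are placeholders):

import torch

t = torch.empty(64, 128, dtype=torch.float16, device='meta')
numel_cost = t.numel()               # a cost counted in elements, as mem_cost is before scaling
element_length = t.element_size()    # 2 bytes per fp16 element
print(numel_cost * element_length)   # 16384 bytes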
@@ -1,15 +1,14 @@
 import uuid
 from dataclasses import asdict
-from typing import Any, Dict, List, NamedTuple, Tuple
+from typing import List

 import torch
 import torch.fx
 from torch.fx import GraphModule
-from torch.fx.node import Argument, Node, Target
-from torch.utils._pytree import tree_map
+from torch.fx.node import Node

 from colossalai.auto_parallel.meta_profiler import MetaInfo
-from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
+from colossalai.fx._compatibility import compatibility
 from colossalai.fx.profiler import GraphInfo
 from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
@@ -68,7 +67,7 @@ class MetaInfoProp:
         """
         graph_info = GraphInfo()
         out = _normalize_tuple(getattr(node, '_meta_data', None))
-        graph_info.fwd_out = list(out)
+        graph_info.fwd_out = list(out) if out[0] is not None else []
         node.meta = {**asdict(graph_info)}

     @compatibility(is_backward_compatible=False)
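The new guard covers nodes whose `_meta_data` attribute is missing: `getattr(node, '_meta_data', None)` then yields `None`, the tuple normalization wraps it as `(None,)`, and the old code would record `fwd_out = [None]`. A minimal sketch of the pattern, with a local stand-in for `_normalize_tuple` (the real helper's body is not part of this diff):

def _normalize_tuple(x):
    # stand-in: assumed to wrap a non-tuple value into a 1-tuple
    return x if isinstance(x, tuple) else (x,)

out = _normalize_tuple(None)                          # node had no _meta_data
fwd_out = list(out) if out[0] is not None else []     # new behaviour: [] instead of [None]
print(fwd_out)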
@@ -97,7 +96,7 @@ class MetaInfoProp:
         """
         Handle other kind of nodes
         """
-        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}"
+        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
         graph_info = GraphInfo()
         meta_info = node.best_metainfo
         meta_info: MetaInfo
@@ -158,5 +157,13 @@ class MetaInfoProp:
         memory_cost = meta_info.memory_cost
         graph_info.fwd_mem_tmp = memory_cost.fwd.temp
         graph_info.bwd_mem_tmp = memory_cost.bwd.temp
+        graph_info.bwd_mem_out = memory_cost.bwd.activation
+
+        # fetch flop information
+        # here we use fwd_time and bwd_time to deal with the case that
+        # communication cost is a float
+        compute_cost = meta_info.compute_cost
+        graph_info.fwd_time = compute_cost.fwd
+        graph_info.bwd_time = compute_cost.bwd

         node.meta = {**asdict(graph_info)}
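The added comment is the motivation for the two new fields: a communication node's "compute cost" is already a scalar estimate rather than a FLOP count, so it is stored directly as fwd_time/bwd_time. A minimal sketch reusing the TrainCycleItem container the diff already imports, illustration only (the numbers are placeholders):

from colossalai.auto_parallel.tensor_shard.sharding_strategy import TrainCycleItem

# a communication node whose cost is a plain float, not a FLOP count
compute_cost = TrainCycleItem(fwd=0.75, bwd=0.75, total=1.5)

fwd_time = compute_cost.fwd    # what the pass writes into graph_info.fwd_time
bwd_time = compute_cost.bwd    # what the pass writes into graph_info.bwd_time
print(fwd_time, bwd_time)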
@@ -47,53 +47,6 @@ def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict:
     return rst


-def construct_meta_info(node: Node, user_node: Node) -> MetaInfo:
-    """
-    This method is used to construct `MetaInto` for shape consistency node
-    TODO: Actually we could attain the cost information from resharding cost in node
-    handler, we should modify this part in the future.
-    """
-
-    def compute_shape(sharding_spec: ShardingSpec):
-        shape = sharding_spec.entire_shape
-        new_shape = []
-        for dim, shard in sharding_spec.dim_partition_dict.items():
-            new_shape.append(shape[dim] // len(shard))
-        return new_shape
-
-    meta_info = MetaInfo()
-    origin_sharding_spec, target_sharding_spec = node.sharding_spec, user_node.best_strategy.get_sharding_spec_by_name(
-        str(node.name))
-    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
-        origin_sharding_spec, target_sharding_spec)
-
-    # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
-    # get mem cost for MetaInfo
-    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
-    element_length = node._meta_data.element_size()
-    mem_cost.fwd.activation *= element_length
-    mem_cost.fwd.temp *= element_length
-    mem_cost.bwd.activation *= element_length
-    mem_cost.bwd.temp *= element_length
-    mem_cost.total.activation *= element_length
-
-    meta_info.memory_cost = mem_cost
-
-    # get computation cost for MetaInfo
-    compute_cost = TrainCycleItem(total_cost['forward'], total_cost['backward'], total_cost['total'])
-    meta_info.compute_cost = compute_cost
-
-    # get tensor shape for MetaInfo
-    input_shape = compute_shape(origin_sharding_spec)
-    output_shape = compute_shape(target_sharding_spec)
-
-    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
-    meta_info.fwd_buffer = []
-    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
-
-    return meta_info
-
-
 def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
     """
     This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
@@ -175,8 +128,6 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                 runtime_apply,
                 args=(node, origin_dict_node, input_dict_node,
                       node_to_index_dict[node], user_node_index))
-            meta_info = construct_meta_info(node, user_node)
-            setattr(shape_consistency_node, 'best_metainfo', meta_info)

             new_args = list(user_node.args)
             new_kwargs = dict(user_node.kwargs)
@@ -138,8 +138,7 @@ class NodeHandler(ABC):
             return None

         if self.node.op == 'call_module':
-            submod = self.node.graph.owning_module.get_submodule(self.node.target)
-            target = type(submod)
+            target = self.node.graph.owning_module.get_submodule(self.node.target)
         elif self.node.op == 'call_function':
             target = self.node.target
         elif self.node.op == 'call_method':
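Net effect of this last hunk: for call_module nodes the handler now records the submodule instance returned by get_submodule rather than its class. A small self-contained torch.fx illustration of the two values, not part of the commit (the toy module is arbitrary):

import torch
from torch.fx import symbolic_trace

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)

gm = symbolic_trace(Net())
node = next(n for n in gm.graph.nodes if n.op == 'call_module')

submod = gm.get_submodule(node.target)
print(type(submod))   # what the old code stored:  <class 'torch.nn.modules.linear.Linear'>
print(submod)         # what the new code stores:  Linear(in_features=4, out_features=4, bias=True)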