mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] new metainfoprop based on metainfo class (#2179)
* [autoparallel] new metainfoprop to combine SPMD solver and checkpoint solver * [autoparallel] new metainfoprop to combine SPMD solver and checkpoint solver * [autoparallel] modify placeholder handler * [autoparallel] modify metainfoprop * [autoparallel] fix function typo * [autoparallel] fix placeholder handler
pull/2212/head^2
parent
78509124d3
commit
d0bc5a1b34
|
@ -0,0 +1,162 @@
|
||||||
|
import uuid
|
||||||
|
from dataclasses import asdict
|
||||||
|
from typing import Any, Dict, List, NamedTuple, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.fx
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
from torch.fx.node import Argument, Node, Target
|
||||||
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.meta_profiler import MetaInfo
|
||||||
|
from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
|
||||||
|
from colossalai.fx.profiler import GraphInfo
|
||||||
|
from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_tuple(x):
|
||||||
|
if not isinstance(x, tuple):
|
||||||
|
return (x,)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=False)
class MetaInfoProp:
    """Attach ``GraphInfo`` records to every node of a ``GraphModule``.

    Each node of ``module.graph`` is dispatched on ``node.op`` to a handler
    that fills a fresh ``GraphInfo`` and stores it, converted to a plain dict
    via ``dataclasses.asdict``, in ``node.meta``.

    Compute nodes (``call_function`` / ``call_module`` / ``call_method``) must
    carry a ``best_metainfo`` attribute; its ``fwd_in`` / ``fwd_buffer`` /
    ``fwd_out`` tensors are tagged with ``data_ptr`` values so that tensors
    flowing between nodes can be identified as the same storage.
    """

    def __init__(self, module: GraphModule) -> None:
        """Store ``module`` and build the ``node.op`` -> handler dispatch table."""
        self.module = module
        self.func_dict = {
            'placeholder': self.placeholder_handler,
            'get_attr': self.get_attr_handler,
            'output': self.output_handler,
            'call_function': self.node_handler,
            'call_module': self.node_handler,
            'call_method': self.node_handler,
        }

    def _set_data_ptr(self, x):
        """
        Set uuid to tensor

        Only tensors whose ``data_ptr()`` is falsy are touched: their
        ``data_ptr`` attribute is replaced by a callable returning a fresh
        ``uuid.uuid4()``. Non-tensor values are ignored.
        """
        if isinstance(x, torch.Tensor):
            if not x.data_ptr():
                data_ptr = uuid.uuid4()
                x.data_ptr = lambda: data_ptr

    @staticmethod
    def _find_unset(candidates, shape):
        """Return the first tensor in ``candidates`` with a falsy ``data_ptr()`` and the given shape, or ``None``."""
        return next((x for x in candidates if not x.data_ptr() and x.shape == shape), None)

    def _is_inplace(self, node: Node):
        """
        Check if the node is inplace operation.

        A ``call_module`` node is in-place when the class of its target
        submodule is in ``OUTPUT_SAVED_MOD``; a ``call_function`` node is
        in-place when its target is in ``OUTPUT_SAVED_OPS``. Every other node
        kind returns ``False``.
        """
        # Submodule lookup is only meaningful for ``call_module`` nodes, whose target
        # is a qualified submodule name. For ``call_method`` the target is a method
        # name, so resolving it with ``get_submodule`` would fail.
        if node.op == 'call_module':
            return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
        elif node.op == "call_function":
            return node.target in OUTPUT_SAVED_OPS
        return False

    def run(self) -> GraphModule:
        """
        Run the meta information propagation pass on the module.

        Returns:
            The module passed to the constructor, whose nodes now carry
            ``meta`` dicts produced by the handlers.
        """
        for node in self.module.graph.nodes:
            node: Node
            self.func_dict[node.op](node)
        # Return the annotated module so the result matches the declared return type.
        return self.module

    @compatibility(is_backward_compatible=False)
    def placeholder_handler(self, node: Node) -> None:
        """
        Handle the placeholder node.

        The node's ``_meta_data`` (if any) becomes the ``fwd_out`` list of its
        ``GraphInfo``. A placeholder without ``_meta_data`` gets an empty
        ``fwd_out`` rather than ``[None]``, so downstream handlers never see a
        ``None`` where a tensor is expected.
        """
        graph_info = GraphInfo()
        meta_data = getattr(node, '_meta_data', None)
        if meta_data is not None:
            graph_info.fwd_out = list(_normalize_tuple(meta_data))
        else:
            graph_info.fwd_out = []
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def get_attr_handler(self, node: Node) -> None:
        """
        Handle the get_attr node.

        The node receives a default-constructed ``GraphInfo``.
        """
        graph_info = GraphInfo()
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def output_handler(self, node: Node) -> None:
        """
        Handle the output node.

        ``fwd_in`` is the concatenation of ``fwd_out`` of every input node
        that already has a non-empty ``meta``.
        """
        graph_info = GraphInfo()
        output_tensors = []
        for par in node._input_nodes:
            if par.meta:
                output_tensors += par.meta["fwd_out"]
        graph_info.fwd_in = output_tensors
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def node_handler(self, node: Node) -> None:
        """
        Handle other kind of nodes

        Reads ``node.best_metainfo`` and copies its tensors and memory costs
        into a ``GraphInfo``:

        * input tensors inherit ``data_ptr`` from shape-matching outputs of
          parent nodes; any input still unset gets a fresh uuid;
        * for in-place nodes, shape-matching buffer/output tensors share the
          ``data_ptr`` of the inputs; otherwise they get fresh uuids;
        * ``fwd_mem_tmp`` / ``bwd_mem_tmp`` come from ``memory_cost``.

        Raises:
            AssertionError: if the node has no ``best_metainfo`` attribute.
        """
        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}"
        graph_info = GraphInfo()
        meta_info = node.best_metainfo
        meta_info: MetaInfo

        # set data_ptr for input_tensor in MetaInfo class
        input_tensor: List[torch.Tensor] = meta_info.fwd_in
        buffer_tensor: List[torch.Tensor] = meta_info.fwd_buffer
        output_tensor: List[torch.Tensor] = meta_info.fwd_out

        if len(input_tensor) > 0:
            for par in node._input_nodes:
                if par.meta:
                    if len(par.meta["fwd_out"]) > 0:
                        # set data_ptr for the input_tensor of current node from the output_tensor of its parent node
                        for tensor in par.meta["fwd_out"]:
                            tensor: torch.Tensor
                            target_tensor = self._find_unset(input_tensor, tensor.shape)
                            # a parent output may have no matching, still-unset input; skip it
                            if target_tensor is not None:
                                target_tensor.data_ptr = tensor.data_ptr

            # set data_ptr for tensor in input_tensor that is not set
            for tensor in input_tensor:
                if not tensor.data_ptr():
                    self._set_data_ptr(tensor)

            # attach it to graph_info
            graph_info.fwd_in = input_tensor

        if self._is_inplace(node):
            # inplace operation will not create new tensor
            # set data_ptr for buffer_tensor and output_tensor of current node
            for tensor in input_tensor:
                tensor: torch.Tensor
                # look up both targets before assigning, so the two searches are independent
                target_buffer_tensor = self._find_unset(buffer_tensor, tensor.shape)
                target_output_tensor = self._find_unset(output_tensor, tensor.shape)
                if target_buffer_tensor is not None:
                    target_buffer_tensor.data_ptr = tensor.data_ptr
                if target_output_tensor is not None:
                    target_output_tensor.data_ptr = tensor.data_ptr
            # attach them to graph_info
            graph_info.fwd_tmp = buffer_tensor
            graph_info.fwd_out = output_tensor

        else:
            # set data_ptr for buffer_tensor
            for tensor in buffer_tensor:
                self._set_data_ptr(tensor)
            # attach it to graph_info
            graph_info.fwd_tmp = buffer_tensor

            # set data_ptr for output_tensor
            for tensor in output_tensor:
                self._set_data_ptr(tensor)
            # attach it to graph_info
            graph_info.fwd_out = output_tensor

        # fetch other memory informations
        memory_cost = meta_info.memory_cost
        graph_info.fwd_mem_tmp = memory_cost.fwd.temp
        graph_info.bwd_mem_tmp = memory_cost.bwd.temp

        node.meta = {**asdict(graph_info)}
|
|
@ -79,6 +79,10 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
|
||||||
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
|
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
|
||||||
str(node))
|
str(node))
|
||||||
|
|
||||||
|
# attach the corresponding metainfo if node has the attribute `metainfo_vector`
|
||||||
|
if hasattr(node, 'metainfo_vector'):
|
||||||
|
setattr(node, 'best_metainfo', node.metainfo_vector[strategy_index])
|
||||||
|
|
||||||
# the dict to get input sharding specs of user node
|
# the dict to get input sharding specs of user node
|
||||||
sharding_spec_convert_dict = {}
|
sharding_spec_convert_dict = {}
|
||||||
# the dict to record comm actions of nodes
|
# the dict to record comm actions of nodes
|
||||||
|
|
|
@ -235,10 +235,15 @@ class MetaInfoNodeHandler(NodeHandler):
|
||||||
"""
|
"""
|
||||||
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
|
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
|
||||||
target = self.get_target_function()
|
target = self.get_target_function()
|
||||||
|
metainfo_vector = []
|
||||||
for strategy in self.strategies_vector:
|
for strategy in self.strategies_vector:
|
||||||
metainfo = MetaInfo(strategy, target)
|
metainfo = MetaInfo(strategy, target)
|
||||||
strategy.compute_cost = metainfo.compute_cost
|
strategy.compute_cost = metainfo.compute_cost
|
||||||
strategy.memory_cost = metainfo.memory_cost
|
strategy.memory_cost = metainfo.memory_cost
|
||||||
|
metainfo_vector.append(metainfo)
|
||||||
|
|
||||||
|
# attach metainfos to the handler
|
||||||
|
setattr(self, "metainfo_vector", metainfo_vector)
|
||||||
|
|
||||||
return self.strategies_vector
|
return self.strategies_vector
|
||||||
|
|
||||||
|
@ -277,9 +282,14 @@ class MetaInfoModuleHandler(ModuleHandler):
|
||||||
"""
|
"""
|
||||||
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
|
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
|
||||||
target = self.get_target_function()
|
target = self.get_target_function()
|
||||||
|
metainfo_vector = []
|
||||||
for strategy in self.strategies_vector:
|
for strategy in self.strategies_vector:
|
||||||
metainfo = MetaInfo(strategy, target)
|
metainfo = MetaInfo(strategy, target)
|
||||||
strategy.compute_cost = metainfo.compute_cost
|
strategy.compute_cost = metainfo.compute_cost
|
||||||
strategy.memory_cost = metainfo.memory_cost
|
strategy.memory_cost = metainfo.memory_cost
|
||||||
|
metainfo_vector.append(metainfo)
|
||||||
|
|
||||||
|
# attach metainfos to the handler
|
||||||
|
setattr(self, "metainfo_vector", metainfo_vector)
|
||||||
|
|
||||||
return self.strategies_vector
|
return self.strategies_vector
|
||||||
|
|
|
@ -111,18 +111,27 @@ class StrategiesConstructor:
|
||||||
submod_type = type(submod)
|
submod_type = type(submod)
|
||||||
handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
|
handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
|
||||||
handler.register_strategy()
|
handler.register_strategy()
|
||||||
|
# attach metainfo_vector to node
|
||||||
|
if hasattr(handler, 'metainfo_vector'):
|
||||||
|
setattr(node, 'metainfo_vector', handler.metainfo_vector)
|
||||||
|
|
||||||
# call_function node
|
# call_function node
|
||||||
elif node.op == 'call_function':
|
elif node.op == 'call_function':
|
||||||
target = node.target
|
target = node.target
|
||||||
handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
|
handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
|
||||||
handler.register_strategy()
|
handler.register_strategy()
|
||||||
|
# attach metainfo_vector to node
|
||||||
|
if hasattr(handler, 'metainfo_vector'):
|
||||||
|
setattr(node, 'metainfo_vector', handler.metainfo_vector)
|
||||||
|
|
||||||
# call_method node
|
# call_method node
|
||||||
elif node.op == 'call_method':
|
elif node.op == 'call_method':
|
||||||
method = getattr(node.args[0]._meta_data.__class__, node.target)
|
method = getattr(node.args[0]._meta_data.__class__, node.target)
|
||||||
handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
|
handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
|
||||||
handler.register_strategy()
|
handler.register_strategy()
|
||||||
|
# attach metainfo_vector to node
|
||||||
|
if hasattr(handler, 'metainfo_vector'):
|
||||||
|
setattr(node, 'metainfo_vector', handler.metainfo_vector)
|
||||||
|
|
||||||
# output node
|
# output node
|
||||||
elif node.op == 'output':
|
elif node.op == 'output':
|
||||||
|
|
Loading…
Reference in New Issue