from collections import OrderedDict as ODict
from dataclasses import dataclass
from typing import List, OrderedDict, Union, Any

from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node

from colossalai.fx.passes.utils import get_node_module

__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']


@dataclass
class LiveVariable:
    """
    LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
    """
    name: str
    node: Node
    is_inplace: bool


class LiveVariableVector(list):
    """
    LiveVariableVector is a data structure to store the list of LiveVariable objects.
    """

    def exists(self, name) -> bool:
        """
        Check if a variable with the given name already exists in the current list.
        """
        for var in self:
            if name == var.name:
                return True
        return False

    def get(self, name) -> LiveVariable:
        for var in self:
            if name == var.name:
                return var
        raise KeyError(f"Variable {name} is not found")

    def copy(self) -> "LiveVariableVector":
        """
        Create a copy of this vector.
        """
        vector = LiveVariableVector()
        for var in self:
            vector.append(var)
        return vector


@dataclass
class LiveStage:
    """
    LiveStage is a data structure to record the variables that are live at the current node.
    """
    name: str
    node: Node
    all_live_vars: LiveVariableVector
    unique_live_vars: LiveVariableVector


class GraphAnalyser:
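    """
    GraphAnalyser analyses the computation graph of a GraphModule and provides
    graph-level information such as variable liveness for each compute stage.
    """
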
    def __init__(self, gm: GraphModule):
        self._gm = gm
        self._graph = gm.graph

    @property
    def gm(self) -> GraphModule:
        """
        Return the GraphModule object associated with this analyser.
        """
        return self._gm

    @property
    def graph(self) -> Graph:
        """
        Return the Graph object associated with this analyser.
        """
        return self._graph

    def liveness_analysis(self) -> List[LiveStage]:
        """
        Analyse the graph to obtain the variable liveness information. This function returns
        a list of LiveStage objects, one for each compute stage in topological order.
        """
        compute_nodes = self.graph.nodes
        liveness_list = []

        # checked: record all variables created since the first stage
        # all: record the live variables that exist up to the current stage.
        #      this can be different from the `checked` list as some variables may be destroyed prior to this stage.
        # unique: record the unique live variables that exist up to the current stage.
        #         this is different from the `all` list as some variables are duplicated.
        checked_variables = LiveVariableVector()
        all_live_variables = LiveVariableVector()
        unique_live_vars = LiveVariableVector()

        for idx, node in enumerate(compute_nodes):
            #############################
            # find new living variables #
            #############################
            # detect whether the current op is an in-place op
            # if it is an in-place op, we would deem it as a duplicate var
            is_inplace = False
            if node.op == 'call_function':
                # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
                if node.kwargs.get('inplace', False):
                    is_inplace = True
            elif node.op == 'call_module':
                # check if this is an inplace module such as torch.nn.ReLU(inplace=True)
                module = get_node_module(node)
                if getattr(module, 'inplace', False):
                    is_inplace = True

            # add the output var
            meta = getattr(node, '_meta_data', None)
            live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
            if not is_inplace:
                unique_live_vars.append(live_var)
            checked_variables.append(live_var)
            all_live_variables.append(live_var)

            # check if any input is not checked yet
            for arg in node.args:
                if not isinstance(arg, Node):
                    continue
                arg_name = arg.name
                if not checked_variables.exists(arg_name):
                    live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False)
                    all_live_variables.append(live_var_from_arg)
                    checked_variables.append(live_var_from_arg)
                    unique_live_vars.append(live_var_from_arg)

            # TODO: add the logic to remove live variables
            # this should be completed if we are able to trace the backward compute graph

            # add this stage to the liveness list
            stage = LiveStage(name=node.name,
                              node=node,
                              all_live_vars=all_live_variables.copy(),
                              unique_live_vars=unique_live_vars.copy())
            # if a LiveStage is covered by another LiveStage, we just keep the larger one.
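            # e.g. if a previous stage keeps {a, b} alive and the current stage keeps {a, b, c},
            # the previous stage is fully covered and gets replaced by the current, larger stage.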
            replace = False
            for index, prev_stage in enumerate(liveness_list):
                all_covered = True
                for ele in prev_stage.unique_live_vars:
                    if ele not in stage.unique_live_vars:
                        all_covered = False
                        break
                if all_covered:
                    replace = True
                    break
            if replace:
                liveness_list[index] = stage
            else:
                liveness_list.append(stage)

        return liveness_list

    def get_alias_set(self):
        pass
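

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original ColossalAI module).
# It assumes an environment where torch.fx.symbolic_trace can trace the toy
# model below; the model and names are made up purely for demonstration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch
    from torch.fx import symbolic_trace

    class _ToyNet(torch.nn.Module):
        # a tiny model whose ReLU is in-place, so the analyser marks its
        # output variable as a duplicate rather than a new unique live var
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.relu = torch.nn.ReLU(inplace=True)

        def forward(self, x):
            return self.relu(self.linear(x))

    traced = symbolic_trace(_ToyNet())
    analyser = GraphAnalyser(traced)
    for stage in analyser.liveness_analysis():
        # each LiveStage records the variables live at that compute stage
        print(stage.name, [var.name for var in stage.unique_live_vars])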