mirror of https://github.com/hpcaitech/ColossalAI
90 lines
3.1 KiB
Python
90 lines
3.1 KiB
Python
|
from typing import OrderedDict
|
||
|
from torch.fx import GraphModule
|
||
|
from collections import OrderedDict
|
||
|
import pdb
|
||
|
|
||
|
|
||
|
def linearize(gm: GraphModule) -> dict:
    """Group the nodes of ``gm.graph`` into consecutive linear regions.

    The graph nodes are visited in the order yielded by ``gm.graph.nodes``.
    For every node already visited, the function remembers which of its
    users have not been visited yet. A region boundary is placed whenever
    the node being visited consumes every value that was still pending,
    i.e. nothing produced before it is needed after it.

    Args:
        gm: Module whose ``graph.nodes`` are iterated; each node must expose
            ``name`` and ``users``.

    Returns:
        An ``OrderedDict`` mapping an integer region index to the list of
        nodes in that region. The first region (index 0, holding only the
        first node of the graph) and the last region are dropped before
        returning, so the remaining keys start at 1. For graphs too small
        to have any interior region the mapping is empty.
    """
    # node name -> users of that node which have not been visited yet
    status_dict = {}
    # region index -> nodes belonging to that region
    node_dict = OrderedDict()
    node_idx = 0

    for node in gm.graph.nodes:
        # Number of producers tracked *before* this node is accounted for.
        last_dict_len = len(status_dict)

        # This node is being visited now, so it no longer counts as a
        # pending user of any earlier node.
        for pending_users in status_dict.values():
            if node in pending_users:
                pending_users.remove(node)

        # Forget producers whose users have all been visited.
        status_dict = {name: users for name, users in status_dict.items() if users}

        if last_dict_len == 0:
            # Very first node of the graph: it forms region 0 on its own
            # (n0-n1 layout) and a fresh region is opened right after it.
            node_dict[node_idx] = [node]
            node_idx += 1
            node_dict[node_idx] = []
        elif status_dict:
            # In-region case: earlier values are still needed later, so the
            # node simply joins the current region. A finer-grained split
            # is possible here but larger regions are kept on purpose.
            node_dict[node_idx].append(node)
        elif last_dict_len == 1:
            # Boundary reached with a single pending producer: the end of
            # one region is the start of the next (n1-n2-n3-... layout), so
            # the node opens a new region unless the current one is empty.
            if node_dict[node_idx]:
                node_idx += 1
                node_dict[node_idx] = []
            node_dict[node_idx].append(node)
        else:
            # Boundary reached with several pending producers: this node is
            # the one that consumes them all, so it closes the current
            # region and a new, empty region is opened after it.
            node_dict[node_idx].append(node)
            node_idx += 1
            node_dict[node_idx] = []

        status_dict[node.name] = list(node.users)

    # A graph without nodes yields no regions at all.
    if not node_dict:
        return node_dict

    # If the final node closed a multi-producer region, an empty region was
    # opened after it; discard that placeholder.
    if not node_dict[node_idx]:
        node_dict.pop(node_idx)
        node_idx -= 1

    # Drop the first and the last region. The defaults keep degenerate
    # graphs (where both indices coincide) from raising KeyError.
    node_dict.pop(0, None)
    node_dict.pop(node_idx, None)
    return node_dict
|