support new op

pull/2364/head
oahzxl 2022-12-06 17:34:24 +08:00
parent f24c418bb0
commit a9d64377bb
1 changed files with 61 additions and 2 deletions

View File

@ -200,8 +200,12 @@ class NodeIndexTracer(object):
Args:
node (node)
node_idx (int)
"""
input_node, weight, bias = node.args
"""
if len(node.args) == 2:
input_node, weight = node.args
bias = None
else:
input_node, weight, bias = node.args
input_node_idx_trace = self._find_idx_trace_from_node(input_node)
weight_idx_trace = self._find_idx_trace_from_node(weight)
@ -284,6 +288,53 @@ class NodeIndexTracer(object):
self._assign_index_as_input(node, idx)
self._inherit_computation(node.args[0], node)
self._mark_computation(node, idx, [node.kwargs['dim']])
def _assign_unsqueeze_index(self, node, node_idx):
"""
Assign index for unsqueeze op.
1. assign new index for unsqueeze dim
Args:
node (node)
node_idx (int)
"""
self._assign_index_as_input(node, node_idx)
self._inherit_computation(node.args[0], node)
self.idx_trace_list[node_idx]['idx'].insert(node.args[1], self._add_index())
def _assign_dropout_index(self, node, node_idx):
"""
Assign index for unsqueeze op.
1. assign new index for unsqueeze dim
Args:
node (node)
node_idx (int)
"""
self._assign_index_as_input(node, node_idx)
def _assign_ones_like_index(self, node, node_idx):
"""
Assign index for oneslike op.
1. assign new index for all dim
Args:
node (node)
node_idx (int)
"""
self._assign_all_index(node, node_idx)
def _assign_to_index(self, node, node_idx):
"""
Assign index for to op.
1. assign new index for all dim
Args:
node (node)
node_idx (int)
"""
self._assign_index_as_input(node, node_idx)
def _assign_view_reshape_index(self, node, node_idx):
"""
@ -388,6 +439,10 @@ class NodeIndexTracer(object):
self._assign_permute_index(node, idx)
elif 'view' in node.name or 'reshape' in node.name:
self._assign_view_reshape_index(node, idx)
elif 'unsqueeze' in node.name:
self._assign_unsqueeze_index(node, idx)
elif 'to' in node.name:
self._assign_to_index(node, idx)
else:
raise NotImplementedError(node.name, "method not implemented yet!")
elif node.op == 'call_function':
@ -399,6 +454,10 @@ class NodeIndexTracer(object):
self._assign_softmax_index(node, idx)
elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
self._assign_elementwise_index(node, idx)
elif 'ones_like' in node.name:
self._assign_ones_like_index(node, idx)
elif 'dropout' in node.name:
self._assign_dropout_index(node, idx)
elif 'getattr' in node.name:
continue # get attr like shape
elif 'getitem' in node.name: