|
|
|
@ -200,8 +200,12 @@ class NodeIndexTracer(object):
|
|
|
|
|
Args:
|
|
|
|
|
node (node)
|
|
|
|
|
node_idx (int)
|
|
|
|
|
"""
|
|
|
|
|
input_node, weight, bias = node.args
|
|
|
|
|
"""
|
|
|
|
|
if len(node.args) == 2:
|
|
|
|
|
input_node, weight = node.args
|
|
|
|
|
bias = None
|
|
|
|
|
else:
|
|
|
|
|
input_node, weight, bias = node.args
|
|
|
|
|
input_node_idx_trace = self._find_idx_trace_from_node(input_node)
|
|
|
|
|
weight_idx_trace = self._find_idx_trace_from_node(weight)
|
|
|
|
|
|
|
|
|
@ -284,6 +288,53 @@ class NodeIndexTracer(object):
|
|
|
|
|
self._assign_index_as_input(node, idx)
|
|
|
|
|
self._inherit_computation(node.args[0], node)
|
|
|
|
|
self._mark_computation(node, idx, [node.kwargs['dim']])
|
|
|
|
|
|
|
|
|
|
def _assign_unsqueeze_index(self, node, node_idx):
|
|
|
|
|
"""
|
|
|
|
|
Assign index for unsqueeze op.
|
|
|
|
|
1. assign new index for unsqueeze dim
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
node (node)
|
|
|
|
|
node_idx (int)
|
|
|
|
|
"""
|
|
|
|
|
self._assign_index_as_input(node, node_idx)
|
|
|
|
|
self._inherit_computation(node.args[0], node)
|
|
|
|
|
self.idx_trace_list[node_idx]['idx'].insert(node.args[1], self._add_index())
|
|
|
|
|
|
|
|
|
|
def _assign_dropout_index(self, node, node_idx):
|
|
|
|
|
"""
|
|
|
|
|
Assign index for unsqueeze op.
|
|
|
|
|
1. assign new index for unsqueeze dim
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
node (node)
|
|
|
|
|
node_idx (int)
|
|
|
|
|
"""
|
|
|
|
|
self._assign_index_as_input(node, node_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _assign_ones_like_index(self, node, node_idx):
|
|
|
|
|
"""
|
|
|
|
|
Assign index for oneslike op.
|
|
|
|
|
1. assign new index for all dim
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
node (node)
|
|
|
|
|
node_idx (int)
|
|
|
|
|
"""
|
|
|
|
|
self._assign_all_index(node, node_idx)
|
|
|
|
|
|
|
|
|
|
def _assign_to_index(self, node, node_idx):
|
|
|
|
|
"""
|
|
|
|
|
Assign index for to op.
|
|
|
|
|
1. assign new index for all dim
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
node (node)
|
|
|
|
|
node_idx (int)
|
|
|
|
|
"""
|
|
|
|
|
self._assign_index_as_input(node, node_idx)
|
|
|
|
|
|
|
|
|
|
def _assign_view_reshape_index(self, node, node_idx):
|
|
|
|
|
"""
|
|
|
|
@ -388,6 +439,10 @@ class NodeIndexTracer(object):
|
|
|
|
|
self._assign_permute_index(node, idx)
|
|
|
|
|
elif 'view' in node.name or 'reshape' in node.name:
|
|
|
|
|
self._assign_view_reshape_index(node, idx)
|
|
|
|
|
elif 'unsqueeze' in node.name:
|
|
|
|
|
self._assign_unsqueeze_index(node, idx)
|
|
|
|
|
elif 'to' in node.name:
|
|
|
|
|
self._assign_to_index(node, idx)
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError(node.name, "method not implemented yet!")
|
|
|
|
|
elif node.op == 'call_function':
|
|
|
|
@ -399,6 +454,10 @@ class NodeIndexTracer(object):
|
|
|
|
|
self._assign_softmax_index(node, idx)
|
|
|
|
|
elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
|
|
|
|
|
self._assign_elementwise_index(node, idx)
|
|
|
|
|
elif 'ones_like' in node.name:
|
|
|
|
|
self._assign_ones_like_index(node, idx)
|
|
|
|
|
elif 'dropout' in node.name:
|
|
|
|
|
self._assign_dropout_index(node, idx)
|
|
|
|
|
elif 'getattr' in node.name:
|
|
|
|
|
continue # get attr like shape
|
|
|
|
|
elif 'getitem' in node.name:
|
|
|
|
|