mirror of https://github.com/hpcaitech/ColossalAI
support new op
parent
f24c418bb0
commit
a9d64377bb
|
@ -200,8 +200,12 @@ class NodeIndexTracer(object):
|
|||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
input_node, weight, bias = node.args
|
||||
"""
|
||||
if len(node.args) == 2:
|
||||
input_node, weight = node.args
|
||||
bias = None
|
||||
else:
|
||||
input_node, weight, bias = node.args
|
||||
input_node_idx_trace = self._find_idx_trace_from_node(input_node)
|
||||
weight_idx_trace = self._find_idx_trace_from_node(weight)
|
||||
|
||||
|
@ -284,6 +288,53 @@ class NodeIndexTracer(object):
|
|||
self._assign_index_as_input(node, idx)
|
||||
self._inherit_computation(node.args[0], node)
|
||||
self._mark_computation(node, idx, [node.kwargs['dim']])
|
||||
|
||||
def _assign_unsqueeze_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for unsqueeze op.
|
||||
1. assign new index for unsqueeze dim
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_index_as_input(node, node_idx)
|
||||
self._inherit_computation(node.args[0], node)
|
||||
self.idx_trace_list[node_idx]['idx'].insert(node.args[1], self._add_index())
|
||||
|
||||
def _assign_dropout_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for unsqueeze op.
|
||||
1. assign new index for unsqueeze dim
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_index_as_input(node, node_idx)
|
||||
|
||||
|
||||
def _assign_ones_like_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for oneslike op.
|
||||
1. assign new index for all dim
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_all_index(node, node_idx)
|
||||
|
||||
def _assign_to_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for to op.
|
||||
1. assign new index for all dim
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_index_as_input(node, node_idx)
|
||||
|
||||
def _assign_view_reshape_index(self, node, node_idx):
|
||||
"""
|
||||
|
@ -388,6 +439,10 @@ class NodeIndexTracer(object):
|
|||
self._assign_permute_index(node, idx)
|
||||
elif 'view' in node.name or 'reshape' in node.name:
|
||||
self._assign_view_reshape_index(node, idx)
|
||||
elif 'unsqueeze' in node.name:
|
||||
self._assign_unsqueeze_index(node, idx)
|
||||
elif 'to' in node.name:
|
||||
self._assign_to_index(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "method not implemented yet!")
|
||||
elif node.op == 'call_function':
|
||||
|
@ -399,6 +454,10 @@ class NodeIndexTracer(object):
|
|||
self._assign_softmax_index(node, idx)
|
||||
elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
|
||||
self._assign_elementwise_index(node, idx)
|
||||
elif 'ones_like' in node.name:
|
||||
self._assign_ones_like_index(node, idx)
|
||||
elif 'dropout' in node.name:
|
||||
self._assign_dropout_index(node, idx)
|
||||
elif 'getattr' in node.name:
|
||||
continue # get attr like shape
|
||||
elif 'getitem' in node.name:
|
||||
|
|
Loading…
Reference in New Issue