add doc str

pull/2364/head
oahzxl 2022-11-15 10:18:00 +08:00
parent 70a98b8f56
commit f379d1a94d
1 changed files with 95 additions and 0 deletions

View File

@ -120,6 +120,13 @@ class NodeIndexTracer(object):
return self.idx_trace_list[node_idx]['compute']
def assign_index_as_input(self, node, node_idx):
"""
Assign node's trace as its input node.
Args:
node (node)
node_idx (int)
"""
input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list)
input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx']
@ -127,6 +134,13 @@ class NodeIndexTracer(object):
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
def assign_all_index(self, node, node_idx):
"""
Add new index for all node's dims.
Args:
node (node)
node_idx (int)
"""
shape = node.meta['tensor_meta'].shape
new_trace = []
for _ in shape:
@ -134,6 +148,15 @@ class NodeIndexTracer(object):
self.idx_trace_list[node_idx]['idx'] = new_trace
def assign_transpose_index(self, node, node_idx):
"""
Assign index for transpose op.
1. swap input's dim according to transpose args
2. inherit input's computation
Args:
node (node)
node_idx (int)
"""
tranpose_dim = node.args[1:]
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
@ -145,6 +168,15 @@ class NodeIndexTracer(object):
self.inherit_computation(node.args[0], node)
def assign_permute_index(self, node, node_idx):
"""
Assign index for permute op.
1. swap input's dim according to permute args
2. inherit input's computation
Args:
node (node)
node_idx (int)
"""
permute_dim = node.args[1:]
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
@ -156,6 +188,16 @@ class NodeIndexTracer(object):
self.inherit_computation(node.args[0], node)
def assign_linear_index(self, node, node_idx):
"""
Assign index for linear op.
1. copy trace from input node and change last index accroding to weight
2. mark equal for input node last index, weight first dim and bias dim.
3. inherit input's computation, mark computation for last dim.
Args:
node (node)
node_idx (int)
"""
input_node, weight, bias = node.args
input_node_idx_trace = self.find_idx_trace_from_node(input_node)
weight_idx_trace = self.find_idx_trace_from_node(weight)
@ -173,6 +215,16 @@ class NodeIndexTracer(object):
self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
def assign_matmul_index(self, node, node_idx):
"""
Assign index for matmul op.
1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length)
2. mark equal for input matmul_left -1 index and matmul_right -2 dim.
3. inherit matmul_left and matmul_right computation, mark computation for last dim.
Args:
node (node)
node_idx (int)
"""
matmul_left, matmul_right = node.args
matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left)
matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right)
@ -188,21 +240,63 @@ class NodeIndexTracer(object):
self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
def assign_layernorm_index(self, node, idx):
"""
Assign index for layernorm op.
1. assign index as input node
2. inherit computation and mark last 2 dims as computed.
Args:
node (node)
node_idx (int)
"""
self.assign_index_as_input(node, idx)
self.inherit_computation(node.args[0], node)
self.mark_computation(node, idx, [-1, -2])
def assign_elementwise_index(self, node, idx):
"""
Assign index for element-wise op (eg. relu sigmoid add mul).
1. assign index as input node
2. inherit computation from all input nodes.
Args:
node (node)
node_idx (int)
"""
self.assign_index_as_input(node, idx)
for node_in in node.args:
if type(node_in) not in (int, float):
self.inherit_computation(node_in, node)
def assign_softmax_index(self, node, idx):
"""
Assign index for softmax op.
1. assign index as input node
2. inherit computation and mark softmax dim as computed.
Args:
node (node)
node_idx (int)
"""
self.assign_index_as_input(node, idx)
self.inherit_computation(node.args[0], node)
self.mark_computation(node, idx, [node.kwargs['dim']])
def assign_view_reshape_index(self, node, node_idx):
"""
Assign index for view and reshape op.
1. get origin shape and target shape by meta info.
2. compute the real value of -1 in target shape.
3. determine changed dim, and assgin index for generated dim.
4. log changed dim and generated dim for restore
5. look into view list to see whether the view is associated with other,
if so assgin equal dim according to previous view.
6. inherit computation.
Args:
node (node)
node_idx (int)
"""
# get data, turn into number
origin_node = node.args[0]
origin_shape = origin_node.meta['tensor_meta'].shape
@ -305,6 +399,7 @@ class NodeIndexTracer(object):
else:
raise NotImplementedError(node.op, "op not implemented yet!")
def _get_meta_node_size(x):
x = x.meta['tensor_meta']
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()