mirror of https://github.com/hpcaitech/ColossalAI
add doc str
parent
70a98b8f56
commit
f379d1a94d
|
@ -120,6 +120,13 @@ class NodeIndexTracer(object):
|
|||
return self.idx_trace_list[node_idx]['compute']
|
||||
|
||||
def assign_index_as_input(self, node, node_idx):
|
||||
"""
|
||||
Assign node's trace as its input node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list)
|
||||
input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx']
|
||||
|
||||
|
@ -127,6 +134,13 @@ class NodeIndexTracer(object):
|
|||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||
|
||||
def assign_all_index(self, node, node_idx):
|
||||
"""
|
||||
Add new index for all node's dims.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
shape = node.meta['tensor_meta'].shape
|
||||
new_trace = []
|
||||
for _ in shape:
|
||||
|
@ -134,6 +148,15 @@ class NodeIndexTracer(object):
|
|||
self.idx_trace_list[node_idx]['idx'] = new_trace
|
||||
|
||||
def assign_transpose_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for transpose op.
|
||||
1. swap input's dim according to transpose args
|
||||
2. inherit input's computation
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
tranpose_dim = node.args[1:]
|
||||
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
|
||||
|
||||
|
@ -145,6 +168,15 @@ class NodeIndexTracer(object):
|
|||
self.inherit_computation(node.args[0], node)
|
||||
|
||||
def assign_permute_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for permute op.
|
||||
1. swap input's dim according to permute args
|
||||
2. inherit input's computation
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
permute_dim = node.args[1:]
|
||||
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
|
||||
|
||||
|
@ -156,6 +188,16 @@ class NodeIndexTracer(object):
|
|||
self.inherit_computation(node.args[0], node)
|
||||
|
||||
def assign_linear_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for linear op.
|
||||
1. copy trace from input node and change last index according to weight
|
||||
2. mark equal for input node last index, weight first dim and bias dim.
|
||||
3. inherit input's computation, mark computation for last dim.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
input_node, weight, bias = node.args
|
||||
input_node_idx_trace = self.find_idx_trace_from_node(input_node)
|
||||
weight_idx_trace = self.find_idx_trace_from_node(weight)
|
||||
|
@ -173,6 +215,16 @@ class NodeIndexTracer(object):
|
|||
self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
|
||||
|
||||
def assign_matmul_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for matmul op.
|
||||
1. copy trace from matmul_left and change last index according to matmul_right. (assert they have same length)
|
||||
2. mark equal for input matmul_left -1 index and matmul_right -2 dim.
|
||||
3. inherit matmul_left and matmul_right computation, mark computation for last dim.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
matmul_left, matmul_right = node.args
|
||||
matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left)
|
||||
matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right)
|
||||
|
@ -188,21 +240,63 @@ class NodeIndexTracer(object):
|
|||
self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
|
||||
|
||||
def assign_layernorm_index(self, node, idx):
    """
    Assign index for layernorm op.

    The node first takes the index trace of its input node, then inherits
    the computation recorded for that input, and finally has its last two
    dims marked as computed.

    Args:
        node (node): the layernorm node; ``node.args[0]`` is its input node.
        idx (int): index of ``node``, forwarded to ``assign_index_as_input``
            and ``mark_computation``.
    """
    self.assign_index_as_input(node, idx)

    source_node = node.args[0]
    self.inherit_computation(source_node, node)

    # layernorm is marked as computing over the two trailing dims
    computed_dims = [-1, -2]
    self.mark_computation(node, idx, computed_dims)
|
||||
|
||||
def assign_elementwise_index(self, node, idx):
    """
    Assign index for element-wise op (eg. relu sigmoid add mul).
    1. assign index as input node
    2. inherit computation from all input nodes.

    Args:
        node (node): the element-wise node; its ``args`` may mix input
            nodes with plain Python scalars.
        idx (int): index of ``node``, forwarded to ``assign_index_as_input``.
    """
    self.assign_index_as_input(node, idx)
    for node_in in node.args:
        # Scalars are not nodes, so there is no computation to inherit from
        # them. isinstance (rather than an exact type comparison) also skips
        # bool, which is a subclass of int.
        if isinstance(node_in, (int, float)):
            continue
        self.inherit_computation(node_in, node)
|
||||
|
||||
def assign_softmax_index(self, node, idx):
    """
    Assign index for softmax op.
    1. assign index as input node
    2. inherit computation and mark softmax dim as computed.

    The softmax dim is read from ``node.kwargs['dim']`` when present;
    otherwise it is taken from the second positional argument, so that
    calls written as ``softmax(x, dim)`` are handled as well.

    Args:
        node (node): the softmax node; ``node.args[0]`` is its input node.
        idx (int): index of ``node``, forwarded to ``assign_index_as_input``
            and ``mark_computation``.

    Raises:
        KeyError: if the softmax dim is given neither as the ``dim`` keyword
            nor as the second positional argument.
    """
    self.assign_index_as_input(node, idx)
    self.inherit_computation(node.args[0], node)

    if 'dim' in node.kwargs:
        softmax_dim = node.kwargs['dim']
    elif len(node.args) > 1:
        # dim was passed positionally, e.g. softmax(x, -1)
        softmax_dim = node.args[1]
    else:
        # keep the exception type callers saw before the positional fallback
        raise KeyError('dim')
    self.mark_computation(node, idx, [softmax_dim])
|
||||
|
||||
def assign_view_reshape_index(self, node, node_idx):
|
||||
"""
|
||||
Assign index for view and reshape op.
|
||||
1. get origin shape and target shape by meta info.
|
||||
2. compute the real value of -1 in target shape.
|
||||
3. determine changed dim, and assign index for generated dim.
|
||||
4. log changed dim and generated dim for restore
|
||||
5. look into view list to see whether the view is associated with other,
|
||||
if so assign equal dim according to previous view.
|
||||
6. inherit computation.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
# get data, turn into number
|
||||
origin_node = node.args[0]
|
||||
origin_shape = origin_node.meta['tensor_meta'].shape
|
||||
|
@ -305,6 +399,7 @@ class NodeIndexTracer(object):
|
|||
else:
|
||||
raise NotImplementedError(node.op, "op not implemented yet!")
|
||||
|
||||
|
||||
def _get_meta_node_size(x):
|
||||
x = x.meta['tensor_meta']
|
||||
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
|
||||
|
|
Loading…
Reference in New Issue