rename ambiguous variable

pull/2364/head
oahzxl 2 years ago
parent 2bde9d2b7f
commit fd87d78a28

@ -126,14 +126,14 @@ class ChunkSelector(object):
) )
return chunk_info return chunk_info
def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos): def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
if l >= 16: if left >= 16:
gap = 4 gap = 4
else: else:
gap = 1 gap = 1
chunk_info = chunk_region_dict["reorder_chunk_info"] chunk_info = chunk_region_dict["reorder_chunk_info"]
while r >= l + gap: while right >= left + gap:
mid = int((l + r) / 2 + 0.5) mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info] cur_chunk_infos = chunk_infos + [chunk_info]
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
@ -143,10 +143,10 @@ class ChunkSelector(object):
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
) )
if cur_chunk_max_mem >= self.max_memory: if cur_chunk_max_mem >= self.max_memory:
r = mid - gap right = mid - gap
else: else:
l = mid + gap left = mid + gap
return l return left
def _get_compute_node_num(self, start, end): def _get_compute_node_num(self, start, end):
count = 0 count = 0

@ -67,10 +67,10 @@ class OutProductMean(nn.Module):
left_act = self.linear_a(M) left_act = self.linear_a(M)
right_act = self.linear_b(M) right_act = self.linear_b(M)
O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() o = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
# O = rearrange(O, 'b i j d e -> b i j (d e)') # O = rearrange(O, 'b i j d e -> b i j (d e)')
O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1) o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1)
Z = self.o_linear(O) Z = self.o_linear(o)
return Z return Z

@ -157,12 +157,12 @@ def _get_minimal_slice_set(
# start_edges and end_edges both indicate whether, starting from any given # start_edges and end_edges both indicate whether, starting from any given
# dimension, the start/end index is at the top/bottom edge of the # dimension, the start/end index is at the top/bottom edge of the
# corresponding tensor, modeled as a tree # corresponding tensor, modeled as a tree
def reduce_edge_list(l): def reduce_edge_list(ll):
tally = 1 tally = 1
for i in range(len(l)): for i in range(len(ll)):
reversed_idx = -1 * (i + 1) reversed_idx = -1 * (i + 1)
l[reversed_idx] *= tally ll[reversed_idx] *= tally
tally = l[reversed_idx] tally = ll[reversed_idx]
if(start_edges is None): if(start_edges is None):
start_edges = [s == 0 for s in start] start_edges = [s == 0 for s in start]

Loading…
Cancel
Save