mirror of https://github.com/hpcaitech/ColossalAI
rename ambiguous variable
parent 2bde9d2b7f
commit fd87d78a28
```diff
@@ -126,14 +126,14 @@ class ChunkSelector(object):
             )
         return chunk_info

-    def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos):
-        if l >= 16:
+    def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
+        if left >= 16:
             gap = 4
         else:
             gap = 1
         chunk_info = chunk_region_dict["reorder_chunk_info"]
-        while r >= l + gap:
-            mid = int((l + r) / 2 + 0.5)
+        while right >= left + gap:
+            mid = int((left + right) / 2 + 0.5)
             chunk_info["chunk_size"] = mid
             cur_chunk_infos = chunk_infos + [chunk_info]
             cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
```
```diff
@@ -143,10 +143,10 @@ class ChunkSelector(object):
                 cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
             )
             if cur_chunk_max_mem >= self.max_memory:
-                r = mid - gap
+                right = mid - gap
             else:
-                l = mid + gap
-        return l
+                left = mid + gap
+        return left

     def _get_compute_node_num(self, start, end):
         count = 0
```
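For reference, a minimal standalone sketch of the `_chunk_size_binary_search` logic from the two hunks above, using the renamed `left`/`right` bounds. The memory estimator is not part of this diff, so it is replaced here by a hypothetical `estimate_peak_mem(chunk_size)` callable; only the search structure mirrors the code.

```python
# Hedged sketch: `estimate_peak_mem` is a hypothetical stand-in for
# self.memory_estimator.estimate_chunk_inference_mem(...), which is not shown here.
def chunk_size_binary_search(left, right, estimate_peak_mem, max_memory):
    # Coarser search step when the lower bound is already a large chunk size.
    gap = 4 if left >= 16 else 1
    while right >= left + gap:
        mid = int((left + right) / 2 + 0.5)   # midpoint, rounded up
        if estimate_peak_mem(mid) >= max_memory:
            right = mid - gap                 # candidate too memory-hungry: shrink
        else:
            left = mid + gap                  # candidate fits: try a larger chunk
    return left

# Toy usage with a linear (mocked) memory model: returns 30 for a 305-unit budget.
print(chunk_size_binary_search(1, 64, lambda size: size * 10.0, max_memory=305.0))
```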
```diff
@@ -67,10 +67,10 @@ class OutProductMean(nn.Module):
         left_act = self.linear_a(M)
         right_act = self.linear_b(M)

-        O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
+        o = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
         # O = rearrange(O, 'b i j d e -> b i j (d e)')
-        O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1)
-        Z = self.o_linear(O)
+        o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1)
+        Z = self.o_linear(o)

         return Z

```
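To see what the renamed einsum computes, here is a self-contained sketch of the `OutProductMean.forward` hunk above. The linear layers and their widths are illustrative assumptions (the module's constructor is outside this diff); the einsum contracts the shared `s` axis and forms an outer product over the two hidden dimensions.

```python
import torch
import torch.nn as nn

# Assumed toy sizes; the real layer widths are defined outside this hunk.
b, s, n_res, c_in, c_hidden, c_out = 1, 4, 8, 16, 32, 64
linear_a = nn.Linear(c_in, c_hidden)               # stands in for self.linear_a
linear_b = nn.Linear(c_in, c_hidden)               # stands in for self.linear_b
o_linear = nn.Linear(c_hidden * c_hidden, c_out)   # stands in for self.o_linear

M = torch.randn(b, s, n_res, c_in)
left_act = linear_a(M)    # (b, s, i, d)
right_act = linear_b(M)   # (b, s, j, e)

# Outer product of the two projections; the shared s axis is summed out.
o = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1)   # flatten (d, e) -> d*e
Z = o_linear(o)
print(Z.shape)   # torch.Size([1, 8, 8, 64])
```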
```diff
@@ -157,12 +157,12 @@ def _get_minimal_slice_set(
     # start_edges and end_edges both indicate whether, starting from any given
     # dimension, the start/end index is at the top/bottom edge of the
     # corresponding tensor, modeled as a tree
-    def reduce_edge_list(l):
+    def reduce_edge_list(ll):
         tally = 1
-        for i in range(len(l)):
+        for i in range(len(ll)):
             reversed_idx = -1 * (i + 1)
-            l[reversed_idx] *= tally
-            tally = l[reversed_idx]
+            ll[reversed_idx] *= tally
+            tally = ll[reversed_idx]

     if(start_edges is None):
         start_edges = [s == 0 for s in start]
```
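The renamed helper in the last hunk computes a right-to-left running product over a list of edge flags, so a dimension only stays marked as "at the edge" if every dimension after it is also at its edge. A small sketch of that behaviour in isolation (the surrounding `_get_minimal_slice_set` machinery is omitted):

```python
def reduce_edge_list(ll):
    # Multiply flags from the last dimension backwards: a flag survives only if
    # all flags after it are also set (the list is modified in place).
    tally = 1
    for i in range(len(ll)):
        reversed_idx = -1 * (i + 1)
        ll[reversed_idx] *= tally
        tally = ll[reversed_idx]

edges = [1, 0, 1, 1]        # e.g. start_edges = [s == 0 for s in start]
reduce_edge_list(edges)
print(edges)                # [0, 0, 1, 1] -- dim 0 is cleared because dim 1 is not at its edge
```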