diff --git a/colossalai/autochunk/chunk_selector.py b/colossalai/autochunk/chunk_selector.py index f84322082..aeab66572 100644 --- a/colossalai/autochunk/chunk_selector.py +++ b/colossalai/autochunk/chunk_selector.py @@ -126,14 +126,14 @@ class ChunkSelector(object): ) return chunk_info - def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos): - if l >= 16: + def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos): + if left >= 16: gap = 4 else: gap = 1 chunk_info = chunk_region_dict["reorder_chunk_info"] - while r >= l + gap: - mid = int((l + r) / 2 + 0.5) + while right >= left + gap: + mid = int((left + right) / 2 + 0.5) chunk_info["chunk_size"] = mid cur_chunk_infos = chunk_infos + [chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( @@ -143,10 +143,10 @@ class ChunkSelector(object): cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] ) if cur_chunk_max_mem >= self.max_memory: - r = mid - gap + right = mid - gap else: - l = mid + gap - return l + left = mid + gap + return left def _get_compute_node_num(self, start, end): count = 0 diff --git a/tests/test_autochunk/evoformer/ops.py b/tests/test_autochunk/evoformer/ops.py index 611b7b0fe..a56057522 100755 --- a/tests/test_autochunk/evoformer/ops.py +++ b/tests/test_autochunk/evoformer/ops.py @@ -67,10 +67,10 @@ class OutProductMean(nn.Module): left_act = self.linear_a(M) right_act = self.linear_b(M) - O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() + o = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() # O = rearrange(O, 'b i j d e -> b i j (d e)') - O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1) - Z = self.o_linear(O) + o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1) + Z = self.o_linear(o) return Z diff --git a/tests/test_autochunk/openfold/tensor_utils.py b/tests/test_autochunk/openfold/tensor_utils.py index 7e5e8e4b6..384a71fb5 100644 --- a/tests/test_autochunk/openfold/tensor_utils.py +++ b/tests/test_autochunk/openfold/tensor_utils.py @@ -157,12 +157,12 @@ def _get_minimal_slice_set( # start_edges and end_edges both indicate whether, starting from any given # dimension, the start/end index is at the top/bottom edge of the # corresponding tensor, modeled as a tree - def reduce_edge_list(l): + def reduce_edge_list(ll): tally = 1 - for i in range(len(l)): + for i in range(len(ll)): reversed_idx = -1 * (i + 1) - l[reversed_idx] *= tally - tally = l[reversed_idx] + ll[reversed_idx] *= tally + tally = ll[reversed_idx] if(start_edges is None): start_edges = [s == 0 for s in start]