diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index 606281c..d02ef2d 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -326,10 +326,14 @@ def find_subset_with_target_sum(nums: List[int], target: int, approximate_thresh if len(part_idxs) > 0 and (-target * approximate_threshold <= tmpTarget <= target * approximate_threshold): indexs.append(part_idxs) + elif tmpTarget > sum(nums[start:]) + target * approximate_threshold: + return elif tmpTarget > 0: for i in range(start, len(nums)): num = nums[i] - _inner_helper(start + 1, tmpTarget - num, part_idxs + [i]) + if num - target * approximate_threshold > tmpTarget: + continue + _inner_helper(i + 1, tmpTarget - num, part_idxs + [i]) _inner_helper(start=0, tmpTarget=target, part_idxs=[])