ColossalAI/colossalai/elixir/search/simulator.py

113 lines
3.4 KiB
Python

import math
from .utils import to_divide
def calc_move_times(param_per_step: list, param_to_chunk: dict, n_blocks: int):
    """Count chunk movements for a given parameter-to-chunk assignment.

    Each entry of ``param_per_step`` is translated into the list of distinct
    chunk ids it touches, and the resulting per-step lists are handed to
    ``move_count`` together with ``n_blocks``.

    Args:
        param_per_step: iterable of collections of parameter names, one
            collection per step.
        param_to_chunk: mapping from parameter name to chunk id. Names that
            are missing from this mapping are treated as ignored parameters.
        n_blocks: forwarded unchanged to ``move_count``.

    Returns:
        Whatever ``move_count(chunk_per_step, n_blocks)`` returns.
    """
    from colossalai.elixir.simulator import move_count

    def chunks_touched(names):
        # ignored parameters have no entry in param_to_chunk and are skipped
        return {param_to_chunk[name] for name in names if name in param_to_chunk}

    touched_sets = (chunks_touched(names) for names in param_per_step)
    # steps that touch no tracked chunk contribute nothing and are dropped
    chunk_per_step = [list(ids) for ids in touched_sets if ids]
    return move_count(chunk_per_step, n_blocks)
def find_optimal_chunk_size(
    # pre-commit: do not rearrange
        param_per_step: list,
        param_names: list,
        param_numels: list,
        cuda_elements: int,
        overlap: bool,
        min_range: int,
        max_range: int,
        interval: int):
    """Search for the chunk size that minimizes the amount of moved elements.

    Candidate sizes start at ``max(max(param_numels), min_range)`` (aligned
    with ``to_divide``) and grow by ``interval`` up to
    ``min(max_range, cuda_elements)``. For each candidate, parameters are
    packed into chunks in order, the number of chunk movements is obtained
    from ``calc_move_times``, and the candidate with the smallest
    ``move_times * chunk_size`` wins (the first one on ties).

    Args:
        param_per_step: parameter names used at each step, forwarded to
            ``calc_move_times``.
        param_names: parameter names, in packing order.
        param_numels: number of elements of each parameter, aligned with
            ``param_names``.
        cuda_elements: element budget used to derive how many blocks fit.
        overlap: when True, two blocks are treated as reserved (prefetch).
        min_range: lower bound for the candidate chunk size.
        max_range: upper bound for the candidate chunk size.
        interval: step between candidates; also the alignment unit.

    Returns:
        A tuple ``(best_size, best_number_blocks, best_waste)``.

    Raises:
        AssertionError: if the first candidate already exceeds the upper bound.
        RuntimeError: if no candidate leaves at least one usable block.
    """
    max_numel = max(param_numels, default=0)
    # NOTE(review): to_divide presumably rounds up to a multiple of interval - confirm in .utils
    test_size = to_divide(max(max_numel, min_range), interval)
    # floor rounding
    cuda_elements = to_divide(cuda_elements - interval + 1, interval)
    max_range = min(max_range, cuda_elements)

    min_move_elements = float('+inf')
    best_size = test_size
    best_number_blocks = 0
    best_waste = 0

    def dispatch_chunks(param_to_chunk: dict, block_size: int) -> tuple:
        """Pack parameters in order into chunks of ``block_size`` elements.

        Fills ``param_to_chunk`` in place and returns ``(last_chunk_id,
        unused_elements)``, where unused elements are the leftovers of every
        chunk that was closed plus the remainder of the last one.
        """
        chunk_id = 0
        acc = 0
        left = 0
        for (name, numel) in zip(param_names, param_numels):
            if numel > left:
                # the parameter does not fit: close this chunk and open a new one
                acc += left
                chunk_id += 1
                left = block_size
            left -= numel
            param_to_chunk[name] = chunk_id
        return (chunk_id, left + acc)

    assert test_size <= max_range, 'max_numel or min_range is larger than max_range or cuda capacity'

    while test_size <= max_range:
        # calculate the number of blocks
        number_blocks = int(cuda_elements // test_size)
        # if prefetch is enabled, we pretend that two chunks are reserved
        if overlap:
            number_blocks -= 2
        if number_blocks <= 0:
            # The block count can only shrink as test_size grows, so no larger
            # candidate can be valid. Stopping here also avoids the endless
            # loop a bare `continue` causes by skipping the size increment.
            break
        # initialize the chunk id for each parameter
        param_to_chunk = dict()
        number_chunks, current_waste = dispatch_chunks(param_to_chunk, test_size)
        number_blocks = min(number_blocks, number_chunks)
        # calculate the minimum number of movements
        move_times = calc_move_times(param_per_step, param_to_chunk, number_blocks)

        current_move_elements = move_times * test_size
        if current_move_elements < min_move_elements:
            min_move_elements = current_move_elements
            best_size = test_size
            best_number_blocks = number_blocks
            best_waste = current_waste

        test_size += interval

    if min_move_elements == float('inf'):
        raise RuntimeError('optimal search: can not find a valid solution')

    return best_size, best_number_blocks, best_waste
def bandwidth_c2g(n: int):
    """Evaluate the linear model ``16.3 * n + 8.7``.

    NOTE(review): judging by the name this is presumably a CPU-to-GPU
    bandwidth estimate; the meaning of ``n`` and the units are not visible
    here - confirm against the caller.
    """
    slope = 16.3
    intercept = 8.7
    return slope * n + intercept
def bandwidth_g2c(n: int):
    """Evaluate the linear model ``15.8 * n + 2.3``.

    NOTE(review): judging by the name this is presumably a GPU-to-CPU
    bandwidth estimate; the meaning of ``n`` and the units are not visible
    here - confirm against the caller.
    """
    slope = 15.8
    intercept = 2.3
    return slope * n + intercept
def velocity_gpu(n: int):
    """Return ``50 * n``.

    NOTE(review): presumably a GPU processing-speed estimate that scales
    linearly with ``n``; units are not visible here - confirm.
    """
    per_unit = 50
    return per_unit * n
def velocity_cpu(n: int):
    """Return ``1.66 * ln(n) + 5.15``.

    ``math.log`` is the natural logarithm, so ``n`` must be positive.

    NOTE(review): presumably a CPU processing-speed estimate that grows
    logarithmically with ``n``; units are not visible here - confirm.
    """
    log_n = math.log(n)
    return 1.66 * log_n + 5.15
def rcache_prioirity_check(n: int, r_os: int, e_p: int, e_o: int):
    """Compare two estimates built from the bandwidth and velocity models.

    The first estimate is ``e_p`` divided by each of the two bandwidth
    values, summed. The second scales a sum of transfer and compute terms
    (which includes the first estimate) by ``n / r_os``.

    NOTE(review): the meaning of ``n``, ``r_os``, ``e_p`` and ``e_o`` is not
    visible in this file - confirm against the caller.

    Returns:
        True when the first estimate is strictly greater than the second.
    """
    c2g = bandwidth_c2g(n)
    g2c = bandwidth_g2c(n)

    first = e_p / c2g + e_p / g2c

    ratio = n / r_os
    # the terms are added in the original left-to-right order so the floating-point result is unchanged
    inner = e_o / c2g + first + e_p / g2c + 1.0 / velocity_cpu(n) - 1.0 / velocity_gpu(n)
    second = ratio * inner

    return first > second