[zero] add chunk size search for chunk manager (#1052)

pull/1063/head
ver217 2022-06-02 13:20:20 +08:00 committed by GitHub
parent 2c42b230f3
commit e1922ea4f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 38 additions and 0 deletions

View File

@ -268,3 +268,41 @@ class ChunkManager:
for i, chunk in enumerate(group):
msg += f'[{i}] {chunk}\n'
return msg
@staticmethod
def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float:
assert len(params_numel) > 0
total_size = 0
total_utilized_size = 0
cur_chunk_utilized_size = 0
for size in params_numel:
assert chunk_size >= size
total_utilized_size += size
if total_size == 0 or cur_chunk_utilized_size + size > chunk_size:
total_size += chunk_size
cur_chunk_utilized_size = 0
cur_chunk_utilized_size += size
return total_utilized_size / total_size
@staticmethod
def search_chunk_size(module: torch.nn.Module,
search_range: int,
n_grids: int,
min_chunk_size: Optional[int] = None) -> int:
assert search_range % n_grids == 0
# TODO(ver217): sort params and filter unused ones
params_numel = [p.numel() for p in module.parameters()]
max_param_numel = max(params_numel)
if min_chunk_size is not None:
assert min_chunk_size >= max_param_numel
else:
min_chunk_size = max_param_numel
step_size = search_range // n_grids
max_chunk_util = -1
best_chunk_size = -1
for chunk_size in range(min_chunk_size, min_chunk_size + search_range + 1, step_size):
chunk_util = ChunkManager.get_chunk_util(chunk_size, params_numel)
if chunk_util > max_chunk_util:
max_chunk_util = chunk_util
best_chunk_size = chunk_size
return best_chunk_size