mirror of https://github.com/hpcaitech/ColossalAI
[zero] add chunk size search for chunk manager (#1052)
parent 2c42b230f3
commit e1922ea4f6
```diff
@@ -268,3 +268,41 @@ class ChunkManager:
         for i, chunk in enumerate(group):
             msg += f'[{i}] {chunk}\n'
         return msg
+
+    @staticmethod
+    def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float:
+        assert len(params_numel) > 0
+        total_size = 0
+        total_utilized_size = 0
+        cur_chunk_utilized_size = 0
+        for size in params_numel:
+            assert chunk_size >= size
+            total_utilized_size += size
+            if total_size == 0 or cur_chunk_utilized_size + size > chunk_size:
+                total_size += chunk_size
+                cur_chunk_utilized_size = 0
+            cur_chunk_utilized_size += size
+        return total_utilized_size / total_size
+
+    @staticmethod
+    def search_chunk_size(module: torch.nn.Module,
+                          search_range: int,
+                          n_grids: int,
+                          min_chunk_size: Optional[int] = None) -> int:
+        assert search_range % n_grids == 0
+        # TODO(ver217): sort params and filter unused ones
+        params_numel = [p.numel() for p in module.parameters()]
+        max_param_numel = max(params_numel)
+        if min_chunk_size is not None:
+            assert min_chunk_size >= max_param_numel
+        else:
+            min_chunk_size = max_param_numel
+        step_size = search_range // n_grids
+        max_chunk_util = -1
+        best_chunk_size = -1
+        for chunk_size in range(min_chunk_size, min_chunk_size + search_range + 1, step_size):
+            chunk_util = ChunkManager.get_chunk_util(chunk_size, params_numel)
+            if chunk_util > max_chunk_util:
+                max_chunk_util = chunk_util
+                best_chunk_size = chunk_size
+        return best_chunk_size
```
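For reference, `get_chunk_util` measures how much of the allocated chunk memory is actually occupied by parameters: parameters are packed greedily into fixed-size chunks, and a new chunk is opened whenever the next parameter would overflow the current one. Below is a minimal standalone sketch of the same packing logic, runnable without ColossalAI; the toy parameter sizes are made up for illustration:

```python
from typing import List

def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float:
    """Fraction of allocated chunk memory actually filled by parameters.

    Mirrors ChunkManager.get_chunk_util from the diff above: params are
    packed greedily into fixed-size chunks, opening a new chunk whenever
    the next parameter would overflow the current one.
    """
    assert len(params_numel) > 0
    total_size = 0            # numel allocated across all chunks
    total_utilized_size = 0   # numel actually occupied by parameters
    cur_chunk_utilized_size = 0
    for size in params_numel:
        assert chunk_size >= size  # every param must fit in one chunk
        total_utilized_size += size
        if total_size == 0 or cur_chunk_utilized_size + size > chunk_size:
            total_size += chunk_size  # open a new chunk
            cur_chunk_utilized_size = 0
        cur_chunk_utilized_size += size
    return total_utilized_size / total_size

# Toy sizes (hypothetical): chunk 1 holds 600, chunk 2 holds 500 + 300.
print(get_chunk_util(1000, [600, 500, 300]))  # 1400 / 2000 = 0.7
```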
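`search_chunk_size` then grid-searches `n_grids + 1` candidate sizes, spaced `search_range // n_grids` apart starting from the largest parameter's numel (or the given `min_chunk_size`), and returns the candidate with the best utilization; the lower bound exists because each parameter must fit entirely inside one chunk. A hypothetical usage sketch follows; the import path and the toy model are assumptions and may not match every ColossalAI version:

```python
# Hypothetical usage sketch; the import path below is an assumption.
import torch
from colossalai.tensor.chunk import ChunkManager  # path may differ by version

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
)

# Scan candidates spaced 1024 numel apart above the largest parameter
# and keep the chunk size with the best packing utilization.
# Note: search_range must be divisible by n_grids.
best = ChunkManager.search_chunk_size(model, search_range=64 * 1024, n_grids=64)
print(best)
```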