mirror of https://github.com/hpcaitech/ColossalAI
[zero] add chunk size search for chunk manager (#1052)
parent
2c42b230f3
commit
e1922ea4f6
|
@ -268,3 +268,41 @@ class ChunkManager:
|
||||||
for i, chunk in enumerate(group):
|
for i, chunk in enumerate(group):
|
||||||
msg += f'[{i}] {chunk}\n'
|
msg += f'[{i}] {chunk}\n'
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
@staticmethod
def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float:
    """Return the fraction of allocated chunk space occupied by parameters.

    Parameters are packed greedily in the order given: a chunk of
    ``chunk_size`` elements is opened for the first parameter, and a new one
    is opened whenever the next parameter does not fit in the space left in
    the current chunk.

    Args:
        chunk_size: Capacity of every chunk, in elements. Must be positive.
        params_numel: Element count of each parameter, in packing order.

    Returns:
        ``sum(params_numel) / (number_of_chunks * chunk_size)``.

    Raises:
        ValueError: If ``params_numel`` is empty, ``chunk_size`` is not
            positive, or any parameter is larger than ``chunk_size``.
    """
    # Explicit checks instead of ``assert``: asserts are removed under
    # ``python -O``, which would let bad input fall through to a division by
    # zero or a utilization greater than 1.
    if not params_numel:
        raise ValueError('params_numel must not be empty')
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')

    num_chunks = 0
    cur_chunk_utilized_size = 0
    for size in params_numel:
        if size > chunk_size:
            raise ValueError(f'parameter with {size} elements does not fit in chunk_size {chunk_size}')
        # Open a chunk for the very first parameter, or when this parameter
        # would overflow the chunk currently being filled.
        if num_chunks == 0 or cur_chunk_utilized_size + size > chunk_size:
            num_chunks += 1
            cur_chunk_utilized_size = 0
        cur_chunk_utilized_size += size
    return sum(params_numel) / (num_chunks * chunk_size)
|
||||||
|
|
||||||
|
@staticmethod
def search_chunk_size(module: torch.nn.Module,
                      search_range: int,
                      n_grids: int,
                      min_chunk_size: Optional[int] = None) -> int:
    """Grid-search the chunk size that gives the highest chunk utilization.

    Candidate sizes start at ``min_chunk_size`` and advance by
    ``search_range // n_grids`` up to and including
    ``min_chunk_size + search_range``. Each candidate is scored with
    ``ChunkManager.get_chunk_util`` on the element counts of
    ``module.parameters()``; on ties the first candidate evaluated wins.

    Args:
        module: Module whose parameters are to be packed into chunks.
        search_range: Width of the interval of chunk sizes to search, starting
            at ``min_chunk_size``. Must be a multiple of ``n_grids``.
        n_grids: Number of equal steps the search range is divided into.
            Must be positive.
        min_chunk_size: Smallest chunk size to consider. Defaults to the
            element count of the largest parameter, and may not be smaller
            than it.

    Returns:
        The candidate chunk size with the highest utilization.

    Raises:
        ValueError: If ``n_grids`` is not positive, ``search_range`` is not a
            multiple of ``n_grids``, ``module`` has no parameters, or
            ``min_chunk_size`` is smaller than the largest parameter.
    """
    # Explicit checks instead of ``assert`` so they survive ``python -O`` and
    # so ``n_grids == 0`` is reported clearly rather than as a
    # ZeroDivisionError from the modulo below.
    if n_grids <= 0:
        raise ValueError(f'n_grids must be positive, got {n_grids}')
    if search_range % n_grids != 0:
        raise ValueError(f'search_range ({search_range}) must be divisible by n_grids ({n_grids})')

    # TODO(ver217): sort params and filter unused ones
    params_numel = [p.numel() for p in module.parameters()]
    if not params_numel:
        raise ValueError('module has no parameters to search a chunk size for')
    max_param_numel = max(params_numel)

    if min_chunk_size is None:
        min_chunk_size = max_param_numel
    elif min_chunk_size < max_param_numel:
        raise ValueError(f'min_chunk_size ({min_chunk_size}) is smaller than the largest parameter '
                         f'({max_param_numel} elements)')

    # A zero-width search range gives a zero step, which range() rejects;
    # use a step of 1 so the single candidate ``min_chunk_size`` is evaluated.
    # NOTE(review): a negative search_range is not validated here and keeps
    # its previous behavior - confirm whether it should be rejected.
    step_size = search_range // n_grids or 1

    max_chunk_util = -1
    best_chunk_size = -1
    for chunk_size in range(min_chunk_size, min_chunk_size + search_range + 1, step_size):
        chunk_util = ChunkManager.get_chunk_util(chunk_size, params_numel)
        # Strict comparison keeps the earliest candidate on ties.
        if chunk_util > max_chunk_util:
            max_chunk_util = chunk_util
            best_chunk_size = chunk_size
    return best_chunk_size
|
||||||
|
|
Loading…
Reference in New Issue