mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish code colossalai/gemini/update/search_utils.py (#1557)
parent
413f9c19f4
commit
46931e3c32
|
@ -48,12 +48,11 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
|
||||||
def search_chunk_configuration(
|
def search_chunk_configuration(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
search_range_mb: int,
|
search_range_mb: int,
|
||||||
search_interval_byte: int, # hidden size is the best value for the interval
|
search_interval_byte: int, # hidden size is the best value for the interval
|
||||||
min_chunk_size_mb: int = 32,
|
min_chunk_size_mb: int = 32,
|
||||||
filter_exlarge_params: bool = True
|
filter_exlarge_params: bool = True):
|
||||||
):
|
search_range_byte = search_range_mb * 1024**2
|
||||||
search_range_byte = search_range_mb * 1024 ** 2
|
min_chunk_size_byte = min_chunk_size_mb * 1024**2
|
||||||
min_chunk_size_byte = min_chunk_size_mb * 1024 ** 2
|
|
||||||
assert search_range_byte % search_interval_byte == 0
|
assert search_range_byte % search_interval_byte == 0
|
||||||
|
|
||||||
params_dict = clasify_params(model)
|
params_dict = clasify_params(model)
|
||||||
|
|
Loading…
Reference in New Issue