ColossalAI/colossalai/zero/gemini/chunk/utils.py

from time import time
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn

from .manager import ChunkManager
from .search_utils import search_chunk_configuration


def safe_div(a, b):
    # Return 0 when the numerator is 0. In this module the denominator
    # (total_size + wasted_size) can only be 0 when wasted_size is 0, so this
    # check also guards the 0 / 0 case in the logging below.
    if a == 0:
        return 0
    return a / b


def init_chunk_manager(
    model: nn.Module,
    init_device: Optional[torch.device] = None,
    hidden_dim: Optional[int] = None,
    reuse_fp16_chunk: bool = True,
    verbose: bool = False,
    max_prefetch: int = 0,
    **kwargs,
) -> ChunkManager:
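    """Search a chunk configuration for ``model`` and build a ``ChunkManager`` from it.

    ``hidden_dim``, when given, is used as the interval of the chunk-size search
    (1024 otherwise). ``init_device``, ``reuse_fp16_chunk`` and ``max_prefetch``
    are forwarded to ``ChunkManager``; all remaining keyword arguments are
    forwarded to ``search_chunk_configuration``.
    """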
    if hidden_dim:
        search_interval = hidden_dim
    else:
        search_interval = 1024  # defaults to 1024
    kwargs["search_interval"] = search_interval
    # Synchronize all ranks so the measured span covers the slowest rank.
    dist.barrier()
    begin = time()
    config_dict, total_size, wasted_size = search_chunk_configuration(model, **kwargs)
    dist.barrier()
    end = time()
    span_s = end - begin

    # Convert element counts to units of 2^20 for readable logging.
    mega_unit = 1024**2
    total_size /= mega_unit
    wasted_size /= mega_unit
    # Log search statistics once, from rank 0 only.
    if verbose and dist.get_rank() == 0:
        print(
            "searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
            "used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
            "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
            sep="",
            flush=True,
        )
    dist.barrier()

    chunk_manager = ChunkManager(config_dict, init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch)
    return chunk_manager
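

# --- Usage sketch (not part of the upstream file) ---
# A minimal, hypothetical way to exercise init_chunk_manager in a one-rank
# process group. The gloo backend, the toy model, and the search_range_m
# value forwarded to search_chunk_configuration are illustrative assumptions,
# not the library's documented setup.
if __name__ == "__main__":
    import os

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # Single-rank group so dist.barrier() and dist.get_rank() work standalone.
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    toy_model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
    # search_range_m is assumed here to be accepted by search_chunk_configuration.
    manager = init_chunk_manager(toy_model, hidden_dim=1024, verbose=True, search_range_m=32.0)
    print(type(manager).__name__)

    dist.destroy_process_group()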