ColossalAI/colossalai/gemini/chunk/utils.py

51 lines
1.5 KiB
Python
Raw Normal View History

from time import time
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.chunk.search_utils import search_chunk_configuration
from colossalai.utils import is_ddp_ignored
def safe_div(a, b):
    """Divide ``a`` by ``b`` without ever raising ``ZeroDivisionError``.

    Args:
        a: Numerator.
        b: Denominator.

    Returns:
        ``0`` when either ``a`` or ``b`` equals zero, otherwise ``a / b``.
    """
    # Guarding only the numerator would still let a non-zero ``a`` be divided
    # by a zero ``b``; check both so the helper is safe for any operands.
    if a == 0 or b == 0:
        return 0
    return a / b
def init_chunk_manager(model: nn.Module,
                       init_device: Optional[torch.device] = None,
                       hidden_dim: Optional[int] = None,
                       **kwargs) -> ChunkManager:
    """Search a chunk configuration for ``model`` and build a ``ChunkManager``.

    The search is timed between two ``dist.barrier()`` calls, and rank 0 prints
    a short report (elapsed seconds, used and wasted size in MB, wasted
    percentage). Because ``dist.barrier()`` and ``dist.get_rank()`` are called,
    the default ``torch.distributed`` process group must already be initialized.

    Args:
        model: Module forwarded to ``search_chunk_configuration``.
        init_device: Forwarded as the second argument of ``ChunkManager``.
        hidden_dim: When truthy, used as the ``search_interval_byte`` value;
            otherwise 1024 is used.
        **kwargs: Extra keyword arguments forwarded to
            ``search_chunk_configuration``. Any ``search_interval_byte`` entry
            supplied here is overwritten.

    Returns:
        The ``ChunkManager`` built from the searched configuration.
    """
    # A falsy hidden_dim (None or 0) falls back to the 1kb default interval.
    kwargs["search_interval_byte"] = hidden_dim if hidden_dim else 1024

    # Barriers on both sides so the measured span covers the search on all ranks.
    dist.barrier()
    started_at = time()
    config_dict, total_size, wasted_size = search_chunk_configuration(model, **kwargs)
    dist.barrier()
    elapsed_s = time() - started_at

    # Scale the reported sizes by 1024**2 so they are printed in MB.
    bytes_per_mb = 1024**2
    total_mb = total_size / bytes_per_mb
    wasted_mb = wasted_size / bytes_per_mb

    # Only rank 0 prints, so the report appears once per job.
    if dist.get_rank() == 0:
        wasted_pct = 100 * safe_div(wasted_mb, total_mb + wasted_mb)
        report = "".join((
            "searching chunk configuration is completed in {:.2f} s.\n".format(elapsed_s),
            "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_mb, wasted_mb),
            "total wasted percentage is {:.2f}%".format(wasted_pct),
        ))
        print(report, flush=True)
    dist.barrier()

    return ChunkManager(config_dict, init_device)