# ColossalAI/tests/test_elixir/test_chunk/fetcher_utils.py
from functools import partial
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.elixir.chunk import BlockRequire, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.chunk.scheduler import FIFOScheduler
from colossalai.elixir.hook import BufferStore, HookParam
from colossalai.elixir.tensor import OutplaceTensor
def to_divide(a: int, b: int) -> int:
    """Round ``a`` up to the nearest multiple of ``b``.

    Equivalent to ``ceil(a / b) * b`` computed in exact integer
    arithmetic (no float round-off for large values).
    """
    return -(-a // b) * b
def grad_handler(grad: torch.Tensor, param: nn.Parameter, fetcher: ChunkFetcher):
    """Gradient hook for a chunk-managed parameter.

    Copies the freshly computed gradient into the parameter's chunk slice,
    marks the tensor ready for reduction, and kicks off the chunk reduce.
    Returns a zero-storage tensor so autograd does not keep the dense
    gradient alive.
    """
    # Placeholder with the grad's metadata but no backing storage.
    placeholder = torch.empty_like(grad)
    placeholder.storage().resize_(0)

    # Bypass the __torch_function__ machinery on HookParam while we poke
    # at chunk internals.
    with torch._C.DisableTorchFunction():
        chunk = fetcher.get_one_chunk(param)
        # The parameter must have finished its backward pass before its
        # gradient can be folded into the chunk.
        if chunk.tensors_info[param].state != TensorState.HOLD_AFTER_BWD:
            raise RuntimeError()
        fetcher.group.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
        chunk.copy_tensor_to_chunk_slice(param, grad)
        fetcher.reduce_chunk(chunk)

    return placeholder
def hook_transform(model: nn.Module, process_group: dist.ProcessGroup):
    """Rewire *model* so its parameters live in Elixir chunks.

    Builds a CUDA memory pool with one fused private block per parameter
    (padded so each block divides evenly across the process group),
    allocates one chunk per parameter, attaches a ``ChunkFetcher`` plus a
    gradient hook to every parameter, and installs a forward pre-hook that
    resets the fetcher and wraps tensor inputs in ``OutplaceTensor``.

    Note: the parameter was previously annotated ``dist.ProcessGroupGloo``,
    but only generic ``ProcessGroup`` operations are used, so the broader
    type is correct and accepts any backend.

    Args:
        model: the module whose parameters will be chunk-managed (mutated
            in place: parameter classes are swapped to ``HookParam``).
        process_group: the process group the chunks are sharded over.

    Returns:
        Tuple of (the same ``model`` object, the ``ChunkGroup`` managing
        its parameters).
    """
    pg_size = dist.get_world_size(process_group)

    # One private block per parameter, padded to a multiple of the group
    # size so the block shards evenly across ranks.
    private_list = []
    for param in model.parameters():
        block_size = to_divide(param.numel(), pg_size)
        private_list.append(BlockRequire(block_size, param.dtype))
    mp = MemoryPool('cuda')
    mp.allocate(private_block_list=private_list)

    # Allocate the chunk group: one rcache-fused chunk per parameter,
    # sized identically to its private block above.
    cg = ChunkGroup(rcache=mp)
    fused_config = dict(rcache_fused=True)
    for param in model.parameters():
        cg.allocate_chunk([param], to_divide(param.numel(), pg_size), param.dtype, process_group, fused_config)

    # Initialize the chunk fetcher and its (empty, fp32) buffer store.
    scheduler = FIFOScheduler()
    fetcher = ChunkFetcher(scheduler, cg)
    buffer = BufferStore(0, torch.float32)

    # Register the fetcher globally on HookParam, then swap every
    # parameter's class and attach its gradient-reduction hook.
    HookParam.attach_fetcher(fetcher, buffer)
    for param in model.parameters():
        param.register_hook(partial(grad_handler, param=param, fetcher=fetcher))
        param.__class__ = HookParam

    # Disable in-place ops (e.g. ReLU(inplace=True)) so activations are
    # not overwritten while chunks are being fetched.
    for module in model.modules():
        if hasattr(module, 'inplace'):
            module.inplace = False

    def transform_input(self_module, inputs):
        # Runs before every forward: reset fetcher state and make tensor
        # inputs out-of-place so downstream ops cannot mutate them.
        fetcher.reset()
        input_list = []
        for t in inputs:
            if isinstance(t, torch.Tensor):
                t = OutplaceTensor(t)
            input_list.append(t)
        return tuple(input_list)

    model.register_forward_pre_hook(transform_input)
    return model, cg