# mirror of https://github.com/hpcaitech/ColossalAI
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.elixir.chunk import BlockRequire, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.chunk.scheduler import FIFOScheduler
from colossalai.elixir.hook import BufferStore, HookParam
from colossalai.elixir.tensor import OutplaceTensor


def to_divide(a: int, b: int):
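    """Round ``a`` up to the nearest multiple of ``b`` (padding helper)."""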
    return a + (-a % b)


def grad_handler(grad: torch.Tensor, param: nn.Parameter, fetcher: ChunkFetcher):
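    """Per-parameter gradient hook used by the chunk system.

    The incoming gradient is copied into the parameter's slice of its chunk,
    the chunk is marked ready for reduction, and a storage-free tensor is
    returned (the hook's return value replaces the gradient), so the dense
    gradient memory can be released.
    """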
    empty_grad = torch.empty_like(grad)
    empty_grad.storage().resize_(0)

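    # __torch_function__ dispatch is disabled while chunk states are manipulated directly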
    with torch._C.DisableTorchFunction():
        chunk = fetcher.get_one_chunk(param)
        if chunk.tensors_info[param].state != TensorState.HOLD_AFTER_BWD:
            raise RuntimeError(f'unexpected tensor state: {chunk.tensors_info[param].state}')
        fetcher.group.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
        chunk.copy_tensor_to_chunk_slice(param, grad)
        fetcher.reduce_chunk(chunk)

    return empty_grad


def hook_transform(model: nn.Module, process_group: dist.ProcessGroupGloo):
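    """Wrap ``model`` for chunk-based execution.

    Every parameter gets its own fused chunk, parameters are turned into
    ``HookParam`` with the ``ChunkFetcher`` attached, and a forward pre-hook
    wraps tensor inputs in ``OutplaceTensor``.
    Returns the transformed model together with its ``ChunkGroup``.
    """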
    pg_size = dist.get_world_size(process_group)

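    # pad each parameter to a multiple of the world size and reserve a private block for it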
    private_list = list()
    for param in model.parameters():
        block_size = to_divide(param.numel(), pg_size)
        private_list.append(BlockRequire(block_size, param.dtype))

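    # build the rcache memory pool that backs the chunk group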
    mp = MemoryPool('cuda')
    mp.allocate(private_block_list=private_list)
    cg = ChunkGroup(rcache=mp)
    # allocate a fused chunk for each parameter
    fused_config = dict(rcache_fused=True)
    for param in model.parameters():
        cg.allocate_chunk([param], to_divide(param.numel(), pg_size), param.dtype, process_group, fused_config)
    # initialize chunk fetcher
    scheduler = FIFOScheduler()
    fetcher = ChunkFetcher(scheduler, cg)
    buffer = BufferStore(0, torch.float32)
    # register fetcher and gradient handler
    HookParam.attach_fetcher(fetcher, buffer)
    for param in model.parameters():
        param.register_hook(partial(grad_handler, param=param, fetcher=fetcher))
        param.__class__ = HookParam
    # set inplace to False for all modules
    for module in model.modules():
        if hasattr(module, 'inplace'):
            module.inplace = False

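    # forward pre-hook: reset the fetcher and wrap tensor inputs in OutplaceTensor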
    def transform_input(self_module, inputs):
        fetcher.reset()
        input_list = list()
        for t in inputs:
            if isinstance(t, torch.Tensor):
                t = OutplaceTensor(t)
            input_list.append(t)
        return tuple(input_list)

    model.register_forward_pre_hook(transform_input)

    return model, cg
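

# A minimal usage sketch (illustrative, not part of the original file). It assumes the
# default process group has already been initialized (e.g. via torchrun) and that the
# model and its inputs live on the current CUDA device:
#
#   model = nn.Sequential(nn.Linear(32, 32), nn.ReLU()).cuda()
#   model, chunk_group = hook_transform(model, dist.GroupMember.WORLD)
#   loss = model(torch.randn(4, 32, device='cuda')).sum()
#   loss.backward()    # grad_handler reduces each parameter's chunk across the group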