ColossalAI/colossalai/elixir/hook/functions.py

import torch

from colossalai.elixir.chunk import ChunkFetcher

from .storage import BufferStore

def prefwd_postbwd_function(fetcher: ChunkFetcher, store: BufferStore):
    # Builds the hook that runs before a module's forward pass and after its
    # backward pass: it fetches the parameter chunks for compute and stages
    # unfused parameters in the buffer, then releases them once gradients exist.

    class PreFwdPostBwd(torch.autograd.Function):

        @staticmethod
        def forward(ctx, params, *args):
            with torch._C.DisableTorchFunction():
                ctx.params = params
                chunks = fetcher.trans_to_compute(params)
                fetcher.fetch_chunks(chunks)

                offset = 0
                for p in ctx.params:
                    if not fetcher.is_in_fused(p):
                        # add unfused parameters to the buffer,
                        # because their blocks may be changed
                        offset = store.insert(p, offset)

            return args

        @staticmethod
        def backward(ctx, *grads):
            with torch._C.DisableTorchFunction():
                fetcher.trans_to_hold(ctx.params, phase='b')

                for p in ctx.params:
                    if not fetcher.is_in_fused(p):
                        store.erase(p)

            return (None, *grads)

    return PreFwdPostBwd.apply

def postfwd_prebwd_function(fetcher: ChunkFetcher, store: BufferStore):
    # Builds the hook that runs after a module's forward pass and before its
    # backward pass: it releases the parameter chunks to the hold state after
    # forward, then fetches them back right before the backward computation.

    class PostFwdPreBwd(torch.autograd.Function):

        @staticmethod
        def forward(ctx, params, *args):
            with torch._C.DisableTorchFunction():
                ctx.params = params
                fetcher.trans_to_hold(ctx.params, phase='f')

                for p in ctx.params:
                    if not fetcher.is_in_fused(p):
                        store.erase(p)

            return args

        @staticmethod
        def backward(ctx, *grads):
            with torch._C.DisableTorchFunction():
                chunks = fetcher.trans_to_compute(ctx.params)
                fetcher.fetch_chunks(chunks)

                offset = 0
                for p in ctx.params:
                    if not fetcher.is_in_fused(p):
                        # add unfused parameters to the buffer,
                        # because their blocks may be changed
                        offset = store.insert(p, offset)

            return (None, *grads)

    return PostFwdPreBwd.apply
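

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module's API): one way the two
# factories above could be wired around a module call.  `wrap_module_call` is
# a hypothetical helper; `fetcher`, `store`, `module`, and `params` are assumed
# to be built elsewhere.
# ---------------------------------------------------------------------------


def wrap_module_call(module: torch.nn.Module, params: list, fetcher: ChunkFetcher, store: BufferStore, *args):
    """Hypothetical example: run `module` with the chunk hooks installed around it."""
    pre_fn = prefwd_postbwd_function(fetcher, store)
    post_fn = postfwd_prebwd_function(fetcher, store)

    # pre-forward: chunks are moved to compute, fetched, and unfused parameters
    # are staged in the buffer; the matching post-backward undoes this.
    args = pre_fn(params, *args)
    outputs = module(*args)
    if not isinstance(outputs, tuple):
        outputs = (outputs, )

    # post-forward: chunks go back to the hold state and buffered parameters are
    # erased; the matching pre-backward re-fetches them when gradients arrive.
    outputs = post_fn(params, *outputs)
    return outputs[0] if len(outputs) == 1 else outputs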