# mirror of https://github.com/hpcaitech/ColossalAI
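"""Autograd hooks for chunk-managed parameters.

``prefwd_postbwd_function`` builds a ``torch.autograd.Function`` that moves
parameter chunks to the compute state and fetches them before an operator's
forward pass, then returns them to the hold state after its backward pass.
``postfwd_prebwd_function`` is its mirror image: it returns chunks to the
hold state after the forward pass and fetches them back before the backward
pass.
"""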
import torch

from colossalai.elixir.chunk import ChunkFetcher

from .storage import BufferStore


def prefwd_postbwd_function(fetcher: ChunkFetcher, store: BufferStore):

    class PreFwdPostBwd(torch.autograd.Function):
        """Fetch the parameters' chunks before the forward pass of an
        operator and return them to the hold state after its backward pass."""

        @staticmethod
        def forward(ctx, params, *args):
            with torch._C.DisableTorchFunction():
                ctx.params = params
                # move the chunks of `params` to the compute state and fetch them
                chunks = fetcher.trans_to_compute(params)
                fetcher.fetch_chunks(chunks)

                offset = 0
                for p in ctx.params:
                    if not fetcher.is_in_fused(p):
                        # we should add parameters to the buffer,
                        # because their blocks may be changed
                        offset = store.insert(p, offset)

            return args

        @staticmethod
        def backward(ctx, *grads):
            with torch._C.DisableTorchFunction():
                fetcher.trans_to_hold(ctx.params, phase='b')

                for p in ctx.params:
                    if not fetcher.is_in_fused(p):
                        store.erase(p)

            # the leading `None` is the gradient of the `params` argument
            return (None, *grads)

    return PreFwdPostBwd.apply

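
# Usage sketch (hypothetical call site, not part of this file): the two
# factories are meant to bracket an operator so that the parameters'
# chunks are resident exactly while the operator runs, in both passes:
#
#     pre_fn = prefwd_postbwd_function(fetcher, store)
#     post_fn = postfwd_prebwd_function(fetcher, store)
#     args = pre_fn(params, *args)         # fwd: fetch, bwd: release
#     outputs = op(*args)
#     outputs = post_fn(params, *outputs)  # fwd: release, bwd: fetch
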
def postfwd_prebwd_function(fetcher: ChunkFetcher, store: BufferStore):

    class PostFwdPreBwd(torch.autograd.Function):
        """Return the parameters' chunks to the hold state after the forward
        pass of an operator and fetch them back before its backward pass."""

        @staticmethod
        def forward(ctx, params, *args):
            with torch._C.DisableTorchFunction():
                ctx.params = params

                fetcher.trans_to_hold(ctx.params, phase='f')
                for p in ctx.params:
                    if not fetcher.is_in_fused(p):
                        store.erase(p)

            return args

        @staticmethod
        def backward(ctx, *grads):
            with torch._C.DisableTorchFunction():
                # move the chunks of `params` back to the compute state
                # and fetch them for the backward computation
                chunks = fetcher.trans_to_compute(ctx.params)
                fetcher.fetch_chunks(chunks)

                offset = 0
                for p in ctx.params:
                    if not fetcher.is_in_fused(p):
                        # we should add parameters to the buffer,
                        # because their blocks may be changed
                        offset = store.insert(p, offset)

            # the leading `None` is the gradient of the `params` argument
            return (None, *grads)

    return PostFwdPreBwd.apply
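

if __name__ == '__main__':
    # Minimal smoke test of the pass-through autograd pattern used above.
    # It is illustrative only (it exercises neither ChunkFetcher nor
    # BufferStore) and, because of the relative import above, must be run
    # in the package context (`python -m <package>.<this_module>`).
    class _PassThrough(torch.autograd.Function):

        @staticmethod
        def forward(ctx, tag, *args):
            ctx.tag = tag    # non-tensor side channel, like `params` above
            return args      # tensors flow through untouched

        @staticmethod
        def backward(ctx, *grads):
            # one `None` for `tag`, then one gradient per tensor in `*args`
            return (None, *grads)

    x = torch.ones(2, requires_grad=True)
    (y,) = _PassThrough.apply('demo', x)
    y.sum().backward()
    assert torch.equal(x.grad, torch.ones(2))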