from typing import List

from torch._tensor import Tensor

from colossalai.tensor.param_op_hook import ColoParamOpHook

_ALL_GATHER_HANDLE = "_all_gather_handle"

def wait_all_gather_handle(p):
    """If parameter ``p`` has a pending all-gather, block until it completes and clear the handle."""
    if hasattr(p, _ALL_GATHER_HANDLE):
        handle = getattr(p, _ALL_GATHER_HANDLE)
        handle.wait()
        delattr(p, _ALL_GATHER_HANDLE)

def set_all_gather_handle(p, handle):
    """Stash an asynchronous all-gather work handle on parameter ``p``."""
    setattr(p, _ALL_GATHER_HANDLE, handle)

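# Illustrative sketch, not part of the original module: one way a ZeRO runtime could pair
# these helpers with a non-blocking all-gather. The function name and arguments below are
# hypothetical; only set_all_gather_handle / wait_all_gather_handle come from this file.
def _example_launch_async_all_gather(p, shard, group=None):
    import torch
    import torch.distributed as dist

    # Gather one shard from every rank without blocking. The returned work handle is
    # attached to the parameter so ZeroOpHook.pre_forward can wait on it lazily, right
    # before the parameter is first used.
    world_size = dist.get_world_size(group)
    gathered = [torch.empty_like(shard) for _ in range(world_size)]
    handle = dist.all_gather(gathered, shard, group=group, async_op=True)
    set_all_gather_handle(p, handle)
    return gathered
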
class ZeroOpHook(ColoParamOpHook):
    """Parameter op hook that waits for pending all-gathers before a parameter is used."""

    def pre_forward(self, params: List[Tensor]) -> None:
        # Ensure every parameter is fully gathered before the op consumes it.
        for p in params:
            wait_all_gather_handle(p)

    def post_forward(self, params: List[Tensor]) -> None:
        pass

    def pre_backward(self, params: List[Tensor]) -> None:
        pass

    def post_backward(self, params: List[Tensor]) -> None:
        pass
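

# Illustrative usage sketch, not part of the original module. It assumes that
# ColoParamOpHookManager.use_hooks (from colossalai.tensor.param_op_hook) is the context
# manager that activates parameter op hooks for a forward pass; ``model`` and ``inputs``
# are hypothetical placeholders.
def _example_forward_with_zero_hook(model, inputs):
    from colossalai.tensor.param_op_hook import ColoParamOpHookManager

    # While the hook is active, pre_forward waits on any pending all-gather for each
    # parameter the wrapped ops touch.
    with ColoParamOpHookManager.use_hooks(ZeroOpHook()):
        return model(inputs)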