mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] ParamOpHook -> ColoParamOpHook (#2080)
parent
4f21c9e8d9
commit
b3b89865e2
|
@ -3,7 +3,7 @@ import torch.nn
|
|||
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
||||
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook
|
||||
from colossalai.nn.parallel.data_parallel import _cast_float
|
||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
|
||||
__all__ = ['RuntimeMemTracer']
|
||||
|
||||
|
@ -53,12 +53,12 @@ class RuntimeMemTracer():
|
|||
args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
self._pre_forward()
|
||||
with ParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
outputs = self.module(*args, **kwargs)
|
||||
return outputs
|
||||
|
||||
def backward(self, loss):
|
||||
with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
loss.backward()
|
||||
self._post_backward()
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch
|
|||
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
||||
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
||||
from colossalai.gemini.tensor_utils import alloc_storage, free_storage
|
||||
from colossalai.tensor.param_op_hook import ParamOpHook
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
||||
|
@ -39,7 +39,7 @@ class GradMemTracerHook():
|
|||
hook.remove()
|
||||
|
||||
|
||||
class ParamMemTracerHook(ParamOpHook):
|
||||
class ParamMemTracerHook(ColoParamOpHook):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
|
|
@ -12,7 +12,7 @@ from colossalai.logging import get_dist_logger
|
|||
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
|
||||
|
||||
|
@ -259,7 +259,7 @@ class ZeroDDP(ColoDDP):
|
|||
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
self.gemini_manager.pre_iter(*args)
|
||||
with ParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
outputs = self.module(*args, **kwargs)
|
||||
if self.force_outputs_fp32:
|
||||
return _cast_float(outputs, torch.float)
|
||||
|
@ -280,12 +280,12 @@ class ZeroDDP(ColoDDP):
|
|||
self.gemini_manager.post_iter()
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
loss.backward()
|
||||
self._post_backward()
|
||||
|
||||
def backward_by_grad(self, tensor, grad):
|
||||
with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
torch.autograd.backward(tensor, grad)
|
||||
self._post_backward()
|
||||
|
||||
|
|
|
@ -5,13 +5,14 @@ from .comm_spec import CollectiveCommPattern, CommSpec
|
|||
from .compute_spec import ComputePattern, ComputeSpec
|
||||
from .dist_spec_mgr import DistSpecManager
|
||||
from .distspec import ReplicaSpec, ShardSpec
|
||||
from .param_op_hook import ParamOpHook, ParamOpHookManager
|
||||
from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager
|
||||
from .process_group import ProcessGroup
|
||||
from .tensor_spec import ColoTensorSpec
|
||||
from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor
|
||||
|
||||
__all__ = [
|
||||
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
|
||||
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
|
||||
'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
|
||||
'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
|
||||
'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
|
||||
'merge_same_dim_mesh_list'
|
||||
]
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
|||
|
||||
from colossalai.tensor.colo_tensor import ColoTensor
|
||||
from colossalai.tensor.const import TensorType
|
||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.tensor.tensor_spec import ColoTensorSpec
|
||||
|
||||
|
||||
|
@ -58,18 +58,18 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
|
|||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=..., kwargs=None):
|
||||
if ParamOpHookManager.has_hook():
|
||||
if ColoParamOpHookManager.has_hook():
|
||||
if not func.__name__.startswith('__'):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
|
||||
if len(params) > 0:
|
||||
with torch._C.DisableTorchFunction():
|
||||
new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
|
||||
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
|
||||
args, kwargs = replace_args(args, kwargs, new_args)
|
||||
ret = super().__torch_function__(func, types, args, kwargs)
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = ParamOpHookManager.post_op(params, ret)
|
||||
ret = ColoParamOpHookManager.post_op(params, ret)
|
||||
return ret
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from colossalai.tensor.colo_tensor import ColoTensor
|
|||
from colossalai.tensor.tensor_spec import ColoTensorSpec
|
||||
|
||||
|
||||
class ParamOpHook(ABC):
|
||||
class ColoParamOpHook(ABC):
|
||||
"""Hook which is triggered by each operation when operands contain ColoParameter.
|
||||
To customize it, you must inherit this abstract class, and implement ``pre_forward``,
|
||||
``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
|
||||
|
@ -32,68 +32,68 @@ class ParamOpHook(ABC):
|
|||
pass
|
||||
|
||||
|
||||
class ParamOpHookManager:
|
||||
class ColoParamOpHookManager:
|
||||
"""Manage your param op hooks. It only has static methods.
|
||||
The only static method you should call is ``use_hooks(*hooks)``.
|
||||
"""
|
||||
hooks: Tuple[ParamOpHook, ...] = tuple()
|
||||
hooks: Tuple[ColoParamOpHook, ...] = tuple()
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def use_hooks(*hooks: ParamOpHook):
|
||||
def use_hooks(*hooks: ColoParamOpHook):
|
||||
"""Change the param op hooks you use. Nested calling is allowed.
|
||||
|
||||
Example:
|
||||
>>> with ParamOpHookManager.use_hooks(*hooks):
|
||||
>>> with ColoParamOpHookManager.use_hooks(*hooks):
|
||||
>>> do_something()
|
||||
>>> with ParamOpHookManager.use_hooks():
|
||||
>>> with ColoParamOpHookManager.use_hooks():
|
||||
>>> // clear hooks
|
||||
>>> do_something()
|
||||
"""
|
||||
try:
|
||||
old_param_op_hooks = ParamOpHookManager.hooks
|
||||
ParamOpHookManager.hooks = hooks
|
||||
old_param_op_hooks = ColoParamOpHookManager.hooks
|
||||
ColoParamOpHookManager.hooks = hooks
|
||||
yield
|
||||
finally:
|
||||
ParamOpHookManager.hooks = old_param_op_hooks
|
||||
ColoParamOpHookManager.hooks = old_param_op_hooks
|
||||
|
||||
@staticmethod
|
||||
def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
|
||||
for hook in ParamOpHookManager.hooks:
|
||||
for hook in ColoParamOpHookManager.hooks:
|
||||
hook.pre_forward(params)
|
||||
|
||||
@staticmethod
|
||||
def _trigger_post_forward(params: List[torch.Tensor]) -> None:
|
||||
for hook in ParamOpHookManager.hooks:
|
||||
for hook in ColoParamOpHookManager.hooks:
|
||||
hook.post_forward(params)
|
||||
|
||||
@staticmethod
|
||||
def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
|
||||
for hook in ParamOpHookManager.hooks:
|
||||
for hook in ColoParamOpHookManager.hooks:
|
||||
hook.pre_backward(params)
|
||||
|
||||
@staticmethod
|
||||
def _trigger_post_backward(params: List[torch.Tensor]) -> None:
|
||||
for hook in ParamOpHookManager.hooks:
|
||||
for hook in ColoParamOpHookManager.hooks:
|
||||
hook.post_backward(params)
|
||||
|
||||
@staticmethod
|
||||
def pre_op(params: List[torch.Tensor], *args: Any) -> list:
|
||||
ParamOpHookManager._trigger_pre_forward(params)
|
||||
ColoParamOpHookManager._trigger_pre_forward(params)
|
||||
args_info = _get_colo_tensors_info(*args)
|
||||
rets = PreFwdPostBwd.apply(params, *args)
|
||||
return _update_colo_tensors(args_info, *rets)
|
||||
|
||||
@staticmethod
|
||||
def post_op(params: List[torch.Tensor], arg: Any) -> Any:
|
||||
ParamOpHookManager._trigger_post_forward(params)
|
||||
ColoParamOpHookManager._trigger_post_forward(params)
|
||||
arg_info = _get_colo_tensors_info(arg)
|
||||
ret = PostFwdPreBwd.apply(params, arg)
|
||||
return _unpack_args(_update_colo_tensors(arg_info, ret))
|
||||
|
||||
@staticmethod
|
||||
def has_hook() -> bool:
|
||||
return len(ParamOpHookManager.hooks) > 0
|
||||
return len(ColoParamOpHookManager.hooks) > 0
|
||||
|
||||
|
||||
class PreFwdPostBwd(torch.autograd.Function):
|
||||
|
@ -105,7 +105,7 @@ class PreFwdPostBwd(torch.autograd.Function):
|
|||
|
||||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
ParamOpHookManager._trigger_post_backward(ctx.params)
|
||||
ColoParamOpHookManager._trigger_post_backward(ctx.params)
|
||||
return (None,) + grads
|
||||
|
||||
|
||||
|
@ -118,7 +118,7 @@ class PostFwdPreBwd(torch.autograd.Function):
|
|||
|
||||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
ParamOpHookManager._trigger_pre_backward(ctx.params)
|
||||
ColoParamOpHookManager._trigger_pre_backward(ctx.params)
|
||||
return (None,) + grads
|
||||
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
|||
|
||||
from colossalai.gemini import TensorState
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.tensor.param_op_hook import ParamOpHook
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
||||
|
@ -15,7 +15,7 @@ class TrainingPhase(Enum):
|
|||
BACKWARD = 1
|
||||
|
||||
|
||||
class GeminiZeROHook(ParamOpHook):
|
||||
class GeminiZeROHook(ColoParamOpHook):
|
||||
|
||||
def __init__(self, gemini_manager: GeminiManager) -> None:
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue