[Gemini] ParamOpHook -> ColoParamOpHook (#2080)

Authored by Jiarui Fang; committed by GitHub
parent 4f21c9e8d9
commit b3b89865e2

@@ -3,7 +3,7 @@ import torch.nn
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook
from colossalai.nn.parallel.data_parallel import _cast_float
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
__all__ = ['RuntimeMemTracer']
@@ -53,12 +53,12 @@ class RuntimeMemTracer():
args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
self.module.zero_grad(set_to_none=True)
self._pre_forward()
-with ParamOpHookManager.use_hooks(self.param_op_hook):
+with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
return outputs
def backward(self, loss):
-with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward()
self._post_backward()

@@ -8,7 +8,7 @@ import torch
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
from colossalai.gemini.tensor_utils import alloc_storage, free_storage
-from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.tensor.param_op_hook import ColoParamOpHook
class TrainingPhase(Enum):
@@ -39,7 +39,7 @@ class GradMemTracerHook():
hook.remove()
-class ParamMemTracerHook(ParamOpHook):
+class ParamMemTracerHook(ColoParamOpHook):
def __init__(self) -> None:
super().__init__()

@@ -12,7 +12,7 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
@@ -259,7 +259,7 @@ class ZeroDDP(ColoDDP):
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter(*args)
-with ParamOpHookManager.use_hooks(self.param_op_hook):
+with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
if self.force_outputs_fp32:
return _cast_float(outputs, torch.float)
@@ -280,12 +280,12 @@ class ZeroDDP(ColoDDP):
self.gemini_manager.post_iter()
def backward(self, loss: torch.Tensor):
-with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward()
self._post_backward()
def backward_by_grad(self, tensor, grad):
-with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
torch.autograd.backward(tensor, grad)
self._post_backward()

@@ -5,13 +5,14 @@ from .comm_spec import CollectiveCommPattern, CommSpec
from .compute_spec import ComputePattern, ComputeSpec
from .dist_spec_mgr import DistSpecManager
from .distspec import ReplicaSpec, ShardSpec
-from .param_op_hook import ParamOpHook, ParamOpHookManager
+from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager
from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec
from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor
__all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
-'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
-'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
+'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
+'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
+'merge_same_dim_mesh_list'
]
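With this change the colossalai.tensor package exports only the new, Colo-prefixed names, so code that previously imported ParamOpHook or ParamOpHookManager from it must switch to the new names. A minimal sketch of the updated imports (assuming a Colossal-AI build that contains this commit):

# before: from colossalai.tensor import ParamOpHook, ParamOpHookManager
from colossalai.tensor import ColoParamOpHook, ColoParamOpHookManager

# the same classes can also be imported from the defining module directly
from colossalai.tensor.param_op_hook import ColoParamOpHook, ColoParamOpHookManager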

@@ -4,7 +4,7 @@ import torch
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.tensor.tensor_spec import ColoTensorSpec
@@ -58,18 +58,18 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
@classmethod
def __torch_function__(cls, func, types, args=..., kwargs=None):
-if ParamOpHookManager.has_hook():
+if ColoParamOpHookManager.has_hook():
if not func.__name__.startswith('__'):
if kwargs is None:
kwargs = {}
params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
if len(params) > 0:
with torch._C.DisableTorchFunction():
-new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
+new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
args, kwargs = replace_args(args, kwargs, new_args)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
-ret = ParamOpHookManager.post_op(params, ret)
+ret = ColoParamOpHookManager.post_op(params, ret)
return ret
return super().__torch_function__(func, types, args, kwargs)

@@ -8,7 +8,7 @@ from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.tensor_spec import ColoTensorSpec
-class ParamOpHook(ABC):
+class ColoParamOpHook(ABC):
"""Hook which is triggered by each operation when operands contain ColoParameter.
To customize it, you must inherit this abstract class, and implement ``pre_forward``,
``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
@@ -32,68 +32,68 @@ class ParamOpHook(ABC):
pass
-class ParamOpHookManager:
+class ColoParamOpHookManager:
"""Manage your param op hooks. It only has static methods.
The only static method you should call is ``use_hooks(*hooks)``.
"""
-hooks: Tuple[ParamOpHook, ...] = tuple()
+hooks: Tuple[ColoParamOpHook, ...] = tuple()
@staticmethod
@contextmanager
-def use_hooks(*hooks: ParamOpHook):
+def use_hooks(*hooks: ColoParamOpHook):
"""Change the param op hooks you use. Nested calling is allowed.
Example:
->>> with ParamOpHookManager.use_hooks(*hooks):
+>>> with ColoParamOpHookManager.use_hooks(*hooks):
>>> do_something()
->>> with ParamOpHookManager.use_hooks():
+>>> with ColoParamOpHookManager.use_hooks():
>>> // clear hooks
>>> do_something()
"""
try:
-old_param_op_hooks = ParamOpHookManager.hooks
-ParamOpHookManager.hooks = hooks
+old_param_op_hooks = ColoParamOpHookManager.hooks
+ColoParamOpHookManager.hooks = hooks
yield
finally:
-ParamOpHookManager.hooks = old_param_op_hooks
+ColoParamOpHookManager.hooks = old_param_op_hooks
@staticmethod
def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
-for hook in ParamOpHookManager.hooks:
+for hook in ColoParamOpHookManager.hooks:
hook.pre_forward(params)
@staticmethod
def _trigger_post_forward(params: List[torch.Tensor]) -> None:
-for hook in ParamOpHookManager.hooks:
+for hook in ColoParamOpHookManager.hooks:
hook.post_forward(params)
@staticmethod
def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
-for hook in ParamOpHookManager.hooks:
+for hook in ColoParamOpHookManager.hooks:
hook.pre_backward(params)
@staticmethod
def _trigger_post_backward(params: List[torch.Tensor]) -> None:
-for hook in ParamOpHookManager.hooks:
+for hook in ColoParamOpHookManager.hooks:
hook.post_backward(params)
@staticmethod
def pre_op(params: List[torch.Tensor], *args: Any) -> list:
-ParamOpHookManager._trigger_pre_forward(params)
+ColoParamOpHookManager._trigger_pre_forward(params)
args_info = _get_colo_tensors_info(*args)
rets = PreFwdPostBwd.apply(params, *args)
return _update_colo_tensors(args_info, *rets)
@staticmethod
def post_op(params: List[torch.Tensor], arg: Any) -> Any:
-ParamOpHookManager._trigger_post_forward(params)
+ColoParamOpHookManager._trigger_post_forward(params)
arg_info = _get_colo_tensors_info(arg)
ret = PostFwdPreBwd.apply(params, arg)
return _unpack_args(_update_colo_tensors(arg_info, ret))
@staticmethod
def has_hook() -> bool:
-return len(ParamOpHookManager.hooks) > 0
+return len(ColoParamOpHookManager.hooks) > 0
class PreFwdPostBwd(torch.autograd.Function):
@@ -105,7 +105,7 @@ class PreFwdPostBwd(torch.autograd.Function):
@staticmethod
def backward(ctx, *grads):
-ParamOpHookManager._trigger_post_backward(ctx.params)
+ColoParamOpHookManager._trigger_post_backward(ctx.params)
return (None,) + grads
@@ -118,7 +118,7 @@ class PostFwdPreBwd(torch.autograd.Function):
@staticmethod
def backward(ctx, *grads):
-ParamOpHookManager._trigger_pre_backward(ctx.params)
+ColoParamOpHookManager._trigger_pre_backward(ctx.params)
return (None,) + grads
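Taken together, the renamed API keeps its previous shape: subclass ColoParamOpHook, implement the four callbacks, and activate the hook with ColoParamOpHookManager.use_hooks around both forward and backward, as ZeroDDP and RuntimeMemTracer do above. The sketch below is illustrative only: OpCounterHook is a hypothetical hook, and the plain nn.Linear stands in for a module whose parameters would, in real use, be ColoParameter instances.

from typing import List

import torch

from colossalai.tensor import ColoParamOpHook, ColoParamOpHookManager


class OpCounterHook(ColoParamOpHook):
    # Hypothetical hook: counts how many operations touch ColoParameters.

    def __init__(self) -> None:
        super().__init__()
        self.fwd_ops = 0
        self.bwd_ops = 0

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        self.fwd_ops += 1

    def post_forward(self, params: List[torch.Tensor]) -> None:
        pass

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        self.bwd_ops += 1

    def post_backward(self, params: List[torch.Tensor]) -> None:
        pass


# A plain nn.Linear keeps the sketch runnable, but its parameters are ordinary
# nn.Parameters, so nothing is intercepted here. With ColoParameter operands
# (as in Gemini/ZeroDDP), every op would route through pre_op/post_op.
model = torch.nn.Linear(4, 4)
data = torch.randn(2, 4)

hook = OpCounterHook()
# Keep the hook installed for backward as well, so the PreFwdPostBwd/PostFwdPreBwd
# autograd functions can trigger the backward callbacks.
with ColoParamOpHookManager.use_hooks(hook):
    loss = model(data).sum()
    loss.backward()
print(hook.fwd_ops, hook.bwd_ops)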

@@ -7,7 +7,7 @@ import torch
from colossalai.gemini import TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.tensor.param_op_hook import ColoParamOpHook
class TrainingPhase(Enum):
@@ -15,7 +15,7 @@ class TrainingPhase(Enum):
BACKWARD = 1
-class GeminiZeROHook(ParamOpHook):
+class GeminiZeROHook(ColoParamOpHook):
def __init__(self, gemini_manager: GeminiManager) -> None:
super().__init__()
