[Gemini] ParamOpHook -> ColoParamOpHook (#2080)

pull/2081/head^2
Jiarui Fang 2022-12-05 17:11:06 +08:00 committed by GitHub
parent 4f21c9e8d9
commit b3b89865e2
7 changed files with 37 additions and 36 deletions


@@ -3,7 +3,7 @@ import torch.nn
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager

 __all__ = ['RuntimeMemTracer']
@@ -53,12 +53,12 @@ class RuntimeMemTracer():
         args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
         self.module.zero_grad(set_to_none=True)
         self._pre_forward()
-        with ParamOpHookManager.use_hooks(self.param_op_hook):
+        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         return outputs

     def backward(self, loss):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()


@@ -8,7 +8,7 @@ import torch
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.gemini.tensor_utils import alloc_storage, free_storage
-from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.tensor.param_op_hook import ColoParamOpHook

 class TrainingPhase(Enum):
@@ -39,7 +39,7 @@ class GradMemTracerHook():
             hook.remove()

-class ParamMemTracerHook(ParamOpHook):
+class ParamMemTracerHook(ColoParamOpHook):

     def __init__(self) -> None:
         super().__init__()


@@ -12,7 +12,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device
 from colossalai.zero.utils.gemini_hook import GeminiZeROHook
@@ -259,7 +259,7 @@ class ZeroDDP(ColoDDP):
         args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter(*args)
-        with ParamOpHookManager.use_hooks(self.param_op_hook):
+        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         if self.force_outputs_fp32:
             return _cast_float(outputs, torch.float)
@@ -280,12 +280,12 @@ class ZeroDDP(ColoDDP):
         self.gemini_manager.post_iter()

     def backward(self, loss: torch.Tensor):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()

     def backward_by_grad(self, tensor, grad):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             torch.autograd.backward(tensor, grad)
         self._post_backward()


@@ -5,13 +5,14 @@ from .comm_spec import CollectiveCommPattern, CommSpec
 from .compute_spec import ComputePattern, ComputeSpec
 from .dist_spec_mgr import DistSpecManager
 from .distspec import ReplicaSpec, ShardSpec
-from .param_op_hook import ParamOpHook, ParamOpHookManager
+from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager
 from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec
 from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor

 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
-    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
-    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
+    'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
+    'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
+    'merge_same_dim_mesh_list'
 ]


@@ -4,7 +4,7 @@ import torch
 from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.const import TensorType
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.tensor.tensor_spec import ColoTensorSpec
@@ -58,18 +58,18 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if ParamOpHookManager.has_hook():
+        if ColoParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
                 if kwargs is None:
                     kwargs = {}
                 params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
+                        new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                     args, kwargs = replace_args(args, kwargs, new_args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
-                        ret = ParamOpHookManager.post_op(params, ret)
+                        ret = ColoParamOpHookManager.post_op(params, ret)
                     return ret
         return super().__torch_function__(func, types, args, kwargs)
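
For readers who have not seen the dispatch mechanism this hunk relies on: `__torch_function__` is the entry point PyTorch calls for every operator whose operands include a Tensor subclass, and ColoParameter uses it to wrap each op between `ColoParamOpHookManager.pre_op` and `post_op`. The sketch below is a minimal, self-contained illustration of that interception pattern only; `TracedParam` is a hypothetical toy class, not part of ColossalAI.

import torch

class TracedParam(torch.nn.Parameter):
    # Hypothetical toy subclass: logs every non-dunder torch op it takes part in,
    # mirroring where ColoParameter calls pre_op/post_op on the hook manager.

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        traced = not func.__name__.startswith('__')
        if traced:
            print(f"pre_op:  {func.__name__}")    # ColoParameter calls ColoParamOpHookManager.pre_op here
        ret = super().__torch_function__(func, types, args, kwargs)
        if traced:
            print(f"post_op: {func.__name__}")    # ...and ColoParamOpHookManager.post_op here
        return ret

p = TracedParam(torch.randn(4, 4))
out = torch.matmul(p, torch.randn(4, 2))    # prints pre_op/post_op around matmul

The `torch._C.DisableTorchFunction()` blocks in the real hunk serve a purpose the toy does not need: `pre_op` and `post_op` manipulate tensors themselves, and disabling the override there keeps those internal ops from re-entering `__torch_function__` recursively.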


@@ -8,7 +8,7 @@ from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.tensor_spec import ColoTensorSpec

-class ParamOpHook(ABC):
+class ColoParamOpHook(ABC):
     """Hook which is triggered by each operation when operands contain ColoParameter.
     To customize it, you must inherit this abstract class, and implement ``pre_forward``,
     ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
@@ -32,68 +32,68 @@ class ParamOpHook(ABC):
         pass

-class ParamOpHookManager:
+class ColoParamOpHookManager:
     """Manage your param op hooks. It only has static methods.
     The only static method you should call is ``use_hooks(*hooks)``.
     """
-    hooks: Tuple[ParamOpHook, ...] = tuple()
+    hooks: Tuple[ColoParamOpHook, ...] = tuple()

     @staticmethod
     @contextmanager
-    def use_hooks(*hooks: ParamOpHook):
+    def use_hooks(*hooks: ColoParamOpHook):
         """Change the param op hooks you use. Nested calling is allowed.

         Example:
-            >>> with ParamOpHookManager.use_hooks(*hooks):
+            >>> with ColoParamOpHookManager.use_hooks(*hooks):
             >>>     do_something()
-            >>> with ParamOpHookManager.use_hooks():
+            >>> with ColoParamOpHookManager.use_hooks():
             >>>     // clear hooks
             >>>     do_something()
         """
         try:
-            old_param_op_hooks = ParamOpHookManager.hooks
-            ParamOpHookManager.hooks = hooks
+            old_param_op_hooks = ColoParamOpHookManager.hooks
+            ColoParamOpHookManager.hooks = hooks
             yield
         finally:
-            ParamOpHookManager.hooks = old_param_op_hooks
+            ColoParamOpHookManager.hooks = old_param_op_hooks

     @staticmethod
     def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.pre_forward(params)

     @staticmethod
     def _trigger_post_forward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.post_forward(params)

     @staticmethod
     def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.pre_backward(params)

     @staticmethod
     def _trigger_post_backward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.post_backward(params)

     @staticmethod
     def pre_op(params: List[torch.Tensor], *args: Any) -> list:
-        ParamOpHookManager._trigger_pre_forward(params)
+        ColoParamOpHookManager._trigger_pre_forward(params)
         args_info = _get_colo_tensors_info(*args)
         rets = PreFwdPostBwd.apply(params, *args)
         return _update_colo_tensors(args_info, *rets)

     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
-        ParamOpHookManager._trigger_post_forward(params)
+        ColoParamOpHookManager._trigger_post_forward(params)
         arg_info = _get_colo_tensors_info(arg)
         ret = PostFwdPreBwd.apply(params, arg)
         return _unpack_args(_update_colo_tensors(arg_info, ret))

     @staticmethod
     def has_hook() -> bool:
-        return len(ParamOpHookManager.hooks) > 0
+        return len(ColoParamOpHookManager.hooks) > 0

 class PreFwdPostBwd(torch.autograd.Function):
@@ -105,7 +105,7 @@ class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *grads):
-        ParamOpHookManager._trigger_post_backward(ctx.params)
+        ColoParamOpHookManager._trigger_post_backward(ctx.params)
         return (None,) + grads
@@ -118,7 +118,7 @@ class PostFwdPreBwd(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *grads):
-        ParamOpHookManager._trigger_pre_backward(ctx.params)
+        ColoParamOpHookManager._trigger_pre_backward(ctx.params)
         return (None,) + grads
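
Since this file defines the public surface touched by the rename, here is a minimal usage sketch under the new names, assuming a module whose parameters are ColoParameter instances; `CountingHook` and `run_step` are hypothetical names for illustration, not part of the codebase.

from typing import List

import torch

from colossalai.tensor import ColoParamOpHook, ColoParamOpHookManager

class CountingHook(ColoParamOpHook):
    # Hypothetical hook: counts how many times each phase fires.

    def __init__(self) -> None:
        super().__init__()
        self.counts = {'pre_fwd': 0, 'post_fwd': 0, 'pre_bwd': 0, 'post_bwd': 0}

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        self.counts['pre_fwd'] += 1

    def post_forward(self, params: List[torch.Tensor]) -> None:
        self.counts['post_fwd'] += 1

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        self.counts['pre_bwd'] += 1

    def post_backward(self, params: List[torch.Tensor]) -> None:
        self.counts['post_bwd'] += 1

def run_step(model, data):
    # Hooks only fire while installed via the manager and only for ops whose
    # operands contain ColoParameter, as in RuntimeMemTracer and ZeroDDP above.
    hook = CountingHook()
    with ColoParamOpHookManager.use_hooks(hook):
        loss = model(data).sum()
        loss.backward()
    return hook.counts

Keeping the backward pass inside the `use_hooks` context matters: the backward-phase triggers run from the `PreFwdPostBwd`/`PostFwdPreBwd` autograd functions above, which read `ColoParamOpHookManager.hooks` at that moment.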


@@ -7,7 +7,7 @@ import torch
 from colossalai.gemini import TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.tensor.param_op_hook import ColoParamOpHook

 class TrainingPhase(Enum):
@@ -15,7 +15,7 @@ class TrainingPhase(Enum):
     BACKWARD = 1

-class GeminiZeROHook(ParamOpHook):
+class GeminiZeROHook(ColoParamOpHook):

     def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()