mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] ParamOpHook -> ColoParamOpHook (#2080)
parent 4f21c9e8d9
commit b3b89865e2
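This commit renames the param op hook API: `ParamOpHook` becomes `ColoParamOpHook` and `ParamOpHookManager` becomes `ColoParamOpHookManager`, and every call site in the Gemini memory tracer, `ZeroDDP`, and `ColoParameter` is updated to the new names. For downstream code the change is a pure rename; a minimal sketch of the import change, based on the `colossalai.tensor` re-exports updated in the diff below:

# old names, removed by this commit:
# from colossalai.tensor import ParamOpHook, ParamOpHookManager

# new names, re-exported from colossalai.tensor:
from colossalai.tensor import ColoParamOpHook, ColoParamOpHookManager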
@@ -3,7 +3,7 @@ import torch.nn
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager

 __all__ = ['RuntimeMemTracer']

@@ -53,12 +53,12 @@ class RuntimeMemTracer():
         args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
         self.module.zero_grad(set_to_none=True)
         self._pre_forward()
-        with ParamOpHookManager.use_hooks(self.param_op_hook):
+        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         return outputs

     def backward(self, loss):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()

@@ -8,7 +8,7 @@ import torch
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.gemini.tensor_utils import alloc_storage, free_storage
-from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.tensor.param_op_hook import ColoParamOpHook


 class TrainingPhase(Enum):
@@ -39,7 +39,7 @@ class GradMemTracerHook():
             hook.remove()


-class ParamMemTracerHook(ParamOpHook):
+class ParamMemTracerHook(ColoParamOpHook):

     def __init__(self) -> None:
         super().__init__()
@@ -12,7 +12,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device
 from colossalai.zero.utils.gemini_hook import GeminiZeROHook

@@ -259,7 +259,7 @@ class ZeroDDP(ColoDDP):
         args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter(*args)
-        with ParamOpHookManager.use_hooks(self.param_op_hook):
+        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         if self.force_outputs_fp32:
             return _cast_float(outputs, torch.float)
@@ -280,12 +280,12 @@ class ZeroDDP(ColoDDP):
         self.gemini_manager.post_iter()

     def backward(self, loss: torch.Tensor):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()

     def backward_by_grad(self, tensor, grad):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             torch.autograd.backward(tensor, grad)
         self._post_backward()

@@ -5,13 +5,14 @@ from .comm_spec import CollectiveCommPattern, CommSpec
 from .compute_spec import ComputePattern, ComputeSpec
 from .dist_spec_mgr import DistSpecManager
 from .distspec import ReplicaSpec, ShardSpec
-from .param_op_hook import ParamOpHook, ParamOpHookManager
+from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager
 from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec
 from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor

 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
-    'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
-    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
+    'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
+    'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
+    'merge_same_dim_mesh_list'
 ]
@@ -4,7 +4,7 @@ import torch

 from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.const import TensorType
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.tensor.tensor_spec import ColoTensorSpec


@@ -58,18 +58,18 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):

     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if ParamOpHookManager.has_hook():
+        if ColoParamOpHookManager.has_hook():
             if not func.__name__.startswith('__'):
                 if kwargs is None:
                     kwargs = {}
                 params = filter_args(lambda arg: isinstance(arg, ColoParameter), *args, *kwargs.values())
                 if len(params) > 0:
                     with torch._C.DisableTorchFunction():
-                        new_args = ParamOpHookManager.pre_op(params, *args, *kwargs.values())
+                        new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                     args, kwargs = replace_args(args, kwargs, new_args)
                     ret = super().__torch_function__(func, types, args, kwargs)
                     with torch._C.DisableTorchFunction():
-                        ret = ParamOpHookManager.post_op(params, ret)
+                        ret = ColoParamOpHookManager.post_op(params, ret)
                     return ret
         return super().__torch_function__(func, types, args, kwargs)

@@ -8,7 +8,7 @@ from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.tensor_spec import ColoTensorSpec


-class ParamOpHook(ABC):
+class ColoParamOpHook(ABC):
     """Hook which is triggered by each operation when operands contain ColoParameter.
     To customize it, you must inherit this abstract class, and implement ``pre_forward``,
     ``post_forward``, ``pre_backward`` and ``post_backward``. These four methods take a list
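`ColoParamOpHook` (renamed from `ParamOpHook` above) is the extension point touched throughout this diff: a subclass implements `pre_forward`, `post_forward`, `pre_backward`, and `post_backward`, each of which receives the list of parameters involved in the current operation. A minimal sketch of a custom hook, assuming only the abstract interface shown in this file; the counters are illustrative and not part of ColossalAI:

from typing import List

import torch

from colossalai.tensor import ColoParamOpHook


class CountingHook(ColoParamOpHook):
    """Illustrative hook that counts how many ops touch ColoParameters (not part of ColossalAI)."""

    def __init__(self) -> None:
        super().__init__()
        self.forward_ops = 0
        self.backward_ops = 0

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        # called before an op that consumes ColoParameters runs in the forward pass
        self.forward_ops += 1

    def post_forward(self, params: List[torch.Tensor]) -> None:
        pass

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        # called before the corresponding gradient computation in the backward pass
        self.backward_ops += 1

    def post_backward(self, params: List[torch.Tensor]) -> None:
        pass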
@@ -32,68 +32,68 @@ class ParamOpHook(ABC):
         pass


-class ParamOpHookManager:
+class ColoParamOpHookManager:
     """Manage your param op hooks. It only has static methods.
     The only static method you should call is ``use_hooks(*hooks)``.
     """
-    hooks: Tuple[ParamOpHook, ...] = tuple()
+    hooks: Tuple[ColoParamOpHook, ...] = tuple()

     @staticmethod
     @contextmanager
-    def use_hooks(*hooks: ParamOpHook):
+    def use_hooks(*hooks: ColoParamOpHook):
         """Change the param op hooks you use. Nested calling is allowed.

         Example:
-            >>> with ParamOpHookManager.use_hooks(*hooks):
+            >>> with ColoParamOpHookManager.use_hooks(*hooks):
             >>>     do_something()
-            >>>     with ParamOpHookManager.use_hooks():
+            >>>     with ColoParamOpHookManager.use_hooks():
             >>>         // clear hooks
             >>>         do_something()
         """
         try:
-            old_param_op_hooks = ParamOpHookManager.hooks
-            ParamOpHookManager.hooks = hooks
+            old_param_op_hooks = ColoParamOpHookManager.hooks
+            ColoParamOpHookManager.hooks = hooks
             yield
         finally:
-            ParamOpHookManager.hooks = old_param_op_hooks
+            ColoParamOpHookManager.hooks = old_param_op_hooks

     @staticmethod
     def _trigger_pre_forward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.pre_forward(params)

     @staticmethod
     def _trigger_post_forward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.post_forward(params)

     @staticmethod
     def _trigger_pre_backward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.pre_backward(params)

     @staticmethod
     def _trigger_post_backward(params: List[torch.Tensor]) -> None:
-        for hook in ParamOpHookManager.hooks:
+        for hook in ColoParamOpHookManager.hooks:
             hook.post_backward(params)

     @staticmethod
     def pre_op(params: List[torch.Tensor], *args: Any) -> list:
-        ParamOpHookManager._trigger_pre_forward(params)
+        ColoParamOpHookManager._trigger_pre_forward(params)
         args_info = _get_colo_tensors_info(*args)
         rets = PreFwdPostBwd.apply(params, *args)
         return _update_colo_tensors(args_info, *rets)

     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
-        ParamOpHookManager._trigger_post_forward(params)
+        ColoParamOpHookManager._trigger_post_forward(params)
         arg_info = _get_colo_tensors_info(arg)
         ret = PostFwdPreBwd.apply(params, arg)
         return _unpack_args(_update_colo_tensors(arg_info, ret))

     @staticmethod
     def has_hook() -> bool:
-        return len(ParamOpHookManager.hooks) > 0
+        return len(ColoParamOpHookManager.hooks) > 0


 class PreFwdPostBwd(torch.autograd.Function):
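`ColoParamOpHookManager.use_hooks` (renamed alongside the hook class) is the only public entry point: it installs the given hooks for the duration of a `with` block, and nested calls are allowed, as the docstring example above shows. A hedged usage sketch, assuming a model whose parameters are `ColoParameter`s; `model` and `data` are placeholders, and `CountingHook` is the illustrative hook sketched earlier:

hook = CountingHook()

with ColoParamOpHookManager.use_hooks(hook):
    output = model(data)        # pre_forward/post_forward fire for ops on ColoParameters
    output.sum().backward()     # pre_backward/post_backward fire during autograd

print(hook.forward_ops, hook.backward_ops)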
@@ -105,7 +105,7 @@ class PreFwdPostBwd(torch.autograd.Function):

     @staticmethod
     def backward(ctx, *grads):
-        ParamOpHookManager._trigger_post_backward(ctx.params)
+        ColoParamOpHookManager._trigger_post_backward(ctx.params)
         return (None,) + grads


@@ -118,7 +118,7 @@ class PostFwdPreBwd(torch.autograd.Function):

     @staticmethod
     def backward(ctx, *grads):
-        ParamOpHookManager._trigger_pre_backward(ctx.params)
+        ColoParamOpHookManager._trigger_pre_backward(ctx.params)
         return (None,) + grads


@@ -7,7 +7,7 @@ import torch

 from colossalai.gemini import TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.tensor.param_op_hook import ColoParamOpHook


 class TrainingPhase(Enum):
@@ -15,7 +15,7 @@ class TrainingPhase(Enum):
     BACKWARD = 1


-class GeminiZeROHook(ParamOpHook):
+class GeminiZeROHook(ColoParamOpHook):

     def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()