From 58d8b8a2dd9a92c1dab3a44d2a35fb30716437c5 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Fri, 18 Oct 2024 16:48:52 +0800
Subject: [PATCH] [misc] fit torch api upgrade and remove legacy import (#6093)

* [amp] fit torch's new api

* [amp] fix api call

* [amp] fix api call

* [misc] fit torch pytree api upgrade

* [misc] remove legacy import

* [misc] fit torch amp api

* [misc] fit torch amp api
---
 colossalai/accelerator/cuda_accelerator.py              |  2 +-
 colossalai/kernel/jit/option.py                         |  2 +-
 colossalai/pipeline/schedule/_utils.py                  | 10 ++++++++--
 .../zero/gemini/memory_tracer/runtime_mem_tracer.py     | 11 ++++++-----
 colossalai/zero/gemini/placement_policy.py              |  3 ++-
 .../features/mixed_precision_training_with_booster.md   |  2 +-
 .../features/mixed_precision_training_with_booster.md   |  2 +-
 7 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/colossalai/accelerator/cuda_accelerator.py b/colossalai/accelerator/cuda_accelerator.py
index f1ab487d4..32e62b33f 100644
--- a/colossalai/accelerator/cuda_accelerator.py
+++ b/colossalai/accelerator/cuda_accelerator.py
@@ -279,4 +279,4 @@ class CudaAccelerator(BaseAccelerator):
         """
         Return autocast function
         """
-        return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
+        return torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py
index d392649a6..1ee93e4e0 100644
--- a/colossalai/kernel/jit/option.py
+++ b/colossalai/kernel/jit/option.py
@@ -1,7 +1,6 @@
 import torch
 
 from colossalai.accelerator import get_accelerator
-from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
 
 from .bias_dropout_add import bias_dropout_add_fused_train
 from .bias_gelu import bias_gelu_impl
@@ -45,6 +44,7 @@ def warmup_jit_fusion(
     dtype: torch.dtype = torch.float32,
 ):
     """Compile JIT functions before the main training steps"""
+    from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
 
     embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
     linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py
index 271b3238f..8f42a9014 100644
--- a/colossalai/pipeline/schedule/_utils.py
+++ b/colossalai/pipeline/schedule/_utils.py
@@ -3,8 +3,9 @@ from typing import Any, List, Optional, Tuple
 
 import torch
 import torch.cuda
+from packaging.version import Version
 from torch.nn import Module
-from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten
+from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, tree_flatten, tree_map, tree_unflatten
 
 
 # this register are for torch under version 1.13.1, maybe removed in the future
@@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
     return OrderedDict((key, value) for key, value in zip(context, values))
 
 
-_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
+if Version(torch.__version__) <= Version("1.13.1"):
+    try:
+        from torch.utils._pytree import register_pytree_node as _register_pytree_node
+    except ImportError:
+        from torch.utils._pytree import _register_pytree_node
+    _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
 
 
 def tree_map_hf(fn: Any, pytree: Any):
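The _utils.py hunk above gates the OrderedDict registration behind a version check because, as the in-file comment notes, the registration is only needed on torch 1.13.1 and older, and the private _register_pytree_node helper was later renamed. A minimal standalone sketch of the same guard, assuming only that torch and packaging are installed; the round-trip check at the end is illustrative and not part of the patch:

    # Standalone sketch, not part of the patch: mirrors the guarded registration above.
    from collections import OrderedDict
    from typing import Any, List, Tuple

    import torch
    from packaging.version import Version
    from torch.utils._pytree import tree_flatten, tree_unflatten


    def _odict_flatten(d: "OrderedDict[Any, Any]") -> Tuple[List[Any], Any]:
        return list(d.values()), list(d.keys())


    def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]":
        return OrderedDict((key, value) for key, value in zip(context, values))


    if Version(torch.__version__) <= Version("1.13.1"):
        # Old torch: OrderedDict is not a built-in pytree node, so register it;
        # the registration helper was renamed later, hence the fallback import.
        try:
            from torch.utils._pytree import register_pytree_node as _register_pytree_node
        except ImportError:
            from torch.utils._pytree import _register_pytree_node
        _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)

    # On either side of the version check, OrderedDicts round-trip through the
    # pytree helpers that the pipeline schedule relies on.
    od = OrderedDict(a=torch.zeros(2), b=torch.ones(2))
    leaves, spec = tree_flatten(od)
    rebuilt = tree_unflatten(leaves, spec)
    assert list(rebuilt.keys()) == ["a", "b"]
    assert torch.equal(rebuilt["b"], torch.ones(2))

The same defer-or-guard idea drives the remaining hunks, which move the colossalai.legacy imports into the functions that actually use them, so importing these modules no longer pulls in the legacy package at load time.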
diff --git a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
index b0d258824..81520326f 100644
--- a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
+++ b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
@@ -1,10 +1,5 @@
 import torch.nn
 
-from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
-    GradMemStats,
-    GradMemTracerHook,
-    ParamMemTracerHook,
-)
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import _cast_float
 
@@ -27,6 +22,12 @@ class RuntimeMemTracer:
     def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
         super().__init__()
 
+        from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
+            GradMemStats,
+            GradMemTracerHook,
+            ParamMemTracerHook,
+        )
+
         self.module = module
         self.dtype = dtype
         self._gradstat = GradMemStats()
diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index 178755d03..2aa8dc3f6 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -8,7 +8,6 @@ import torch
 import torch.distributed as dist
 
 from colossalai.accelerator import get_accelerator
-from colossalai.legacy.utils.memory import colo_device_memory_capacity
 from colossalai.zero.gemini.chunk import Chunk
 
 from .chunk import Chunk, ChunkManager
@@ -172,6 +171,8 @@ class AutoPlacementPolicy(PlacementPolicy):
         Returns:
             int: the volume of memory that is evicted
         """
+        from colossalai.legacy.utils.memory import colo_device_memory_capacity
+
         start = time()
         cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
         used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
diff --git a/docs/source/en/features/mixed_precision_training_with_booster.md b/docs/source/en/features/mixed_precision_training_with_booster.md
index 65304b1f4..1e17c2bb5 100644
--- a/docs/source/en/features/mixed_precision_training_with_booster.md
+++ b/docs/source/en/features/mixed_precision_training_with_booster.md
@@ -16,7 +16,7 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan)
 
 AMP stands for automatic mixed precision training.
 In Colossal-AI, we have incorporated different implementations of mixed precision training:
 
-1. torch.cuda.amp
+1. torch.amp
 2. apex.amp
 3. naive amp
diff --git a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md
index da377ceb2..93a69830c 100644
--- a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md
+++ b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md
@@ -16,7 +16,7 @@
 
 AMP 代表自动混合精度训练。
 在 Colossal-AI 中, 我们结合了混合精度训练的不同实现:
 
-1. torch.cuda.amp
+1. torch.amp
 2. apex.amp
 3. naive amp
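Beyond the lazy legacy imports, the recurring change is the move from the deprecated torch.cuda.amp entry points to the unified torch.amp namespace, which is what both updated docs now list as the first AMP backend. A minimal training-step sketch of the new-style calls, assuming a CUDA device and torch 2.3 or later (for the unified torch.amp.GradScaler); the model, optimizer, and data below are placeholders rather than Colossal-AI code:

    # Standalone sketch of the torch.amp entry points the accelerator now uses.
    import torch

    model = torch.nn.Linear(16, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = torch.amp.GradScaler("cuda")

    for _ in range(3):
        inputs = torch.randn(8, 16, device="cuda")
        targets = torch.randn(8, 4, device="cuda")
        optimizer.zero_grad()
        # Equivalent to the removed torch.cuda.amp.autocast(...) call, with the
        # device spelled out, as in the patched CudaAccelerator.autocast().
        with torch.amp.autocast(
            device_type="cuda", enabled=True, dtype=torch.float16, cache_enabled=True
        ):
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

Passing device_type explicitly follows the deprecation notice that recent torch releases attach to torch.cuda.amp.autocast, while leaving the accelerator's public autocast() signature unchanged.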