From d202cc28c0e7707762bb5f94d944575b327ba903 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Tue, 9 Jan 2024 10:20:05 +0800
Subject: [PATCH] [npu] change device to accelerator api (#5239)

* update accelerator
* fix timer
* fix amp
* update
* fix
* update bug
* add error raise
* fix autocast
* fix set device
* remove doc accelerator
* update doc
* update doc
* update doc
* use nullcontext
* update cpu
* update null context
* change time limit for example
* update
* update
* update
* update
* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
---
 .../workflows/example_check_on_dispatch.yml   |   2 +-
 .github/workflows/example_check_on_pr.yml     |   2 +-
 .../workflows/example_check_on_schedule.yml   |   2 +-
 applications/Chat/coati/trainer/ppo.py        |   4 +-
 .../coati/trainer/strategies/colossalai.py    |  13 +-
 applications/Colossal-LLaMA-2/train.py        |  56 ++--
 colossalai/accelerator/__init__.py            |   2 +
 colossalai/accelerator/api.py                 |  13 +-
 colossalai/accelerator/base_accelerator.py    | 236 ++++++++++++++-
 colossalai/accelerator/cpu_accelerator.py     | 277 ++++++++++++++++++
 colossalai/accelerator/cuda_accelerator.py    | 224 +++++++++++++-
 colossalai/accelerator/npu_accelerator.py     | 226 +++++++++++++-
 .../naive_amp/grad_scaler/base_grad_scaler.py |   4 +-
 .../grad_scaler/dynamic_grad_scaler.py        |  14 +-
 .../naive_amp/mixed_precision_mixin/fp16.py   |   4 +-
 .../auto_parallel/offload/amp_optimizer.py    |   6 +-
 colossalai/auto_parallel/offload/solver.py    |   7 +-
 .../booster/mixed_precision/fp16_torch.py     |   4 +-
 colossalai/booster/plugin/gemini_plugin.py    |   5 +-
 .../booster/plugin/hybrid_parallel_plugin.py  |  28 +-
 .../booster/plugin/low_level_zero_plugin.py   |   4 +-
 colossalai/initialize.py                      |  15 +-
 .../extensions/flash_attention/utils.py       |   6 +-
 colossalai/kernel/jit/option.py               |  22 +-
 colossalai/legacy/amp/torch_amp/torch_amp.py  |   5 +-
 colossalai/legacy/communication/p2p.py        |  10 +-
 colossalai/legacy/communication/ring.py       |   6 +-
 colossalai/legacy/communication/utils.py      |   6 +-
 .../legacy/engine/schedule/_base_schedule.py  |   6 +-
 .../engine/schedule/_pipeline_schedule.py     |   6 +-
 .../engine/schedule/_pipeline_schedule_v2.py  |   4 +-
 colossalai/legacy/initialize.py               |   6 +-
 .../nn/layer/colossalai_layer/embedding.py    |   4 +-
 .../layer/colossalai_layer/normalization.py   |   4 +-
 .../legacy/nn/layer/parallel_1d/layers.py     |  20 +-
 .../legacy/nn/layer/parallel_2d/_operation.py |   8 +-
 .../legacy/nn/layer/parallel_2d/layers.py     |  39 ++-
 .../nn/layer/parallel_2p5d/_operation.py      |  14 +-
 .../legacy/nn/layer/parallel_2p5d/layers.py   |  39 ++-
 .../legacy/nn/layer/parallel_3d/layers.py     |  53 +++-
 .../nn/layer/parallel_sequence/_operation.py  |  12 +-
 colossalai/legacy/nn/layer/vanilla/layers.py  |  28 +-
 colossalai/legacy/nn/loss/loss_2d.py          |   4 +-
 colossalai/legacy/nn/loss/loss_2p5d.py        |   4 +-
 colossalai/legacy/nn/loss/loss_3d.py          |   6 +-
 .../legacy/trainer/hooks/_metric_hook.py      |  22 +-
 .../legacy/utils/activation_checkpoint.py     |  10 +-
 colossalai/legacy/utils/memory.py             |   9 +-
 .../utils/profiler/legacy/comm_profiler.py    |   4 +-
 .../legacy/zero/gemini/stateful_tensor_mgr.py |   4 +-
 .../zero/gemini/tensor_placement_policy.py    |   6 +-
 .../bucket_tensor_shard_strategy.py           |   8 +-
 .../zero/shard_utils/tensor_shard_strategy.py |  10 +-
 .../zero/sharded_model/sharded_model_v2.py    |  14 +-
 .../legacy/zero/sharded_model/zero_hook.py    |   4 +-
 colossalai/moe/routers.py                     | 116 ++++----
 colossalai/moe/utils.py                       |  27 +-
 colossalai/pipeline/schedule/generate.py      |   4 +-
 .../pipeline/schedule/interleaved_pp.py       |   6 +-
 colossalai/pipeline/schedule/one_f_one_b.py   |   6 +-
 colossalai/shardformer/layer/utils.py         |  18 +-
 colossalai/testing/utils.py                   |  15 +-
 colossalai/utils/__init__.py                  |   7 -
 colossalai/utils/device.py                    | 223 --------------
 colossalai/utils/timer.py                     |   8 +-
 colossalai/zero/gemini/chunk/chunk.py         |  26 +-
 colossalai/zero/gemini/chunk/manager.py       |  12 +-
 colossalai/zero/gemini/gemini_ddp.py          |   7 +-
 colossalai/zero/gemini/gemini_optimizer.py    |  15 +-
 .../memory_tracer/chunk_memstats_collector.py |   4 +-
 .../gemini/memory_tracer/memory_monitor.py    |   6 +-
 colossalai/zero/gemini/placement_policy.py    |   8 +-
 colossalai/zero/gemini/utils.py               |   6 +-
 colossalai/zero/low_level/low_level_optim.py  |  33 +--
 .../train_gpt_using_hybrid_parallelism.md     |   5 +-
 .../train_gpt_using_hybrid_parallelism.md     |   3 +-
 .../roberta/pretraining/run_pretraining.py    |  11 +-
 .../dreambooth/train_dreambooth_colossalai.py |  12 +-
 .../train_dreambooth_colossalai_lora.py       |  12 +-
 examples/images/resnet/train.py               |   6 +-
 examples/images/vit/vit_benchmark.py          |   5 +-
 examples/inference/benchmark_llama.py         |  11 +-
 examples/inference/run_llama_inference.py     |   4 +-
 examples/language/bert/finetune.py            |  10 +-
 .../auto_offload/train_gpt_offload.py         |   4 +-
 .../language/gpt/gemini/train_gpt_demo.py     |   8 +-
 .../gpt/hybridparallelism/finetune.py         |  10 +-
 examples/language/gpt/titans/model/embed.py   |  14 +-
 examples/language/llama2/benchmark.py         |  11 +-
 examples/language/llama2/data_utils.py        |   6 +-
 examples/language/llama2/finetune.py          |   6 +-
 .../language/llama2/performance_evaluator.py  |   9 +-
 examples/language/llama2/pretrain.py          |   6 +-
 .../openmoe/benchmark/benchmark_cai.py        |  10 +-
 examples/language/openmoe/train.py            |   6 +-
 examples/language/palm/train.py               |   8 +-
 .../tutorial/new_api/cifar_resnet/train.py    |   6 +-
 examples/tutorial/new_api/cifar_vit/train.py  |   6 +-
 .../tutorial/new_api/glue_bert/finetune.py    |   4 +-
 examples/tutorial/opt/opt/run_clm.py          |  16 +-
 .../test_offload/test_perf.py                 |   4 +-
 .../test_compatibility_with_gemini.py         |   8 +-
 .../test_plugin/test_low_level_zero_plugin.py |   6 +-
 tests/test_legacy/test_comm/test_comm.py      |   8 +-
 .../test_1d/checks_1d/check_layer_1d.py       |  26 +-
 .../test_2d/checks_2d/check_layer_2d.py       |  36 +--
 .../test_2d/checks_2d/check_operation_2d.py   |  12 +-
 .../test_2p5d/checks_2p5d/check_layer_2p5d.py |  36 +--
 .../checks_2p5d/check_operation_2p5d.py       |  12 +-
 .../test_3d/checks_3d/check_layer_3d.py       |  30 +-
 .../checks_seq/check_layer_seq.py             |   8 +-
 .../test_trainer/test_pipeline/test_p2p.py    |   4 +-
 tests/test_legacy/test_utils/test_memory.py   |   6 +-
 .../test_utils/test_norm_gradient_clipping.py |   4 +-
 tests/test_moe/test_grad_handler.py           |   6 +-
 tests/test_moe/test_kernel.py                 |  10 +-
 tests/test_moe/test_moe_checkpoint.py         |   4 +-
 tests/test_moe/test_moe_ep_tp.py              |  62 ++--
 tests/test_moe/test_moe_group.py              |   4 +-
 tests/test_optimizer/test_adam_kernel.py      |   7 +-
 tests/test_pipeline/test_p2p_communication.py |   4 +-
 tests/test_zero/test_gemini/test_chunkv2.py   |   6 +-
 tests/test_zero/test_gemini/test_fwd_bwd.py   |   4 +-
 .../test_zero/test_gemini/test_grad_accum.py  |   4 +-
 tests/test_zero/test_gemini/test_inference.py |   8 +-
 tests/test_zero/test_gemini/test_optim.py     |   4 +-
 tests/test_zero/test_gemini/test_search.py    |   4 +-
 .../test_zero/test_low_level/test_grad_acc.py |   7 +-
 128 files changed, 1773 insertions(+), 868 deletions(-)
 create mode 100644 colossalai/accelerator/cpu_accelerator.py
 delete mode 100644 colossalai/utils/device.py

diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml
index 9d3bd9a48..011a0ae03
100644 --- a/.github/workflows/example_check_on_dispatch.yml +++ b/.github/workflows/example_check_on_dispatch.yml @@ -47,7 +47,7 @@ jobs: container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ - timeout-minutes: 10 + timeout-minutes: 15 steps: - name: 📚 Checkout uses: actions/checkout@v3 diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 5934704f4..608ae863f 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -79,7 +79,7 @@ jobs: container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ - timeout-minutes: 10 + timeout-minutes: 15 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }} cancel-in-progress: true diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml index 5ed128c3e..4fcd1e3a9 100644 --- a/.github/workflows/example_check_on_schedule.yml +++ b/.github/workflows/example_check_on_schedule.yml @@ -35,7 +35,7 @@ jobs: matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 - timeout-minutes: 10 + timeout-minutes: 15 steps: - name: 📚 Checkout uses: actions/checkout@v3 diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index d69666898..330e4e0e3 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, DistributedSampler from tqdm import tqdm from transformers import PreTrainedTokenizerBase -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from .base import OnPolicyTrainer from .callbacks import Callback @@ -105,7 +105,7 @@ class PPOTrainer(OnPolicyTrainer): self.critic_optim = critic_optim self.offload_inference_models = offload_inference_models - self.device = get_current_device() + self.device = get_accelerator().get_current_device() def _before_fit( self, diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index 7129edb06..95f016786 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -6,7 +6,6 @@ import torch.nn as nn import colossalai from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.utils import get_current_device from colossalai.zero.gemini.gemini_ddp import GeminiDDP from .ddp import DDPStrategy @@ -158,9 +157,19 @@ class GeminiStrategy(DDPStrategy): warnings.warn(f"Stage 3 only supports fp16. 
Precision is set to fp16.") + # colossalai has changed api for get_current_device in 0.3.4 version or newer + try: + from colossalai.accelerator import get_accelerator + + chunk_init_device = get_accelerator().get_current_device() + except: + from colossalai.utils import get_current_device + + chunk_init_device = get_current_device() + # NOTE: dist should be initialized before calling get_current_device() plugin_initializer = lambda: GeminiPlugin( - chunk_init_device=get_current_device(), + chunk_init_device=chunk_init_device, placement_policy=placement_policy, shard_param_frac=shard_param_frac, offload_optim_frac=offload_optim_frac, diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py index 41b4ef031..92863e8e4 100644 --- a/applications/Colossal-LLaMA-2/train.py +++ b/applications/Colossal-LLaMA-2/train.py @@ -1,44 +1,37 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Continual Pre-training of LLaMA-2 developed by Colossal-AI Team +Continual Pre-training of LLaMA-2 developed by Colossal-AI Team """ -import json import argparse +import json import os import resource from contextlib import nullcontext -from tqdm import tqdm import torch import torch.distributed as dist +from colossal_llama2.dataset.loader import ( + DataCollatorForSupervisedDataset, + StatefulDistributedSampler, + load_tokenized_dataset, + setup_distributed_dataloader, +) +from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint +from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention +from colossal_llama2.utils.froze import freeze_non_embeds_parameters from torch.utils.tensorboard import SummaryWriter -from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig +from tqdm import tqdm +from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import ( - GeminiPlugin, - LowLevelZeroPlugin, - HybridParallelPlugin, -) +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device - -from colossal_llama2.dataset.loader import ( - load_tokenized_dataset, - setup_distributed_dataloader, - DataCollatorForSupervisedDataset, - StatefulDistributedSampler, -) - -from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention -from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint -from colossal_llama2.utils.froze import freeze_non_embeds_parameters def get_model_numel(model: torch.nn.Module) -> int: @@ -215,9 +208,18 @@ def main() -> None: # ====================================================== # Initialize Model, Objective, Optimizer and LR Scheduler # ====================================================== - init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() - ) + + # colossalai has changed api for get_current_device in 0.3.4 version or newer + try: + from colossalai.accelerator import get_accelerator + + current_device = get_accelerator().get_current_device() + except: + from colossalai.utils import get_current_device + + current_device = get_current_device() + + init_ctx = LazyInitContext(default_device=current_device) if 
isinstance(plugin, (GeminiPlugin,)) else nullcontext() with init_ctx: model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained)) # Freeze part of parameters. @@ -320,7 +322,7 @@ def main() -> None: initial=start_step, ) as pbar: for step, batch in pbar: - batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} + batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)} batch_output = model(**batch) @@ -372,9 +374,7 @@ def main() -> None: # Final save. coordinator.print_on_master("Start saving final model checkpoint") booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) - coordinator.print_on_master( - f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}" - ) + coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}") coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") diff --git a/colossalai/accelerator/__init__.py b/colossalai/accelerator/__init__.py index d144235d3..1405133af 100644 --- a/colossalai/accelerator/__init__.py +++ b/colossalai/accelerator/__init__.py @@ -1,5 +1,6 @@ from .api import auto_set_accelerator, get_accelerator, set_accelerator from .base_accelerator import BaseAccelerator +from .cpu_accelerator import CpuAccelerator from .cuda_accelerator import CudaAccelerator from .npu_accelerator import NpuAccelerator @@ -10,4 +11,5 @@ __all__ = [ "BaseAccelerator", "CudaAccelerator", "NpuAccelerator", + "CpuAccelerator", ] diff --git a/colossalai/accelerator/api.py b/colossalai/accelerator/api.py index 393340b71..02b3055d7 100644 --- a/colossalai/accelerator/api.py +++ b/colossalai/accelerator/api.py @@ -3,6 +3,7 @@ from collections import OrderedDict from typing import Union from .base_accelerator import BaseAccelerator +from .cpu_accelerator import CpuAccelerator from .cuda_accelerator import CudaAccelerator from .npu_accelerator import NpuAccelerator @@ -15,7 +16,7 @@ _ACCELERATOR = None # we use ordered dictionary here to associate the # order with device check priority # i.e. auto_set_accelerator will check cuda first -_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator) +_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator, cpu=CpuAccelerator) def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None: @@ -43,19 +44,17 @@ def auto_set_accelerator() -> None: """ global _ACCELERATOR - for _, accelerator_cls in _ACCELERATOR_MAPPING.items(): + for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items(): try: accelerator = accelerator_cls() - if accelerator.is_available(): + if accelerator_name == "cpu" or accelerator.is_available(): _ACCELERATOR = accelerator - break + break except: pass if _ACCELERATOR is None: - raise RuntimeError( - f"No accelerator is available. Please check your environment. 
The list of accelerators we support is {list(_ACCELERATOR_MAPPING.keys())}" - ) + raise RuntimeError("No accelerator is available.") def get_accelerator() -> BaseAccelerator: diff --git a/colossalai/accelerator/base_accelerator.py b/colossalai/accelerator/base_accelerator.py index 71d03b8d6..a550cd7a2 100644 --- a/colossalai/accelerator/base_accelerator.py +++ b/colossalai/accelerator/base_accelerator.py @@ -1,6 +1,7 @@ #!/usr/bin/env python + from abc import ABC, abstractmethod -from typing import Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -8,6 +9,8 @@ __all__ = ["BaseAccelerator"] class BaseAccelerator(ABC): + support_set_device: bool = True + def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None: self._name = name self._communication_backend = communication_backend @@ -45,6 +48,12 @@ class BaseAccelerator(ABC): # ======================= # device APIs # ======================= + @abstractmethod + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + @abstractmethod def current_device(self) -> int: """ @@ -52,7 +61,7 @@ class BaseAccelerator(ABC): """ @abstractmethod - def set_device(self, device: Union[torch.device, int]) -> None: + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: """ Bind the current process to a device. """ @@ -79,3 +88,226 @@ class BaseAccelerator(ABC): """ Return the number of devices on the machine. """ + + def set_to_device(self, models: Any) -> Any: + """ + Send model to device. + + :param models: nn.module or a list of module + """ + if isinstance(models, list) and len(models) > 1: + ret = [] + for model in models: + ret.append(model.to(self.get_current_device())) + return ret + elif isinstance(models, list): + return models[0].to(self.get_current_device()) + else: + return models.to(self.get_current_device()) + + @abstractmethod + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the capability of a device. + """ + + @abstractmethod + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + + @abstractmethod + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + + @abstractmethod + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the device as given by nvidia-smi or npu-smi, etc. + """ + + # ======================= + # random number generator APIs + # ======================= + @abstractmethod + def get_rng_state(self, device="cuda") -> torch.Tensor: + """ + Returns the random number generator state of the specified device as a ByteTensor. + """ + + @abstractmethod + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + + @abstractmethod + def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None: + """ + Sets the random number generator state of the specified device. + """ + + @abstractmethod + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + + @abstractmethod + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current device. 
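
Illustrative aside, not part of the patch: a minimal sketch of how the selection API in colossalai/accelerator/api.py is meant to be called from user code, assuming ColossalAI 0.3.4 or newer with at least one supported backend installed.

import torch

from colossalai.accelerator import auto_set_accelerator, get_accelerator

auto_set_accelerator()                     # probes cuda, then npu, then falls back to cpu
accelerator = get_accelerator()
device = accelerator.get_current_device()  # torch.device, e.g. cuda:0, npu:0 or cpu
x = torch.randn(4, 4, device=device)       # tensors are placed backend-agnostically
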
+ """ + + @abstractmethod + def manual_seed_all(self, seed: int) -> None: + """ + Sets the seed for generating random numbers on all devices. + """ + + @abstractmethod + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current device. + """ + + @abstractmethod + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all devices. + """ + + @abstractmethod + def initial_seed(self) -> int: + """ + Returns the current random seed of the current device. + """ + + # ======================= + # memory management APIs + # ======================= + @abstractmethod + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other device application and visible in nvidia-smi. + """ + + @abstractmethod + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of CUDA memory allocator statistics for a given device. + """ + + @abstractmethod + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + + @abstractmethod + def memory_snapshot(self): + """ + Returns a snapshot of the CUDA memory allocator state across all devices. + """ + + @abstractmethod + def memory_allocated(self, device=None) -> int: + """ + Returns the current device memory occupied by tensors in bytes for a given device. + """ + + @abstractmethod + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum device memory occupied by tensors in bytes for a given device. + """ + + @abstractmethod + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum device memory occupied by tensors for a given device. + """ + + @abstractmethod + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum device memory managed by the caching allocator for a given device. + """ + + @abstractmethod + def memory_reserved(self, device=None) -> int: + """ + Returns the current device memory managed by the caching allocator in bytes for a given device. + """ + + @abstractmethod + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum device memory managed by the caching allocator in bytes for a given device. + """ + + @abstractmethod + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + + @abstractmethod + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the device memory allocator. + """ + + # ======================= + # streams and events APIs + # ======================= + + @abstractmethod + def Stream(self, device=None, priority=0, **kwargs): + """ + A device stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details. + """ + + @abstractmethod + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + device events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams. + """ + + @abstractmethod + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. 
+ """ + + @abstractmethod + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + + @abstractmethod + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + + @abstractmethod + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + + # ======================= + # amp APIs + # ======================= + @abstractmethod + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ diff --git a/colossalai/accelerator/cpu_accelerator.py b/colossalai/accelerator/cpu_accelerator.py new file mode 100644 index 000000000..c1f01b4f7 --- /dev/null +++ b/colossalai/accelerator/cpu_accelerator.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python + +import resource +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import psutil +import torch + +from .base_accelerator import BaseAccelerator + +__all__ = ["CpuAccelerator"] + + +class CpuAccelerator(BaseAccelerator): + support_set_device: bool = False + """ + Accelerator class for cpu. + """ + + def __init__(self): + super().__init__(name="cpu", communication_backend="gloo", is_synchronous=False) + + # ======================= + # device APIs + # ======================= + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + return torch.device("cpu") + + def current_device(self) -> int: + """ + Return the current device index. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: + """ + Bind the current process to a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def is_available(self): + """ + Check if the accelerator is available. + """ + return True + + def device_count(self): + """ + Return the number of devices on the machine. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the cuda capability of a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_properties(self, device): + """ + Gets the properties of a device. 
+ """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # random number generator APIs + # ======================= + def get_rng_state(self, device=None) -> torch.Tensor: + """ + Returns the random number generator state of the specified GPU as a ByteTensor. + """ + return torch.get_rng_state(device) + + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def set_rng_state(self, new_state: torch.ByteTensor, device: str = None) -> None: + """ + Sets the random number generator state of the specified GPU. + """ + torch.set_rng_state(new_state) + + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current GPU. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def manual_seed_all(self, seed: int) -> None: + """ + Set the random seed for the all processes. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current GPU. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all GPUs. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def initial_seed(self) -> int: + """ + Returns the current random seed of the current GPU. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # memory management APIs + # ======================= + + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of CUDA memory allocator statistics for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_snapshot(self): + """ + Returns a snapshot of the CUDA memory allocator state across all devices. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_allocated(self, device=None) -> int: + """ + Returns the current GPU memory occupied by tensors in bytes for a given device. 
+ """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum GPU memory occupied by tensors in bytes for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_reserved(self, device=None) -> int: + """ + Returns the current GPU memory managed by the caching allocator in bytes for a given device. + """ + return psutil.Process().memory_info().rss + + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum GPU memory managed by the caching allocator in bytes for a given device. + """ + return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + max_memory = int(psutil.virtual_memory().total * fraction) + _, hard = resource.getrlimit(resource.RLIMIT_AS) + resource.setrlimit(resource.RLIMIT_AS, (max_memory, hard)) + + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the CUDA memory allocator. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # streams and events APIs + # ======================= + + def Stream(self, device=None, priority=0, **kwargs): + """ + A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. 
+ """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # amp APIs + # ======================= + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ + return nullcontext diff --git a/colossalai/accelerator/cuda_accelerator.py b/colossalai/accelerator/cuda_accelerator.py index 72152834a..bdaf53bd5 100644 --- a/colossalai/accelerator/cuda_accelerator.py +++ b/colossalai/accelerator/cuda_accelerator.py @@ -1,7 +1,9 @@ #!/usr/bin/env python -from typing import Union + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist from .base_accelerator import BaseAccelerator @@ -19,16 +21,26 @@ class CudaAccelerator(BaseAccelerator): # ======================= # device APIs # ======================= + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + return torch.device(f"cuda:{torch.cuda.current_device()}") + def current_device(self) -> int: """ Return the current device index. """ return torch.cuda.current_device() - def set_device(self, device: Union[torch.device, int]) -> None: + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: """ Bind the current process to a device. """ + if device is None: + if not dist.is_initialized(): + raise RuntimeError("Cannot get current device when distributed is not initialized.") + device = dist.get_rank() % self.device_count() torch.cuda.set_device(device) def get_device_name(self, device: Union[torch.device, int]) -> str: @@ -54,3 +66,211 @@ class CudaAccelerator(BaseAccelerator): Return the number of devices on the machine. """ return torch.cuda.device_count() + + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the cuda capability of a device. + """ + return torch.cuda.get_device_capability(device) + + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + return torch.cuda.get_device_name(device) + + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + return torch.cuda.get_device_properties(device) + + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi + """ + return torch.cuda.utilization(device) + + # ======================= + # random number generator APIs + # ======================= + def get_rng_state(self, device="cuda") -> torch.Tensor: + """ + Returns the random number generator state of the specified GPU as a ByteTensor. + """ + return torch.cuda.get_rng_state(device) + + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + return torch.cuda.get_rng_state_all() + + def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None: + """ + Sets the random number generator state of the specified GPU. + """ + torch.cuda.set_rng_state(new_state, device) + + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + torch.cuda.set_rng_state_all(new_states) + + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current GPU. 
+ """ + torch.cuda.manual_seed(seed) + + def manual_seed_all(self, seed: int) -> None: + """ + Set the random seed for the all processes. + """ + torch.cuda.manual_seed_all(seed) + + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current GPU. + """ + torch.cuda.seed() + + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all GPUs. + """ + torch.cuda.seed_all() + + def initial_seed(self) -> int: + """ + Returns the current random seed of the current GPU. + """ + return torch.cuda.initial_seed() + + # ======================= + # memory management APIs + # ======================= + + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi. + """ + torch.cuda.empty_cache() + + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of CUDA memory allocator statistics for a given device. + """ + return torch.cuda.memory_stats(device=device) + + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + return torch.cuda.memory_summary(device=device, abbreviated=abbreviated) + + def memory_snapshot(self): + """ + Returns a snapshot of the CUDA memory allocator state across all devices. + """ + return torch.cuda.memory_snapshot() + + def memory_allocated(self, device=None) -> int: + """ + Returns the current GPU memory occupied by tensors in bytes for a given device. + """ + return torch.cuda.memory_allocated(device=device) + + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum GPU memory occupied by tensors in bytes for a given device. + """ + return torch.cuda.max_memory_allocated(device=device) + + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device. + """ + torch.cuda.reset_max_memory_allocated(device=device) + + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + """ + torch.cuda.reset_max_memory_cached(device=device) + + def memory_reserved(self, device=None) -> int: + """ + Returns the current GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.cuda.memory_reserved(device=device) + + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.cuda.max_memory_reserved(device=device) + + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + torch.cuda.set_per_process_memory_fraction(fraction, device=device) + + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the CUDA memory allocator. + """ + torch.cuda.reset_peak_memory_stats(device=device) + + # ======================= + # streams and events APIs + # ======================= + + def Stream(self, device=None, priority=0, **kwargs): + """ + A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details. 
+ """ + return torch.cuda.Stream(device, priority, **kwargs) + + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams. + """ + return torch.cuda.Event(enable_timing, blocking, interprocess) + + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + return torch.cuda.current_stream(device) + + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + return torch.cuda.default_stream(device) + + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + torch.cuda.set_stream(stream_) + + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + return torch.cuda.stream(stream_) + + # ======================= + # amp APIs + # ======================= + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ + return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) diff --git a/colossalai/accelerator/npu_accelerator.py b/colossalai/accelerator/npu_accelerator.py index a8bba6eaf..b3575dbfe 100644 --- a/colossalai/accelerator/npu_accelerator.py +++ b/colossalai/accelerator/npu_accelerator.py @@ -1,13 +1,17 @@ #!/usr/bin/env python -from typing import Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist from .base_accelerator import BaseAccelerator +IS_NPU_AVAILABLE = False try: import torch_npu # noqa + + IS_NPU_AVAILABLE = True except ImportError: pass @@ -26,16 +30,26 @@ class NpuAccelerator(BaseAccelerator): # ======================= # device APIs # ======================= + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + return torch.device(f"npu:{torch.npu.current_device()}") + def current_device(self) -> int: """ Return the current device index. """ return torch.npu.current_device() - def set_device(self, device: Union[torch.device, int]) -> None: + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: """ Bind the current process to a device. """ + if device is None: + if not dist.is_initialized(): + raise RuntimeError("Cannot get current device when distributed is not initialized.") + device = dist.get_rank() % self.device_count() torch.npu.set_device(device) def get_device_name(self, device: Union[torch.device, int]) -> str: @@ -61,3 +75,211 @@ class NpuAccelerator(BaseAccelerator): Return the number of devices on the machine. """ return torch.npu.device_count() + + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the npu capability of a device. + """ + return torch.npu.get_device_capability(device) + + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + return torch.npu.get_device_name(device) + + def get_device_properties(self, device): + """ + Gets the properties of a device. 
+ """ + return torch.npu.get_device_properties(device) + + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi + """ + return torch.npu.utilization(device) + + # ======================= + # random number generator APIs + # ======================= + def get_rng_state(self, device="npu") -> torch.Tensor: + """ + Returns the random number generator state of the specified GPU as a ByteTensor. + """ + return torch.npu.get_rng_state(device) + + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + return torch.npu.get_rng_state_all() + + def set_rng_state(self, new_state: torch.ByteTensor, device: str = "npu") -> None: + """ + Sets the random number generator state of the specified GPU. + """ + torch.npu.set_rng_state(new_state, device) + + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + torch.npu.set_rng_state_all(new_states) + + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current GPU. + """ + torch.npu.manual_seed(seed) + + def manual_seed_all(self, seed: int) -> None: + """ + Set the random seed for the all processes. + """ + torch.npu.manual_seed_all(seed) + + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current GPU. + """ + torch.npu.seed() + + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all GPUs. + """ + torch.npu.seed_all() + + def initial_seed(self) -> int: + """ + Returns the current random seed of the current GPU. + """ + return torch.npu.initial_seed() + + # ======================= + # memory management APIs + # ======================= + + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi. + """ + torch.npu.empty_cache() + + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of npu memory allocator statistics for a given device. + """ + return torch.npu.memory_stats(device=device) + + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + return torch.npu.memory_summary(device=device, abbreviated=abbreviated) + + def memory_snapshot(self): + """ + Returns a snapshot of the npu memory allocator state across all devices. + """ + return torch.npu.memory_snapshot() + + def memory_allocated(self, device=None) -> int: + """ + Returns the current GPU memory occupied by tensors in bytes for a given device. + """ + return torch.npu.memory_allocated(device=device) + + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum GPU memory occupied by tensors in bytes for a given device. + """ + return torch.npu.max_memory_allocated(device=device) + + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device. 
+ """ + torch.npu.reset_max_memory_allocated(device=device) + + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + """ + torch.npu.reset_max_memory_cached(device=device) + + def memory_reserved(self, device=None) -> int: + """ + Returns the current GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.npu.memory_reserved(device=device) + + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.npu.max_memory_reserved(device=device) + + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + torch.npu.set_per_process_memory_fraction(fraction, device=device) + + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the npu memory allocator. + """ + torch.npu.reset_peak_memory_stats(device=device) + + # ======================= + # streams and events APIs + # ======================= + + def Stream(self, device=None, priority=0, **kwargs): + """ + A npu stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See npu-semantics for details. + """ + return torch.npu.Stream(device, priority, **kwargs) + + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + npu events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize npu streams. + """ + return torch.npu.Event(enable_timing, blocking, interprocess) + + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + return torch.npu.current_stream(device) + + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + return torch.npu.default_stream(device) + + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + torch.npu.set_stream(stream_) + + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. 
+ """ + return torch.npu.stream(stream_) + + # ======================= + # amp APIs + # ======================= + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ + return torch.npu.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py index 439d13dcf..fc4c884d4 100644 --- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -7,8 +7,8 @@ from typing import Dict import torch from torch import Tensor +from colossalai.accelerator import get_accelerator from colossalai.logging import get_dist_logger -from colossalai.utils.device import get_current_device __all__ = ["BaseGradScaler"] @@ -23,7 +23,7 @@ class BaseGradScaler(ABC): def __init__(self, initial_scale: float, verbose: bool): assert initial_scale > 0 - self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float) + self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float) self._verbose = verbose if self._verbose: diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py index 86ba919ee..5cd8035d7 100644 --- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py @@ -5,7 +5,7 @@ from typing import Optional import torch -from colossalai.utils.device import get_current_device +from colossalai.accelerator import get_accelerator from .base_grad_scaler import BaseGradScaler @@ -37,14 +37,20 @@ class DynamicGradScaler(BaseGradScaler): hysteresis: int = 2, verbose: bool = False, ): + a = get_accelerator() + a.device_count() super().__init__(initial_scale, verbose) if min_scale: - self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float) + self._min_scale = torch.tensor( + [min_scale], device=get_accelerator().get_current_device(), dtype=torch.float + ) else: self._min_scale = None if max_scale: - self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float) + self._max_scale = torch.tensor( + [max_scale], device=get_accelerator().get_current_device(), dtype=torch.float + ) else: self._max_scale = None @@ -117,7 +123,7 @@ class DynamicGradScaler(BaseGradScaler): return state_dict def load_state_dict(self, state_dict): - self._scale = state_dict["scale"].to(get_current_device()) + self._scale = state_dict["scale"].to(get_accelerator().get_current_device()) self._growth_factor = state_dict["growth_factor"] self._backoff_factor = state_dict["backoff_factor"] self._hysteresis = state_dict["hysteresis"] diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py index 9ce272356..2e7c8a281 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py @@ -5,8 +5,8 @@ import torch import torch.distributed as dist from torch import Tensor +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler -from colossalai.utils import get_current_device from .base import MixedPrecisionMixin @@ -40,7 +40,7 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin): 
max_scale=max_scale, ) self.optim_state = OptimState.UNSCALED - self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device()) + self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device()) @property def loss_scale(self) -> float: diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py index 601bf2926..fe8439269 100644 --- a/colossalai/auto_parallel/offload/amp_optimizer.py +++ b/colossalai/auto_parallel/offload/amp_optimizer.py @@ -4,10 +4,10 @@ from typing import Dict, Tuple import torch from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device from .base_offload_module import BaseOffloadModule from .region import Region @@ -79,7 +79,9 @@ class AMPOptimizer(OptimizerWrapper): hysteresis=hysteresis, max_scale=max_scale, ) - self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + self._found_overflow: torch.Tensor = torch.zeros( + 1, dtype=torch.int64, device=get_accelerator().get_current_device() + ) self._logger = get_dist_logger() def _set_grad_ptr(self): diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py index a6628e29c..3ad210de9 100644 --- a/colossalai/auto_parallel/offload/solver.py +++ b/colossalai/auto_parallel/offload/solver.py @@ -11,7 +11,7 @@ except: import torch from torch.fx.node import Node -from colossalai.utils.device import get_current_device +from colossalai.accelerator import get_accelerator from .region import Region from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator @@ -57,7 +57,10 @@ class Solver(ABC): if memory_budget > 0: self.memory_budget = memory_budget * self.error_factor else: - self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor + self.memory_budget = ( + torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory + * self.error_factor + ) self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth() self.comp_power: float = self._extract_computing_power() diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index 443c4094c..c757a878d 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -5,8 +5,8 @@ import torch.nn as nn from torch import Tensor from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.utils.device import autocast from .mixed_precision_base import MixedPrecision @@ -89,7 +89,7 @@ class TorchAMPModule(ModelWrapper): super().__init__(module) def forward(self, *args, **kwargs): - with autocast(): + with get_accelerator().autocast(): return self.module(*args, **kwargs) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 261080dc9..d6610a3e1 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -12,6 +12,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from 
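
Illustrative aside, not part of the patch: TorchAMPModule now enters autocast through the accelerator instead of the removed colossalai.utils.device.autocast helper. Roughly, mirroring the forward shown above:

import torch
import torch.nn as nn

from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()
model = nn.Linear(16, 16).to(device)
inputs = torch.randn(2, 16, device=device)

# torch.cuda.amp.autocast on CUDA, torch.npu.amp.autocast on NPU, nullcontext on the CPU accelerator
with get_accelerator().autocast():
    out = model(inputs)
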
torch.utils.data import DataLoader +from colossalai.accelerator import IS_NPU_AVAILABLE, get_accelerator from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO from colossalai.checkpoint_io.utils import ( get_model_base_filenames, @@ -24,8 +25,6 @@ from colossalai.checkpoint_io.utils import ( from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.utils import get_current_device -from colossalai.utils.device import IS_NPU_AVAILABLE from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats @@ -367,7 +366,7 @@ class GeminiPlugin(DPPluginBase): assert placement_policy == "static", "NPU only supports static placement policy" self.gemini_config = dict( chunk_config_dict=chunk_config_dict, - chunk_init_device=(chunk_init_device or get_current_device()), + chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()), placement_policy=placement_policy, enable_gradient_accumulation=enable_gradient_accumulation, shard_param_frac=shard_param_frac, diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index bbc36ceab..2cc9e19bf 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -18,6 +18,7 @@ from torch.utils._pytree import tree_map from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh @@ -29,7 +30,6 @@ from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.d_tensor.api import is_distributed_tensor from colossalai.zero.low_level import LowLevelZeroOptimizer -from colossalai.utils.device import get_current_device from .pp_plugin_base import PipelinePluginBase @@ -82,7 +82,7 @@ class HybridParallelModule(ModelWrapper): self.mixed_precision = torch.bfloat16 if self.mixed_precision is not None: module = module.to(self.mixed_precision) - module = module.to(get_current_device()) + module = module.to(get_accelerator().get_current_device()) # setting input type cast when using mixed precision self.convert_fn = None @@ -346,7 +346,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if self.tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) if self.pp_size > 1: @@ -385,7 +387,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): total_norm_exponentiated += grad_norm_exponentiated - total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32) + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 
+ ) if self.tp_size > 1: # compute norm in tp process group dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) @@ -543,7 +547,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): # so we need to calculate the norm of 'tp' and 'pp' gradients. total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if self.tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) @@ -586,7 +592,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): total_norm_exponentiated += grad_norm_exponentiated - total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32) + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if self.tp_size > 1: # compute norm in tp process group dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) @@ -798,7 +806,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): # so we only need to calculate the norm 'tp' of 'pp' gradients. total_norm = super()._compute_grad_norm(gradients, norm_type) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) @@ -837,7 +847,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): total_norm_exponentiated += grad_norm_exponentiated - total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32) + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if dp_size > 1: # compute norm in dp process group dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 89102820c..d21496f0b 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -12,6 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils._pytree import tree_map from torch.utils.data import DataLoader +from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO from colossalai.checkpoint_io.utils import ( get_optimizer_base_filenames, @@ -24,7 +25,6 @@ from colossalai.checkpoint_io.utils import ( sharded_optimizer_loading_epilogue, ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper -from colossalai.utils import get_current_device from colossalai.zero import LowLevelZeroOptimizer from .dp_plugin_base import DPPluginBase @@ -52,7 +52,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): self.dtype = torch.bfloat16 if self.dtype is not None: module = module.to(self.dtype) - module = module.to(get_current_device()) + module = 
module.to(get_accelerator().get_current_device()) self.module = module self.convert_fn = None if self.dtype is not None: diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 25076b742..aaeaad382 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -6,12 +6,12 @@ import warnings from pathlib import Path from typing import Dict, Union -import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.context import Config from colossalai.logging import get_dist_logger -from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed +from colossalai.utils import set_seed def launch( @@ -47,17 +47,18 @@ def launch( if rank == 0: warnings.warn("`config` is deprecated and will be removed soon.") - if IS_NPU_AVAILABLE and backend == "nccl": - backend = "hccl" + cur_accelerator = get_accelerator() + + backend = cur_accelerator.communication_backend # init default process group init_method = f"tcp://[{host}]:{port}" dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # set cuda device - if torch.cuda.is_available() or IS_NPU_AVAILABLE: - # if local rank is not given, calculate automatically - set_device(local_rank) + # if local rank is not given, calculate automatically + if cur_accelerator.support_set_device: + cur_accelerator.set_device(local_rank) set_seed(seed) diff --git a/colossalai/kernel/extensions/flash_attention/utils.py b/colossalai/kernel/extensions/flash_attention/utils.py index 0eab9e89f..06fef491f 100644 --- a/colossalai/kernel/extensions/flash_attention/utils.py +++ b/colossalai/kernel/extensions/flash_attention/utils.py @@ -6,7 +6,7 @@ import torch import torch.nn.functional as F from einops import rearrange -from colossalai.utils.device import get_current_device +from colossalai.accelerator import get_accelerator class Unpad(torch.autograd.Function): @@ -70,7 +70,9 @@ class SeqLenInfo: cu_seqlens: torch.Tensor = None @staticmethod - def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()): + def materialize( + attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device() + ): if attn_mask is not None: indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py index 8bebad894..d392649a6 100644 --- a/colossalai/kernel/jit/option.py +++ b/colossalai/kernel/jit/option.py @@ -1,7 +1,7 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear -from colossalai.utils import get_current_device from .bias_dropout_add import bias_dropout_add_fused_train from .bias_gelu import bias_gelu_impl @@ -46,11 +46,13 @@ def warmup_jit_fusion( ): """Compile JIT functions before the main training steps""" - embed = Embedding(vocab_size, hidden_size).to(get_current_device()) - linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device()) - linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device()) + embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device()) + linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device()) + linear_2 = Linear(hidden_size * 4, hidden_size, 
skip_bias_add=True).to(get_accelerator().get_current_device()) - x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device()) + x = torch.randint( + vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_accelerator().get_current_device() + ) x = embed(x) y, y_bias = linear_1(x) z, z_bias = linear_2(y) @@ -58,8 +60,8 @@ def warmup_jit_fusion( # prop and recomputation for bias_grad, input_grad in zip([True, True], [False, True]): for _ in range(10): - bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device()) - input_ = torch.rand_like(y, dtype=dtype, device=get_current_device()) + bias = torch.rand_like(y_bias, dtype=dtype, device=get_accelerator().get_current_device()) + input_ = torch.rand_like(y, dtype=dtype, device=get_accelerator().get_current_device()) bias.requires_grad, input_.requires_grad = bias_grad, input_grad bias_gelu_impl(input_, bias) @@ -69,9 +71,9 @@ def warmup_jit_fusion( # prop and recomputation for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]): for _ in range(10): - input_ = torch.rand_like(z, dtype=dtype, device=get_current_device()) - residual = torch.rand_like(x, dtype=dtype, device=get_current_device()) - bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device()) + input_ = torch.rand_like(z, dtype=dtype, device=get_accelerator().get_current_device()) + residual = torch.rand_like(x, dtype=dtype, device=get_accelerator().get_current_device()) + bias = torch.rand_like(z_bias, dtype=dtype, device=get_accelerator().get_current_device()) input_.requires_grad = input_grad bias.requires_grad = bias_grad residual.requires_grad = residual_grad diff --git a/colossalai/legacy/amp/torch_amp/torch_amp.py b/colossalai/legacy/amp/torch_amp/torch_amp.py index 0a8d09be2..08f867eee 100644 --- a/colossalai/legacy/amp/torch_amp/torch_amp.py +++ b/colossalai/legacy/amp/torch_amp/torch_amp.py @@ -1,18 +1,19 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from colossalai.utils.device import autocast - import torch.nn as nn from torch import Tensor from torch.nn.modules.loss import _Loss from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper from colossalai.legacy.utils import clip_grad_norm_fp32 from ._grad_scaler import GradScaler +autocast = get_accelerator().autocast + class TorchAMPOptimizer(OptimizerWrapper): """A wrapper class which integrate Pytorch AMP with an optimizer diff --git a/colossalai/legacy/communication/p2p.py b/colossalai/legacy/communication/p2p.py index 19c3919b6..cf0bd4ba2 100644 --- a/colossalai/legacy/communication/p2p.py +++ b/colossalai/legacy/communication/p2p.py @@ -8,9 +8,9 @@ from typing import List, Tuple, Union import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks @@ -43,12 +43,16 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors): if isinstance(recv_shapes, torch.Size): recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors) - buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), 
dtype=dtype) + buffer_recv = torch.empty( + recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype + ) return buffer_recv, recv_split buffer_recv = [] for recv_shape in recv_shapes: recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors) - tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) + tensor_recv = torch.empty( + recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype + ) buffer_recv.append(tensor_recv) return buffer_recv, recv_split diff --git a/colossalai/legacy/communication/ring.py b/colossalai/legacy/communication/ring.py index a61dae56c..792a15abd 100644 --- a/colossalai/legacy/communication/ring.py +++ b/colossalai/legacy/communication/ring.py @@ -3,9 +3,9 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device, synchronize def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor: @@ -29,7 +29,7 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> current_rank = gpc.get_global_rank() tensor_recv_prev = torch.empty( - buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype + buffer_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=tensor_send_next.dtype ) # send to next rank @@ -52,6 +52,6 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> req.wait() # To protect against race condition when using batch_isend_irecv(). - synchronize() + get_accelerator().synchronize() return tensor_recv_prev diff --git a/colossalai/legacy/communication/utils.py b/colossalai/legacy/communication/utils.py index 6d77f3753..0b7c0eb74 100644 --- a/colossalai/legacy/communication/utils.py +++ b/colossalai/legacy/communication/utils.py @@ -3,9 +3,9 @@ from typing import List, Tuple, Union import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device TensorShape = Union[torch.Size, List[int], Tuple[int]] @@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool: if next_rank is None: next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()} if isinstance(obj, torch.Tensor): send_obj_nums = torch.tensor(1, **tensor_kwargs) dist.send(send_obj_nums, next_rank) @@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size: if prev_rank is None: prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()} recv_obj_nums = torch.empty((), **tensor_kwargs) dist.recv(recv_obj_nums, prev_rank) if recv_obj_nums.item() == 1: diff --git a/colossalai/legacy/engine/schedule/_base_schedule.py b/colossalai/legacy/engine/schedule/_base_schedule.py index 4a3ccfda1..9b2913442 100644 --- a/colossalai/legacy/engine/schedule/_base_schedule.py +++ 
b/colossalai/legacy/engine/schedule/_base_schedule.py @@ -6,8 +6,8 @@ from typing import Callable, Iterable import torch +from colossalai.accelerator import get_accelerator from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device class BaseSchedule(ABC): @@ -29,12 +29,12 @@ class BaseSchedule(ABC): def _move_tensor(element): if torch.is_tensor(element): if not element.is_cuda: - return element.to(get_current_device()).detach() + return element.to(get_accelerator().get_current_device()).detach() return element def _move_to_device(self, data): if isinstance(data, torch.Tensor): - data = data.to(get_current_device()) + data = data.to(get_accelerator().get_current_device()) elif isinstance(data, (list, tuple)): data_to_return = [] for element in data: diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py index 5fd5602e7..4a23853c1 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py @@ -7,12 +7,12 @@ from typing import Callable, List, Tuple, Union import torch.cuda import colossalai.legacy.communication as comm +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp.naive_amp import NaiveAMPModel from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank from colossalai.logging import get_dist_logger -from colossalai.utils.device import get_current_device from ._base_schedule import BaseSchedule @@ -352,7 +352,7 @@ class PipelineSchedule(BaseSchedule): output_objs = [] return_tensors = [] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None # Used for tensor meta information communication @@ -584,7 +584,7 @@ class InterleavedPipelineSchedule(PipelineSchedule): if not forward_only: output_obj_grads = [[] for _ in range(len(model))] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py index 4cd7e47c3..6e7760218 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py @@ -6,10 +6,10 @@ from typing import Iterable, Tuple import torch.cuda import colossalai.legacy.communication.p2p_v2 as comm +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.engine import Engine -from colossalai.utils.device import get_current_device from ._pipeline_schedule import PipelineSchedule @@ -99,7 +99,7 @@ class PipelineScheduleV2(PipelineSchedule): output_objs = [] return_tensors = [] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None diff --git a/colossalai/legacy/initialize.py 
b/colossalai/legacy/initialize.py index 4035bd6b5..d99a7d3f0 100644 --- a/colossalai/legacy/initialize.py +++ b/colossalai/legacy/initialize.py @@ -15,6 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader +from colossalai.accelerator import get_accelerator from colossalai.context import Config, ConfigException from colossalai.interface import OptimizerWrapper from colossalai.legacy.amp import AMP_TYPE, convert_to_amp @@ -34,7 +35,6 @@ from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2 from colossalai.legacy.zero.gemini.ophooks import BaseOpHook from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device def get_default_parser(): @@ -309,9 +309,9 @@ def initialize( else: if isinstance(model, nn.Module): # first sync model across dp ranks - model.to(get_current_device()) + model.to(get_accelerator().get_current_device()) elif isinstance(model, Callable): - model = model().to(get_current_device()) + model = model().to(get_accelerator().get_current_device()) # optimizer maybe a optimizer_cls if isinstance(optimizer, Callable): diff --git a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py index e1db0fe98..aa661664f 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py @@ -3,8 +3,8 @@ from typing import Callable from torch import dtype, nn +from colossalai.accelerator import get_accelerator from colossalai.nn import init -from colossalai.utils import get_current_device from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D @@ -83,7 +83,7 @@ class Embedding(ColossalaiModule): embed = ( nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, **kwargs) .to(dtype) - .to(get_current_device()) + .to(get_accelerator().get_current_device()) ) weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) elif num_embeddings <= vocab_parallel_limit: diff --git a/colossalai/legacy/nn/layer/colossalai_layer/normalization.py b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py index f8e317e72..58842f481 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/normalization.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py @@ -1,6 +1,6 @@ from torch import nn -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from ..parallel_1d import LayerNorm1D from ..parallel_2d import LayerNorm2D @@ -36,7 +36,7 @@ class LayerNorm(ColossalaiModule): def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None: tensor_parallel = get_tensor_parallel_mode() if tensor_parallel is None: - norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device()) + norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_accelerator().get_current_device()) else: norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) super().__init__(norm) diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py index b6ec5347f..36cb09d32 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/layers.py +++ 
b/colossalai/legacy/nn/layer/parallel_1d/layers.py @@ -10,6 +10,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn.parameter import Parameter +from colossalai.accelerator import get_accelerator from colossalai.kernel import LayerNorm from colossalai.legacy.communication import broadcast from colossalai.legacy.context import ParallelMode, seed @@ -22,7 +23,6 @@ from colossalai.legacy.utils.checkpointing import ( partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..base_layer import ParallelLayer from ..colossalai_layer._utils import ColossalaiModule @@ -221,7 +221,7 @@ class Classifier1D(ParallelLayer): # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False @@ -357,7 +357,7 @@ class VocabParallelClassifier1D(ParallelLayer): # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False @@ -499,7 +499,7 @@ class Linear1D_Col(ParallelLayer): # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) if bias: @@ -638,7 +638,7 @@ class Linear1D_Row(ParallelLayer): # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) if self.stream_chunk_num > 1: @@ -802,7 +802,9 @@ class Embedding1D(ParallelLayer): self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -912,7 +914,11 @@ class VocabParallelEmbedding1D(ParallelLayer): self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype) + torch.empty( + (self.num_embeddings_per_partition, self.embed_dim), + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.reset_parameters(weight_initializer) diff --git a/colossalai/legacy/nn/layer/parallel_2d/_operation.py b/colossalai/legacy/nn/layer/parallel_2d/_operation.py index f1eff7128..f67ee2e60 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2d/_operation.py @@ -5,10 +5,10 @@ import torch.distributed as dist from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from 
colossalai.utils import get_current_device def matmul_2d( @@ -250,7 +250,7 @@ class Matmul_AB_2D(torch.autograd.Function): B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[-1]) - C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -399,7 +399,7 @@ class Matmul_ABT_2D(torch.autograd.Function): B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[0]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -556,7 +556,7 @@ class Matmul_ATB_2D(torch.autograd.Function): B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[-1], B.shape[-1]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases diff --git a/colossalai/legacy/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py index f81c5334a..4987afa18 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import broadcast from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.core import global_context as gpc @@ -18,7 +19,6 @@ from colossalai.legacy.utils.checkpointing import ( partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple @@ -82,7 +82,7 @@ class Linear2D(ParallelLayer): self.hidden_size_per_partition = divide(self.out_features, self.summa_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter( torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) ) @@ -259,7 +259,7 @@ class LayerNorm2D(ParallelLayer): self.partitioned_partition = divide(normalized_shape, self.summa_dim**2) # create parameters - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) if bias: @@ -438,18 +438,24 @@ class PatchEmbedding2D(ParallelLayer): self.weight = Parameter( torch.empty( (self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) - self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) + ) self.cls_token = Parameter( - torch.zeros((1, 1, 
self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.pos_embed = Parameter( torch.zeros( - (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + (1, self.num_patches + 1, self.embed_size_per_partition), + device=get_accelerator().get_current_device(), + dtype=dtype, ) ) @@ -619,7 +625,9 @@ class Embedding2D(ParallelLayer): self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -758,7 +766,7 @@ class VocabParallelEmbedding2D(ParallelLayer): self.weight = Parameter( torch.empty( (self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) @@ -895,11 +903,18 @@ class Classifier2D(ParallelLayer): self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, + self.input_size_per_partition, + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.has_weight = True if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -1052,7 +1067,7 @@ class VocabParallelClassifier2D(ParallelLayer): self.output_size_per_partition = divide(num_classes, self.summa_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py index 50900c135..43328bd03 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py @@ -5,10 +5,10 @@ import torch.distributed as dist from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device def get_parallel_group(parallel_mode: ParallelMode): @@ -205,7 +205,7 @@ class Matmul_AB_2p5D(torch.autograd.Function): B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[-1]) - C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -362,7 +362,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function): B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[0]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = 
torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -527,7 +527,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function): B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[-1], B.shape[-1]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -661,7 +661,9 @@ class _Add_Bias_2p5D(torch.autograd.Function): if row_rank == 0: bias_temp = bias.clone() else: - bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device()) + bias_temp = torch.zeros( + output_size_per_partition, dtype=bias.dtype, device=get_accelerator().get_current_device() + ) src_rank = ( col_rank + dep_rank * tesseract_dim**2 @@ -984,7 +986,7 @@ class SplitFirst(torch.autograd.Function): @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: grad_shape = (ctx.batch_size,) + output_grad.shape[1:] - grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device()) + grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_accelerator().get_current_device()) dist.all_gather( list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode) ) diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py index b451a4031..d9410f1cb 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import broadcast from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.core import global_context as gpc @@ -19,7 +20,6 @@ from colossalai.legacy.utils.checkpointing import ( partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple @@ -84,7 +84,7 @@ class Linear2p5D(ParallelLayer): self.hidden_size_per_partition = divide(out_features, self.tesseract_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter( torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) ) @@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer): self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * # create parameters - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) if bias: @@ -451,18 +451,24 @@ class PatchEmbedding2p5D(ParallelLayer): self.weight = Parameter( torch.empty( (self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) - self.bias = Parameter(torch.empty(self.embed_size_per_partition, 
device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) + ) self.cls_token = Parameter( - torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.pos_embed = Parameter( torch.zeros( - (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + (1, self.num_patches + 1, self.embed_size_per_partition), + device=get_accelerator().get_current_device(), + dtype=dtype, ) ) @@ -632,7 +638,9 @@ class Embedding2p5D(ParallelLayer): self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -772,7 +780,7 @@ class VocabParallelEmbedding2p5D(ParallelLayer): self.weight = Parameter( torch.empty( (self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) @@ -910,11 +918,18 @@ class Classifier2p5D(ParallelLayer): self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, + self.input_size_per_partition, + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.has_weight = True if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -1068,7 +1083,7 @@ class VocabParallelClassifier2p5D(ParallelLayer): self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False diff --git a/colossalai/legacy/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py index 16e515f87..bb01ec851 100644 --- a/colossalai/legacy/nn/layer/parallel_3d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import all_reduce, broadcast from colossalai.legacy.constants import ( INPUT_GROUP_3D, @@ -27,7 +28,6 @@ from colossalai.legacy.utils.checkpointing import ( partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ._operation import ( @@ -69,11 +69,13 @@ class LayerNorm3D(ParallelLayer): self.normalized_shape_per_partition = divide(normalized_shape, self.depth) self.weight = Parameter( - torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + torch.ones(self.normalized_shape_per_partition, 
device=get_accelerator().get_current_device(), dtype=dtype) ) if bias: self.bias = Parameter( - torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + torch.zeros( + self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype + ) ) else: self.bias = None @@ -202,13 +204,15 @@ class Linear3D(ParallelLayer): torch.empty( self.in_features_per_partition, self.out_features_per_partition, - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) if bias: self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + torch.zeros( + self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype + ) ) else: self.bias = None @@ -380,11 +384,18 @@ class Classifier3D(ParallelLayer): self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, + self.in_features_per_partition, + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.has_weight = True if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -523,14 +534,16 @@ class VocabParallelClassifier3D(ParallelLayer): torch.empty( self.out_features_per_partition, self.in_features_per_partition, - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) self.has_weight = True if bias: self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + torch.zeros( + self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype + ) ) else: self.bias = None @@ -705,16 +718,24 @@ class PatchEmbedding3D(ParallelLayer): self.weight = nn.Parameter( torch.empty( - (embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype + (embed_size_per_partition, in_chans, *self.patch_size), + device=get_accelerator().get_current_device(), + dtype=dtype, ) ) - self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype)) + self.bias = nn.Parameter( + torch.empty(embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) + ) self.cls_token = nn.Parameter( - torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros((1, 1, embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype) ) self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, self.num_patches + 1, embed_size_per_partition), + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) @@ -880,7 +901,9 @@ class Embedding3D(ParallelLayer): self.embed_kwargs = kwargs self.weight = nn.Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -1019,7 +1042,7 @@ class 
VocabParallelEmbedding3D(ParallelLayer): self.weight = Parameter( torch.empty( (self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) diff --git a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py index 24d5499e3..4e9bf364d 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py @@ -5,11 +5,11 @@ import torch from torch import distributed as dist from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import ring_forward from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range -from colossalai.utils import get_current_device class RingQK(torch.autograd.Function): @@ -30,7 +30,7 @@ class RingQK(torch.autograd.Function): sub_seq_length, sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE), dtype=sub_q.dtype, - device=get_current_device(), + device=get_accelerator().get_current_device(), ) # compute local QK^T @@ -71,7 +71,7 @@ class RingQK(torch.autograd.Function): grad_q = torch.zeros_like( sub_q, dtype=sub_q.dtype, - device=get_current_device(), + device=get_accelerator().get_current_device(), ) # compute with local sub_k @@ -105,7 +105,7 @@ class RingAV(torch.autograd.Function): batch_size * num_attention_heads, sub_seq_length, attention_head_size, - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=attention_score.dtype, ) @@ -142,7 +142,9 @@ class RingAV(torch.autograd.Function): grad_v /= local_world_size # calculate gradient for attention score - grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device()) + grad_attention_score = torch.zeros_like( + attention_scores, dtype=grad_output.dtype, device=get_accelerator().get_current_device() + ) # compute with local sub_k grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1)) diff --git a/colossalai/legacy/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py index 590ad5ff6..3a1c2e57b 100644 --- a/colossalai/legacy/nn/layer/vanilla/layers.py +++ b/colossalai/legacy/nn/layer/vanilla/layers.py @@ -7,10 +7,10 @@ from torch import Tensor from torch import nn as nn from torch.nn.parameter import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import seed from colossalai.legacy.registry import LAYERS from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..utils import to_2tuple @@ -173,12 +173,18 @@ class VanillaPatchEmbedding(nn.Module): self.flatten = flatten self.weight = nn.Parameter( - torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype) + torch.empty( + (embed_size, in_chans, *self.patch_size), device=get_accelerator().get_current_device(), dtype=dtype + ) + ) + self.bias = nn.Parameter(torch.empty(embed_size, device=get_accelerator().get_current_device(), dtype=dtype)) + self.cls_token = nn.Parameter( + torch.zeros((1, 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype) ) - self.bias = 
nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) - self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, self.num_patches + 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) @@ -242,11 +248,15 @@ class VanillaClassifier(nn.Module): self.has_weight = False else: self.weight = nn.Parameter( - torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, self.in_features, device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.has_weight = True if bias: - self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = nn.Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -287,7 +297,7 @@ class VanillaLayerNorm(nn.Module): self.normalized_shape = (normalized_shape,) self.variance_epsilon = eps - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) if bias: @@ -333,7 +343,7 @@ class VanillaLinear(nn.Module): self.in_features = in_features self.out_features = out_features self.skip_bias_add = skip_bias_add - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) if bias: self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) diff --git a/colossalai/legacy/nn/loss/loss_2d.py b/colossalai/legacy/nn/loss/loss_2d.py index 44f39a6db..474fd4a2c 100644 --- a/colossalai/legacy/nn/loss/loss_2d.py +++ b/colossalai/legacy/nn/loss/loss_2d.py @@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization from colossalai.legacy.registry import LOSSES -from colossalai.utils import get_current_device @LOSSES.register_module @@ -118,7 +118,7 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function): grad_2d = grad_input.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device()) grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. 
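Note: the hunks above and below all apply the same substitution: resolve the active backend once through get_accelerator() and route every device query, autocast context, and synchronization call through it instead of the removed colossalai.utils / colossalai.utils.device helpers. The following is a minimal usage sketch of that pattern, not part of the patch itself; it only uses calls that appear in this patch (get_current_device, autocast, synchronize, set_device, communication_backend), the variable names and tensor shape are illustrative, and it assumes autocast() returns a context manager, as the converted torch_amp.py wrapper expects.

    import torch
    from colossalai.accelerator import get_accelerator

    accelerator = get_accelerator()            # resolves the CUDA, NPU, or CPU backend once
    device = accelerator.get_current_device()  # replaces colossalai.utils.get_current_device()

    # allocate tensors directly on the active device, as the converted layers now do
    weight = torch.empty(1024, 1024, device=device, dtype=torch.float32)

    with accelerator.autocast():               # replaces colossalai.utils.device.autocast()
        out = weight @ weight.t()

    accelerator.synchronize()                  # replaces the old colossalai.utils.synchronize()

In launch(), the same object supplies the collective backend name via communication_backend and exposes set_device(local_rank) behind support_set_device, which is why the explicit IS_NPU_AVAILABLE checks removed in this patch are no longer needed at call sites.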
diff --git a/colossalai/legacy/nn/loss/loss_2p5d.py b/colossalai/legacy/nn/loss/loss_2p5d.py index c57bf26e9..b423ab3d8 100644 --- a/colossalai/legacy/nn/loss/loss_2p5d.py +++ b/colossalai/legacy/nn/loss/loss_2p5d.py @@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization from colossalai.legacy.registry import LOSSES -from colossalai.utils import get_current_device @LOSSES.register_module @@ -112,7 +112,7 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function): grad_2d = grad_input.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device()) grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. diff --git a/colossalai/legacy/nn/loss/loss_3d.py b/colossalai/legacy/nn/loss/loss_3d.py index 988317cae..de6a674d6 100644 --- a/colossalai/legacy/nn/loss/loss_3d.py +++ b/colossalai/legacy/nn/loss/loss_3d.py @@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.accelerator import get_accelerator from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.legacy.registry import LOSSES -from colossalai.utils import get_current_device @LOSSES.register_module @@ -80,7 +80,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function): target_mask = (targets < vocab_start) | (targets > vocab_end) masked_target = targets.clone() - vocab_start masked_target[target_mask] = 0 - arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_accelerator().get_current_device()) predicted_logits = logits[arange_1d, masked_target] predicted_logits = predicted_logits.clone().contiguous().view_as(targets) predicted_logits[target_mask] = 0.0 @@ -110,7 +110,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function): grad_2d = input_grad.view(-1, partition_vocab_size) # Add the gradient from matching classes. 
- arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device()) grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() input_grad.mul_(output_grad.unsqueeze(dim=-1)) diff --git a/colossalai/legacy/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py index 35a7f0a15..0e6731db5 100644 --- a/colossalai/legacy/trainer/hooks/_metric_hook.py +++ b/colossalai/legacy/trainer/hooks/_metric_hook.py @@ -7,12 +7,12 @@ from typing import Callable import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import all_reduce from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.registry import HOOKS from colossalai.legacy.utils import is_no_pp_or_last_stage -from colossalai.utils import get_current_device from ._base_hook import BaseHook from ._commons_ import _format_number @@ -82,8 +82,8 @@ class LossMetric(Metric): def __init__(self, epoch_only): super().__init__(epoch_only=epoch_only) - self.last_step_loss = torch.zeros(1, device=get_current_device()) - self.accum_loss = torch.zeros(1, device=get_current_device()) + self.last_step_loss = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) self.count = 0 def reset(self) -> None: @@ -164,10 +164,10 @@ class AccuracyMetric(Metric): def __init__(self, epoch_only: bool, accuracy_func: Callable): super().__init__(epoch_only=epoch_only) self.acc = accuracy_func - self.last_step_sum = torch.zeros(1, device=get_current_device()) - self.last_step_correct = torch.zeros(1, device=get_current_device()) - self.accumulated_sum = torch.zeros(1, device=get_current_device()) - self.accumulated_correct = torch.zeros(1, device=get_current_device()) + self.last_step_sum = torch.zeros(1, device=get_accelerator().get_current_device()) + self.last_step_correct = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accumulated_sum = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accumulated_correct = torch.zeros(1, device=get_accelerator().get_current_device()) def reset(self) -> None: self.last_step_sum.zero_() @@ -320,10 +320,10 @@ class ThroughputMetric(Metric): super().__init__(epoch_only=epoch_only) self.ignored_steps = ignored_steps self.cur_steps = 0 - self.accumulated_num_samples = torch.zeros(1, device=get_current_device()) - self.accumulated_used_time = torch.zeros(1, device=get_current_device()) - self.last_step_num_samples = torch.zeros(1, device=get_current_device()) - self.last_step_used_time = torch.zeros(1, device=get_current_device()) + self.accumulated_num_samples = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accumulated_used_time = torch.zeros(1, device=get_accelerator().get_current_device()) + self.last_step_num_samples = torch.zeros(1, device=get_accelerator().get_current_device()) + self.last_step_used_time = torch.zeros(1, device=get_accelerator().get_current_device()) self._tflop_per_step = tflop_per_step self._use_local = use_local diff --git a/colossalai/legacy/utils/activation_checkpoint.py b/colossalai/legacy/utils/activation_checkpoint.py index 9a8051ae9..d1382cb1e 100644 --- a/colossalai/legacy/utils/activation_checkpoint.py +++ 
b/colossalai/legacy/utils/activation_checkpoint.py @@ -6,8 +6,8 @@ import weakref import torch from torch.utils.checkpoint import check_backward_validity, detach_variable +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states -from colossalai.utils.device import autocast, get_current_device def copy_to_device(obj, device): @@ -33,7 +33,7 @@ class CheckpointFunction(torch.autograd.Function): check_backward_validity(args) ctx.run_function = run_function ctx.activation_offload = activation_offload - ctx.device = get_current_device() + ctx.device = get_accelerator().get_current_device() # preserve rng states ctx.fwd_cpu_rng_state = torch.get_rng_state() @@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function): inputs[idx] = tensors[i] detached_inputs = detach_variable(tuple(inputs)) if ctx.had_autocast_in_fwd: - with torch.enable_grad(), autocast(): + with torch.enable_grad(), get_accelerator().autocast(): outputs = ctx.run_function(*detached_inputs) else: with torch.enable_grad(): @@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args): # rerun forward, the inner_pack will store all the activations in storage if has_autocast_in_fwd: - with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks( + with torch.enable_grad(), get_accelerator().autocast(), torch.autograd.graph.saved_tensors_hooks( inner_pack, inner_unpack ): _unused = function(*args) @@ -245,7 +245,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args): # get device if we need to offload the activation if activation_offload: - device = get_current_device() + device = get_accelerator().get_current_device() # run function with pack and unpack as saved_tensors_hooks with torch.autograd.graph.saved_tensors_hooks(pack, unpack): diff --git a/colossalai/legacy/utils/memory.py b/colossalai/legacy/utils/memory.py index 2f99a7d2f..cfb22d315 100644 --- a/colossalai/legacy/utils/memory.py +++ b/colossalai/legacy/utils/memory.py @@ -6,9 +6,9 @@ import torch import torch.distributed as dist from packaging import version +from colossalai.accelerator import get_accelerator from colossalai.legacy.core import global_context as gpc from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device _GLOBAL_CUDA_MEM_FRACTION = 1.0 _GLOBAL_CPU_MEM_CAPACITY = -1 @@ -112,7 +112,10 @@ def colo_device_memory_capacity(device: torch.device) -> int: # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node if device.type == "cuda": - return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION + return ( + torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory + * _GLOBAL_CUDA_MEM_FRACTION + ) def colo_device_memory_used(device: torch.device) -> int: @@ -153,7 +156,7 @@ def colo_set_process_memory_fraction(ratio: float) -> None: return global _GLOBAL_CUDA_MEM_FRACTION _GLOBAL_CUDA_MEM_FRACTION = ratio - torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device()) + torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_accelerator().get_current_device()) def colo_set_cpu_memory_capacity(size: int) -> None: diff --git a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py index ad54b989f..a9e3ffe1a 100644 --- a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py +++ b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py @@ -8,7 +8,7 @@ import torch.distributed as dist from torch.autograd.profiler import profile from torch.distributed import ReduceOp -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time @@ -177,7 +177,7 @@ class CommProfiler(BaseProfiler): assert current_comm_event is not None, "dist op has not been found" - buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device()) + buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_accelerator().get_current_device()) torch_all_reduce(buffer, op=ReduceOp.MIN, group=group) current_comm_event.self_cuda_time = buffer.item() diff --git a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py index e336717f4..b0360880e 100644 --- a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py +++ b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py @@ -3,7 +3,7 @@ import types from time import time from typing import List -from colossalai.utils.device import get_current_device +from colossalai.accelerator import get_accelerator from .stateful_tensor import StatefulTensor, TensorState from .tensor_placement_policy import TensorPlacementPolicy @@ -69,7 +69,7 @@ class StatefulTensorMgr(object): # move COMPUTE tensors to CUDA self._cpu_gpu_move_volume += cuda_demand for t in move_to_cuda_tensor_list: - colo_model_data_tensor_move_inline(t, get_current_device()) + colo_model_data_tensor_move_inline(t, get_accelerator().get_current_device()) @property def cpu_gpu_move_volume(self): diff --git a/colossalai/legacy/zero/gemini/tensor_placement_policy.py b/colossalai/legacy/zero/gemini/tensor_placement_policy.py index 3aca80cfe..6fde91d4a 100644 --- a/colossalai/legacy/zero/gemini/tensor_placement_policy.py +++ b/colossalai/legacy/zero/gemini/tensor_placement_policy.py @@ -5,8 +5,8 @@ from typing import List, Optional, Type import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.utils.memory import colo_device_memory_capacity -from colossalai.utils import get_current_device from colossalai.zero.gemini.memory_tracer import MemStatsCollector from .stateful_tensor import StatefulTensor @@ -38,7 +38,7 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy): class CUDATensorPlacementPolicy(TensorPlacementPolicy): def __init__(self, 
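The capacity math in `colo_device_memory_capacity` is unchanged; only the device lookup moves to the accelerator. A short sketch of the computation, assuming CUDA is available (the helper name below is illustrative, not part of the module):

```python
# Total memory of the current accelerator device scaled by the global
# per-process memory fraction (1.0 by default in the module above).
import torch

from colossalai.accelerator import get_accelerator


def cuda_capacity(mem_fraction: float = 1.0) -> int:
    device = get_accelerator().get_current_device()
    return int(torch.cuda.get_device_properties(device).total_memory * mem_fraction)
```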
mem_stats_collector: Optional[MemStatsCollector] = None) -> None: assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available" - super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector) + super().__init__(get_accelerator().get_current_device(), mem_stats_collector=mem_stats_collector) def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: return 0, 0 @@ -78,7 +78,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy): int: the volume of memory that is evicted """ start = time() - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. diff --git a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py index b9d3071a8..e5a35dea1 100644 --- a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py @@ -4,8 +4,8 @@ import torch import torch.distributed as dist from torch._utils import _flatten_dense_tensors as flatten +from colossalai.accelerator import get_accelerator from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.utils import get_current_device from .tensor_shard_strategy import TensorShardStrategy @@ -30,9 +30,11 @@ class BucketTensorShardStrategy(TensorShardStrategy): rank = dist.get_rank(process_group) for i in range(world_size): if i == rank: - buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device())) + buffer_list.append( + flatten([t.payload for t in tensor_list]).cuda(get_accelerator().get_current_device()) + ) else: - buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device())) + buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_accelerator().get_current_device())) dist.all_gather(buffer_list, buffer_list[rank], group=process_group) # Move to target device before splitting buffer # Ensure we utilize maximum PCIE bandwidth diff --git a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py index ebaef774b..fb6ef534b 100644 --- a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py @@ -3,11 +3,11 @@ from typing import List, Optional import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.legacy.zero.shard_utils.commons import get_shard from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.utils import get_current_device class TensorShardStrategy(BaseShardStrategy): @@ -34,9 +34,9 @@ class TensorShardStrategy(BaseShardStrategy): if t.is_sharded: return if t.payload.device.type == "cuda": - assert t.payload.device == get_current_device(), ( + assert t.payload.device == get_accelerator().get_current_device(), ( f"shard tensor on cuda device index {t.payload.device.index}," - f" but current cuda device is {get_current_device()}" + f" but current cuda device is 
{get_accelerator().get_current_device()}" ) sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group)) t.payload_reset(sharded_payload) @@ -50,7 +50,9 @@ class TensorShardStrategy(BaseShardStrategy): world_size = dist.get_world_size(process_group) rank = dist.get_rank(process_group) - buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device()) + buffer = torch.empty( + payload_numel * world_size, dtype=t.payload.dtype, device=get_accelerator().get_current_device() + ) buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0)) buffer_list[rank].copy_(t.payload) diff --git a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py index 85f2ac215..bb7744a80 100644 --- a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py @@ -10,6 +10,7 @@ import torch.nn as nn from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils.memory import colo_device_memory_capacity @@ -22,7 +23,7 @@ from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_move_to_c from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer from colossalai.logging import get_dist_logger -from colossalai.utils import disposable, get_current_device +from colossalai.utils import disposable from colossalai.zero.gemini.memory_tracer import MemStatsCollector from ._utils import ( @@ -212,8 +213,12 @@ class ShardedModelV2(nn.Module): self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0]) if gpc.get_global_rank() == 0: with open(filename, "w+") as f: - f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n") - f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n") + f.write( + f"cuda reserved {torch.cuda.memory_reserved(get_accelerator().get_current_device()) / 1e9} GB\n" + ) + f.write( + f"cuda max allocated {torch.cuda.max_memory_allocated(get_accelerator().get_current_device()) / 1e9} GB\n" + ) f.write("CUDA model data (GB)\n") f.write("\n") f.write("CUDA non model data (GB)\n") @@ -266,7 +271,8 @@ class ShardedModelV2(nn.Module): # model data is fixed in cuda during training. # cuda margin space can be used to store OS. 
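The gather path in `TensorShardStrategy` above builds one flat buffer of `payload_numel * world_size` elements and views it as one chunk per rank, so `all_gather` can fill each rank's slice in place. A single-process sketch of that layout (the collective itself is omitted; `world_size` is hard-coded for illustration):

```python
import torch

from colossalai.accelerator import get_accelerator

world_size = 4  # taken from the process group in the real code
payload = torch.randn(6, device=get_accelerator().get_current_device())
buffer = torch.empty(payload.numel() * world_size, dtype=payload.dtype, device=payload.device)
buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
buffer_list[0].copy_(payload)  # each rank copies its own shard; dist.all_gather fills the rest
```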
self._cuda_margin_space = ( - colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda + colo_device_memory_capacity(get_accelerator().get_current_device()) + - self._memstats_collector._memstats.max_overall_cuda ) @torch.no_grad() diff --git a/colossalai/legacy/zero/sharded_model/zero_hook.py b/colossalai/legacy/zero/sharded_model/zero_hook.py index 892e9f31d..332f44d53 100644 --- a/colossalai/legacy/zero/sharded_model/zero_hook.py +++ b/colossalai/legacy/zero/sharded_model/zero_hook.py @@ -3,13 +3,13 @@ from typing import Optional import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.registry import OPHOOKS from colossalai.legacy.zero.gemini.ophooks import BaseOpHook from colossalai.legacy.zero.gemini.stateful_tensor import TensorState from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device from colossalai.zero.gemini.memory_tracer import MemStatsCollector @@ -33,7 +33,7 @@ class ZeroHook(BaseOpHook): self.process_group = process_group # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU - self.computing_device = get_current_device() + self.computing_device = get_accelerator().get_current_device() self._memstarts_collector = memstarts_collector self._stateful_tensor_mgr = stateful_tensor_mgr diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index c5bb50862..f5815d05d 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -8,9 +8,9 @@ import torch.nn as nn import torch.nn.functional as F from torch.distributed import ProcessGroup +from colossalai.accelerator import get_accelerator from colossalai.moe._operation import moe_cumsum from colossalai.moe.manager import MOE_MANAGER -from colossalai.utils import get_current_device class MoeRouter(nn.Module, ABC): @@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC): drop_tks (bool, optional): Whether drops tokens in evaluation """ - def __init__(self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True, - use_kernel: bool = False): + def __init__( + self, + k_value: int, + capacity_factor_train: float, + capacity_factor_eval: float, + min_capacity: int, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + use_kernel: bool = False, + ): super().__init__() self.k_value = k_value self.capacity_factor_train = capacity_factor_train @@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC): if router_probs.dim() == expert_indices.dim() == 2: router_probs = router_probs.unsqueeze(0) expert_indices = expert_indices.unsqueeze(0) - assert router_probs.dim() == expert_indices.dim() == 3, \ - "router_probs must be 3D tensor and expert_indices must be 4D tensor" + assert ( + router_probs.dim() == expert_indices.dim() == 3 + ), "router_probs must be 3D tensor and expert_indices must be 4D tensor" # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. 
expert_mask = F.one_hot(expert_indices, num_experts) @@ -122,25 +125,29 @@ class Top1Router(MoeRouter): drop_tks (bool, optional): Whether drops tokens in evaluation """ - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + select_policy: str = "first", + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=1, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) self.select_policy = select_policy assert select_policy in {"first", "random"} if select_policy == "random": self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, device=get_current_device()) + low=torch.tensor(0.0, device=get_accelerator().get_current_device()), + high=torch.tensor(1.0, device=get_accelerator().get_current_device()), ).rsample def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: @@ -216,18 +223,22 @@ class Top2Router(MoeRouter): drop_tks (bool, optional): Whether drops tokens in evaluation. """ - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: """ @@ -255,8 +266,8 @@ class Top2Router(MoeRouter): top2_idx = torch.argmax(logits_except1, dim=-1) mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - cmask = (mask1 + mask2) # loss: [s, e] - cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 + cmask = mask1 + mask2 # loss: [s, e] + cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 # calculate loss expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) @@ -269,7 +280,7 @@ class Top2Router(MoeRouter): dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) capacity = max_num.item() - rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] + rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel) rank2 += torch.sum(mask1, dim=-2, keepdim=True) @@ -336,15 +347,18 @@ class TopKRouter(MoeRouter): oversubscribed / reach capacity. 
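For readers following the reformatted `Top2Router.forward`, the masking step can be reproduced in a few lines. This is a simplified stand-in: `moe_cumsum` and the capacity/drop handling of the real router are replaced with a plain `torch.cumsum`, so it runs without ColossalAI's kernels and is not the router's exact behavior.

```python
import torch
import torch.nn.functional as F

num_tokens, num_experts = 8, 4
probs = F.softmax(torch.randn(num_tokens, num_experts), dim=-1)

top1_idx = torch.argmax(probs, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)

cmask = (mask1 + mask2).float() / 2.0                        # normalized load per (token, expert)
rank1 = torch.cumsum(mask1, dim=0) - mask1                   # queue position at the top-1 expert
rank2 = torch.cumsum(mask2, dim=0) - mask2 + mask1.sum(dim=0, keepdim=True)
```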
""" - def __init__(self, - num_selected_experts: int, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, - drop_tks) + def __init__( + self, + num_selected_experts: int, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks + ) def forward( self, @@ -410,7 +424,7 @@ class TopKRouter(MoeRouter): # The combine array will be used for combining expert outputs, scaled by the # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, # expert_capacity]. - combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask) + combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) return combine_array, dispatch_mask diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index 5a17a6e0d..e25e7dd48 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -7,13 +7,12 @@ import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +from colossalai.accelerator import get_accelerator from colossalai.moe.manager import MOE_MANAGER from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor -from colossalai.utils import get_current_device class ForceFP32Parameter(torch.nn.Parameter): - def half(self, memory_format=None): return self.data.clone() @@ -30,8 +29,8 @@ class NormalNoiseGenerator: def __init__(self, num_experts: int): self.normal = torch.distributions.normal.Normal( - loc=torch.tensor(0.0, device=get_current_device()), - scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()), + loc=torch.tensor(0.0, device=get_accelerator().get_current_device()), + scale=torch.tensor(1.0 / num_experts**2, device=get_accelerator().get_current_device()), ).rsample def __call__(self, inputs: torch.Tensor): @@ -52,8 +51,8 @@ class UniformNoiseGenerator: def __init__(self, eps: float = 1e-2): self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(1.0 - eps, device=get_current_device()), - high=torch.tensor(1.0 + eps, device=get_current_device()), + low=torch.tensor(1.0 - eps, device=get_accelerator().get_current_device()), + high=torch.tensor(1.0 + eps, device=get_accelerator().get_current_device()), ).rsample def __call__(self, inputs: torch.Tensor): @@ -142,7 +141,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]] epsize_param_dict = dict() for param in model.parameters(): if not is_moe_tensor(param): - ep_size = 1 # set ep_size to 1 for dp parameters + ep_size = 1 # set ep_size to 1 for dp parameters else: ep_size = get_ep_size(param) if ep_size not in epsize_param_dict: @@ -193,18 +192,13 @@ def create_ep_hierarchical_group( assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually." nproc_per_node = int(nproc_per_node) else: - assert dist.get_world_size() % nproc_per_node == 0, \ - "nproc_per_node should be a divisor of world_size." + assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size." 
num_node = dist.get_world_size() // nproc_per_node intra_src_rank = None ep_intra_node_group = None for i in range(num_node): - ep_intra_ranks = [ - i * nproc_per_node + j - for j in range(nproc_per_node) - if j in ep_group_ranks - ] + ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks] group = dist.new_group(ep_intra_ranks) if rank in ep_intra_ranks: assert ep_intra_node_group is None @@ -212,10 +206,7 @@ def create_ep_hierarchical_group( intra_src_rank = ep_intra_ranks[0] ep_inter_node_group = None - ep_inter_ranks = [ - ep_group_ranks[0] + i * nproc_per_node - for i in range(num_node) - ] + ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)] if len(ep_inter_ranks) > 1: group = dist.new_group(ep_inter_ranks) if rank in ep_inter_ranks: diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 72480526b..20f316c2a 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -7,10 +7,10 @@ import torch.cuda from torch.nn import Module from torch.utils._pytree import tree_map +from colossalai.accelerator import get_accelerator from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.device import get_current_device from ._utils import get_batch_size, get_micro_batch, model_forward, to_device from .base import PipelineSchedule @@ -86,7 +86,7 @@ class GenerateSchedule(PipelineSchedule): """ micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) self.microbatch_offset += self.microbatch_size - return tree_map(partial(to_device, device=get_current_device()), micro_batch) + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def _prepare_inputs_for_interval_stage(self): """ diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index cbf6dd80f..91d936bfd 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -6,10 +6,10 @@ import torch.cuda from torch.nn import Module from torch.utils._pytree import tree_map +from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.device import get_current_device from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule @@ -56,7 +56,7 @@ class InterleavedSchedule(PipelineSchedule): """ micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) self.microbatch_offset[model_chunk_id] += self.microbatch_size - return tree_map(partial(to_device, device=get_current_device()), micro_batch) + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int: """Helper method to get the model chunk ID given the iteration number. 
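The rank arithmetic in `create_ep_hierarchical_group` is easy to check without a launcher. A standalone sketch with made-up sizes (two nodes, four processes per node, an EP group of ranks 0/2/4/6):

```python
nproc_per_node = 4
world_size = 8
ep_group_ranks = [0, 2, 4, 6]

num_node = world_size // nproc_per_node
intra_groups = [
    [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
    for i in range(num_node)
]
inter_group = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]

print(intra_groups)  # [[0, 2], [4, 6]]  -> one dist.new_group per node
print(inter_group)   # [0, 4]            -> one representative rank per node
```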
@@ -292,7 +292,7 @@ class InterleavedSchedule(PipelineSchedule): outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None if return_loss and self.stage_manager.is_last_stage(): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index fd918cf19..606bf8797 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -6,10 +6,10 @@ import torch.cuda from torch.nn import Module from torch.utils._pytree import tree_map +from colossalai.accelerator import get_accelerator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.device import get_current_device from ._utils import ( detach, @@ -80,7 +80,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): """ micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) self.microbatch_offset += self.microbatch_size - return tree_map(partial(to_device, device=get_current_device()), micro_batch) + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def recv_forward(self, prev_rank: int = None) -> Any: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. @@ -297,7 +297,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None if return_loss and self.stage_manager.is_last_stage(): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 96fd3bd7b..0d2cc1b33 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -7,7 +7,7 @@ from torch import nn from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup, get_world_size -from colossalai.utils.device import get_current_device, get_rng_state, manual_seed, set_rng_state +from colossalai.accelerator import get_accelerator class SeqParallelUtils: @@ -110,10 +110,10 @@ class Randomizer: # 1. get the current rng state # 2. set the seed and store the rng state # 3. 
recover the original rng state - device_original_rng_state = get_rng_state() - manual_seed(seed) - self.device_rng_state = get_rng_state() - set_rng_state(device_original_rng_state) + device_original_rng_state = get_accelerator().get_rng_state() + get_accelerator().manual_seed(seed) + self.device_rng_state = get_accelerator().get_rng_state() + get_accelerator().set_rng_state(device_original_rng_state) # to the same for cpu rng state cpu_original_rng_state = torch.get_rng_state() @@ -122,10 +122,10 @@ class Randomizer: torch.set_rng_state(cpu_original_rng_state) def _set_device_rng_state(self, rng_state): - set_rng_state(rng_state) + get_accelerator().set_rng_state(rng_state) def _get_device_rng_state(self): - current_state = get_rng_state() + current_state = get_accelerator().get_rng_state() return current_state def _set_cpu_rng_state(self, rng_state): @@ -210,7 +210,7 @@ class Randomizer: index = Randomizer.index() if dist.is_initialized(): # convert the index to tensor - index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device()) + index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device()) # all gather the index gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] @@ -232,7 +232,7 @@ class Randomizer: if dist.is_initialized(): # convert the index to tensor - index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device()) + index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device()) # all gather the index gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py index 7cd24b0ad..5f6864ff0 100644 --- a/colossalai/testing/utils.py +++ b/colossalai/testing/utils.py @@ -9,7 +9,8 @@ from typing import Any, Callable, List import torch import torch.multiprocessing as mp from packaging import version -from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count + +from colossalai.accelerator import get_accelerator def parameterize(argument: str, values: List[Any]) -> Callable: @@ -199,7 +200,7 @@ def skip_if_not_enough_gpus(min_gpus: int): def _wrap_func(f): def _execute_by_gpu_num(*args, **kwargs): - num_avail_gpu = device_count() + num_avail_gpu = get_accelerator().device_count() if num_avail_gpu >= min_gpus: f(*args, **kwargs) @@ -263,11 +264,11 @@ def clear_cache_before_run(): def _wrap_func(f): def _clear_cache(*args, **kwargs): - empty_cache() - reset_peak_memory_stats() - reset_max_memory_allocated() - reset_max_memory_cached() - synchronize() + get_accelerator().empty_cache() + get_accelerator().reset_peak_memory_stats() + get_accelerator().reset_max_memory_allocated() + get_accelerator().reset_max_memory_cached() + get_accelerator().synchronize() gc.collect() f(*args, **kwargs) diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 0246a35e2..9d33e4668 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -7,17 +7,12 @@ from .common import ( is_ddp_ignored, set_seed, ) -from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize from .multi_tensor_apply import multi_tensor_applier from .tensor_detector import TensorDetector from .timer import MultiTimer, Timer __all__ = [ "conditional_context", - "get_current_device", - 
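`Randomizer` now drives the device RNG through the accelerator object instead of the removed `colossalai.utils.device` helpers. A minimal sketch of the seed-and-restore pattern it uses, assuming a backend whose RNG state is exposed by the accelerator:

```python
import torch

from colossalai.accelerator import get_accelerator


def capture_seeded_state(seed: int) -> torch.Tensor:
    acc = get_accelerator()
    original = acc.get_rng_state()  # remember the caller's RNG state
    acc.manual_seed(seed)           # seed the device RNG
    seeded = acc.get_rng_state()    # capture the state produced by the seed
    acc.set_rng_state(original)     # restore, so the caller is unaffected
    return seeded
```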
"synchronize", - "empty_cache", - "set_to_cuda", "Timer", "MultiTimer", "multi_tensor_applier", @@ -28,6 +23,4 @@ __all__ = [ "free_storage", "set_seed", "is_ddp_ignored", - "set_device", - "IS_NPU_AVAILABLE", ] diff --git a/colossalai/utils/device.py b/colossalai/utils/device.py deleted file mode 100644 index c70dbdaa5..000000000 --- a/colossalai/utils/device.py +++ /dev/null @@ -1,223 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import Any, Dict, List, Optional, Tuple, Callable - -import torch -import torch.distributed as dist - -IS_NPU_AVAILABLE: bool = False -try: - import torch_npu # noqa - - IS_NPU_AVAILABLE = torch.npu.is_available() -except ImportError: - pass - - -def set_to_cuda(models): - """Send model to gpu. - - :param models: nn.module or a list of module - """ - if isinstance(models, list) and len(models) > 1: - ret = [] - for model in models: - ret.append(model.to(get_current_device())) - return ret - elif isinstance(models, list): - return models[0].to(get_current_device()) - else: - return models.to(get_current_device()) - - -def get_current_device() -> torch.device: - """ - Returns currently selected device (gpu/cpu). - If cuda available, return gpu, otherwise return cpu. - """ - if torch.cuda.is_available(): - return torch.device(f"cuda:{torch.cuda.current_device()}") - elif IS_NPU_AVAILABLE: - return torch.device(f"npu:{torch.npu.current_device()}") - else: - return torch.device("cpu") - - -def _dispatch_device_func(fn_name: str, *args, **kwargs): - if torch.cuda.is_available(): - return getattr(torch.cuda, fn_name)(*args, **kwargs) - elif IS_NPU_AVAILABLE: - return getattr(torch.npu, fn_name)(*args, **kwargs) - else: - raise RuntimeError("No device available") - - -# device semantics - - -def can_device_access_peer(device, peer_device) -> bool: - return _dispatch_device_func("can_device_access_peer", device, peer_device) - - -def current_device() -> int: - return _dispatch_device_func("current_device") - - -def current_stream(device=None): - return _dispatch_device_func("current_stream", device) - - -def default_stream(device=None): - return _dispatch_device_func("default_stream", device) - - -def device_count() -> int: - return _dispatch_device_func("device_count") - - -def get_device_capability(device=None) -> Tuple[int, int]: - return _dispatch_device_func("get_device_capability", device) - - -def get_device_name(device=None) -> str: - return _dispatch_device_func("get_device_name", device) - - -def get_device_properties(device): - return _dispatch_device_func("get_device_properties", device) - - -def set_device(index: Optional[int] = None) -> None: - if index is None: - index = dist.get_rank() % device_count() - _dispatch_device_func("set_device", index) - - -def set_stream(stream_): - return _dispatch_device_func("set_stream", stream_) - - -def stream(stream_): - return _dispatch_device_func("stream", stream_) - - -def synchronize(): - return _dispatch_device_func("synchronize") - - -def utilization(device=None) -> int: - return _dispatch_device_func("utilization", device) - - -# random number generator - - -def get_rng_state(device="cuda") -> torch.Tensor: - return _dispatch_device_func("get_rng_state", device) - - -def get_rng_state_all() -> List[torch.Tensor]: - return _dispatch_device_func("get_rng_state_all") - - -def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None: - return _dispatch_device_func("set_rng_state", new_state, device) - - -def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None: - return 
_dispatch_device_func("set_rng_state_all", new_states) - - -def manual_seed(seed: int) -> None: - return _dispatch_device_func("manual_seed", seed) - - -def manual_seed_all(seed: int) -> None: - return _dispatch_device_func("manual_seed_all", seed) - - -def seed() -> None: - return _dispatch_device_func("seed") - - -def seed_all() -> None: - return _dispatch_device_func("seed_all") - - -def initial_seed() -> int: - return _dispatch_device_func("initial_seed") - - -# streams and events - - -def Stream(device=None, priority=0, **kwargs): - return _dispatch_device_func("Stream", device, priority, **kwargs) - - -def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): - return _dispatch_device_func("Event", enable_timing, blocking, interprocess) - - -# memory management - - -def empty_cache() -> None: - return _dispatch_device_func("empty_cache") - - -def memory_stats(device=None) -> Dict[str, Any]: - return _dispatch_device_func("memory_stats", device) - - -def memory_summary(device=None, abbreviated=False) -> str: - return _dispatch_device_func("memory_summary", device, abbreviated) - - -def memory_snapshot(): - return _dispatch_device_func("memory_snapshot") - - -def memory_allocated(device=None) -> int: - return _dispatch_device_func("memory_allocated", device) - - -def max_memory_allocated(device=None) -> int: - return _dispatch_device_func("max_memory_allocated", device) - - -def reset_max_memory_allocated(device=None) -> None: - return _dispatch_device_func("reset_max_memory_allocated", device) - - -def reset_max_memory_cached(device=None) -> None: - return _dispatch_device_func("reset_max_memory_cached", device) - - -def memory_reserved(device=None) -> int: - return _dispatch_device_func("memory_reserved", device) - - -def max_memory_reserved(device=None) -> int: - return _dispatch_device_func("max_memory_reserved", device) - - -def set_per_process_memory_fraction(fraction: float, device=None) -> None: - return _dispatch_device_func("set_per_process_memory_fraction", fraction, device) - - -def reset_peak_memory_stats(device=None) -> None: - return _dispatch_device_func("reset_peak_memory_stats", device) - - -# amp - - -def autocast() -> Callable: - if torch.cuda.is_available(): - return torch.cuda.amp.autocast() - elif IS_NPU_AVAILABLE: - return torch.npu.amp.autocast() - else: - raise RuntimeError("No device available") diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py index 8ab6b46f2..0fbdd0932 100644 --- a/colossalai/utils/timer.py +++ b/colossalai/utils/timer.py @@ -3,7 +3,7 @@ import time from typing import Tuple -from .device import synchronize +from colossalai.accelerator import get_accelerator class Timer: @@ -21,13 +21,13 @@ class Timer: @property def current_time(self) -> float: - synchronize() + get_accelerator().synchronize() return time.time() def start(self): """Firstly synchronize cuda, reset the clock and then start the timer.""" self._elapsed = 0 - synchronize() + get_accelerator().synchronize() self._start_time = time.time() self._started = True @@ -44,7 +44,7 @@ class Timer: Returns: int: Start-stop interval. 
""" - synchronize() + get_accelerator().synchronize() end_time = time.time() elapsed = end_time - self._start_time if keep_in_history: diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index defc6c4cb..7a9f58701 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -6,8 +6,7 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from colossalai.utils import get_current_device -from colossalai.utils.device import IS_NPU_AVAILABLE +from colossalai.accelerator import get_accelerator class TensorState(Enum): @@ -107,7 +106,7 @@ class Chunk: self.valid_end = self.shard_size self.dtype = dtype - device = init_device or get_current_device() + device = init_device or get_accelerator().get_current_device() # chunk_temp is a global chunk, which only exists during building the chunks. self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero @@ -125,7 +124,7 @@ class Chunk: # configure the init device of the shard # no-offload default: fp16, fp32 -> CUDA # offload default: fp16, fp32 -> CPU - self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device() + self.shard_device = torch.device("cpu") if cpu_shard_init else get_accelerator().get_current_device() self.chunk_mem = self.chunk_size * self.chunk_temp.element_size() self.shard_mem = self.chunk_mem // self.pg_size @@ -192,10 +191,7 @@ class Chunk: if self.chunk_temp is not None: return self.chunk_temp.device.type else: - if self.is_gathered or self.cuda_shard is not None: - return "npu" if IS_NPU_AVAILABLE else "cuda" - else: - return "cpu" + return get_accelerator().name @property def payload(self) -> torch.Tensor: @@ -297,7 +293,7 @@ class Chunk: self.valid_end = self.utilized_size - self.shard_begin if self.chunk_temp.device.type == "cpu": - self.cuda_global_chunk = self.chunk_temp.to(get_current_device()) + self.cuda_global_chunk = self.chunk_temp.to(get_accelerator().get_current_device()) self.__update_tensors_ptr() else: self.cuda_global_chunk = self.chunk_temp @@ -334,12 +330,12 @@ class Chunk: return if device.type == "cuda" or device.type == "npu": - assert device == get_current_device(), "can't move chunk to another device" + assert device == get_accelerator().get_current_device(), "can't move chunk to another device" if self.cuda_shard: return - self.cuda_shard = self.cpu_shard.to(get_current_device()) + self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device()) if not self.pin_memory: self.cpu_shard = None @@ -394,7 +390,9 @@ class Chunk: if self.extra_dp_group is not None: dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group) else: - self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device()) + self.cuda_shard = torch.empty( + self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device() + ) input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0)) dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg) @@ -533,7 +531,7 @@ class Chunk: # only be called when optimizer state is in CPU memory # the grad and param should be in the same device assert self.cuda_shard is None - temp = optim_chunk.cpu_shard.to(get_current_device()) + temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device()) # avoid to transform FP32 in CPU self.cuda_shard = temp.to(self.dtype) @@ -631,7 +629,7 @@ class Chunk: grad_chunk.valid_end = self.valid_end if 
grad_chunk.chunk_temp.device.type == "cpu": - grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device()) + grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_accelerator().get_current_device()) else: grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp grad_chunk.chunk_temp = None diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 5f4f37c26..5bc662a61 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -5,7 +5,8 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from colossalai.utils import free_storage, get_current_device +from colossalai.accelerator import get_accelerator +from colossalai.utils import free_storage from .chunk import Chunk, ChunkFullError, TensorState @@ -20,7 +21,7 @@ class ChunkManager: """ def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None: - self.device = init_device or get_current_device() + self.device = init_device or get_accelerator().get_current_device() self.dp_degree_chunk_size_dict: Dict[int, int] = dict() self.kwargs_config = chunk_configuration for k, v in self.kwargs_config.items(): @@ -107,7 +108,7 @@ class ChunkManager: return self.__sub_memory_usage(chunk.memory_usage) if chunk.device_type == "cpu": - chunk.shard_move(get_current_device()) + chunk.shard_move(get_accelerator().get_current_device()) self.__add_accessed_chunk(chunk) self.__add_memory_usage(chunk.memory_usage) @@ -276,7 +277,10 @@ class ChunkManager: accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size) else: accumulated_grad = ( - chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size) + chunk.grad_chunk.cpu_shard.to(get_accelerator().get_current_device()) + .clone() + .detach() + .mul_(chunk.pg_size) ) accumulated_grad_gathered = False diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 5217b8036..79831cf33 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -10,6 +10,7 @@ import torch.nn as nn from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group +from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import ModelWrapper from colossalai.lazy import LazyTensor @@ -27,7 +28,7 @@ from colossalai.tensor.d_tensor import ( is_distributed_tensor, ) from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored +from colossalai.utils import _cast_float, free_storage, is_ddp_ignored from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager from .gemini_hook import GeminiZeROHook @@ -766,7 +767,7 @@ class GeminiDDP(ModelWrapper): # move ignored parameters to CUDA if is_ddp_ignored(p): - p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision) + p.data = p.data.to(device=get_accelerator().get_current_device(), dtype=self.mixed_precision) continue # create a fp16 parameter @@ -815,7 +816,7 @@ class GeminiDDP(ModelWrapper): for buffer in self.module.buffers(): if isinstance(buffer, LazyTensor): buffer.materialize() - buffer.data = buffer.to(get_current_device()) + buffer.data = buffer.to(get_accelerator().get_current_device()) if 
torch.is_floating_point(buffer): buffer.data = buffer.to(self.mixed_precision) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 8f828bd6c..09fad1e77 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -11,6 +11,7 @@ from torch.distributed import ProcessGroup from torch.nn import Parameter from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import OptimizerWrapper @@ -26,7 +27,7 @@ from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, is_distributed_tensor, ) -from colossalai.utils import disposable, get_current_device, is_ddp_ignored +from colossalai.utils import disposable, is_ddp_ignored from .chunk import Chunk, ChunkManager from .gemini_ddp import GeminiDDP @@ -233,7 +234,7 @@ class GeminiOptimizer(OptimizerWrapper): grad_chunk.l2_norm = None # clear l2 norm - comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device()) + comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device()) for group, part_norm in group_to_norm.items(): comm_buffer.fill_(part_norm) dist.all_reduce(comm_buffer, group=group) @@ -314,10 +315,10 @@ class GeminiOptimizer(OptimizerWrapper): continue if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: - self.chunk_manager.move_chunk(chunk32, get_current_device()) + self.chunk_manager.move_chunk(chunk32, get_accelerator().get_current_device()) # stores grad now - self.chunk_manager.move_chunk(chunk16, get_current_device()) - self.module.set_chunk_grad_device(chunk16, get_current_device()) + self.chunk_manager.move_chunk(chunk16, get_accelerator().get_current_device()) + self.module.set_chunk_grad_device(chunk16, get_accelerator().get_current_device()) fp32_params_used_cuda_margin_mem += chunk32.payload_mem for group in self.param_groups: @@ -328,7 +329,7 @@ class GeminiOptimizer(OptimizerWrapper): state = self.optim.state[fake_param] for k, v in state.items(): if isinstance(v, torch.Tensor): - state[k] = v.to(get_current_device()) + state[k] = v.to(get_accelerator().get_current_device()) def _register_states_(self): for group in self.optim.param_groups: @@ -551,7 +552,7 @@ class GeminiOptimizer(OptimizerWrapper): self, param_id: int, state_names: list, - device: torch.device = get_current_device(), + device: torch.device = get_accelerator().get_current_device(), dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ diff --git a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py index b5e40a817..e302805df 100644 --- a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py @@ -1,6 +1,6 @@ from typing import Optional -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from colossalai.zero.gemini.chunk import ChunkManager from .memory_stats import MemStats @@ -33,4 +33,4 @@ class ChunkMemStatsCollector(MemStatsCollector): def cuda_margin_mem(self) -> float: from colossalai.legacy.utils.memory import colo_device_memory_capacity - return colo_device_memory_capacity(get_current_device()) 
- self._memstats.max_overall_cuda + return colo_device_memory_capacity(get_accelerator().get_current_device()) - self._memstats.max_overall_cuda diff --git a/colossalai/zero/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py index 513a6326d..82c8e9dab 100644 --- a/colossalai/zero/gemini/memory_tracer/memory_monitor.py +++ b/colossalai/zero/gemini/memory_tracer/memory_monitor.py @@ -5,7 +5,7 @@ from time import sleep, time import torch -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator class MemoryMonitor: @@ -77,7 +77,7 @@ class AsyncMemoryMonitor(MemoryMonitor): super().__init__() self.keep_measuring = False - current_device = get_current_device() + current_device = get_accelerator().get_current_device() def _set_cuda_device(): torch.cuda.set_device(current_device) @@ -116,7 +116,7 @@ class AsyncMemoryMonitor(MemoryMonitor): while self.keep_measuring: max_usage = max( max_usage, - colo_device_memory_used(get_current_device()), + colo_device_memory_used(get_accelerator().get_current_device()), ) sleep(self.interval) return max_usage diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 8a74eb587..388999549 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.utils.memory import colo_device_memory_capacity -from colossalai.utils import get_current_device from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager @@ -85,7 +85,7 @@ class StaticPlacementPolicy(PlacementPolicy): # init offload optim settings # keep gathered chunks are in CUDA if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem: - device = get_current_device() + device = get_accelerator().get_current_device() else: device = torch.device("cpu") # real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here @@ -140,7 +140,7 @@ class AutoPlacementPolicy(PlacementPolicy): int: the volume of memory that is evicted """ start = time() - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) used_cuda_model_data = self.chunk_manager.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. 
@@ -194,7 +194,7 @@ class AutoPlacementPolicy(PlacementPolicy): # init offload optim settings # keep gathered chunks are in CUDA if chunk.keep_gathered: - grads_device_map[p] = get_current_device() + grads_device_map[p] = get_accelerator().get_current_device() else: grads_device_map[p] = torch.device("cpu") diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py index 5305953fe..b563ea5b2 100644 --- a/colossalai/zero/gemini/utils.py +++ b/colossalai/zero/gemini/utils.py @@ -6,7 +6,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from .chunk import Chunk @@ -18,11 +18,11 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype): if chunk.cuda_shard is not None: shard_temp = chunk.cuda_shard else: - shard_temp = chunk.cpu_shard.to(get_current_device()) + shard_temp = chunk.cpu_shard.to(get_accelerator().get_current_device()) shard_temp = shard_temp.to(dtype) - total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device()) + total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_accelerator().get_current_device()) gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0)) dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index c1b35ee17..81eba6fe5 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -12,7 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup from torch.optim import Optimizer -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.mixed_precision_mixin import ( BF16MixedPrecisionMixin, FP16MixedPrecisionMixin, @@ -22,9 +22,6 @@ from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.tensor.moe_tensor.api import is_moe_tensor -# from colossalai.tensor import ColoParameter, ProcessGroup -from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device - from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor from .bookkeeping import BucketStore, GradientStore, ParameterStore @@ -183,7 +180,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # intialize communication stream for # communication-compuation overlapping if self._overlap_communication: - self._comm_stream = device_utils.Stream() + self._comm_stream = get_accelerator().Stream() # reduction hook is only used if overlapping communication # or stage 2 is used @@ -217,7 +214,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): return len(self._working_param_groups) def _sanity_checks(self): - assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required" + assert get_accelerator().name in ["cuda", "npu"], "device is required" for param_group in self.optim.param_groups: group_params = param_group["params"] for param in group_params: @@ -228,7 +225,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): def _create_master_param_current_rank(self, param_list): # split each param evenly by world size params_current_rank = [] - device = "cpu" if self._cpu_offload else get_current_device() + device = "cpu" if self._cpu_offload else 
get_accelerator().get_current_device() for param in param_list: padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size @@ -340,11 +337,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if len(moe_grad_list) > 0: moe_flat_grads.record_stream(stream) # waiting for ops in the default stream finishing - stream.wait_stream(device_utils.current_stream()) + stream.wait_stream(get_accelerator().current_stream()) else: - stream = device_utils.current_stream() + stream = get_accelerator().current_stream() - with device_utils.stream(stream): + with get_accelerator().stream(stream): group_id = self._bucket_store.current_group_id if self.moe_extra_dp_pg is None: @@ -486,7 +483,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # clear reduced grads if self._overlap_communication: - device_utils.synchronize() + get_accelerator().synchronize() self.zero_grad() @@ -505,7 +502,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # clear reduced grads if self._overlap_communication: - device_utils.synchronize() + get_accelerator().synchronize() self.zero_grad() @@ -621,7 +618,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): release_param_grad(self._master_param_groups_of_current_rank[group_id]) # update working partition updated by the current rank - device = get_current_device() + device = get_accelerator().get_current_device() for group_id in range(self.num_param_groups): master_working_param = self.optim.param_groups[group_id]["params"] for idx, splited_param in enumerate(master_working_param): @@ -661,7 +658,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper): norm_type = float(norm_type) if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float + ) dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) total_norm = total_norm_cuda.item() @@ -673,7 +672,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # Sum across all model parallel GPUs. 
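The overlap path in `LowLevelZeroOptimizer` keeps its structure and only swaps the `device_utils` calls for accelerator methods: record the flat gradients on a side stream, make that stream wait on the default stream, then run the reduction inside it. A sketch assuming a CUDA or NPU accelerator; the `mul_` stands in for the real bucket reduce-scatter/all-reduce.

```python
import torch

from colossalai.accelerator import get_accelerator

acc = get_accelerator()
comm_stream = acc.Stream()
flat_grads = torch.ones(1 << 20, device=acc.get_current_device())

flat_grads.record_stream(comm_stream)          # keep the tensor alive for the side stream
comm_stream.wait_stream(acc.current_stream())  # wait for ops queued on the default stream
with acc.stream(comm_stream):
    flat_grads.mul_(0.25)                      # stand-in for the bucket reduction
acc.synchronize()                              # as the optimizer does before consuming reduced grads
```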
total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float ) torch.distributed.all_reduce( total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg @@ -765,7 +764,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): Dict: the pytorch form state_dict """ zero_state = dict() - device = get_current_device() + device = get_accelerator().get_current_device() for param, state in self.optim.state.items(): zero_state[param] = copy.deepcopy(state) for k, v in state.items(): @@ -827,7 +826,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ret_block = dict() ret_block_size = 0 - device = get_current_device() + device = get_accelerator().get_current_device() local_states = self.optim.state_dict()["state"] for param_idx, states in local_states.items(): current_block_size = 0 diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 7a0e3b1a0..e87eafb6e 100644 --- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -45,7 +45,6 @@ from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device ``` ## Define Plugin Create a `HybridParallelPlugin` object and specify the desired parallelism strategies to be used. In this example, both pipeline parallelism and ZeRO-1 are used simultaneously. @@ -149,7 +148,7 @@ model, optimizer, _criterion, _, lr_scheduler = booster.boost( ## Training GPT-2 using hybrid parallelism -In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training. +In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training. Define a training function. When pipeline parallelism is used, you need to call `booster.execute_pipeline` to schedule the stages of model training. 
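The gradient-norm reduction in `low_level_optim.py` above follows the usual definition: the inf-norm takes a MAX across ranks of the local maxima, while a p-norm sums `|g|^p` locally, all-reduces the sums, and takes the p-th root. A single-process sketch; the commented lines mark where the collectives shown in the hunks go, and the final root is the standard step not visible in the diff.

```python
import torch

norm_type = 2.0
gradients = [torch.randn(10), torch.randn(5)]  # this rank's shard of the gradients

if norm_type == float("inf"):
    local = max(float(g.data.abs().max()) for g in gradients)
    # dist.all_reduce(..., op=ReduceOp.MAX, group=dp_pg) in the real code
    total_norm = local
else:
    local = sum(float(g.data.abs().pow(norm_type).sum()) for g in gradients)
    # dist.all_reduce(..., op=ReduceOp.SUM, group=dp_pg) in the real code
    total_norm = local ** (1.0 / norm_type)
```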
```python def train_epoch( @@ -204,4 +203,4 @@ Training the gpt-2 model for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` - \ No newline at end of file + diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 117406980..ae941b489 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -43,7 +43,6 @@ from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device ``` ### 定义plugin 定义一个[`HybridParallelPlugin`](../basics/booster_plugins.md)对象,指定所需要使用的并行策略,在该例子中,同时使用了流水线并行和zero1. @@ -201,4 +200,4 @@ def train_epoch( for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` - \ No newline at end of file + diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py index 5396de693..40b11d649 100644 --- a/examples/community/roberta/pretraining/run_pretraining.py +++ b/examples/community/roberta/pretraining/run_pretraining.py @@ -16,10 +16,10 @@ from utils.global_vars import get_tensorboard_writer, get_timers, set_global_var from utils.logger import Logger import colossalai +from colossalai.accelerator import get_accelerator from colossalai.context import ParallelMode from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.tensor import ProcessGroup, ShardSpec -from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext @@ -53,7 +53,7 @@ def main(): set_global_variables(launch_time, args.tensorboard_path) world_size = torch.distributed.get_world_size() - get_current_device() + get_accelerator().get_current_device() # build model, optimizer and criterion if args.distplan.startswith("CAI"): @@ -67,7 +67,10 @@ def main(): # build GPT model with ColoInitContext( - device=get_current_device(), dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg + device=get_accelerator().get_current_device(), + dtype=torch.half, + default_dist_spec=default_dist_spec, + default_pg=shard_pg, ): config, model, numel = get_model(args, logger) @@ -78,7 +81,7 @@ def main(): elif args.distplan == "CAI_Gemini": gemini_config = dict( strict_ddp_mode=args.tp_degree == 1, - device=get_current_device(), + device=get_accelerator().get_current_device(), placement_policy=args.placement, pin_memory=True, hidden_dim=model.config.hidden_size, diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index 1a7f8da7f..cc2b2ebc7 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -20,11 +20,11 @@ from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from 
colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device disable_existing_loggers() logger = get_dist_logger() @@ -386,7 +386,7 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32 + torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32 pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -401,7 +401,7 @@ def main(args): sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - pipeline.to(get_current_device()) + pipeline.to(get_accelerator().get_current_device()) for example in tqdm( sample_dataloader, @@ -578,8 +578,8 @@ def main(args): # Move text_encode and vae to gpu. # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. - vae.to(get_current_device(), dtype=weight_dtype) - text_encoder.to(get_current_device(), dtype=weight_dtype) + vae.to(get_accelerator().get_current_device(), dtype=weight_dtype) + text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader)) @@ -613,7 +613,7 @@ def main(args): torch.cuda.reset_peak_memory_stats() # Move batch to gpu for key, value in batch.items(): - batch[key] = value.to(get_current_device(), non_blocking=True) + batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True) # Convert images to latent space optimizer.zero_grad() diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py index ea6dde8bb..227488abe 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -21,13 +21,13 @@ from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device disable_existing_loggers() logger = get_dist_logger() @@ -385,7 +385,7 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32 + torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32 pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -400,7 +400,7 @@ def main(args): sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, 
batch_size=args.sample_batch_size) - pipeline.to(get_current_device()) + pipeline.to(get_accelerator().get_current_device()) for example in tqdm( sample_dataloader, @@ -598,8 +598,8 @@ def main(args): # Move text_encode and vae to gpu. # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. - vae.to(get_current_device(), dtype=weight_dtype) - text_encoder.to(get_current_device(), dtype=weight_dtype) + vae.to(get_accelerator().get_current_device(), dtype=weight_dtype) + text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader)) @@ -633,7 +633,7 @@ def main(args): torch.cuda.reset_peak_memory_stats() # Move batch to gpu for key, value in batch.items(): - batch[key] = value.to(get_current_device(), non_blocking=True) + batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True) # Convert images to latent space optimizer.zero_grad() diff --git a/examples/images/resnet/train.py b/examples/images/resnet/train.py index 13df516d4..5871bbf87 100644 --- a/examples/images/resnet/train.py +++ b/examples/images/resnet/train.py @@ -13,12 +13,12 @@ from torch.utils.data import DataLoader from tqdm import tqdm import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin.dp_plugin_base import DPPluginBase from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl @torch.no_grad() def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: model.eval() - correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) - total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) for images, labels in test_dataloader: images = images.cuda() labels = labels.cuda() diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index b770bc9cf..078017324 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -33,9 +33,10 @@ def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224 def colo_memory_cap(size_in_GB): - from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + from colossalai.accelerator import get_accelerator + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) print(f"Limiting GPU memory usage to {size_in_GB} GB") diff --git a/examples/inference/benchmark_llama.py 
b/examples/inference/benchmark_llama.py index 9a26098b3..26cac977a 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -6,10 +6,9 @@ import torch.distributed as dist import transformers import colossalai -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.inference import InferenceEngine from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -from colossalai.utils.device import get_current_device GIGABYTE = 1024**3 MEGABYTE = 1024 * 1024 @@ -52,7 +51,7 @@ CONFIG_MAP = { def data_gen(batch_size: int = 4, seq_len: int = 512): - input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device()) + input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_accelerator().get_current_device()) attention_mask = torch.ones_like(input_ids) data = dict(input_ids=input_ids, attention_mask=attention_mask) return data @@ -97,9 +96,9 @@ def print_details_info(outputs, model_config, args, whole_end2end): msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n" if torch.cuda.is_available(): - msg += f"-------Memory Summary Device:{device_utils.current_device()}-------\n" - msg += f"Max memory allocated: {device_utils.max_memory_allocated() / GIGABYTE:.2f} GB\n" - msg += f"Max memory reserved: {device_utils.max_memory_reserved() / GIGABYTE:.2f} GB\n" + msg += f"-------Memory Summary Device:{get_accelerator().current_device()}-------\n" + msg += f"Max memory allocated: {get_accelerator().max_memory_allocated() / GIGABYTE:.2f} GB\n" + msg += f"Max memory reserved: {get_accelerator().max_memory_reserved() / GIGABYTE:.2f} GB\n" print(msg) diff --git a/examples/inference/run_llama_inference.py b/examples/inference/run_llama_inference.py index 8f85a9363..b5228c64e 100644 --- a/examples/inference/run_llama_inference.py +++ b/examples/inference/run_llama_inference.py @@ -5,9 +5,9 @@ import torch.distributed as dist from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.inference import InferenceEngine from colossalai.testing import spawn -from colossalai.utils.device import get_current_device INPUT_TEXTS = [ "What is the longest river in the world?", @@ -57,7 +57,7 @@ def run_inference(args): ) inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True) - inputs = {k: v.to(get_current_device()) for k, v in inputs.items()} + inputs = {k: v.to(get_accelerator().get_current_device()) for k, v in inputs.items()} outputs = engine.generate(inputs) if rank == 0: diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 563cfa58d..dc6768e58 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -18,11 +18,11 @@ from transformers import ( ) import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -59,7 +59,7 @@ def evaluate_model( use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_stage = 
use_pipeline and booster.plugin.stage_manager.is_last_stage() - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) labels = batch["labels"] @@ -88,8 +88,10 @@ def evaluate_model( object_list = [None, None] dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group) - metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels) - accum_loss.add_(object_list[1].to(get_current_device())) + metric.add_batch( + predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels + ) + accum_loss.add_(object_list[1].to(get_accelerator().get_current_device())) else: batch = move_to_cuda(batch) diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py index e811e1acb..b35112498 100644 --- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -7,13 +7,13 @@ from model_zoo import GPTLMLoss, get_gpt2_components from torch.utils._pytree import tree_map import colossalai +from colossalai.accelerator import get_accelerator from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.fx.profiler import parameter_size from colossalai.nn.optimizer import HybridAdam from colossalai.testing import spawn -from colossalai.utils import get_current_device def parse_args(): @@ -41,7 +41,7 @@ def train_gpt(args): 64, 8, ), - device=get_current_device(), + device=get_accelerator().get_current_device(), ) criterion = GPTLMLoss() diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 88b76c654..78d090ba2 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -12,12 +12,12 @@ from commons.utils import get_data, get_profile_context, get_tflops, get_time_st from packaging import version import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device CAI_VERSION = colossalai.__version__ @@ -141,7 +141,11 @@ def main(): criterion = GPTLMLoss() torch.manual_seed(123) if args.distplan.startswith("CAI"): - ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext() + ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if args.distplan == "CAI_Gemini" + else nullcontext() + ) # build GPT model with ctx: model = model_builder(args.model_type)(checkpoint=True) diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index 62804eff8..eb56ee530 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -13,11 +13,11 @@ from tqdm import tqdm from transformers import AutoConfig, GPT2ForSequenceClassification, 
get_linear_schedule_with_warmup import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -54,7 +54,7 @@ def evaluate_model( use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) labels = batch["labels"] @@ -83,8 +83,10 @@ def evaluate_model( object_list = [None, None] dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group) - metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels) - accum_loss.add_(object_list[1].to(get_current_device())) + metric.add_batch( + predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels + ) + accum_loss.add_(object_list[1].to(get_accelerator().get_current_device())) else: batch = move_to_cuda(batch) diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py index b2e3f71a5..ec3df50c4 100644 --- a/examples/language/gpt/titans/model/embed.py +++ b/examples/language/gpt/titans/model/embed.py @@ -5,6 +5,7 @@ from torch import nn as nn from torch.nn import functional as F from torch.nn.parameter import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.base_layer import ParallelLayer @@ -12,7 +13,6 @@ from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_b from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row from colossalai.legacy.nn.layer.utils import divide from colossalai.legacy.registry import LAYERS, LOSSES -from colossalai.utils import get_current_device class VocabParallelEmbedding(torch.nn.Module): @@ -96,7 +96,9 @@ class VocabParallelEmbedding(torch.nn.Module): if position_ids is not None: position_ids = position_ids.view(-1, input_shape[-1]) if position_ids is None: - position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device()) + position_ids = torch.arange( + 0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device() + ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) position_embeddings = self.position_embeddings(position_ids) @@ -194,7 +196,7 @@ class VocabParallelEmbedding1D(torch.nn.Module): self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index # Allocate weights and initialize. 
- factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs)) init.uniform_(self.weight, -1, 1) @@ -439,7 +441,9 @@ class HiddenParallelEmbedding(torch.nn.Module): if position_ids is not None: position_ids = position_ids.view(-1, input_shape[-1]) if position_ids is None: - position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device()) + position_ids = torch.arange( + 0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device() + ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) position_embeddings = self.position_embeddings(position_ids) @@ -532,7 +536,7 @@ class HiddenParallelEmbedding1D(torch.nn.Module): self._weight = None # Allocate weights and initialize. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs)) init.uniform_(self.weight, -1, 1) diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index d7a79a022..2f8a76044 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -13,13 +13,12 @@ from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaForCausalLM import colossalai -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Constants @@ -166,7 +165,7 @@ def main(): # Initialize Model and Optimizer # ============================== init_ctx = ( - LazyInitContext(default_device=get_current_device()) + LazyInitContext(default_device=get_accelerator().get_current_device()) if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) else nullcontext() ) @@ -197,7 +196,9 @@ def main(): torch.set_default_dtype(torch.bfloat16) model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) torch.set_default_dtype(torch.float) - coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" + ) coordinator.print_on_master( f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" ) @@ -223,7 +224,7 @@ def main(): performance_evaluator.on_step_end(**batch) performance_evaluator.on_fit_end() - coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": diff --git a/examples/language/llama2/data_utils.py b/examples/language/llama2/data_utils.py index a438833e1..6b9e8ef28 100644 --- a/examples/language/llama2/data_utils.py +++ 
b/examples/language/llama2/data_utils.py @@ -8,7 +8,7 @@ from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group from torch.utils.data import DataLoader, Dataset, DistributedSampler -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator class StatefulDistributedSampler(DistributedSampler): @@ -108,7 +108,9 @@ class RandomDataset(Dataset): def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): self.num_samples = num_samples self.max_length = max_length - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py index f7708b1a3..66b540076 100644 --- a/examples/language/llama2/finetune.py +++ b/examples/language/llama2/finetune.py @@ -21,13 +21,13 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM from transformers.models.llama.tokenization_llama import LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def get_model_numel(model: nn.Module) -> int: @@ -191,7 +191,9 @@ def main(): config = LlamaConfig.from_pretrained(args.model_path) # use lazy init when using GeminiPlugin init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, GeminiPlugin) + else nullcontext() ) with init_ctx: diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/llama2/performance_evaluator.py index 6b1c92711..c2169a730 100644 --- a/examples/language/llama2/performance_evaluator.py +++ b/examples/language/llama2/performance_evaluator.py @@ -5,9 +5,8 @@ import torch import torch.distributed as dist from torch import Tensor -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.cluster import DistCoordinator -from colossalai.utils.device import get_current_device def divide(x: float, y: float) -> float: @@ -22,7 +21,7 @@ def divide(x: float, y: float) -> float: def all_reduce_mean(x: float, world_size: int) -> float: if world_size == 1: return x - tensor = torch.tensor([x], device=get_current_device()) + tensor = torch.tensor([x], device=get_accelerator().get_current_device()) dist.all_reduce(tensor) tensor = tensor / world_size return tensor.item() @@ -86,13 +85,13 @@ class PerformanceEvaluator: self.disable = self.ignore_steps > 0 and step < self.ignore_steps if self.disable: return - device_utils.synchronize() + get_accelerator().synchronize() self.timer.start() def on_step_end(self, input_ids: Tensor, **kwargs) -> None: if self.disable: return - device_utils.synchronize() + get_accelerator().synchronize() self.timer.end() batch_size, seq_len = input_ids.shape diff --git 
a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py index bb10f7a00..d32cec2a2 100644 --- a/examples/language/llama2/pretrain.py +++ b/examples/language/llama2/pretrain.py @@ -20,13 +20,13 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM from transformers.models.llama.tokenization_llama import LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device MODEL_CONFIGS = { "7b": LlamaConfig(max_position_embeddings=4096), @@ -227,7 +227,9 @@ def main(): config = MODEL_CONFIGS[args.config] # use lazy init when using GeminiPlugin init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, GeminiPlugin) + else nullcontext() ) with init_ctx: diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 65562b386..03b660ecf 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -14,6 +14,7 @@ from transformers.models.llama import LlamaConfig from utils import PerformanceEvaluator, get_model_numel import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator @@ -21,7 +22,6 @@ from colossalai.moe.layers import apply_load_balance from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import skip_init from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def move_to_cuda(batch, device): @@ -64,13 +64,15 @@ class RandomDataset(Dataset): ) self.input_ids.append(encode["input_ids"]) self.attention_mask.append(encode["attention_mask"]) - self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device()) - self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device()) + self.input_ids = torch.cat(self.input_ids, dim=0).to(get_accelerator().get_current_device()) + self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_accelerator().get_current_device()) repeat_times = num_samples // self.input_ids.shape[0] + 1 self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples] self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples] else: - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index b08436166..1ae661f54 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -15,6 +15,7 @@ from transformers import T5Tokenizer from transformers.models.llama import LlamaConfig import colossalai +from 
colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator @@ -22,7 +23,6 @@ from colossalai.moe.layers import apply_load_balance from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import skip_init from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def move_to_cuda(batch, device): @@ -61,7 +61,9 @@ class RandomDataset(Dataset): def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None): self.num_samples = num_samples self.max_length = max_length - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index 7af02e24e..4fac7b507 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -14,12 +14,12 @@ from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper from torch.utils.data import DataLoader, Dataset import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn import HybridAdam -from colossalai.utils import get_current_device # constants @@ -159,7 +159,11 @@ if args.distplan == "colossalai": logger.info(f"plugin: {plugin}") booster = Booster(plugin=plugin, **booster_kwargs) - ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == "gemini" else nullcontext() + ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if args.plugin == "gemini" + else nullcontext() + ) with ctx: model = PaLM(num_tokens=50304, dim=4096, depth=64) diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py index 4407a51c3..a4733126f 100644 --- a/examples/tutorial/new_api/cifar_resnet/train.py +++ b/examples/tutorial/new_api/cifar_resnet/train.py @@ -13,12 +13,12 @@ from torch.utils.data import DataLoader from tqdm import tqdm import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin.dp_plugin_base import DPPluginBase from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl @torch.no_grad() def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: model.eval() - correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) - total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) + total = torch.zeros(1, dtype=torch.int64, 
device=get_accelerator().get_current_device()) for images, labels in test_dataloader: images = images.cuda() labels = labels.cuda() diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py index 700e4d2e0..ec6c852b5 100644 --- a/examples/tutorial/new_api/cifar_vit/train.py +++ b/examples/tutorial/new_api/cifar_vit/train.py @@ -13,13 +13,13 @@ from torch.utils.data import DataLoader from tqdm import tqdm import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin.dp_plugin_base import DPPluginBase from colossalai.cluster import DistCoordinator from colossalai.nn.lr_scheduler import LinearWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -73,8 +73,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl @torch.no_grad() def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: model.eval() - correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) - total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) for images, labels in test_dataloader: images = images.cuda() labels = labels.cuda() diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py index 990822c9f..e97c9017f 100644 --- a/examples/tutorial/new_api/glue_bert/finetune.py +++ b/examples/tutorial/new_api/glue_bert/finetune.py @@ -12,11 +12,11 @@ from tqdm import tqdm from transformers import AutoConfig, BertForSequenceClassification, get_linear_schedule_with_warmup import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -45,7 +45,7 @@ def evaluate( model.eval() def evaluate_subset(dataloader: DataLoader): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) outputs = model(**batch) diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index 9bd23ffc8..3f0d04879 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -51,13 +51,13 @@ from transformers import ( from transformers.utils.versions import require_version import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.tensor import ProcessGroup from colossalai.legacy.utils import get_dataloader from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device from colossalai.zero import GeminiOptimizer require_version("datasets>=1.8.0", 
"To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") @@ -249,9 +249,9 @@ def parse_args(): def colo_memory_cap(size_in_GB): - from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) print("Using {} GB of GPU memory".format(size_in_GB)) @@ -265,7 +265,9 @@ class DummyDataloader: self.vocab_size = vocab_size def generate(self): - input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device()) + input_ids = torch.randint( + 0, self.vocab_size, (self.batch_size, self.seq_len), device=get_accelerator().get_current_device() + ) attention_mask = torch.ones_like(input_ids) return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids} @@ -390,7 +392,7 @@ def main(): if args.init_in_cpu: init_dev = torch.device("cpu") else: - init_dev = get_current_device() + init_dev = get_accelerator().get_current_device() cai_version = colossalai.__version__ logger.info(f"using Colossal-AI version {cai_version}") @@ -439,7 +441,9 @@ def main(): except ImportError: # this works for unreleased main branch, and this may be released on 0.2.9 from colossalai.zero import GeminiDDP - model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) + model = GeminiDDP( + model, device=get_accelerator().get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True + ) elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): from colossalai.gemini import ChunkManager, GeminiManager diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index 2c8b260e6..373ba28b8 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -5,13 +5,13 @@ import torch from torch.utils._pytree import tree_map import colossalai +from colossalai.accelerator import get_accelerator from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.fx.profiler import parameter_size from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper from tests.test_auto_parallel.test_offload.model_utils import * from tests.test_tensor.common_utils import set_seed @@ -31,7 +31,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): 64, 8, ), - device=get_current_device(), + device=get_accelerator().get_current_device(), ) criterion = LMLoss() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index aba746f19..d57717326 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ 
b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -10,12 +10,12 @@ try: except: NO_CODEGEN = True +from colossalai.accelerator import get_accelerator from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.nn.optimizer import HybridAdam from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn -from colossalai.utils import get_current_device from colossalai.zero import zero_model_wrapper, zero_optim_wrapper @@ -72,7 +72,11 @@ def check_auto_parallel_with_gemini(rank, world_size, port): print("=" * msg_length) gemini_config = dict( - strict_ddp_mode=False, device=get_current_device(), placement_policy="cpu", pin_memory=True, search_range_m=128 + strict_ddp_mode=False, + device=get_accelerator().get_current_device(), + placement_policy="cpu", + pin_memory=True, + search_range_m=128, ) gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 3eaaf882c..490c015a8 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -5,7 +5,7 @@ import torch.distributed as dist from torch.optim import Adam import colossalai -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin @@ -22,7 +22,7 @@ _STUCK_MODELS = ["transformers_albert_for_multiple_choice"] def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: - device = device_utils.get_current_device() + device = get_accelerator().get_current_device() try: plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5) booster = Booster(plugin=plugin) @@ -69,7 +69,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): continue err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) - device_utils.empty_cache() + get_accelerator().empty_cache() if err is None: passed_models.append(name) diff --git a/tests/test_legacy/test_comm/test_comm.py b/tests/test_legacy/test_comm/test_comm.py index 7d2c81972..079022e93 100644 --- a/tests/test_legacy/test_comm/test_comm.py +++ b/tests/test_legacy/test_comm/test_comm.py @@ -2,12 +2,12 @@ import pytest import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.initialize import launch from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) @@ -16,7 +16,7 @@ SIZE = 8 def check_all_gather(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) + tensor = tensor.to(get_accelerator().get_current_device()) print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True) print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) @@ -27,7 +27,7 @@ def check_all_gather(): def 
check_reduce_scatter(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) + tensor = tensor.to(get_accelerator().get_current_device()) print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True) print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) @@ -38,7 +38,7 @@ def check_reduce_scatter(): def check_all_reduce(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) + tensor = tensor.to(get_accelerator().get_current_device()) print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True) print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) diff --git a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py index 8a9a73d65..f09df9253 100644 --- a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py +++ b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -2,6 +2,7 @@ import torch import torch.distributed as dist from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.global_variables import tensor_parallel_env as env @@ -16,13 +17,12 @@ from colossalai.legacy.nn import ( VocabParallelEmbedding1D, ) from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal def check_linear_col(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -68,7 +68,7 @@ def check_linear_col(): print_rank_0("linear_col forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) dist.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] grad = grad.clone() @@ -91,7 +91,7 @@ def check_linear_col(): def check_linear_row(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -137,7 +137,7 @@ def check_linear_row(): print_rank_0("linear_row forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) dist.broadcast(grad_master, src=0) grad = grad_master.clone() out.backward(grad) @@ -159,7 +159,7 @@ def check_linear_row(): def check_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -201,7 +201,7 @@ def check_embed(): def check_vocab_parallel_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -243,7 +243,7 @@ def check_vocab_parallel_embed(): def check_classifier_no_given_weight(): - device = get_current_device() + device = 
get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -309,7 +309,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -369,7 +369,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -420,7 +420,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_vocab_parallel_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -508,7 +508,7 @@ def check_vocab_parallel_loss(): @torch.no_grad() def check_linear_row_stream_inference(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py index 0bbc72eca..78bd407b9 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py @@ -1,5 +1,6 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import ( @@ -16,13 +17,12 @@ from colossalai.legacy.nn import ( VocabParallelEmbedding2D, ) from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal def check_linear(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = HIDDEN_SIZE @@ -74,7 +74,7 @@ def check_linear(): print_rank_0("linear forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -103,7 +103,7 @@ def check_linear(): def check_layernorm(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE EPS = 1e-12 @@ -139,7 +139,7 @@ def check_layernorm(): print_rank_0("layer norm forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ 
-154,7 +154,7 @@ def check_layernorm(): def check_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -201,7 +201,7 @@ def check_embed(): def check_patch_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -274,7 +274,7 @@ def check_patch_embed(): def check_vocab_parallel_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -321,7 +321,7 @@ def check_vocab_parallel_embed(): def check_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = NUM_CLASSES @@ -371,7 +371,7 @@ def check_classifier_no_given_weight(): print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] # grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -399,7 +399,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -467,7 +467,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -519,7 +519,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -573,7 +573,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -608,7 +608,7 @@ def check_loss(): def check_vocab_parallel_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -645,7 +645,7 @@ def check_vocab_parallel_loss(): # def check_attention(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 @@ -683,7 +683,7 @@ def check_vocab_parallel_loss(): # print_rank_0('self attention backward: pass') # def check_mlp(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE @@ -716,7 +716,7 @@ def check_vocab_parallel_loss(): # print_rank_0('mlp backward: pass') # def check_transformerlayer(): -# device = get_current_device() +# device = get_accelerator().get_current_device() 
# dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py index 9c126cefe..4506cfee6 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py @@ -3,11 +3,11 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal @@ -27,7 +27,7 @@ def check_AB(): i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(A_master, src=0) A = torch.chunk(A_master, DEPTH, dim=0)[i] A = torch.chunk(A, DEPTH, dim=-1)[j] @@ -35,7 +35,7 @@ def check_AB(): A.requires_grad = True B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, DEPTH, dim=0)[i] B = torch.chunk(B, DEPTH, dim=-1)[j] @@ -72,7 +72,7 @@ def check_AB(): print_rank_0("AB forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -105,7 +105,7 @@ def check_ABT(): tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float - device = get_current_device() + device = get_accelerator().get_current_device() j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -184,7 +184,7 @@ def check_ATB(): ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py index 283e7f683..914607614 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py @@ -1,6 +1,7 @@ import torch from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import ( @@ -17,13 +18,12 @@ from colossalai.legacy.nn import ( VocabParallelEmbedding2p5D, ) from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import * def check_linear(): - 
device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -76,7 +76,7 @@ def check_linear(): print_rank_0("linear forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -104,7 +104,7 @@ def check_linear(): def check_layernorm(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE EPS = 1e-12 @@ -141,7 +141,7 @@ def check_layernorm(): print_rank_0("layer norm forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -156,7 +156,7 @@ def check_layernorm(): def check_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -204,7 +204,7 @@ def check_embed(): def check_patch_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -278,7 +278,7 @@ def check_patch_embed(): def check_vocab_parallel_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -326,7 +326,7 @@ def check_vocab_parallel_embed(): def check_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = NUM_CLASSES @@ -377,7 +377,7 @@ def check_classifier_no_given_weight(): print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] # grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -405,7 +405,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -524,7 +524,7 @@ def 
check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -578,7 +578,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -613,7 +613,7 @@ def check_loss(): def check_vocab_parallel_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -650,7 +650,7 @@ def check_vocab_parallel_loss(): # def check_attention(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 @@ -689,7 +689,7 @@ def check_vocab_parallel_loss(): # print_rank_0('self attention backward: pass') # def check_mlp(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE @@ -725,7 +725,7 @@ def check_vocab_parallel_loss(): # print_rank_0('mlp backward: pass') # def check_transformerlayer(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py index 992bd6107..91a15c81d 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py @@ -1,10 +1,10 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import * @@ -25,7 +25,7 @@ def check_AB(): k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(A_master, src=0) A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] @@ -33,7 +33,7 @@ def check_AB(): A.requires_grad = True B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] @@ -70,7 +70,7 @@ def check_AB(): print_rank_0("AB forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, 
device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -103,7 +103,7 @@ def check_ABT(): tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float - device = get_current_device() + device = get_accelerator().get_current_device() i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -184,7 +184,7 @@ def check_ATB(): ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) diff --git a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py index a4a4ae9a5..f9f19a17b 100644 --- a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py +++ b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py @@ -5,6 +5,7 @@ import time import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.legacy.core import global_context from colossalai.legacy.nn import ( @@ -23,7 +24,6 @@ from colossalai.legacy.nn import ( from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.legacy.utils import print_rank_0 from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal @@ -31,7 +31,7 @@ from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_L def check_linear(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -84,7 +84,7 @@ def check_linear(): logger.info("Rank {} linear forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, device=get_current_device()) + grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -119,7 +119,7 @@ def check_linear(): def check_layernorm(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -206,7 +206,7 @@ def check_layernorm(): def check_classifier_no_given_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -258,7 +258,7 @@ def check_classifier_no_given_weight(): logger.info("Rank {} classifier (no given weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, device=get_current_device()) + grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, 
DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=0)[j] @@ -306,7 +306,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -413,7 +413,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -463,7 +463,7 @@ def check_classifier_given_embed_weight(): logger.info("Rank {} classifier (given embed weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=0)[j] @@ -497,7 +497,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -580,7 +580,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_patch_embed(): rank = torch.distributed.get_rank() - device = get_current_device() + device = get_accelerator().get_current_device() logger = get_dist_logger() torch.float32 @@ -678,7 +678,7 @@ def check_patch_embed(): def check_embed(): rank = torch.distributed.get_rank() - device = get_current_device() + device = get_accelerator().get_current_device() logger = get_dist_logger() torch.float32 @@ -746,7 +746,7 @@ def check_embed(): def check_vocab_parallel_embed(): rank = torch.distributed.get_rank() - device = get_current_device() + device = get_accelerator().get_current_device() logger = get_dist_logger() torch.float32 @@ -823,7 +823,7 @@ def check_vocab_parallel_embed(): def check_loss(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -876,7 +876,7 @@ def check_loss(): def check_vocab_parallel_loss(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) diff --git a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py index aa4d5d6ce..f4ad0d6d1 100644 --- a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py +++ b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py @@ -1,9 +1,9 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import 
global_context as gpc from colossalai.legacy.nn import TransformerSelfAttentionRing -from colossalai.utils import get_current_device def check_selfattention(): @@ -13,10 +13,10 @@ def check_selfattention(): HIDDEN_SIZE = 16 layer = TransformerSelfAttentionRing(16, 8, 8, 0.1) - layer = layer.to(get_current_device()) + layer = layer.to(get_accelerator().get_current_device()) - hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device()) + hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_accelerator().get_current_device()) attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to( - get_current_device() + get_accelerator().get_current_device() ) layer(hidden_states, attention_mask) diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py index a5a2d3857..cab111358 100644 --- a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py @@ -5,6 +5,7 @@ import pytest import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import ( recv_backward, recv_forward, @@ -18,7 +19,6 @@ from colossalai.legacy.core import global_context as gpc from colossalai.legacy.initialize import launch from colossalai.logging import get_dist_logger from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device BATCH_SIZE = 4 SEQ_LENGTH = 2 @@ -73,7 +73,7 @@ def check_forward_backward(output_tensor, output_grad, rank, logger): def check_comm(size, rank, prev_rank, next_rank, logger): dtype = torch.float32 - device = get_current_device() + device = get_accelerator().get_current_device() tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) tensor = torch.randn(tensor_shape, dtype=dtype, device=device) diff --git a/tests/test_legacy/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py index 9df7cf75a..4993df4f3 100644 --- a/tests/test_legacy/test_utils/test_memory.py +++ b/tests/test_legacy/test_utils/test_memory.py @@ -1,15 +1,15 @@ import pytest import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction from colossalai.testing import spawn -from colossalai.utils.device import get_current_device def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): - frac1 = colo_device_memory_capacity(get_current_device()) + frac1 = colo_device_memory_capacity(get_accelerator().get_current_device()) colo_set_process_memory_fraction(0.5) - frac2 = colo_device_memory_capacity(get_current_device()) + frac2 = colo_device_memory_capacity(get_accelerator().get_current_device()) assert frac2 * 2 == frac1 diff --git a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py index b5f2be705..9975cc04f 100644 --- a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py +++ b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py @@ -4,12 +4,12 @@ from torch.nn.parameter import Parameter from torch.nn.utils import clip_grad_norm_ import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup, distspec from colossalai.legacy.utils.common import 
clip_grad_norm from colossalai.logging import disable_existing_loggers from colossalai.tensor.colo_parameter import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8): @@ -36,7 +36,7 @@ def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None: @parameterize("norm_type", [2.0, 3.0, float("inf")]) def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float): print(f"{world_size}, {dtype}, {device}, {norm_type}") - cuda_device = get_current_device() + cuda_device = get_accelerator().get_current_device() devices = [cuda_device] * 4 if device == "cpu": devices = [torch.device("cpu")] * 4 diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index 3fac62472..a349bc5a9 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -4,10 +4,10 @@ import torch.distributed as dist import torch.nn as nn import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from tests.test_moe.moe_utils import MoeGradientHandler BATCH_SIZE = 4 @@ -38,7 +38,7 @@ def run_test(rank, world_size, port): layer_list.append(moe_layer) model = nn.ModuleList(layer_list) - model = model.to(get_current_device()) + model = model.to(get_accelerator().get_current_device()) dist_dict = MOE_MANAGER.parallel_info_dict assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group) assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group) @@ -52,7 +52,7 @@ def run_test(rank, world_size, port): rank = dist.get_rank() torch.cuda.manual_seed(78 + rank) - data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) + data = torch.randn(BATCH_SIZE, DIM, device=get_accelerator().get_current_device()) grad = torch.randn_like(data) MOE_MANAGER.reset_loss() diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 255ec7444..62d61a3d4 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -3,10 +3,10 @@ import torch import torch.distributed as dist import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device BATCH_SIZE = 4 NUM_EXPERTS = 4 @@ -28,7 +28,9 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f torch.manual_seed(rs + local_rank) # set each process has different random seed # get randomized data - tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) + tokens = torch.randn( + BATCH_SIZE, hidden_size, dtype=data_type, device=get_accelerator().get_current_device(), requires_grad=True + ) layer = SparseMLP( hidden_size=hidden_size, @@ -37,7 +39,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f router_top_k=topk, router_capacity_factor_train=1.0, ) - layer = layer.to(get_current_device()) + layer = layer.to(get_accelerator().get_current_device()) if data_type == torch.float16: layer = layer.half() @@ -45,7 +47,7 @@ def 
run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f layer.enable_kernel = False old_out = layer(tokens) ech = old_out.shape - grad = torch.randn(ech, device=get_current_device()) + grad = torch.randn(ech, device=get_accelerator().get_current_device()) old_out.backward(grad) # get gradient # save all results diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index bd1103df3..8f51e1663 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -9,11 +9,11 @@ import torch.distributed as dist from transformers.models.llama import LlamaConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device sys.path.append( os.path.join( @@ -28,7 +28,7 @@ OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenM def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20): - input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device()) + input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_accelerator().get_current_device()) attention_mask = torch.ones_like(input_ids) return { "input_ids": input_ids, diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index f87d4c792..74feeeb59 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -7,12 +7,12 @@ import torch import torch.distributed as dist import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from tests.test_moe.moe_utils import MoeGradientHandler @@ -23,8 +23,9 @@ def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_ tp_model (MoeModule) local_model (MoeModule) """ - for (tp_name, tp_param), (local_name, local_param) in \ - zip(tp_model.named_parameters(), local_model.named_parameters()): + for (tp_name, tp_param), (local_name, local_param) in zip( + tp_model.named_parameters(), local_model.named_parameters() + ): assert tp_name == local_name if not is_moe_tensor(tp_param): if assert_grad_flag: @@ -54,8 +55,7 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: tp_model (MoeModule) ep_model (MoeModule) """ - for (tp_name, tp_param), (ep_name, ep_param) in \ - zip(tp_model.named_parameters(), ep_model.named_parameters()): + for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()): assert tp_name == ep_name if not is_moe_tensor(tp_param): if assert_grad_flag: @@ -97,8 +97,9 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_ local_model (MoeModule) ep_model (MoeModule) """ - for (local_name, local_param), (ep_name, ep_param) in \ - zip(local_model.named_parameters(), ep_model.named_parameters()): + for (local_name, local_param), 
(ep_name, ep_param) in zip( + local_model.named_parameters(), ep_model.named_parameters() + ): assert local_name == ep_name if "experts" not in local_name: if assert_grad_flag: @@ -141,14 +142,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2, - enable_hierarchical_comm=enable_hierarchical_comm + enable_hierarchical_comm=enable_hierarchical_comm, ) MOE_MANAGER.__init__() MOE_MANAGER.setup(parallel="TP") tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) - ep_model = ep_model.to(get_current_device()) - tp_model = tp_model.to(get_current_device()) - local_model = local_model.to(get_current_device()) + ep_model = ep_model.to(get_accelerator().get_current_device()) + tp_model = tp_model.to(get_accelerator().get_current_device()) + local_model = local_model.to(get_accelerator().get_current_device()) # sync ep param sync_moe_model_param(ep_model) @@ -163,11 +164,11 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size tp_grad_handler = MoeGradientHandler(tp_model) rank = dist.get_rank() - input_data = torch.randn(batch_size, dim, device=get_current_device()) + input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device()) micro_batch_size = batch_size // world_size index = rank * micro_batch_size # NOTE: ep & tp takes in sharded data for each process - shard_data = input_data.detach()[index:index + micro_batch_size] + shard_data = input_data.detach()[index : index + micro_batch_size] out_local = local_model(input_data) MOE_MANAGER.reset_loss() @@ -176,13 +177,15 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size out_ep = ep_model(shard_data) MOE_MANAGER.reset_loss() - assert torch.allclose(out_tp, out_ep, atol=1e-6), \ - f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}" + assert torch.allclose( + out_tp, out_ep, atol=1e-6 + ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}" try: - out_local_slice = out_local[index:index + micro_batch_size] - assert torch.allclose(out_ep, out_local_slice, atol=1e-6), \ - f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}" - except AssertionError as e: + out_local_slice = out_local[index : index + micro_batch_size] + assert torch.allclose( + out_ep, out_local_slice, atol=1e-6 + ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}" + except AssertionError: """ e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1 router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2 @@ -193,8 +196,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature. """ warnings.warn( - "EP & TP may result in different behavior from local model. " - "Please check the comments for details." + "EP & TP may result in different behavior from local model. " "Please check the comments for details." ) out_local.mean().backward() @@ -208,10 +210,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True) try: sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) - except AssertionError as e: + except AssertionError: warnings.warn( - "EP & TP may result in different behavior from local model. 
" - "Please check the comments for details." + "EP & TP may result in different behavior from local model. " "Please check the comments for details." ) @@ -219,14 +220,17 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size @pytest.mark.parametrize("num_experts", [4, 64]) @pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("dim", [64]) -@pytest.mark.parametrize("config", [ - {"enable_hierarchical_comm": False}, - {"enable_hierarchical_comm": True}, -]) +@pytest.mark.parametrize( + "config", + [ + {"enable_hierarchical_comm": False}, + {"enable_hierarchical_comm": True}, + ], +) @rerun_if_address_is_in_use() def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict): spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_ep_tp(num_experts=8, batch_size=32, dim=32) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 95c0e715d..2f08a335d 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -3,11 +3,11 @@ import torch.distributed as dist import torch.nn as nn import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe.experts import MLPExperts from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device HIDDEN_SIZE = 4 INTERMEDIATE_SIZE = 8 @@ -46,7 +46,7 @@ def run_moe_init(expert_parallel): assert dist.get_rank(parallel_info_dict[1].dp_group) == rank model = nn.ModuleList([exp0, exp1, exp2]) - model = model.to(get_current_device()) + model = model.to(get_accelerator().get_current_device()) sync_moe_model_param(model) # MOE experts layout success when ep_size = 1 diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index c136f78a1..2ff4b3016 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -8,7 +8,8 @@ import pytest import torch from torch import Tensor -from colossalai.utils import get_current_device, multi_tensor_applier +from colossalai.accelerator import get_accelerator +from colossalai.utils import multi_tensor_applier _FUSED_ALLOWED_P_G_TYPES = [ (torch.float, torch.half), @@ -155,7 +156,9 @@ def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): rtol, atol = 1e-3, 1e-3 if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16: rtol, atol = 4e-3, 4e-3 - check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol) + check_adam_kernel( + FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_accelerator().get_current_device(), 3, rtol, atol + ) @pytest.mark.parametrize("adamw", [False, True]) diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index 1665711ce..5ebe2a128 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -3,11 +3,11 @@ import torch import torch.distributed as dist import colossalai +from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.testing import 
rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device def check_p2p_communication(): @@ -17,7 +17,7 @@ def check_p2p_communication(): rank = dist.get_rank() - tensor = torch.ones(1, device=get_current_device()) + tensor = torch.ones(1, device=get_accelerator().get_current_device()) if rank == 0: p2p.send_forward(tensor) diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py index 5977c706f..e4dc569b8 100644 --- a/tests/test_zero/test_gemini/test_chunkv2.py +++ b/tests/test_zero/test_gemini/test_chunkv2.py @@ -4,15 +4,15 @@ import torch.distributed as dist from torch.distributed.distributed_c10d import _get_default_group import colossalai +from colossalai.accelerator import get_accelerator from colossalai.tensor import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.chunk import Chunk def dist_sum(x): - temp = torch.tensor([x], device=get_current_device()) + temp = torch.tensor([x], device=get_accelerator().get_current_device()) dist.all_reduce(temp) return temp.item() @@ -66,7 +66,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): assert my_chunk.cpu_shard.size(0) == 1024 // world_size assert my_chunk.device_type == "cpu" assert my_chunk.can_move - my_chunk.shard_move(get_current_device()) + my_chunk.shard_move(get_accelerator().get_current_device()) else: assert my_chunk.cuda_global_chunk.size(0) == 1024 assert my_chunk.device_type == "cuda" diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 21afff753..3a9742e01 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -5,11 +5,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd_bwd @@ -47,7 +47,7 @@ def exam_gpt_fwd_bwd( use_grad_checkpoint: bool = False, master_weights: bool = True, ): - init_device = get_current_device() + init_device = get_accelerator().get_current_device() model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( iter(model_zoo.get_sub_registry(model_name).values()) ) diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py index 35323e516..36a803492 100644 --- a/tests/test_zero/test_gemini/test_grad_accum.py +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -6,10 +6,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, 
GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd @@ -53,7 +53,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): def exam_gemini_grad_acc( placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool ): - init_device = get_current_device() + init_device = get_accelerator().get_current_device() model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( iter(model_zoo.get_sub_registry(model_name).values()) ) diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 152bf2895..7f3c7176e 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -7,11 +7,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd @@ -47,7 +47,9 @@ def multi_chunk_init(model: torch.nn.Module, placement_config: dict): def single_chunk_init(model: torch.nn.Module, placement_config: dict): - model = GeminiDDP(model, chunk_init_device=get_current_device(), pin_memory=True, **placement_config) + model = GeminiDDP( + model, chunk_init_device=get_accelerator().get_current_device(), pin_memory=True, **placement_config + ) return model @@ -63,7 +65,7 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - init_dev = get_current_device() + init_dev = get_accelerator().get_current_device() model = model_builder().to(init_dev) for torch_p, p in zip(torch_model.parameters(), model.parameters()): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 405d7d789..71bb27b4a 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -5,11 +5,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd_bwd @@ -150,7 +150,7 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. 
model = GeminiDDP( model, - chunk_init_device=get_current_device(), + chunk_init_device=get_accelerator().get_current_device(), search_range_m=1, pin_memory=True, mixed_precision=mixed_precision, diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py index e99f6d59b..cf3658bf9 100644 --- a/tests/test_zero/test_gemini/test_search.py +++ b/tests/test_zero/test_gemini/test_search.py @@ -2,8 +2,8 @@ import pytest import torch import colossalai +from colossalai.accelerator import get_accelerator from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration from tests.kit.model_zoo import model_zoo @@ -34,7 +34,7 @@ def exam_chunk_manager(): sharded_ddp_model = model_builder() chunk_manager = init_chunk_manager( sharded_ddp_model, - get_current_device(), + get_accelerator().get_current_device(), hidden_dim=128, search_range_m=1, min_chunk_size_m=0, diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index 351ae5f67..11f738615 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -7,9 +7,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.testing import spawn from colossalai.testing.random import seed_all -from colossalai.utils import conditional_context, get_current_device +from colossalai.utils import conditional_context from colossalai.zero import LowLevelZeroOptimizer @@ -28,7 +29,7 @@ class MlpModel(nn.Module): def exam_zero_1_2_grad_acc(): local_rank = torch.distributed.get_rank() seed_all(2009) - device = get_current_device() + device = get_accelerator().get_current_device() # create model zero1_model = MlpModel().to(device) zero2_model = copy.deepcopy(zero1_model) @@ -71,7 +72,7 @@ def exam_zero_1_2_grad_acc(): def exam_zero_1_grad_acc(sync): local_rank = torch.distributed.get_rank() seed_all(2008) - device = get_current_device() + device = get_accelerator().get_current_device() # create models zero_model = MlpModel()