[npu] change device to accelerator api (#5239)

* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
Hongxin Liu 2024-01-09 10:20:05 +08:00 committed by GitHub
parent dd2c28a323
commit d202cc28c0
128 changed files with 1773 additions and 868 deletions
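At its core, the change is a one-line migration repeated across the tree: every call site of `colossalai.utils.get_current_device` moves to the new `colossalai.accelerator` entry point. A minimal before/after sketch (the `nn.Linear` model is only illustrative, not from the PR):

```python
import torch.nn as nn

# Before this commit:
#   from colossalai.utils import get_current_device
#   device = get_current_device()

# After this commit:
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()            # CUDA, NPU, or CPU backend, resolved at runtime
device = accelerator.get_current_device()  # e.g. torch.device("cuda:0")
model = nn.Linear(8, 8).to(device)
```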


@@ -47,7 +47,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
       options: --gpus all --rm -v /data/scratch/examples-data:/data/
-    timeout-minutes: 10
+    timeout-minutes: 15
     steps:
       - name: 📚 Checkout
         uses: actions/checkout@v3


@@ -79,7 +79,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
       options: --gpus all --rm -v /data/scratch/examples-data:/data/
-    timeout-minutes: 10
+    timeout-minutes: 15
     concurrency:
      group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }}
      cancel-in-progress: true


@@ -35,7 +35,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
-    timeout-minutes: 10
+    timeout-minutes: 15
     steps:
       - name: 📚 Checkout
         uses: actions/checkout@v3


@@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from tqdm import tqdm
 from transformers import PreTrainedTokenizerBase

-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator

 from .base import OnPolicyTrainer
 from .callbacks import Callback
@@ -105,7 +105,7 @@ class PPOTrainer(OnPolicyTrainer):
         self.critic_optim = critic_optim
         self.offload_inference_models = offload_inference_models
-        self.device = get_current_device()
+        self.device = get_accelerator().get_current_device()

     def _before_fit(
         self,


@@ -6,7 +6,6 @@ import torch.nn as nn
 import colossalai
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.gemini_ddp import GeminiDDP

 from .ddp import DDPStrategy
@@ -158,9 +157,19 @@ class GeminiStrategy(DDPStrategy):
             warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")

+        # colossalai has changed api for get_current_device in 0.3.4 version or newer
+        try:
+            from colossalai.accelerator import get_accelerator
+
+            chunk_init_device = get_accelerator().get_current_device()
+        except:
+            from colossalai.utils import get_current_device
+
+            chunk_init_device = get_current_device()
+
         # NOTE: dist should be initialized before calling get_current_device()
         plugin_initializer = lambda: GeminiPlugin(
-            chunk_init_device=get_current_device(),
+            chunk_init_device=chunk_init_device,
             placement_policy=placement_policy,
             shard_param_frac=shard_param_frac,
             offload_optim_frac=offload_optim_frac,
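The try/except above keeps the Chat example working against both old and new ColossalAI releases, and the same shim recurs in the Colossal-LLaMA-2 script below. As a sketch, it could be factored into a helper (the name `_resolve_current_device` is ours, not part of the PR):

```python
import torch

def _resolve_current_device() -> torch.device:
    # Hypothetical helper: prefer the accelerator API (colossalai >= 0.3.4),
    # fall back to the legacy utils API on older releases.
    try:
        from colossalai.accelerator import get_accelerator

        return get_accelerator().get_current_device()
    except ImportError:
        from colossalai.utils import get_current_device

        return get_current_device()
```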


@@ -4,41 +4,34 @@
 Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
 """
-import json
 import argparse
+import json
 import os
 import resource
 from contextlib import nullcontext

-from tqdm import tqdm
 import torch
 import torch.distributed as dist
+from colossal_llama2.dataset.loader import (
+    DataCollatorForSupervisedDataset,
+    StatefulDistributedSampler,
+    load_tokenized_dataset,
+    setup_distributed_dataloader,
+)
+from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
+from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
+from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from torch.utils.tensorboard import SummaryWriter
-from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
+from tqdm import tqdm
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import (
-    GeminiPlugin,
-    LowLevelZeroPlugin,
-    HybridParallelPlugin,
-)
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
-from colossal_llama2.dataset.loader import (
-    load_tokenized_dataset,
-    setup_distributed_dataloader,
-    DataCollatorForSupervisedDataset,
-    StatefulDistributedSampler,
-)
-from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
-from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
-from colossal_llama2.utils.froze import freeze_non_embeds_parameters


 def get_model_numel(model: torch.nn.Module) -> int:
@@ -215,9 +208,18 @@ def main() -> None:
     # ======================================================
     # Initialize Model, Objective, Optimizer and LR Scheduler
     # ======================================================
-    init_ctx = (
-        LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
-    )
+    # colossalai has changed api for get_current_device in 0.3.4 version or newer
+    try:
+        from colossalai.accelerator import get_accelerator
+
+        current_device = get_accelerator().get_current_device()
+    except:
+        from colossalai.utils import get_current_device
+
+        current_device = get_current_device()
+
+    init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
     with init_ctx:
         model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
     # Freeze part of parameters.
@@ -320,7 +322,7 @@ def main() -> None:
         initial=start_step,
     ) as pbar:
         for step, batch in pbar:
-            batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
+            batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
             batch_output = model(**batch)
@@ -372,9 +374,7 @@ def main() -> None:
     # Final save.
     coordinator.print_on_master("Start saving final model checkpoint")
     booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
-    coordinator.print_on_master(
-        f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
-    )
+    coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")

     coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


@@ -1,5 +1,6 @@
 from .api import auto_set_accelerator, get_accelerator, set_accelerator
 from .base_accelerator import BaseAccelerator
+from .cpu_accelerator import CpuAccelerator
 from .cuda_accelerator import CudaAccelerator
 from .npu_accelerator import NpuAccelerator
@@ -10,4 +11,5 @@ __all__ = [
     "BaseAccelerator",
     "CudaAccelerator",
     "NpuAccelerator",
+    "CpuAccelerator",
 ]


@@ -3,6 +3,7 @@ from collections import OrderedDict
 from typing import Union

 from .base_accelerator import BaseAccelerator
+from .cpu_accelerator import CpuAccelerator
 from .cuda_accelerator import CudaAccelerator
 from .npu_accelerator import NpuAccelerator
@@ -15,7 +16,7 @@ _ACCELERATOR = None
 # we use ordered dictionary here to associate the
 # order with device check priority
 # i.e. auto_set_accelerator will check cuda first
-_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator)
+_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator, cpu=CpuAccelerator)


 def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
@@ -43,19 +44,17 @@ def auto_set_accelerator() -> None:
     """
     global _ACCELERATOR

-    for _, accelerator_cls in _ACCELERATOR_MAPPING.items():
+    for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items():
         try:
             accelerator = accelerator_cls()
-            if accelerator.is_available():
+            if accelerator_name == "cpu" or accelerator.is_available():
                 _ACCELERATOR = accelerator
                 break
         except:
             pass

     if _ACCELERATOR is None:
-        raise RuntimeError(
-            f"No accelerator is available. Please check your environment. The list of accelerators we support is {list(_ACCELERATOR_MAPPING.keys())}"
-        )
+        raise RuntimeError("No accelerator is available.")


 def get_accelerator() -> BaseAccelerator:
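A short usage sketch of the three entry points, assuming only the behavior visible in this hunk (`cpu` is probed last and treated as always available):

```python
from colossalai.accelerator import auto_set_accelerator, get_accelerator, set_accelerator

# Pick a backend explicitly, by name or by instance...
set_accelerator("cuda")

# ...or probe in mapping order (cuda -> npu -> cpu); the special case for
# "cpu" guarantees a fallback even on machines with no accelerator at all.
auto_set_accelerator()

acc = get_accelerator()
print(acc.get_current_device())
```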


@@ -1,6 +1,7 @@
 #!/usr/bin/env python

 from abc import ABC, abstractmethod
-from typing import Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
@@ -8,6 +9,8 @@ __all__ = ["BaseAccelerator"]


 class BaseAccelerator(ABC):
+    support_set_device: bool = True
+
     def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None:
         self._name = name
         self._communication_backend = communication_backend
@@ -45,6 +48,12 @@ class BaseAccelerator(ABC):
     # =======================
     # device APIs
     # =======================
+    @abstractmethod
+    def get_current_device(self) -> torch.device:
+        """
+        Return the current device.
+        """
+
     @abstractmethod
     def current_device(self) -> int:
         """
@@ -52,7 +61,7 @@ class BaseAccelerator(ABC):
         """

     @abstractmethod
-    def set_device(self, device: Union[torch.device, int]) -> None:
+    def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
         """
         Bind the current process to a device.
         """
@@ -79,3 +88,226 @@ class BaseAccelerator(ABC):
         """
         Return the number of devices on the machine.
         """
+
+    def set_to_device(self, models: Any) -> Any:
+        """
+        Send model to device.
+
+        :param models: nn.module or a list of module
+        """
+        if isinstance(models, list) and len(models) > 1:
+            ret = []
+            for model in models:
+                ret.append(model.to(self.get_current_device()))
+            return ret
+        elif isinstance(models, list):
+            return models[0].to(self.get_current_device())
+        else:
+            return models.to(self.get_current_device())
+
+    @abstractmethod
+    def get_device_capability(self, device=None) -> Tuple[int, int]:
+        """
+        Gets the capability of a device.
+        """
+
+    @abstractmethod
+    def get_device_name(self, device=None) -> str:
+        """
+        Gets the name of a device.
+        """
+
+    @abstractmethod
+    def get_device_properties(self, device):
+        """
+        Gets the properties of a device.
+        """
+
+    @abstractmethod
+    def utilization(self, device=None) -> int:
+        """
+        Returns the percent of time over the past sample period during which one or more kernels was executing on the device as given by nvidia-smi or npu-smi, etc.
+        """
+
+    # =======================
+    # random number generator APIs
+    # =======================
+    @abstractmethod
+    def get_rng_state(self, device="cuda") -> torch.Tensor:
+        """
+        Returns the random number generator state of the specified device as a ByteTensor.
+        """
+
+    @abstractmethod
+    def get_rng_state_all(self) -> List[torch.Tensor]:
+        """
+        Returns a list of ByteTensor representing the random number states of all devices.
+        """
+
+    @abstractmethod
+    def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
+        """
+        Sets the random number generator state of the specified device.
+        """
+
+    @abstractmethod
+    def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
+        """
+        Sets the random number generator state of all devices.
+        """
+
+    @abstractmethod
+    def manual_seed(self, seed: int) -> None:
+        """
+        Sets the seed for generating random numbers for the current device.
+        """
+
+    @abstractmethod
+    def manual_seed_all(self, seed: int) -> None:
+        """
+        Sets the seed for generating random numbers on all devices.
+        """
+
+    @abstractmethod
+    def seed(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number for the current device.
+        """
+
+    @abstractmethod
+    def seed_all(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number on all devices.
+        """
+
+    @abstractmethod
+    def initial_seed(self) -> int:
+        """
+        Returns the current random seed of the current device.
+        """
+
+    # =======================
+    # memory management APIs
+    # =======================
+    @abstractmethod
+    def empty_cache(self) -> None:
+        """
+        Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other device application and visible in nvidia-smi.
+        """
+
+    @abstractmethod
+    def memory_stats(self, device=None) -> Dict[str, Any]:
+        """
+        Returns a dictionary of CUDA memory allocator statistics for a given device.
+        """
+
+    @abstractmethod
+    def memory_summary(self, device=None, abbreviated=False) -> str:
+        """
+        Returns a human-readable printout of the current memory allocator statistics for a given device.
+        """
+
+    @abstractmethod
+    def memory_snapshot(self):
+        """
+        Returns a snapshot of the CUDA memory allocator state across all devices.
+        """
+
+    @abstractmethod
+    def memory_allocated(self, device=None) -> int:
+        """
+        Returns the current device memory occupied by tensors in bytes for a given device.
+        """
+
+    @abstractmethod
+    def max_memory_allocated(self, device=None) -> int:
+        """
+        Returns the maximum device memory occupied by tensors in bytes for a given device.
+        """
+
+    @abstractmethod
+    def reset_max_memory_allocated(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum device memory occupied by tensors for a given device.
+        """
+
+    @abstractmethod
+    def reset_max_memory_cached(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum device memory managed by the caching allocator for a given device.
+        """
+
+    @abstractmethod
+    def memory_reserved(self, device=None) -> int:
+        """
+        Returns the current device memory managed by the caching allocator in bytes for a given device.
+        """
+
+    @abstractmethod
+    def max_memory_reserved(self, device=None) -> int:
+        """
+        Returns the maximum device memory managed by the caching allocator in bytes for a given device.
+        """
+
+    @abstractmethod
+    def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
+        """
+        Set memory fraction for a process.
+        """
+
+    @abstractmethod
+    def reset_peak_memory_stats(self, device=None) -> None:
+        """
+        Resets the "peak" stats tracked by the device memory allocator.
+        """
+
+    # =======================
+    # streams and events APIs
+    # =======================
+    @abstractmethod
+    def Stream(self, device=None, priority=0, **kwargs):
+        """
+        A device stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
+        """
+
+    @abstractmethod
+    def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
+        """
+        Device events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
+        """
+
+    @abstractmethod
+    def current_stream(self, device=None):
+        """
+        Returns the currently selected Stream for a given device.
+        """
+
+    @abstractmethod
+    def default_stream(self, device=None):
+        """
+        Returns the default Stream for a given device.
+        """
+
+    @abstractmethod
+    def set_stream(self, stream_):
+        """
+        Sets the current stream. This is a wrapper API to set the stream.
+        """
+
+    @abstractmethod
+    def stream(self, stream_):
+        """
+        Wrapper around the Context-manager StreamContext that selects a given stream.
+        """
+
+    # =======================
+    # amp APIs
+    # =======================
+    @abstractmethod
+    def autocast(
+        self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
+    ) -> Callable:
+        """
+        Return autocast function
+        """


@@ -0,0 +1,277 @@
+#!/usr/bin/env python
+
+import resource
+from contextlib import nullcontext
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import psutil
+import torch
+
+from .base_accelerator import BaseAccelerator
+
+__all__ = ["CpuAccelerator"]
+
+
+class CpuAccelerator(BaseAccelerator):
+    support_set_device: bool = False
+    """
+    Accelerator class for cpu.
+    """
+
+    def __init__(self):
+        super().__init__(name="cpu", communication_backend="gloo", is_synchronous=False)
+
+    # =======================
+    # device APIs
+    # =======================
+    def get_current_device(self) -> torch.device:
+        """
+        Return the current device.
+        """
+        return torch.device("cpu")
+
+    def current_device(self) -> int:
+        """
+        Return the current device index.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
+        """
+        Bind the current process to a device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def get_device_name(self, device: Union[torch.device, int]) -> str:
+        """
+        Return the name of the device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def synchronize(self, device: Union[torch.device, int] = None):
+        """
+        Synchronize the current process.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def is_available(self):
+        """
+        Check if the accelerator is available.
+        """
+        return True
+
+    def device_count(self):
+        """
+        Return the number of devices on the machine.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def get_device_capability(self, device=None) -> Tuple[int, int]:
+        """
+        Gets the cuda capability of a device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def get_device_name(self, device=None) -> str:
+        """
+        Gets the name of a device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def get_device_properties(self, device):
+        """
+        Gets the properties of a device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def utilization(self, device=None) -> int:
+        """
+        Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    # =======================
+    # random number generator APIs
+    # =======================
+    def get_rng_state(self, device=None) -> torch.Tensor:
+        """
+        Returns the random number generator state of the specified GPU as a ByteTensor.
+        """
+        return torch.get_rng_state(device)
+
+    def get_rng_state_all(self) -> List[torch.Tensor]:
+        """
+        Returns a list of ByteTensor representing the random number states of all devices.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def set_rng_state(self, new_state: torch.ByteTensor, device: str = None) -> None:
+        """
+        Sets the random number generator state of the specified GPU.
+        """
+        torch.set_rng_state(new_state)
+
+    def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
+        """
+        Sets the random number generator state of all devices.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def manual_seed(self, seed: int) -> None:
+        """
+        Sets the seed for generating random numbers for the current GPU.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def manual_seed_all(self, seed: int) -> None:
+        """
+        Set the random seed for the all processes.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def seed(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number for the current GPU.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def seed_all(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number on all GPUs.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def initial_seed(self) -> int:
+        """
+        Returns the current random seed of the current GPU.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    # =======================
+    # memory management APIs
+    # =======================
+    def empty_cache(self) -> None:
+        """
+        Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def memory_stats(self, device=None) -> Dict[str, Any]:
+        """
+        Returns a dictionary of CUDA memory allocator statistics for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def memory_summary(self, device=None, abbreviated=False) -> str:
+        """
+        Returns a human-readable printout of the current memory allocator statistics for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def memory_snapshot(self):
+        """
+        Returns a snapshot of the CUDA memory allocator state across all devices.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def memory_allocated(self, device=None) -> int:
+        """
+        Returns the current GPU memory occupied by tensors in bytes for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def max_memory_allocated(self, device=None) -> int:
+        """
+        Returns the maximum GPU memory occupied by tensors in bytes for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def reset_max_memory_allocated(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def reset_max_memory_cached(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def memory_reserved(self, device=None) -> int:
+        """
+        Returns the current GPU memory managed by the caching allocator in bytes for a given device.
+        """
+        return psutil.Process().memory_info().rss
+
+    def max_memory_reserved(self, device=None) -> int:
+        """
+        Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
+        """
+        return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+
+    def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
+        """
+        Set memory fraction for a process.
+        """
+        max_memory = int(psutil.virtual_memory().total * fraction)
+        _, hard = resource.getrlimit(resource.RLIMIT_AS)
+        resource.setrlimit(resource.RLIMIT_AS, (max_memory, hard))
+
+    def reset_peak_memory_stats(self, device=None) -> None:
+        """
+        Resets the "peak" stats tracked by the CUDA memory allocator.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    # =======================
+    # streams and events APIs
+    # =======================
+    def Stream(self, device=None, priority=0, **kwargs):
+        """
+        A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
+        """
+        CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def current_stream(self, device=None):
+        """
+        Returns the currently selected Stream for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def default_stream(self, device=None):
+        """
+        Returns the default Stream for a given device.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def set_stream(self, stream_):
+        """
+        Sets the current stream. This is a wrapper API to set the stream.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def stream(self, stream_):
+        """
+        Wrapper around the Context-manager StreamContext that selects a given stream.
+        """
+        raise RuntimeError("this method is not supported for cpu accelerator")
+
+    # =======================
+    # amp APIs
+    # =======================
+    def autocast(
+        self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
+    ) -> Callable:
+        """
+        Return autocast function
+        """
+        return nullcontext
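A quick smoke test of the CPU backend, assuming the exports added to `__init__.py` above; note that most device-specific calls raise, while the memory queries fall back to `psutil`/`resource`:

```python
from colossalai.accelerator import CpuAccelerator

acc = CpuAccelerator()
assert acc.is_available()            # always True: this is the fallback backend
print(acc.get_current_device())      # torch.device("cpu")
print(acc.memory_reserved())         # resident set size in bytes, via psutil

try:
    acc.device_count()               # unsupported on CPU
except RuntimeError as e:
    print(e)
```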


@ -1,7 +1,9 @@
#!/usr/bin/env python #!/usr/bin/env python
from typing import Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.distributed as dist
from .base_accelerator import BaseAccelerator from .base_accelerator import BaseAccelerator
@ -19,16 +21,26 @@ class CudaAccelerator(BaseAccelerator):
# ======================= # =======================
# device APIs # device APIs
# ======================= # =======================
def get_current_device(self) -> torch.device:
"""
Return the current device.
"""
return torch.device(f"cuda:{torch.cuda.current_device()}")
def current_device(self) -> int: def current_device(self) -> int:
""" """
Return the current device index. Return the current device index.
""" """
return torch.cuda.current_device() return torch.cuda.current_device()
def set_device(self, device: Union[torch.device, int]) -> None: def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
""" """
Bind the current process to a device. Bind the current process to a device.
""" """
if device is None:
if not dist.is_initialized():
raise RuntimeError("Cannot get current device when distributed is not initialized.")
device = dist.get_rank() % self.device_count()
torch.cuda.set_device(device) torch.cuda.set_device(device)
def get_device_name(self, device: Union[torch.device, int]) -> str: def get_device_name(self, device: Union[torch.device, int]) -> str:
@ -54,3 +66,211 @@ class CudaAccelerator(BaseAccelerator):
Return the number of devices on the machine. Return the number of devices on the machine.
""" """
return torch.cuda.device_count() return torch.cuda.device_count()
def get_device_capability(self, device=None) -> Tuple[int, int]:
"""
Gets the cuda capability of a device.
"""
return torch.cuda.get_device_capability(device)
def get_device_name(self, device=None) -> str:
"""
Gets the name of a device.
"""
return torch.cuda.get_device_name(device)
def get_device_properties(self, device):
"""
Gets the properties of a device.
"""
return torch.cuda.get_device_properties(device)
def utilization(self, device=None) -> int:
"""
Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
"""
return torch.cuda.utilization(device)
# =======================
# random number generator APIs
# =======================
def get_rng_state(self, device="cuda") -> torch.Tensor:
"""
Returns the random number generator state of the specified GPU as a ByteTensor.
"""
return torch.cuda.get_rng_state(device)
def get_rng_state_all(self) -> List[torch.Tensor]:
"""
Returns a list of ByteTensor representing the random number states of all devices.
"""
return torch.cuda.get_rng_state_all()
def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
"""
Sets the random number generator state of the specified GPU.
"""
torch.cuda.set_rng_state(new_state, device)
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
"""
Sets the random number generator state of all devices.
"""
torch.cuda.set_rng_state_all(new_states)
def manual_seed(self, seed: int) -> None:
"""
Sets the seed for generating random numbers for the current GPU.
"""
torch.cuda.manual_seed(seed)
def manual_seed_all(self, seed: int) -> None:
"""
Set the random seed for the all processes.
"""
torch.cuda.manual_seed_all(seed)
def seed(self) -> None:
"""
Sets the seed for generating random numbers to a random number for the current GPU.
"""
torch.cuda.seed()
def seed_all(self) -> None:
"""
Sets the seed for generating random numbers to a random number on all GPUs.
"""
torch.cuda.seed_all()
def initial_seed(self) -> int:
"""
Returns the current random seed of the current GPU.
"""
return torch.cuda.initial_seed()
# =======================
# memory management APIs
# =======================
def empty_cache(self) -> None:
"""
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
"""
torch.cuda.empty_cache()
def memory_stats(self, device=None) -> Dict[str, Any]:
"""
Returns a dictionary of CUDA memory allocator statistics for a given device.
"""
return torch.cuda.memory_stats(device=device)
def memory_summary(self, device=None, abbreviated=False) -> str:
"""
Returns a human-readable printout of the current memory allocator statistics for a given device.
"""
return torch.cuda.memory_summary(device=device, abbreviated=abbreviated)
def memory_snapshot(self):
"""
Returns a snapshot of the CUDA memory allocator state across all devices.
"""
return torch.cuda.memory_snapshot()
def memory_allocated(self, device=None) -> int:
"""
Returns the current GPU memory occupied by tensors in bytes for a given device.
"""
return torch.cuda.memory_allocated(device=device)
def max_memory_allocated(self, device=None) -> int:
"""
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
"""
return torch.cuda.max_memory_allocated(device=device)
def reset_max_memory_allocated(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
"""
torch.cuda.reset_max_memory_allocated(device=device)
def reset_max_memory_cached(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
"""
torch.cuda.reset_max_memory_cached(device=device)
def memory_reserved(self, device=None) -> int:
"""
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
"""
return torch.cuda.memory_reserved(device=device)
def max_memory_reserved(self, device=None) -> int:
"""
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
"""
return torch.cuda.max_memory_reserved(device=device)
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
"""
Set memory fraction for a process.
"""
torch.cuda.set_per_process_memory_fraction(fraction, device=device)
def reset_peak_memory_stats(self, device=None) -> None:
"""
Resets the "peak" stats tracked by the CUDA memory allocator.
"""
torch.cuda.reset_peak_memory_stats(device=device)
# =======================
# streams and events APIs
# =======================
def Stream(self, device=None, priority=0, **kwargs):
"""
A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
"""
return torch.cuda.Stream(device, priority, **kwargs)
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
"""
CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
"""
return torch.cuda.Event(enable_timing, blocking, interprocess)
def current_stream(self, device=None):
"""
Returns the currently selected Stream for a given device.
"""
return torch.cuda.current_stream(device)
def default_stream(self, device=None):
"""
Returns the default Stream for a given device.
"""
return torch.cuda.default_stream(device)
def set_stream(self, stream_):
"""
Sets the current stream.This is a wrapper API to set the stream.
"""
torch.cuda.set_stream(stream_)
def stream(self, stream_):
"""
Wrapper around the Context-manager StreamContext that selects a given stream.
"""
return torch.cuda.stream(stream_)
# =======================
# amp APIs
# =======================
def autocast(
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
) -> Callable:
"""
Return autocast function
"""
return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
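The new `set_device(None)` path binds each process to a GPU derived from its distributed rank. A hedged launch sketch (the `torchrun` command and script name are illustrative):

```python
# Launched e.g. as: torchrun --nproc_per_node=8 train.py
import torch.distributed as dist

from colossalai.accelerator import get_accelerator

dist.init_process_group(backend="nccl")

# With device=None, each rank is bound to rank % device_count(),
# so no manual rank-to-GPU arithmetic is needed in user code.
get_accelerator().set_device(None)
```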


@@ -1,13 +1,17 @@
 #!/usr/bin/env python

-from typing import Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
+import torch.distributed as dist

 from .base_accelerator import BaseAccelerator

+IS_NPU_AVAILABLE = False
 try:
     import torch_npu  # noqa
+
+    IS_NPU_AVAILABLE = True
 except ImportError:
     pass
@@ -26,16 +30,26 @@
     # =======================
     # device APIs
     # =======================
+    def get_current_device(self) -> torch.device:
+        """
+        Return the current device.
+        """
+        return torch.device(f"npu:{torch.npu.current_device()}")
+
     def current_device(self) -> int:
         """
         Return the current device index.
         """
         return torch.npu.current_device()

-    def set_device(self, device: Union[torch.device, int]) -> None:
+    def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
         """
         Bind the current process to a device.
         """
+        if device is None:
+            if not dist.is_initialized():
+                raise RuntimeError("Cannot get current device when distributed is not initialized.")
+            device = dist.get_rank() % self.device_count()
         torch.npu.set_device(device)

     def get_device_name(self, device: Union[torch.device, int]) -> str:
@@ -61,3 +75,211 @@ class NpuAccelerator(BaseAccelerator):
         Return the number of devices on the machine.
         """
         return torch.npu.device_count()
+
+    def get_device_capability(self, device=None) -> Tuple[int, int]:
+        """
+        Gets the npu capability of a device.
+        """
+        return torch.npu.get_device_capability(device)
+
+    def get_device_name(self, device=None) -> str:
+        """
+        Gets the name of a device.
+        """
+        return torch.npu.get_device_name(device)
+
+    def get_device_properties(self, device):
+        """
+        Gets the properties of a device.
+        """
+        return torch.npu.get_device_properties(device)
+
+    def utilization(self, device=None) -> int:
+        """
+        Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
+        """
+        return torch.npu.utilization(device)
+
+    # =======================
+    # random number generator APIs
+    # =======================
+    def get_rng_state(self, device="npu") -> torch.Tensor:
+        """
+        Returns the random number generator state of the specified GPU as a ByteTensor.
+        """
+        return torch.npu.get_rng_state(device)
+
+    def get_rng_state_all(self) -> List[torch.Tensor]:
+        """
+        Returns a list of ByteTensor representing the random number states of all devices.
+        """
+        return torch.npu.get_rng_state_all()
+
+    def set_rng_state(self, new_state: torch.ByteTensor, device: str = "npu") -> None:
+        """
+        Sets the random number generator state of the specified GPU.
+        """
+        torch.npu.set_rng_state(new_state, device)
+
+    def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
+        """
+        Sets the random number generator state of all devices.
+        """
+        torch.npu.set_rng_state_all(new_states)
+
+    def manual_seed(self, seed: int) -> None:
+        """
+        Sets the seed for generating random numbers for the current GPU.
+        """
+        torch.npu.manual_seed(seed)
+
+    def manual_seed_all(self, seed: int) -> None:
+        """
+        Set the random seed for the all processes.
+        """
+        torch.npu.manual_seed_all(seed)
+
+    def seed(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number for the current GPU.
+        """
+        torch.npu.seed()
+
+    def seed_all(self) -> None:
+        """
+        Sets the seed for generating random numbers to a random number on all GPUs.
+        """
+        torch.npu.seed_all()
+
+    def initial_seed(self) -> int:
+        """
+        Returns the current random seed of the current GPU.
+        """
+        return torch.npu.initial_seed()
+
+    # =======================
+    # memory management APIs
+    # =======================
+    def empty_cache(self) -> None:
+        """
+        Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
+        """
+        torch.npu.empty_cache()
+
+    def memory_stats(self, device=None) -> Dict[str, Any]:
+        """
+        Returns a dictionary of npu memory allocator statistics for a given device.
+        """
+        return torch.npu.memory_stats(device=device)
+
+    def memory_summary(self, device=None, abbreviated=False) -> str:
+        """
+        Returns a human-readable printout of the current memory allocator statistics for a given device.
+        """
+        return torch.npu.memory_summary(device=device, abbreviated=abbreviated)
+
+    def memory_snapshot(self):
+        """
+        Returns a snapshot of the npu memory allocator state across all devices.
+        """
+        return torch.npu.memory_snapshot()
+
+    def memory_allocated(self, device=None) -> int:
+        """
+        Returns the current GPU memory occupied by tensors in bytes for a given device.
+        """
+        return torch.npu.memory_allocated(device=device)
+
+    def max_memory_allocated(self, device=None) -> int:
+        """
+        Returns the maximum GPU memory occupied by tensors in bytes for a given device.
+        """
+        return torch.npu.max_memory_allocated(device=device)
+
+    def reset_max_memory_allocated(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
+        """
+        torch.npu.reset_max_memory_allocated(device=device)
+
+    def reset_max_memory_cached(self, device=None) -> None:
+        """
+        Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
+        """
+        torch.npu.reset_max_memory_cached(device=device)
+
+    def memory_reserved(self, device=None) -> int:
+        """
+        Returns the current GPU memory managed by the caching allocator in bytes for a given device.
+        """
+        return torch.npu.memory_reserved(device=device)
+
+    def max_memory_reserved(self, device=None) -> int:
+        """
+        Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
+        """
+        return torch.npu.max_memory_reserved(device=device)
+
+    def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
+        """
+        Set memory fraction for a process.
+        """
+        torch.npu.set_per_process_memory_fraction(fraction, device=device)
+
+    def reset_peak_memory_stats(self, device=None) -> None:
+        """
+        Resets the "peak" stats tracked by the npu memory allocator.
+        """
+        torch.npu.reset_peak_memory_stats(device=device)
+
+    # =======================
+    # streams and events APIs
+    # =======================
+    def Stream(self, device=None, priority=0, **kwargs):
+        """
+        A npu stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See npu-semantics for details.
+        """
+        return torch.npu.Stream(device, priority, **kwargs)
+
+    def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
+        """
+        npu events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize npu streams.
+        """
+        return torch.npu.Event(enable_timing, blocking, interprocess)
+
+    def current_stream(self, device=None):
+        """
+        Returns the currently selected Stream for a given device.
+        """
+        return torch.npu.current_stream(device)
+
+    def default_stream(self, device=None):
+        """
+        Returns the default Stream for a given device.
+        """
+        return torch.npu.default_stream(device)
+
+    def set_stream(self, stream_):
+        """
+        Sets the current stream. This is a wrapper API to set the stream.
+        """
+        torch.npu.set_stream(stream_)
+
+    def stream(self, stream_):
+        """
+        Wrapper around the Context-manager StreamContext that selects a given stream.
+        """
+        return torch.npu.stream(stream_)
+
+    # =======================
+    # amp APIs
+    # =======================
+    def autocast(
+        self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
+    ) -> Callable:
+        """
+        Return autocast function
+        """
+        return torch.npu.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
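`IS_NPU_AVAILABLE` only records whether the `torch_npu` extension imported successfully; a sketch of opting into the Ascend backend explicitly when it is present (the re-export from `colossalai.accelerator` is used by the GeminiPlugin hunk below):

```python
from colossalai.accelerator import IS_NPU_AVAILABLE, NpuAccelerator, set_accelerator

if IS_NPU_AVAILABLE:
    # torch_npu imported cleanly, so torch.npu.* is patched in and usable.
    set_accelerator(NpuAccelerator())
```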


@@ -7,8 +7,8 @@ from typing import Dict
 import torch
 from torch import Tensor

+from colossalai.accelerator import get_accelerator
 from colossalai.logging import get_dist_logger
-from colossalai.utils.device import get_current_device

 __all__ = ["BaseGradScaler"]
@@ -23,7 +23,7 @@ class BaseGradScaler(ABC):

     def __init__(self, initial_scale: float, verbose: bool):
         assert initial_scale > 0
-        self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
+        self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float)
         self._verbose = verbose

         if self._verbose:


@@ -5,7 +5,7 @@ from typing import Optional

 import torch

-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator

 from .base_grad_scaler import BaseGradScaler
@@ -37,14 +37,20 @@ class DynamicGradScaler(BaseGradScaler):
         hysteresis: int = 2,
         verbose: bool = False,
     ):
+        a = get_accelerator()
+        a.device_count()
         super().__init__(initial_scale, verbose)
         if min_scale:
-            self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
+            self._min_scale = torch.tensor(
+                [min_scale], device=get_accelerator().get_current_device(), dtype=torch.float
+            )
         else:
             self._min_scale = None

         if max_scale:
-            self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
+            self._max_scale = torch.tensor(
+                [max_scale], device=get_accelerator().get_current_device(), dtype=torch.float
+            )
         else:
             self._max_scale = None
@@ -117,7 +123,7 @@ class DynamicGradScaler(BaseGradScaler):
         return state_dict

     def load_state_dict(self, state_dict):
-        self._scale = state_dict["scale"].to(get_current_device())
+        self._scale = state_dict["scale"].to(get_accelerator().get_current_device())
         self._growth_factor = state_dict["growth_factor"]
         self._backoff_factor = state_dict["backoff_factor"]
         self._hysteresis = state_dict["hysteresis"]


@@ -5,8 +5,8 @@ import torch
 import torch.distributed as dist
 from torch import Tensor

+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
-from colossalai.utils import get_current_device

 from .base import MixedPrecisionMixin
@@ -40,7 +40,7 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin):
             max_scale=max_scale,
         )
         self.optim_state = OptimState.UNSCALED
-        self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
+        self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())

     @property
     def loss_scale(self) -> float:


@@ -4,10 +4,10 @@ from typing import Dict, Tuple
 import torch
 from torch.optim import Optimizer

+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device

 from .base_offload_module import BaseOffloadModule
 from .region import Region
@@ -79,7 +79,9 @@ class AMPOptimizer(OptimizerWrapper):
             hysteresis=hysteresis,
             max_scale=max_scale,
         )
-        self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+        self._found_overflow: torch.Tensor = torch.zeros(
+            1, dtype=torch.int64, device=get_accelerator().get_current_device()
+        )
         self._logger = get_dist_logger()

     def _set_grad_ptr(self):


@@ -11,7 +11,7 @@ except:
 import torch
 from torch.fx.node import Node

-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator

 from .region import Region
 from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
@@ -57,7 +57,10 @@ class Solver(ABC):
         if memory_budget > 0:
             self.memory_budget = memory_budget * self.error_factor
         else:
-            self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
+            self.memory_budget = (
+                torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
+                * self.error_factor
+            )

         self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
         self.comp_power: float = self._extract_computing_power()


@@ -5,8 +5,8 @@ import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer

+from colossalai.accelerator import get_accelerator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.utils.device import autocast

 from .mixed_precision_base import MixedPrecision
@@ -89,7 +89,7 @@ class TorchAMPModule(ModelWrapper):
         super().__init__(module)

     def forward(self, *args, **kwargs):
-        with autocast():
+        with get_accelerator().autocast():
             return self.module(*args, **kwargs)
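The wrapper's forward now reduces to a device-agnostic pattern: `autocast()` returns whichever context the active backend provides (`torch.cuda.amp.autocast` on CUDA, `torch.npu.amp.autocast` on NPU). A minimal sketch:

```python
import torch
import torch.nn as nn

from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()
model = nn.Linear(16, 16).to(device)
x = torch.randn(4, 16, device=device)

with get_accelerator().autocast():  # backend-appropriate mixed-precision context
    y = model(x)
```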


@@ -12,6 +12,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

+from colossalai.accelerator import IS_NPU_AVAILABLE, get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_model_base_filenames,
@@ -24,8 +25,6 @@ from colossalai.checkpoint_io.utils import (
 from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.utils import get_current_device
-from colossalai.utils.device import IS_NPU_AVAILABLE
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -367,7 +366,7 @@ class GeminiPlugin(DPPluginBase):
             assert placement_policy == "static", "NPU only supports static placement policy"
         self.gemini_config = dict(
             chunk_config_dict=chunk_config_dict,
-            chunk_init_device=(chunk_init_device or get_current_device()),
+            chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
             placement_policy=placement_policy,
             enable_gradient_accumulation=enable_gradient_accumulation,
             shard_param_frac=shard_param_frac,


@ -18,6 +18,7 @@ from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh from colossalai.cluster import ProcessGroupMesh
@ -29,7 +30,6 @@ from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
-from colossalai.utils.device import get_current_device

from .pp_plugin_base import PipelinePluginBase

@@ -82,7 +82,7 @@ class HybridParallelModule(ModelWrapper):
self.mixed_precision = torch.bfloat16
if self.mixed_precision is not None:
module = module.to(self.mixed_precision)
-module = module.to(get_current_device())
+module = module.to(get_accelerator().get_current_device())
# setting input type cast when using mixed precision
self.convert_fn = None
@@ -346,7 +346,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
-total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+total_norm_cuda = torch.tensor(
+    [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+)
if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if self.pp_size > 1:
@@ -385,7 +387,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
total_norm_exponentiated += grad_norm_exponentiated
-total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
+total_norm_exponentiated_cuda = torch.tensor(
+    [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
+)
if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -543,7 +547,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
# so we need to calculate the norm of 'tp' and 'pp' gradients.
total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
-total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+total_norm_cuda = torch.tensor(
+    [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+)
if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -586,7 +592,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
total_norm_exponentiated += grad_norm_exponentiated
-total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
+total_norm_exponentiated_cuda = torch.tensor(
+    [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
+)
if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -798,7 +806,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# so we only need to calculate the norm 'tp' of 'pp' gradients.
total_norm = super()._compute_grad_norm(gradients, norm_type)
-total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+total_norm_cuda = torch.tensor(
+    [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+)
if tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -837,7 +847,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
total_norm_exponentiated += grad_norm_exponentiated
-total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
+total_norm_exponentiated_cuda = torch.tensor(
+    [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
+)
if dp_size > 1:
# compute norm in dp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
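The hunks above all repeat one pattern: a scalar norm is materialized on the accelerator's current device before being all-reduced across the parallel groups. A minimal sketch of that pattern, assuming only the get_accelerator() API this PR introduces (the helper name reduce_max_norm is hypothetical):

import torch
import torch.distributed as dist

from colossalai.accelerator import get_accelerator


def reduce_max_norm(total_norm: float, group: dist.ProcessGroup) -> float:
    # hypothetical helper: build the buffer on the accelerator device (CUDA or NPU)
    device = get_accelerator().get_current_device()
    norm = torch.tensor([total_norm], device=device, dtype=torch.float32)
    # take the max across ranks, as the inf-norm branches above do
    dist.all_reduce(norm, op=dist.ReduceOp.MAX, group=group)
    return norm.item()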

View File

@@ -12,6 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader

+from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
from colossalai.checkpoint_io.utils import (
get_optimizer_base_filenames,
@@ -24,7 +25,6 @@ from colossalai.checkpoint_io.utils import (
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
-from colossalai.utils import get_current_device
from colossalai.zero import LowLevelZeroOptimizer

from .dp_plugin_base import DPPluginBase
@@ -52,7 +52,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
self.dtype = torch.bfloat16
if self.dtype is not None:
module = module.to(self.dtype)
-module = module.to(get_current_device())
+module = module.to(get_accelerator().get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:

View File

@@ -6,12 +6,12 @@ import warnings
from pathlib import Path
from typing import Dict, Union

-import torch
import torch.distributed as dist

+from colossalai.accelerator import get_accelerator
from colossalai.context import Config
from colossalai.logging import get_dist_logger
-from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
+from colossalai.utils import set_seed

def launch(
@@ -47,17 +47,18 @@ def launch(
if rank == 0:
warnings.warn("`config` is deprecated and will be removed soon.")
-if IS_NPU_AVAILABLE and backend == "nccl":
-backend = "hccl"
+cur_accelerator = get_accelerator()
+backend = cur_accelerator.communication_backend
# init default process group
init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# set cuda device
-if torch.cuda.is_available() or IS_NPU_AVAILABLE:
# if local rank is not given, calculate automatically
-set_device(local_rank)
+if cur_accelerator.support_set_device:
+cur_accelerator.set_device(local_rank)
set_seed(seed)
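The launch() change above is the heart of the PR: instead of special-casing NPUs, the accelerator object now reports its own communication backend and owns device binding. A condensed sketch of the new flow, assuming the accelerator attributes shown in this diff (init_distributed is a hypothetical wrapper, not a ColossalAI function):

import torch.distributed as dist

from colossalai.accelerator import get_accelerator


def init_distributed(rank: int, world_size: int, host: str, port: int, local_rank: int):
    acc = get_accelerator()
    # communication_backend is e.g. "nccl" on CUDA and "hccl" on Ascend NPU,
    # as the removed IS_NPU_AVAILABLE branch used to hard-code
    dist.init_process_group(
        rank=rank,
        world_size=world_size,
        backend=acc.communication_backend,
        init_method=f"tcp://[{host}]:{port}",
    )
    if acc.support_set_device:  # some accelerators (e.g. CPU) may not bind a device
        acc.set_device(local_rank)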

View File

@@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from einops import rearrange

-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator

class Unpad(torch.autograd.Function):
@@ -70,7 +70,9 @@ class SeqLenInfo:
cu_seqlens: torch.Tensor = None
@staticmethod
-def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()):
+def materialize(
+    attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
+):
if attn_mask is not None:
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()

View File

@@ -1,7 +1,7 @@
import torch

+from colossalai.accelerator import get_accelerator
from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
-from colossalai.utils import get_current_device

from .bias_dropout_add import bias_dropout_add_fused_train
from .bias_gelu import bias_gelu_impl
@@ -46,11 +46,13 @@ def warmup_jit_fusion(
):
"""Compile JIT functions before the main training steps"""
-embed = Embedding(vocab_size, hidden_size).to(get_current_device())
-linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
-linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device())
-x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device())
+embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
+linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
+linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_accelerator().get_current_device())
+x = torch.randint(
+    vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_accelerator().get_current_device()
+)
x = embed(x)
y, y_bias = linear_1(x)
z, z_bias = linear_2(y)
@@ -58,8 +60,8 @@ def warmup_jit_fusion(
# prop and recomputation
for bias_grad, input_grad in zip([True, True], [False, True]):
for _ in range(10):
-bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device())
-input_ = torch.rand_like(y, dtype=dtype, device=get_current_device())
+bias = torch.rand_like(y_bias, dtype=dtype, device=get_accelerator().get_current_device())
+input_ = torch.rand_like(y, dtype=dtype, device=get_accelerator().get_current_device())
bias.requires_grad, input_.requires_grad = bias_grad, input_grad
bias_gelu_impl(input_, bias)
@@ -69,9 +71,9 @@ def warmup_jit_fusion(
# prop and recomputation
for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
for _ in range(10):
-input_ = torch.rand_like(z, dtype=dtype, device=get_current_device())
-residual = torch.rand_like(x, dtype=dtype, device=get_current_device())
-bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device())
+input_ = torch.rand_like(z, dtype=dtype, device=get_accelerator().get_current_device())
+residual = torch.rand_like(x, dtype=dtype, device=get_accelerator().get_current_device())
+bias = torch.rand_like(z_bias, dtype=dtype, device=get_accelerator().get_current_device())
input_.requires_grad = input_grad
bias.requires_grad = bias_grad
residual.requires_grad = residual_grad

View File

@@ -1,18 +1,19 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-from colossalai.utils.device import autocast
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

+from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.legacy.utils import clip_grad_norm_fp32

from ._grad_scaler import GradScaler

+autocast = get_accelerator().autocast

class TorchAMPOptimizer(OptimizerWrapper):
"""A wrapper class which integrate Pytorch AMP with an optimizer

View File

@@ -8,9 +8,9 @@ from typing import List, Tuple, Union
import torch
import torch.distributed as dist

+from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device

from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
@@ -43,12 +43,16 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
if isinstance(recv_shapes, torch.Size):
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
-buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
+buffer_recv = torch.empty(
+    recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype
+)
return buffer_recv, recv_split
buffer_recv = []
for recv_shape in recv_shapes:
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
-tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
+tensor_recv = torch.empty(
+    recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype
+)
buffer_recv.append(tensor_recv)
return buffer_recv, recv_split

View File

@@ -3,9 +3,9 @@
import torch

+from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device, synchronize

def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:
@@ -29,7 +29,7 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
current_rank = gpc.get_global_rank()
tensor_recv_prev = torch.empty(
-    buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype
+    buffer_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=tensor_send_next.dtype
)
# send to next rank
@@ -52,6 +52,6 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
req.wait()
# To protect against race condition when using batch_isend_irecv().
-synchronize()
+get_accelerator().synchronize()
return tensor_recv_prev
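The last hunk swaps the CUDA-only synchronize() helper for the accelerator's own barrier. A minimal sketch of the wait-then-sync pattern used around batch_isend_irecv (wait_all is a hypothetical helper name):

from colossalai.accelerator import get_accelerator


def wait_all(reqs):
    # hypothetical helper mirroring the hunk above
    for req in reqs:
        req.wait()
    # guard against the batch_isend_irecv() race noted in the comment above
    get_accelerator().synchronize()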

View File

@@ -3,9 +3,9 @@ from typing import List, Tuple, Union
import torch
import torch.distributed as dist

+from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device

TensorShape = Union[torch.Size, List[int], Tuple[int]]
@@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
if next_rank is None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
-tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
+tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()}
if isinstance(obj, torch.Tensor):
send_obj_nums = torch.tensor(1, **tensor_kwargs)
dist.send(send_obj_nums, next_rank)
@@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
-tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
+tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()}
recv_obj_nums = torch.empty((), **tensor_kwargs)
dist.recv(recv_obj_nums, prev_rank)
if recv_obj_nums.item() == 1:

View File

@@ -6,8 +6,8 @@ from typing import Callable, Iterable
import torch

+from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device

class BaseSchedule(ABC):
@@ -29,12 +29,12 @@ class BaseSchedule(ABC):
def _move_tensor(element):
if torch.is_tensor(element):
if not element.is_cuda:
-return element.to(get_current_device()).detach()
+return element.to(get_accelerator().get_current_device()).detach()
return element

def _move_to_device(self, data):
if isinstance(data, torch.Tensor):
-data = data.to(get_current_device())
+data = data.to(get_accelerator().get_current_device())
elif isinstance(data, (list, tuple)):
data_to_return = []
for element in data:
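_move_to_device above now targets whatever device the accelerator reports. A generalized sketch of the same idea for nested batches (move_to_device is a hypothetical standalone version of the BaseSchedule method; the original handles lists and tuples inline):

import torch

from colossalai.accelerator import get_accelerator


def move_to_device(data):
    # hypothetical free function; recursively mirrors BaseSchedule._move_to_device
    device = get_accelerator().get_current_device()
    if torch.is_tensor(data):
        return data.to(device).detach()
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(x) for x in data)
    if isinstance(data, dict):
        return {k: move_to_device(v) for k, v in data.items()}
    return data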

View File

@@ -7,12 +7,12 @@ from typing import Callable, List, Tuple, Union
import torch.cuda

import colossalai.legacy.communication as comm
+from colossalai.accelerator import get_accelerator
from colossalai.legacy.amp.naive_amp import NaiveAMPModel
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
from colossalai.logging import get_dist_logger
-from colossalai.utils.device import get_current_device

from ._base_schedule import BaseSchedule
@@ -352,7 +352,7 @@ class PipelineSchedule(BaseSchedule):
output_objs = []
return_tensors = []
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
-accum_loss = torch.zeros(1, device=get_current_device())
+accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else:
accum_loss = None
# Used for tensor meta information communication
@@ -584,7 +584,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
if not forward_only:
output_obj_grads = [[] for _ in range(len(model))]
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
-accum_loss = torch.zeros(1, device=get_current_device())
+accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else:
accum_loss = None

View File

@@ -6,10 +6,10 @@ from typing import Iterable, Tuple
import torch.cuda

import colossalai.legacy.communication.p2p_v2 as comm
+from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.engine import Engine
-from colossalai.utils.device import get_current_device

from ._pipeline_schedule import PipelineSchedule
@@ -99,7 +99,7 @@ class PipelineScheduleV2(PipelineSchedule):
output_objs = []
return_tensors = []
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
-accum_loss = torch.zeros(1, device=get_current_device())
+accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else:
accum_loss = None

View File

@@ -15,6 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

+from colossalai.accelerator import get_accelerator
from colossalai.context import Config, ConfigException
from colossalai.interface import OptimizerWrapper
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
@@ -34,7 +35,6 @@ from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence
from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device

def get_default_parser():
@@ -309,9 +309,9 @@ def initialize(
else:
if isinstance(model, nn.Module):
# first sync model across dp ranks
-model.to(get_current_device())
+model.to(get_accelerator().get_current_device())
elif isinstance(model, Callable):
-model = model().to(get_current_device())
+model = model().to(get_accelerator().get_current_device())
# optimizer maybe a optimizer_cls
if isinstance(optimizer, Callable):

View File

@@ -3,8 +3,8 @@ from typing import Callable
from torch import dtype, nn

+from colossalai.accelerator import get_accelerator
from colossalai.nn import init
-from colossalai.utils import get_current_device

from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D
from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D
@@ -83,7 +83,7 @@ class Embedding(ColossalaiModule):
embed = (
nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, **kwargs)
.to(dtype)
-.to(get_current_device())
+.to(get_accelerator().get_current_device())
)
weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
elif num_embeddings <= vocab_parallel_limit:

View File

@@ -1,6 +1,6 @@
from torch import nn

-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator

from ..parallel_1d import LayerNorm1D
from ..parallel_2d import LayerNorm2D
@@ -36,7 +36,7 @@ class LayerNorm(ColossalaiModule):
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel is None:
-norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
+norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_accelerator().get_current_device())
else:
norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
super().__init__(norm)

View File

@@ -10,6 +10,7 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter

+from colossalai.accelerator import get_accelerator
from colossalai.kernel import LayerNorm
from colossalai.legacy.communication import broadcast
from colossalai.legacy.context import ParallelMode, seed
@@ -22,7 +23,6 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device

from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
@@ -221,7 +221,7 @@ class Classifier1D(ParallelLayer):
# Parameters.
# Initialize weight.
-factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
@@ -357,7 +357,7 @@ class VocabParallelClassifier1D(ParallelLayer):
# Parameters.
# Initialize weight.
-factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
@@ -499,7 +499,7 @@ class Linear1D_Col(ParallelLayer):
# Parameters.
# Initialize weight.
-factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
if bias:
@@ -638,7 +638,7 @@ class Linear1D_Row(ParallelLayer):
# Parameters.
# Initialize weight.
-factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
if self.stream_chunk_num > 1:
@@ -802,7 +802,9 @@ class Embedding1D(ParallelLayer):
self.embed_kwargs = kwargs

self.weight = Parameter(
-torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+torch.empty(
+    (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+)
)
self.reset_parameters(weight_initializer)
@@ -912,7 +914,11 @@ class VocabParallelEmbedding1D(ParallelLayer):
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
-torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)
+torch.empty(
+    (self.num_embeddings_per_partition, self.embed_dim),
+    device=get_accelerator().get_current_device(),
+    dtype=dtype,
+)
)
self.reset_parameters(weight_initializer)
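Every layer in this file (and the 2D/2.5D/3D variants below) repeats the same factory_kwargs idiom, allocating parameters directly on the accelerator device. A stripped-down sketch (make_linear_weight is a hypothetical helper; the real layers also partition shapes across the tensor-parallel group):

import torch
from torch.nn import Parameter

from colossalai.accelerator import get_accelerator


def make_linear_weight(out_features: int, in_features: int, dtype=None) -> Parameter:
    # hypothetical helper showing the factory_kwargs idiom from the hunks above
    factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
    return Parameter(torch.empty(out_features, in_features, **factory_kwargs))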

View File

@@ -5,10 +5,10 @@ import torch.distributed as dist
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd

+from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device

def matmul_2d(
@@ -250,7 +250,7 @@ class Matmul_AB_2D(torch.autograd.Function):
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[-1])
-C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
+C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
@@ -399,7 +399,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[0])
-C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
+C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
@@ -556,7 +556,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[-1], B.shape[-1])
-C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
+C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases

View File

@@ -8,6 +8,7 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

+from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication import broadcast
from colossalai.legacy.context import ParallelMode, seed
from colossalai.legacy.core import global_context as gpc
@@ -18,7 +19,6 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device

from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
@@ -82,7 +82,7 @@ class Linear2D(ParallelLayer):
self.hidden_size_per_partition = divide(self.out_features, self.summa_dim)
# create weight, shape: [k/q, h/q]
-factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
)
@@ -259,7 +259,7 @@ class LayerNorm2D(ParallelLayer):
self.partitioned_partition = divide(normalized_shape, self.summa_dim**2)
# create parameters
-factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
if bias:
@@ -438,18 +438,24 @@ class PatchEmbedding2D(ParallelLayer):
self.weight = Parameter(
torch.empty(
(self.embed_size_per_partition, in_chans, *self.patch_size),
-device=get_current_device(),
+device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
-self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
+self.bias = Parameter(
+    torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
+)
self.cls_token = Parameter(
-torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
+torch.zeros(
+    (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+)
)
self.pos_embed = Parameter(
torch.zeros(
-(1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
+(1, self.num_patches + 1, self.embed_size_per_partition),
+device=get_accelerator().get_current_device(),
+dtype=dtype,
)
)
@@ -619,7 +625,9 @@ class Embedding2D(ParallelLayer):
self.embed_kwargs = kwargs

self.weight = Parameter(
-torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+torch.empty(
+    (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+)
)
self.reset_parameters(weight_initializer)
@@ -758,7 +766,7 @@ class VocabParallelEmbedding2D(ParallelLayer):
self.weight = Parameter(
torch.empty(
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
-device=get_current_device(),
+device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
@@ -895,11 +903,18 @@ class Classifier2D(ParallelLayer):
self.has_weight = False
else:
self.weight = Parameter(
-torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
+torch.empty(
+    self.num_classes,
+    self.input_size_per_partition,
+    device=get_accelerator().get_current_device(),
+    dtype=dtype,
+)
)
self.has_weight = True
if bias:
-self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
+self.bias = Parameter(
+    torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
+)
else:
self.bias = None
@@ -1052,7 +1067,7 @@ class VocabParallelClassifier2D(ParallelLayer):
self.output_size_per_partition = divide(num_classes, self.summa_dim)
# create weight, shape: [k/q, h/q]
-factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
if weight is not None:
self.weight = weight
self.has_weight = False

View File

@@ -5,10 +5,10 @@ import torch.distributed as dist
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd

+from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device

def get_parallel_group(parallel_mode: ParallelMode):
@@ -205,7 +205,7 @@ class Matmul_AB_2p5D(torch.autograd.Function):
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[-1])
-C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
+C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
@@ -362,7 +362,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[0])
-C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
+C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
@@ -527,7 +527,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[-1], B.shape[-1])
-C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
+C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
# use circular buffer to store the communication tensor
# 2 is enough for all cases
@@ -661,7 +661,9 @@ class _Add_Bias_2p5D(torch.autograd.Function):
if row_rank == 0:
bias_temp = bias.clone()
else:
-bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
+bias_temp = torch.zeros(
+    output_size_per_partition, dtype=bias.dtype, device=get_accelerator().get_current_device()
+)
src_rank = (
col_rank
+ dep_rank * tesseract_dim**2
@@ -984,7 +986,7 @@ class SplitFirst(torch.autograd.Function):
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
-grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
+grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_accelerator().get_current_device())
dist.all_gather(
list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode)
)

View File

@@ -8,6 +8,7 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

+from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication import broadcast
from colossalai.legacy.context import ParallelMode, seed
from colossalai.legacy.core import global_context as gpc
@@ -19,7 +20,6 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device

from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
@@ -84,7 +84,7 @@ class Linear2p5D(ParallelLayer):
self.hidden_size_per_partition = divide(out_features, self.tesseract_dim)
# create weight, shape: [k/q, h/q]
-factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
)
@@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer):
self.partitioned_partition = divide(normalized_shape, self.tesseract_dim)  # *
# create parameters
-factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
if bias:
@@ -451,18 +451,24 @@ class PatchEmbedding2p5D(ParallelLayer):
self.weight = Parameter(
torch.empty(
(self.embed_size_per_partition, in_chans, *self.patch_size),
-device=get_current_device(),
+device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
-self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
+self.bias = Parameter(
+    torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
+)
self.cls_token = Parameter(
-torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
+torch.zeros(
+    (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+)
)
self.pos_embed = Parameter(
torch.zeros(
-(1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
+(1, self.num_patches + 1, self.embed_size_per_partition),
+device=get_accelerator().get_current_device(),
+dtype=dtype,
)
)
@@ -632,7 +638,9 @@ class Embedding2p5D(ParallelLayer):
self.embed_kwargs = kwargs

self.weight = Parameter(
-torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+torch.empty(
+    (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+)
)
self.reset_parameters(weight_initializer)
@@ -772,7 +780,7 @@ class VocabParallelEmbedding2p5D(ParallelLayer):
self.weight = Parameter(
torch.empty(
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
-device=get_current_device(),
+device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
@@ -910,11 +918,18 @@ class Classifier2p5D(ParallelLayer):
self.has_weight = False
else:
self.weight = Parameter(
-torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
+torch.empty(
+    self.num_classes,
+    self.input_size_per_partition,
+    device=get_accelerator().get_current_device(),
+    dtype=dtype,
+)
)
self.has_weight = True
if bias:
-self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
+self.bias = Parameter(
+    torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
+)
else:
self.bias = None
@@ -1068,7 +1083,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim)
# create weight, shape: [k/q, h/q]
-factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
if weight is not None:
self.weight = weight
self.has_weight = False

View File

@@ -8,6 +8,7 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

+from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication import all_reduce, broadcast
from colossalai.legacy.constants import (
INPUT_GROUP_3D,
@@ -27,7 +28,6 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device

from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (
@@ -69,11 +69,13 @@ class LayerNorm3D(ParallelLayer):
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
self.weight = Parameter(
-torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
+torch.ones(self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
)
if bias:
self.bias = Parameter(
-torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
+torch.zeros(
+    self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
+)
)
else:
self.bias = None
@@ -202,13 +204,15 @@ class Linear3D(ParallelLayer):
torch.empty(
self.in_features_per_partition,
self.out_features_per_partition,
-device=get_current_device(),
+device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
if bias:
self.bias = Parameter(
-torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
+torch.zeros(
+    self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
+)
)
else:
self.bias = None
@@ -380,11 +384,18 @@ class Classifier3D(ParallelLayer):
self.has_weight = False
else:
self.weight = Parameter(
-torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)
+torch.empty(
+    self.num_classes,
+    self.in_features_per_partition,
+    device=get_accelerator().get_current_device(),
+    dtype=dtype,
+)
)
self.has_weight = True
if bias:
-self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
+self.bias = Parameter(
+    torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
+)
else:
self.bias = None
@@ -523,14 +534,16 @@ class VocabParallelClassifier3D(ParallelLayer):
torch.empty(
self.out_features_per_partition,
self.in_features_per_partition,
-device=get_current_device(),
+device=get_accelerator().get_current_device(),
dtype=dtype,
)
)
self.has_weight = True
if bias:
self.bias = Parameter(
-torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
+torch.zeros(
+    self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
+)
)
else:
self.bias = None
@@ -705,16 +718,24 @@ class PatchEmbedding3D(ParallelLayer):
self.weight = nn.Parameter(
torch.empty(
-(embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype
+(embed_size_per_partition, in_chans, *self.patch_size),
+device=get_accelerator().get_current_device(),
+dtype=dtype,
)
)
-self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype))
+self.bias = nn.Parameter(
+    torch.empty(embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
+)
self.cls_token = nn.Parameter(
-torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
+torch.zeros((1, 1, embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype)
)
self.pos_embed = nn.Parameter(
-torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
+torch.zeros(
+    (1, self.num_patches + 1, embed_size_per_partition),
+    device=get_accelerator().get_current_device(),
+    dtype=dtype,
+)
)
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
@@ -880,7 +901,9 @@ class Embedding3D(ParallelLayer):
self.embed_kwargs = kwargs

self.weight = nn.Parameter(
-torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+torch.empty(
+    (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+)
)
self.reset_parameters(weight_initializer)
@@ -1019,7 +1042,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
self.weight = Parameter(
torch.empty(
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
-device=get_current_device(),
+device=get_accelerator().get_current_device(),
dtype=dtype,
)
)

View File

@@ -5,11 +5,11 @@ import torch
from torch import distributed as dist
from torch.cuda.amp import custom_bwd, custom_fwd

+from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication import ring_forward
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range
-from colossalai.utils import get_current_device

class RingQK(torch.autograd.Function):
@@ -30,7 +30,7 @@ class RingQK(torch.autograd.Function):
sub_seq_length,
sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
dtype=sub_q.dtype,
-device=get_current_device(),
+device=get_accelerator().get_current_device(),
)
# compute local QK^T
@@ -71,7 +71,7 @@ class RingQK(torch.autograd.Function):
grad_q = torch.zeros_like(
sub_q,
dtype=sub_q.dtype,
-device=get_current_device(),
+device=get_accelerator().get_current_device(),
)
# compute with local sub_k
@@ -105,7 +105,7 @@ class RingAV(torch.autograd.Function):
batch_size * num_attention_heads,
sub_seq_length,
attention_head_size,
-device=get_current_device(),
+device=get_accelerator().get_current_device(),
dtype=attention_score.dtype,
)
@@ -142,7 +142,9 @@ class RingAV(torch.autograd.Function):
grad_v /= local_world_size
# calculate gradient for attention score
-grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device())
+grad_attention_score = torch.zeros_like(
+    attention_scores, dtype=grad_output.dtype, device=get_accelerator().get_current_device()
+)
# compute with local sub_k
grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1))

View File

@ -7,10 +7,10 @@ from torch import Tensor
from torch import nn as nn from torch import nn as nn
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context import seed from colossalai.legacy.context import seed
from colossalai.legacy.registry import LAYERS from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.utils.device import get_current_device
from ..utils import to_2tuple from ..utils import to_2tuple
@ -173,12 +173,18 @@ class VanillaPatchEmbedding(nn.Module):
self.flatten = flatten self.flatten = flatten
self.weight = nn.Parameter( self.weight = nn.Parameter(
torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype) torch.empty(
(embed_size, in_chans, *self.patch_size), device=get_accelerator().get_current_device(), dtype=dtype
)
)
self.bias = nn.Parameter(torch.empty(embed_size, device=get_accelerator().get_current_device(), dtype=dtype))
self.cls_token = nn.Parameter(
torch.zeros((1, 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype)
) )
self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter( self.pos_embed = nn.Parameter(
torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype) torch.zeros(
(1, self.num_patches + 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype
)
) )
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
@ -242,11 +248,15 @@ class VanillaClassifier(nn.Module):
self.has_weight = False self.has_weight = False
else: else:
self.weight = nn.Parameter( self.weight = nn.Parameter(
torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype) torch.empty(
self.num_classes, self.in_features, device=get_accelerator().get_current_device(), dtype=dtype
)
) )
self.has_weight = True self.has_weight = True
if bias: if bias:
self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) self.bias = nn.Parameter(
torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
)
else: else:
self.bias = None self.bias = None
@ -287,7 +297,7 @@ class VanillaLayerNorm(nn.Module):
self.normalized_shape = (normalized_shape,) self.normalized_shape = (normalized_shape,)
self.variance_epsilon = eps self.variance_epsilon = eps
factory_kwargs = {"device": get_current_device(), "dtype": dtype} factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
if bias: if bias:
@ -333,7 +343,7 @@ class VanillaLinear(nn.Module):
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
factory_kwargs = {"device": get_current_device(), "dtype": dtype} factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
if bias: if bias:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
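The factory_kwargs idiom above keeps parameter construction device-agnostic in one place. A short sketch of the same idiom (sizes are arbitrary):

import torch
from torch.nn.parameter import Parameter

from colossalai.accelerator import get_accelerator

factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": torch.float32}
weight = Parameter(torch.empty(64, 128, **factory_kwargs))  # materialized on the active accelerator
bias = Parameter(torch.empty(64, **factory_kwargs))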
@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context import ParallelMode from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization
from colossalai.legacy.registry import LOSSES from colossalai.legacy.registry import LOSSES
from colossalai.utils import get_current_device
@LOSSES.register_module @LOSSES.register_module
@ -118,7 +118,7 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function):
grad_2d = grad_input.view(-1, partition_vocab_size) grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes. # Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
# Finally elementwise multiplication with the output gradients. # Finally elementwise multiplication with the output gradients.
@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context import ParallelMode from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
from colossalai.legacy.registry import LOSSES from colossalai.legacy.registry import LOSSES
from colossalai.utils import get_current_device
@LOSSES.register_module @LOSSES.register_module
@ -112,7 +112,7 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function):
grad_2d = grad_input.view(-1, partition_vocab_size) grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes. # Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
# Finally elementwise multiplication with the output gradients. # Finally elementwise multiplication with the output gradients.
@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from colossalai.accelerator import get_accelerator
from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.legacy.registry import LOSSES from colossalai.legacy.registry import LOSSES
from colossalai.utils import get_current_device
@LOSSES.register_module @LOSSES.register_module
@ -80,7 +80,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
target_mask = (targets < vocab_start) | (targets > vocab_end) target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0 masked_target[target_mask] = 0
arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device()) arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_accelerator().get_current_device())
predicted_logits = logits[arange_1d, masked_target] predicted_logits = logits[arange_1d, masked_target]
predicted_logits = predicted_logits.clone().contiguous().view_as(targets) predicted_logits = predicted_logits.clone().contiguous().view_as(targets)
predicted_logits[target_mask] = 0.0 predicted_logits[target_mask] = 0.0
@ -110,7 +110,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
grad_2d = input_grad.view(-1, partition_vocab_size) grad_2d = input_grad.view(-1, partition_vocab_size)
# Add the gradient from matching classes. # Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
input_grad.mul_(output_grad.unsqueeze(dim=-1)) input_grad.mul_(output_grad.unsqueeze(dim=-1))
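The arange/masked-target indexing shared by these vocab-parallel losses is easier to follow on a toy shard. A CPU-only sketch with made-up sizes (the device argument is dropped for brevity):

import torch

partition_vocab_size, vocab_start, vocab_end = 10, 10, 19  # this rank owns vocab ids 10..19
targets = torch.tensor([3, 12, 25, 19])                    # global vocabulary ids

# Targets outside this shard are masked and clamped to a valid local index.
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0

logits = torch.randn(4, partition_vocab_size)
arange_1d = torch.arange(start=0, end=logits.size(0))
predicted_logits = logits[arange_1d, masked_target]
predicted_logits[target_mask] = 0.0  # masked rows contribute nothing on this rank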
@ -7,12 +7,12 @@ from typing import Callable
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication import all_reduce from colossalai.legacy.communication import all_reduce
from colossalai.legacy.context import ParallelMode from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.registry import HOOKS from colossalai.legacy.registry import HOOKS
from colossalai.legacy.utils import is_no_pp_or_last_stage from colossalai.legacy.utils import is_no_pp_or_last_stage
from colossalai.utils import get_current_device
from ._base_hook import BaseHook from ._base_hook import BaseHook
from ._commons_ import _format_number from ._commons_ import _format_number
@ -82,8 +82,8 @@ class LossMetric(Metric):
def __init__(self, epoch_only): def __init__(self, epoch_only):
super().__init__(epoch_only=epoch_only) super().__init__(epoch_only=epoch_only)
self.last_step_loss = torch.zeros(1, device=get_current_device()) self.last_step_loss = torch.zeros(1, device=get_accelerator().get_current_device())
self.accum_loss = torch.zeros(1, device=get_current_device()) self.accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
self.count = 0 self.count = 0
def reset(self) -> None: def reset(self) -> None:
@ -164,10 +164,10 @@ class AccuracyMetric(Metric):
def __init__(self, epoch_only: bool, accuracy_func: Callable): def __init__(self, epoch_only: bool, accuracy_func: Callable):
super().__init__(epoch_only=epoch_only) super().__init__(epoch_only=epoch_only)
self.acc = accuracy_func self.acc = accuracy_func
self.last_step_sum = torch.zeros(1, device=get_current_device()) self.last_step_sum = torch.zeros(1, device=get_accelerator().get_current_device())
self.last_step_correct = torch.zeros(1, device=get_current_device()) self.last_step_correct = torch.zeros(1, device=get_accelerator().get_current_device())
self.accumulated_sum = torch.zeros(1, device=get_current_device()) self.accumulated_sum = torch.zeros(1, device=get_accelerator().get_current_device())
self.accumulated_correct = torch.zeros(1, device=get_current_device()) self.accumulated_correct = torch.zeros(1, device=get_accelerator().get_current_device())
def reset(self) -> None: def reset(self) -> None:
self.last_step_sum.zero_() self.last_step_sum.zero_()
@ -320,10 +320,10 @@ class ThroughputMetric(Metric):
super().__init__(epoch_only=epoch_only) super().__init__(epoch_only=epoch_only)
self.ignored_steps = ignored_steps self.ignored_steps = ignored_steps
self.cur_steps = 0 self.cur_steps = 0
self.accumulated_num_samples = torch.zeros(1, device=get_current_device()) self.accumulated_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
self.accumulated_used_time = torch.zeros(1, device=get_current_device()) self.accumulated_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
self.last_step_num_samples = torch.zeros(1, device=get_current_device()) self.last_step_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
self.last_step_used_time = torch.zeros(1, device=get_current_device()) self.last_step_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
self._tflop_per_step = tflop_per_step self._tflop_per_step = tflop_per_step
self._use_local = use_local self._use_local = use_local
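These hook metrics keep their counters as one-element tensors on the accelerator device precisely so they can be all-reduced across ranks without a host round trip. The idea, condensed (the numbers are invented):

import torch

from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()
accumulated_sum = torch.zeros(1, device=device)
accumulated_correct = torch.zeros(1, device=device)

accumulated_sum += 32      # samples seen this step
accumulated_correct += 29  # correct predictions this step
# in a distributed run: dist.all_reduce(accumulated_sum); dist.all_reduce(accumulated_correct)
accuracy = (accumulated_correct / accumulated_sum).item()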
@ -6,8 +6,8 @@ import weakref
import torch import torch
from torch.utils.checkpoint import check_backward_validity, detach_variable from torch.utils.checkpoint import check_backward_validity, detach_variable
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
from colossalai.utils.device import autocast, get_current_device
def copy_to_device(obj, device): def copy_to_device(obj, device):
@ -33,7 +33,7 @@ class CheckpointFunction(torch.autograd.Function):
check_backward_validity(args) check_backward_validity(args)
ctx.run_function = run_function ctx.run_function = run_function
ctx.activation_offload = activation_offload ctx.activation_offload = activation_offload
ctx.device = get_current_device() ctx.device = get_accelerator().get_current_device()
# preserve rng states # preserve rng states
ctx.fwd_cpu_rng_state = torch.get_rng_state() ctx.fwd_cpu_rng_state = torch.get_rng_state()
@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function):
inputs[idx] = tensors[i] inputs[idx] = tensors[i]
detached_inputs = detach_variable(tuple(inputs)) detached_inputs = detach_variable(tuple(inputs))
if ctx.had_autocast_in_fwd: if ctx.had_autocast_in_fwd:
with torch.enable_grad(), autocast(): with torch.enable_grad(), get_accelerator().autocast()():
outputs = ctx.run_function(*detached_inputs) outputs = ctx.run_function(*detached_inputs)
else: else:
with torch.enable_grad(): with torch.enable_grad():
@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
# rerun forward, the inner_pack will store all the activations in storage # rerun forward, the inner_pack will store all the activations in storage
if has_autocast_in_fwd: if has_autocast_in_fwd:
with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks( with torch.enable_grad(), get_accelerator().autocast()(), torch.autograd.graph.saved_tensors_hooks(
inner_pack, inner_unpack inner_pack, inner_unpack
): ):
_unused = function(*args) _unused = function(*args)
@ -245,7 +245,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
# get device if we need to offload the activation # get device if we need to offload the activation
if activation_offload: if activation_offload:
device = get_current_device() device = get_accelerator().get_current_device()
# run function with pack and unpack as saved_tensors_hooks # run function with pack and unpack as saved_tensors_hooks
with torch.autograd.graph.saved_tensors_hooks(pack, unpack): with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
@ -6,9 +6,9 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from packaging import version from packaging import version
from colossalai.accelerator import get_accelerator
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
_GLOBAL_CUDA_MEM_FRACTION = 1.0 _GLOBAL_CUDA_MEM_FRACTION = 1.0
_GLOBAL_CPU_MEM_CAPACITY = -1 _GLOBAL_CPU_MEM_CAPACITY = -1
@ -112,7 +112,10 @@ def colo_device_memory_capacity(device: torch.device) -> int:
# In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory. # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node
if device.type == "cuda": if device.type == "cuda":
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION return (
torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
* _GLOBAL_CUDA_MEM_FRACTION
)
def colo_device_memory_used(device: torch.device) -> int: def colo_device_memory_used(device: torch.device) -> int:
@ -153,7 +156,7 @@ def colo_set_process_memory_fraction(ratio: float) -> None:
return return
global _GLOBAL_CUDA_MEM_FRACTION global _GLOBAL_CUDA_MEM_FRACTION
_GLOBAL_CUDA_MEM_FRACTION = ratio _GLOBAL_CUDA_MEM_FRACTION = ratio
torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device()) torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_accelerator().get_current_device())
def colo_set_cpu_memory_capacity(size: int) -> None: def colo_set_cpu_memory_capacity(size: int) -> None:
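In the CUDA branch above, both the capacity query and the per-process cap are keyed on the accelerator's current device. A guarded sketch (the 0.9 fraction is an arbitrary example):

import torch

from colossalai.accelerator import get_accelerator

if torch.cuda.is_available():
    dev = get_accelerator().get_current_device()
    fraction = 0.9
    capacity = torch.cuda.get_device_properties(dev).total_memory * fraction
    torch.cuda.set_per_process_memory_fraction(fraction, dev)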
@ -8,7 +8,7 @@ import torch.distributed as dist
from torch.autograd.profiler import profile from torch.autograd.profiler import profile
from torch.distributed import ReduceOp from torch.distributed import ReduceOp
from colossalai.utils import get_current_device from colossalai.accelerator import get_accelerator
from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time
@ -177,7 +177,7 @@ class CommProfiler(BaseProfiler):
assert current_comm_event is not None, "dist op has not been found" assert current_comm_event is not None, "dist op has not been found"
buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device()) buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_accelerator().get_current_device())
torch_all_reduce(buffer, op=ReduceOp.MIN, group=group) torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
current_comm_event.self_cuda_time = buffer.item() current_comm_event.self_cuda_time = buffer.item()
@ -3,7 +3,7 @@ import types
from time import time from time import time
from typing import List from typing import List
from colossalai.utils.device import get_current_device from colossalai.accelerator import get_accelerator
from .stateful_tensor import StatefulTensor, TensorState from .stateful_tensor import StatefulTensor, TensorState
from .tensor_placement_policy import TensorPlacementPolicy from .tensor_placement_policy import TensorPlacementPolicy
@ -69,7 +69,7 @@ class StatefulTensorMgr(object):
# move COMPUTE tensors to CUDA # move COMPUTE tensors to CUDA
self._cpu_gpu_move_volume += cuda_demand self._cpu_gpu_move_volume += cuda_demand
for t in move_to_cuda_tensor_list: for t in move_to_cuda_tensor_list:
colo_model_data_tensor_move_inline(t, get_current_device()) colo_model_data_tensor_move_inline(t, get_accelerator().get_current_device())
@property @property
def cpu_gpu_move_volume(self): def cpu_gpu_move_volume(self):
@ -5,8 +5,8 @@ from typing import List, Optional, Type
import torch import torch
from colossalai.accelerator import get_accelerator
from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.legacy.utils.memory import colo_device_memory_capacity
from colossalai.utils import get_current_device
from colossalai.zero.gemini.memory_tracer import MemStatsCollector from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from .stateful_tensor import StatefulTensor from .stateful_tensor import StatefulTensor
@ -38,7 +38,7 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy):
class CUDATensorPlacementPolicy(TensorPlacementPolicy): class CUDATensorPlacementPolicy(TensorPlacementPolicy):
def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available" assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available"
super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector) super().__init__(get_accelerator().get_current_device(), mem_stats_collector=mem_stats_collector)
def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
return 0, 0 return 0, 0
@ -78,7 +78,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
int: the volume of memory that is evicted int: the volume of memory that is evicted
""" """
start = time() start = time()
cuda_capacity = colo_device_memory_capacity(get_current_device()) cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"] used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"]
if warmup: if warmup:
# We designate a part of CUDA memory for model data in warmup iterations. # We designate a part of CUDA memory for model data in warmup iterations.
@ -4,8 +4,8 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten from torch._utils import _flatten_dense_tensors as flatten
from colossalai.accelerator import get_accelerator
from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils import get_current_device
from .tensor_shard_strategy import TensorShardStrategy from .tensor_shard_strategy import TensorShardStrategy
@ -30,9 +30,11 @@ class BucketTensorShardStrategy(TensorShardStrategy):
rank = dist.get_rank(process_group) rank = dist.get_rank(process_group)
for i in range(world_size): for i in range(world_size):
if i == rank: if i == rank:
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device())) buffer_list.append(
flatten([t.payload for t in tensor_list]).cuda(get_accelerator().get_current_device())
)
else: else:
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device())) buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_accelerator().get_current_device()))
dist.all_gather(buffer_list, buffer_list[rank], group=process_group) dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
# Move to target device before splitting buffer # Move to target device before splitting buffer
# Ensure we utilize maximum PCIE bandwidth # Ensure we utilize maximum PCIE bandwidth
@ -3,11 +3,11 @@ from typing import List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.legacy.zero.shard_utils import BaseShardStrategy
from colossalai.legacy.zero.shard_utils.commons import get_shard from colossalai.legacy.zero.shard_utils.commons import get_shard
from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils import get_current_device
class TensorShardStrategy(BaseShardStrategy): class TensorShardStrategy(BaseShardStrategy):
@ -34,9 +34,9 @@ class TensorShardStrategy(BaseShardStrategy):
if t.is_sharded: if t.is_sharded:
return return
if t.payload.device.type == "cuda": if t.payload.device.type == "cuda":
assert t.payload.device == get_current_device(), ( assert t.payload.device == get_accelerator().get_current_device(), (
f"shard tensor on cuda device index {t.payload.device.index}," f"shard tensor on cuda device index {t.payload.device.index},"
f" but current cuda device is {get_current_device()}" f" but current cuda device is {get_accelerator().get_current_device()}"
) )
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group)) sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
t.payload_reset(sharded_payload) t.payload_reset(sharded_payload)
@ -50,7 +50,9 @@ class TensorShardStrategy(BaseShardStrategy):
world_size = dist.get_world_size(process_group) world_size = dist.get_world_size(process_group)
rank = dist.get_rank(process_group) rank = dist.get_rank(process_group)
buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device()) buffer = torch.empty(
payload_numel * world_size, dtype=t.payload.dtype, device=get_accelerator().get_current_device()
)
buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0)) buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
buffer_list[rank].copy_(t.payload) buffer_list[rank].copy_(t.payload)
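The gather path above works because torch.chunk returns views into one flat buffer, so all_gather writes every rank's shard back to back. A single-process sketch of that layout, assuming a hypothetical 4-rank world:

import torch

world_size, rank, payload_numel = 4, 1, 3
buffer = torch.empty(payload_numel * world_size)
buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))  # views, not copies

buffer_list[rank].copy_(torch.tensor([1.0, 2.0, 3.0]))  # this rank's payload
# dist.all_gather(buffer_list, buffer_list[rank], group=process_group) fills the rest;
# afterwards `buffer` holds all shards contiguously.
assert buffer[3:6].tolist() == [1.0, 2.0, 3.0]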
@ -10,6 +10,7 @@ import torch.nn as nn
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.legacy.utils.memory import colo_device_memory_capacity
@ -22,7 +23,7 @@ from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_move_to_c
from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.legacy.zero.shard_utils import BaseShardStrategy
from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import disposable, get_current_device from colossalai.utils import disposable
from colossalai.zero.gemini.memory_tracer import MemStatsCollector from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from ._utils import ( from ._utils import (
@ -212,8 +213,12 @@ class ShardedModelV2(nn.Module):
self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0]) self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0])
if gpc.get_global_rank() == 0: if gpc.get_global_rank() == 0:
with open(filename, "w+") as f: with open(filename, "w+") as f:
f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n") f.write(
f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n") f"cuda reserved {torch.cuda.memory_reserved(get_accelerator().get_current_device()) / 1e9} GB\n"
)
f.write(
f"cuda max allocated {torch.cuda.max_memory_allocated(get_accelerator().get_current_device()) / 1e9} GB\n"
)
f.write("CUDA model data (GB)\n") f.write("CUDA model data (GB)\n")
f.write("\n") f.write("\n")
f.write("CUDA non model data (GB)\n") f.write("CUDA non model data (GB)\n")
@ -266,7 +271,8 @@ class ShardedModelV2(nn.Module):
# model data is fixed in cuda during training. # model data is fixed in cuda during training.
# cuda margin space can be used to store OS. # cuda margin space can be used to store OS.
self._cuda_margin_space = ( self._cuda_margin_space = (
colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda colo_device_memory_capacity(get_accelerator().get_current_device())
- self._memstats_collector._memstats.max_overall_cuda
) )
@torch.no_grad() @torch.no_grad()
@ -3,13 +3,13 @@ from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.registry import OPHOOKS from colossalai.legacy.registry import OPHOOKS
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
from colossalai.legacy.zero.gemini.stateful_tensor import TensorState from colossalai.legacy.zero.gemini.stateful_tensor import TensorState
from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.legacy.zero.shard_utils import BaseShardStrategy
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.zero.gemini.memory_tracer import MemStatsCollector from colossalai.zero.gemini.memory_tracer import MemStatsCollector
@ -33,7 +33,7 @@ class ZeroHook(BaseOpHook):
self.process_group = process_group self.process_group = process_group
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
self.computing_device = get_current_device() self.computing_device = get_accelerator().get_current_device()
self._memstarts_collector = memstarts_collector self._memstarts_collector = memstarts_collector
self._stateful_tensor_mgr = stateful_tensor_mgr self._stateful_tensor_mgr = stateful_tensor_mgr
@ -8,9 +8,9 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from colossalai.accelerator import get_accelerator
from colossalai.moe._operation import moe_cumsum from colossalai.moe._operation import moe_cumsum
from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.manager import MOE_MANAGER
from colossalai.utils import get_current_device
class MoeRouter(nn.Module, ABC): class MoeRouter(nn.Module, ABC):
@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC):
drop_tks (bool, optional): Whether to drop tokens in evaluation drop_tks (bool, optional): Whether to drop tokens in evaluation
""" """
def __init__(self, def __init__(
self,
k_value: int, k_value: int,
capacity_factor_train: float, capacity_factor_train: float,
capacity_factor_eval: float, capacity_factor_eval: float,
min_capacity: int, min_capacity: int,
noisy_func: Optional[Callable] = None, noisy_func: Optional[Callable] = None,
drop_tks: bool = True, drop_tks: bool = True,
use_kernel: bool = False): use_kernel: bool = False,
):
super().__init__() super().__init__()
self.k_value = k_value self.k_value = k_value
self.capacity_factor_train = capacity_factor_train self.capacity_factor_train = capacity_factor_train
@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC):
if router_probs.dim() == expert_indices.dim() == 2: if router_probs.dim() == expert_indices.dim() == 2:
router_probs = router_probs.unsqueeze(0) router_probs = router_probs.unsqueeze(0)
expert_indices = expert_indices.unsqueeze(0) expert_indices = expert_indices.unsqueeze(0)
assert router_probs.dim() == expert_indices.dim() == 3, \ assert (
"router_probs and expert_indices must be 3D tensors" router_probs.dim() == expert_indices.dim() == 3
), "router_probs and expert_indices must be 3D tensors"
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_indices, num_experts) expert_mask = F.one_hot(expert_indices, num_experts)
@ -122,25 +125,29 @@ class Top1Router(MoeRouter):
drop_tks (bool, optional): Whether to drop tokens in evaluation drop_tks (bool, optional): Whether to drop tokens in evaluation
""" """
def __init__(self, def __init__(
self,
capacity_factor_train: float = 1.25, capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0, capacity_factor_eval: float = 2.0,
min_capacity: int = 4, min_capacity: int = 4,
select_policy: str = "first", select_policy: str = "first",
noisy_func: Optional[Callable] = None, noisy_func: Optional[Callable] = None,
drop_tks: bool = True): drop_tks: bool = True,
super().__init__(k_value=1, ):
super().__init__(
k_value=1,
capacity_factor_train=capacity_factor_train, capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval, capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity, min_capacity=min_capacity,
noisy_func=noisy_func, noisy_func=noisy_func,
drop_tks=drop_tks) drop_tks=drop_tks,
)
self.select_policy = select_policy self.select_policy = select_policy
assert select_policy in {"first", "random"} assert select_policy in {"first", "random"}
if select_policy == "random": if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform( self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0, device=get_current_device()), low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
high=torch.tensor(1.0, device=get_current_device()) high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
).rsample ).rsample
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
@ -216,18 +223,22 @@ class Top2Router(MoeRouter):
drop_tks (bool, optional): Whether to drop tokens in evaluation. drop_tks (bool, optional): Whether to drop tokens in evaluation.
""" """
def __init__(self, def __init__(
self,
capacity_factor_train: float = 1.25, capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0, capacity_factor_eval: float = 2.0,
min_capacity: int = 4, min_capacity: int = 4,
noisy_func: Optional[Callable] = None, noisy_func: Optional[Callable] = None,
drop_tks: bool = True): drop_tks: bool = True,
super().__init__(k_value=2, ):
super().__init__(
k_value=2,
capacity_factor_train=capacity_factor_train, capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval, capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity, min_capacity=min_capacity,
noisy_func=noisy_func, noisy_func=noisy_func,
drop_tks=drop_tks) drop_tks=drop_tks,
)
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
""" """
@ -255,7 +266,7 @@ class Top2Router(MoeRouter):
top2_idx = torch.argmax(logits_except1, dim=-1) top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
cmask = (mask1 + mask2) # loss: [s, e] cmask = mask1 + mask2 # loss: [s, e]
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
# calculate loss # calculate loss
@ -336,15 +347,18 @@ class TopKRouter(MoeRouter):
oversubscribed / reach capacity. oversubscribed / reach capacity.
""" """
def __init__(self, def __init__(
self,
num_selected_experts: int, num_selected_experts: int,
capacity_factor_train: float = 1.25, capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0, capacity_factor_eval: float = 2.0,
min_capacity: int = 4, min_capacity: int = 4,
noisy_func: Optional[Callable] = None, noisy_func: Optional[Callable] = None,
drop_tks: bool = True): drop_tks: bool = True,
super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, ):
drop_tks) super().__init__(
num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
)
def forward( def forward(
self, self,
@ -410,7 +424,7 @@ class TopKRouter(MoeRouter):
# The combine array will be used for combining expert outputs, scaled by the # The combine array will be used for combining expert outputs, scaled by the
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts, # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
# expert_capacity]. # expert_capacity].
combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask) combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
return combine_array, dispatch_mask return combine_array, dispatch_mask
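The einsum at the end broadcasts each token's router probability over its dispatched capacity slot. A shapes-only sketch with invented sizes:

import torch

num_groups, tokens, num_experts, capacity = 1, 4, 2, 2
router_probs = torch.rand(num_groups, tokens, num_experts)
dispatch_mask = torch.randint(0, 2, (num_groups, tokens, num_experts, capacity)).float()

combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
assert combine_array.shape == (num_groups, tokens, num_experts, capacity)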
@ -7,13 +7,12 @@ import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.accelerator import get_accelerator
from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
from colossalai.utils import get_current_device
class ForceFP32Parameter(torch.nn.Parameter): class ForceFP32Parameter(torch.nn.Parameter):
def half(self, memory_format=None): def half(self, memory_format=None):
return self.data.clone() return self.data.clone()
@ -30,8 +29,8 @@ class NormalNoiseGenerator:
def __init__(self, num_experts: int): def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal( self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()), loc=torch.tensor(0.0, device=get_accelerator().get_current_device()),
scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()), scale=torch.tensor(1.0 / num_experts**2, device=get_accelerator().get_current_device()),
).rsample ).rsample
def __call__(self, inputs: torch.Tensor): def __call__(self, inputs: torch.Tensor):
@ -52,8 +51,8 @@ class UniformNoiseGenerator:
def __init__(self, eps: float = 1e-2): def __init__(self, eps: float = 1e-2):
self.uniform = torch.distributions.uniform.Uniform( self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(1.0 - eps, device=get_current_device()), low=torch.tensor(1.0 - eps, device=get_accelerator().get_current_device()),
high=torch.tensor(1.0 + eps, device=get_current_device()), high=torch.tensor(1.0 + eps, device=get_accelerator().get_current_device()),
).rsample ).rsample
def __call__(self, inputs: torch.Tensor): def __call__(self, inputs: torch.Tensor):
@ -193,18 +192,13 @@ def create_ep_hierarchical_group(
assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually." assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
nproc_per_node = int(nproc_per_node) nproc_per_node = int(nproc_per_node)
else: else:
assert dist.get_world_size() % nproc_per_node == 0, \ assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
"nproc_per_node should be a divisor of world_size."
num_node = dist.get_world_size() // nproc_per_node num_node = dist.get_world_size() // nproc_per_node
intra_src_rank = None intra_src_rank = None
ep_intra_node_group = None ep_intra_node_group = None
for i in range(num_node): for i in range(num_node):
ep_intra_ranks = [ ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
i * nproc_per_node + j
for j in range(nproc_per_node)
if j in ep_group_ranks
]
group = dist.new_group(ep_intra_ranks) group = dist.new_group(ep_intra_ranks)
if rank in ep_intra_ranks: if rank in ep_intra_ranks:
assert ep_intra_node_group is None assert ep_intra_node_group is None
@ -212,10 +206,7 @@ def create_ep_hierarchical_group(
intra_src_rank = ep_intra_ranks[0] intra_src_rank = ep_intra_ranks[0]
ep_inter_node_group = None ep_inter_node_group = None
ep_inter_ranks = [ ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
ep_group_ranks[0] + i * nproc_per_node
for i in range(num_node)
]
if len(ep_inter_ranks) > 1: if len(ep_inter_ranks) > 1:
group = dist.new_group(ep_inter_ranks) group = dist.new_group(ep_inter_ranks)
if rank in ep_inter_ranks: if rank in ep_inter_ranks:
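The rank bookkeeping simplified above is easy to verify numerically. A sketch for a hypothetical 2-node, 8-process-per-node layout where the EP group holds local ranks 0, 2, 4, and 6:

nproc_per_node, num_node = 8, 2
ep_group_ranks = [0, 2, 4, 6]  # illustrative

ep_intra_ranks_per_node = [
    [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
    for i in range(num_node)
]
ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]

assert ep_intra_ranks_per_node == [[0, 2, 4, 6], [8, 10, 12, 14]]
assert ep_inter_ranks == [0, 8]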
@ -7,10 +7,10 @@ import torch.cuda
from torch.nn import Module from torch.nn import Module
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.device import get_current_device
from ._utils import get_batch_size, get_micro_batch, model_forward, to_device from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
from .base import PipelineSchedule from .base import PipelineSchedule
@ -86,7 +86,7 @@ class GenerateSchedule(PipelineSchedule):
""" """
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
self.microbatch_offset += self.microbatch_size self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch) return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def _prepare_inputs_for_interval_stage(self): def _prepare_inputs_for_interval_stage(self):
""" """
@ -6,10 +6,10 @@ import torch.cuda
from torch.nn import Module from torch.nn import Module
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.device import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule from .base import PipelineSchedule
@ -56,7 +56,7 @@ class InterleavedSchedule(PipelineSchedule):
""" """
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
self.microbatch_offset[model_chunk_id] += self.microbatch_size self.microbatch_offset[model_chunk_id] += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch) return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int: def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
"""Helper method to get the model chunk ID given the iteration number. """Helper method to get the model chunk ID given the iteration number.
@ -292,7 +292,7 @@ class InterleavedSchedule(PipelineSchedule):
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
if return_loss and self.stage_manager.is_last_stage(): if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_current_device()) accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else: else:
accum_loss = None accum_loss = None
@ -6,10 +6,10 @@ import torch.cuda
from torch.nn import Module from torch.nn import Module
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.device import get_current_device
from ._utils import ( from ._utils import (
detach, detach,
@ -80,7 +80,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
""" """
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
self.microbatch_offset += self.microbatch_size self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch) return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def recv_forward(self, prev_rank: int = None) -> Any: def recv_forward(self, prev_rank: int = None) -> Any:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage. """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
@ -297,7 +297,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
if return_loss and self.stage_manager.is_last_stage(): if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_current_device()) accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else: else:
accum_loss = None accum_loss = None
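All three pipeline schedules stage a micro-batch the same way: map every tensor leaf of the batch onto the accelerator device. A sketch with a stand-in to_device (the real helper lives in the schedules' _utils module):

from functools import partial

import torch
from torch.utils._pytree import tree_map

from colossalai.accelerator import get_accelerator

def to_device(x, device):  # stand-in for the _utils helper of the same name
    return x.to(device) if torch.is_tensor(x) else x

micro_batch = {"input_ids": torch.ones(2, 8, dtype=torch.long), "step": 3}
micro_batch = tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)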
@ -7,7 +7,7 @@ from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size from torch.distributed import ProcessGroup, get_world_size
from colossalai.utils.device import get_current_device, get_rng_state, manual_seed, set_rng_state from colossalai.accelerator import get_accelerator
class SeqParallelUtils: class SeqParallelUtils:
@ -110,10 +110,10 @@ class Randomizer:
# 1. get the current rng state # 1. get the current rng state
# 2. set the seed and store the rng state # 2. set the seed and store the rng state
# 3. recover the original rng state # 3. recover the original rng state
device_original_rng_state = get_rng_state() device_original_rng_state = get_accelerator().get_rng_state()
manual_seed(seed) get_accelerator().manual_seed(seed)
self.device_rng_state = get_rng_state() self.device_rng_state = get_accelerator().get_rng_state()
set_rng_state(device_original_rng_state) get_accelerator().set_rng_state(device_original_rng_state)
# to the same for cpu rng state # to the same for cpu rng state
cpu_original_rng_state = torch.get_rng_state() cpu_original_rng_state = torch.get_rng_state()
@ -122,10 +122,10 @@ class Randomizer:
torch.set_rng_state(cpu_original_rng_state) torch.set_rng_state(cpu_original_rng_state)
def _set_device_rng_state(self, rng_state): def _set_device_rng_state(self, rng_state):
set_rng_state(rng_state) get_accelerator().set_rng_state(rng_state)
def _get_device_rng_state(self): def _get_device_rng_state(self):
current_state = get_rng_state() current_state = get_accelerator().get_rng_state()
return current_state return current_state
def _set_cpu_rng_state(self, rng_state): def _set_cpu_rng_state(self, rng_state):
@ -210,7 +210,7 @@ class Randomizer:
index = Randomizer.index() index = Randomizer.index()
if dist.is_initialized(): if dist.is_initialized():
# convert the index to tensor # convert the index to tensor
index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device()) index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())
# all gather the index # all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
@ -232,7 +232,7 @@ class Randomizer:
if dist.is_initialized(): if dist.is_initialized():
# convert the index to tensor # convert the index to tensor
index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device()) index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())
# all gather the index # all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
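The save/seed/capture/restore sequence above is what lets each Randomizer own a private RNG stream without disturbing the global one. Condensed, using only accelerator methods that appear in this diff (the seed value is arbitrary):

from colossalai.accelerator import get_accelerator

acc = get_accelerator()
original_state = acc.get_rng_state()  # 1. stash the live device RNG state
acc.manual_seed(1024)                 # 2. reseed for this randomizer ...
private_state = acc.get_rng_state()   #    ... and capture the seeded state
acc.set_rng_state(original_state)     # 3. restore the global stream untouched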
@ -9,7 +9,8 @@ from typing import Any, Callable, List
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from packaging import version from packaging import version
from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count
from colossalai.accelerator import get_accelerator
def parameterize(argument: str, values: List[Any]) -> Callable: def parameterize(argument: str, values: List[Any]) -> Callable:
@ -199,7 +200,7 @@ def skip_if_not_enough_gpus(min_gpus: int):
def _wrap_func(f): def _wrap_func(f):
def _execute_by_gpu_num(*args, **kwargs): def _execute_by_gpu_num(*args, **kwargs):
num_avail_gpu = device_count() num_avail_gpu = get_accelerator().device_count()
if num_avail_gpu >= min_gpus: if num_avail_gpu >= min_gpus:
f(*args, **kwargs) f(*args, **kwargs)
@ -263,11 +264,11 @@ def clear_cache_before_run():
def _wrap_func(f): def _wrap_func(f):
def _clear_cache(*args, **kwargs): def _clear_cache(*args, **kwargs):
empty_cache() get_accelerator().empty_cache()
reset_peak_memory_stats() get_accelerator().reset_peak_memory_stats()
reset_max_memory_allocated() get_accelerator().reset_max_memory_allocated()
reset_max_memory_cached() get_accelerator().reset_max_memory_cached()
synchronize() get_accelerator().synchronize()
gc.collect() gc.collect()
f(*args, **kwargs) f(*args, **kwargs)
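Call sites of the decorator are unchanged by this rewrite; assuming it is still exported from colossalai.testing, usage looks like:

from colossalai.testing import clear_cache_before_run

@clear_cache_before_run()
def test_allocates_large_buffers():
    ...  # runs after caches are emptied and peak-memory stats are reset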
@ -7,17 +7,12 @@ from .common import (
is_ddp_ignored, is_ddp_ignored,
set_seed, set_seed,
) )
from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize
from .multi_tensor_apply import multi_tensor_applier from .multi_tensor_apply import multi_tensor_applier
from .tensor_detector import TensorDetector from .tensor_detector import TensorDetector
from .timer import MultiTimer, Timer from .timer import MultiTimer, Timer
__all__ = [ __all__ = [
"conditional_context", "conditional_context",
"get_current_device",
"synchronize",
"empty_cache",
"set_to_cuda",
"Timer", "Timer",
"MultiTimer", "MultiTimer",
"multi_tensor_applier", "multi_tensor_applier",
@ -28,6 +23,4 @@ __all__ = [
"free_storage", "free_storage",
"set_seed", "set_seed",
"is_ddp_ignored", "is_ddp_ignored",
"set_device",
"IS_NPU_AVAILABLE",
] ]
@ -1,223 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Dict, List, Optional, Tuple, Callable
import torch
import torch.distributed as dist
IS_NPU_AVAILABLE: bool = False
try:
import torch_npu # noqa
IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
pass
def set_to_cuda(models):
"""Send model to gpu.
:param models: nn.module or a list of module
"""
if isinstance(models, list) and len(models) > 1:
ret = []
for model in models:
ret.append(model.to(get_current_device()))
return ret
elif isinstance(models, list):
return models[0].to(get_current_device())
else:
return models.to(get_current_device())
def get_current_device() -> torch.device:
"""
Returns currently selected device (gpu/cpu).
If cuda available, return gpu, otherwise return cpu.
"""
if torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
elif IS_NPU_AVAILABLE:
return torch.device(f"npu:{torch.npu.current_device()}")
else:
return torch.device("cpu")
def _dispatch_device_func(fn_name: str, *args, **kwargs):
if torch.cuda.is_available():
return getattr(torch.cuda, fn_name)(*args, **kwargs)
elif IS_NPU_AVAILABLE:
return getattr(torch.npu, fn_name)(*args, **kwargs)
else:
raise RuntimeError("No device available")
# device semantics
def can_device_access_peer(device, peer_device) -> bool:
return _dispatch_device_func("can_device_access_peer", device, peer_device)
def current_device() -> int:
return _dispatch_device_func("current_device")
def current_stream(device=None):
return _dispatch_device_func("current_stream", device)
def default_stream(device=None):
return _dispatch_device_func("default_stream", device)
def device_count() -> int:
return _dispatch_device_func("device_count")
def get_device_capability(device=None) -> Tuple[int, int]:
return _dispatch_device_func("get_device_capability", device)
def get_device_name(device=None) -> str:
return _dispatch_device_func("get_device_name", device)
def get_device_properties(device):
return _dispatch_device_func("get_device_properties", device)
def set_device(index: Optional[int] = None) -> None:
if index is None:
index = dist.get_rank() % device_count()
_dispatch_device_func("set_device", index)
def set_stream(stream_):
return _dispatch_device_func("set_stream", stream_)
def stream(stream_):
return _dispatch_device_func("stream", stream_)
def synchronize():
return _dispatch_device_func("synchronize")
def utilization(device=None) -> int:
return _dispatch_device_func("utilization", device)
# random number generator
def get_rng_state(device="cuda") -> torch.Tensor:
return _dispatch_device_func("get_rng_state", device)
def get_rng_state_all() -> List[torch.Tensor]:
return _dispatch_device_func("get_rng_state_all")
def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
return _dispatch_device_func("set_rng_state", new_state, device)
def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None:
return _dispatch_device_func("set_rng_state_all", new_states)
def manual_seed(seed: int) -> None:
return _dispatch_device_func("manual_seed", seed)
def manual_seed_all(seed: int) -> None:
return _dispatch_device_func("manual_seed_all", seed)
def seed() -> None:
return _dispatch_device_func("seed")
def seed_all() -> None:
return _dispatch_device_func("seed_all")
def initial_seed() -> int:
return _dispatch_device_func("initial_seed")
# streams and events
def Stream(device=None, priority=0, **kwargs):
return _dispatch_device_func("Stream", device, priority, **kwargs)
def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
return _dispatch_device_func("Event", enable_timing, blocking, interprocess)
# memory management
def empty_cache() -> None:
return _dispatch_device_func("empty_cache")
def memory_stats(device=None) -> Dict[str, Any]:
return _dispatch_device_func("memory_stats", device)
def memory_summary(device=None, abbreviated=False) -> str:
return _dispatch_device_func("memory_summary", device, abbreviated)
def memory_snapshot():
return _dispatch_device_func("memory_snapshot")
def memory_allocated(device=None) -> int:
return _dispatch_device_func("memory_allocated", device)
def max_memory_allocated(device=None) -> int:
return _dispatch_device_func("max_memory_allocated", device)
def reset_max_memory_allocated(device=None) -> None:
return _dispatch_device_func("reset_max_memory_allocated", device)
def reset_max_memory_cached(device=None) -> None:
return _dispatch_device_func("reset_max_memory_cached", device)
def memory_reserved(device=None) -> int:
return _dispatch_device_func("memory_reserved", device)
def max_memory_reserved(device=None) -> int:
return _dispatch_device_func("max_memory_reserved", device)
def set_per_process_memory_fraction(fraction: float, device=None) -> None:
return _dispatch_device_func("set_per_process_memory_fraction", fraction, device)
def reset_peak_memory_stats(device=None) -> None:
return _dispatch_device_func("reset_peak_memory_stats", device)
# amp
def autocast() -> Callable:
if torch.cuda.is_available():
return torch.cuda.amp.autocast()
elif IS_NPU_AVAILABLE:
return torch.npu.amp.autocast()
else:
raise RuntimeError("No device available")
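The deleted helpers map one-to-one onto accelerator methods; each call below is confirmed by a replacement elsewhere in this PR:

from colossalai.accelerator import get_accelerator

acc = get_accelerator()
acc.get_current_device()  # was colossalai.utils.device.get_current_device()
acc.device_count()        # was device_count()
acc.synchronize()         # was synchronize()
acc.empty_cache()         # was empty_cache()
acc.manual_seed(42)       # was manual_seed(42)
acc.get_rng_state()       # was get_rng_state(device="cuda")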
@ -3,7 +3,7 @@
import time import time
from typing import Tuple from typing import Tuple
from .device import synchronize from colossalai.accelerator import get_accelerator
class Timer: class Timer:
@ -21,13 +21,13 @@ class Timer:
@property @property
def current_time(self) -> float: def current_time(self) -> float:
synchronize() get_accelerator().synchronize()
return time.time() return time.time()
def start(self): def start(self):
"""Firstly synchronize cuda, reset the clock and then start the timer.""" """Firstly synchronize cuda, reset the clock and then start the timer."""
self._elapsed = 0 self._elapsed = 0
synchronize() get_accelerator().synchronize()
self._start_time = time.time() self._start_time = time.time()
self._started = True self._started = True
@ -44,7 +44,7 @@ class Timer:
Returns: Returns:
float: Start-stop interval. float: Start-stop interval.
""" """
synchronize() get_accelerator().synchronize()
end_time = time.time() end_time = time.time()
elapsed = end_time - self._start_time elapsed = end_time - self._start_time
if keep_in_history: if keep_in_history:
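Because current_time, start, and stop all synchronize through the accelerator now, the timer reads the same on NPU as on CUDA. Typical usage, assuming Timer keeps its colossalai.utils export:

from colossalai.utils import Timer

timer = Timer()
timer.start()
# ... launch device work here ...
elapsed = timer.stop(keep_in_history=True)  # synchronizes, then returns the interval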

View File

@@ -6,8 +6,7 @@
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

-from colossalai.utils import get_current_device
-from colossalai.utils.device import IS_NPU_AVAILABLE
+from colossalai.accelerator import get_accelerator


 class TensorState(Enum):
@@ -107,7 +106,7 @@ class Chunk:
         self.valid_end = self.shard_size

         self.dtype = dtype
-        device = init_device or get_current_device()
+        device = init_device or get_accelerator().get_current_device()
         # chunk_temp is a global chunk, which only exists during building the chunks.
         self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)  # keep all zero
@@ -125,7 +124,7 @@ class Chunk:
         # configure the init device of the shard
         # no-offload default: fp16, fp32 -> CUDA
         # offload default: fp16, fp32 -> CPU
-        self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
+        self.shard_device = torch.device("cpu") if cpu_shard_init else get_accelerator().get_current_device()

         self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
         self.shard_mem = self.chunk_mem // self.pg_size
@@ -192,10 +191,7 @@ class Chunk:
         if self.chunk_temp is not None:
             return self.chunk_temp.device.type
         else:
-            if self.is_gathered or self.cuda_shard is not None:
-                return "npu" if IS_NPU_AVAILABLE else "cuda"
-            else:
-                return "cpu"
+            return get_accelerator().name

     @property
     def payload(self) -> torch.Tensor:
@@ -297,7 +293,7 @@ class Chunk:
         self.valid_end = self.utilized_size - self.shard_begin

         if self.chunk_temp.device.type == "cpu":
-            self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
+            self.cuda_global_chunk = self.chunk_temp.to(get_accelerator().get_current_device())
             self.__update_tensors_ptr()
         else:
             self.cuda_global_chunk = self.chunk_temp
@@ -334,12 +330,12 @@ class Chunk:
             return

         if device.type == "cuda" or device.type == "npu":
-            assert device == get_current_device(), "can't move chunk to another device"
+            assert device == get_accelerator().get_current_device(), "can't move chunk to another device"

             if self.cuda_shard:
                 return

-            self.cuda_shard = self.cpu_shard.to(get_current_device())
+            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device())

             if not self.pin_memory:
                 self.cpu_shard = None
@@ -394,7 +390,9 @@ class Chunk:
                 if self.extra_dp_group is not None:
                     dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
             else:
-                self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
+                self.cuda_shard = torch.empty(
+                    self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
+                )

                 input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
                 dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
@@ -533,7 +531,7 @@ class Chunk:
         # only be called when optimizer state is in CPU memory
         # the grad and param should be in the same device
         assert self.cuda_shard is None
-        temp = optim_chunk.cpu_shard.to(get_current_device())
+        temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device())
         # avoid to transform FP32 in CPU
         self.cuda_shard = temp.to(self.dtype)
@@ -631,7 +629,7 @@ class Chunk:
         grad_chunk.valid_end = self.valid_end

         if grad_chunk.chunk_temp.device.type == "cpu":
-            grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device())
+            grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_accelerator().get_current_device())
         else:
             grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp
         grad_chunk.chunk_temp = None

View File

@@ -5,7 +5,8 @@
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

-from colossalai.utils import free_storage, get_current_device
+from colossalai.accelerator import get_accelerator
+from colossalai.utils import free_storage

 from .chunk import Chunk, ChunkFullError, TensorState
@@ -20,7 +21,7 @@ class ChunkManager:
     """

     def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
-        self.device = init_device or get_current_device()
+        self.device = init_device or get_accelerator().get_current_device()
         self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
         self.kwargs_config = chunk_configuration
         for k, v in self.kwargs_config.items():
@@ -107,7 +108,7 @@ class ChunkManager:
             return
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
-            chunk.shard_move(get_current_device())
+            chunk.shard_move(get_accelerator().get_current_device())
         self.__add_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
@@ -276,7 +277,10 @@ class ChunkManager:
                 accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size)
             else:
                 accumulated_grad = (
-                    chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size)
+                    chunk.grad_chunk.cpu_shard.to(get_accelerator().get_current_device())
+                    .clone()
+                    .detach()
+                    .mul_(chunk.pg_size)
                 )
             accumulated_grad_gathered = False

View File

@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import _get_default_group

+from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
@@ -27,7 +28,7 @@ from colossalai.tensor.d_tensor import (
     is_distributed_tensor,
 )
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
+from colossalai.utils import _cast_float, free_storage, is_ddp_ignored

 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -766,7 +767,7 @@ class GeminiDDP(ModelWrapper):
             # move ignored parameters to CUDA
             if is_ddp_ignored(p):
-                p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
+                p.data = p.data.to(device=get_accelerator().get_current_device(), dtype=self.mixed_precision)
                 continue

             # create a fp16 parameter
@@ -815,7 +816,7 @@ class GeminiDDP(ModelWrapper):
         for buffer in self.module.buffers():
             if isinstance(buffer, LazyTensor):
                 buffer.materialize()
-            buffer.data = buffer.to(get_current_device())
+            buffer.data = buffer.to(get_accelerator().get_current_device())
             if torch.is_floating_point(buffer):
                 buffer.data = buffer.to(self.mixed_precision)

View File

@@ -11,6 +11,7 @@ from torch.distributed import ProcessGroup
 from torch.nn import Parameter
 from torch.optim import Optimizer

+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
 from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
 from colossalai.interface import OptimizerWrapper
@@ -26,7 +27,7 @@ from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
 )
-from colossalai.utils import disposable, get_current_device, is_ddp_ignored
+from colossalai.utils import disposable, is_ddp_ignored

 from .chunk import Chunk, ChunkManager
 from .gemini_ddp import GeminiDDP
@@ -233,7 +234,7 @@ class GeminiOptimizer(OptimizerWrapper):
                 grad_chunk.l2_norm = None  # clear l2 norm

-        comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
+        comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
         for group, part_norm in group_to_norm.items():
             comm_buffer.fill_(part_norm)
             dist.all_reduce(comm_buffer, group=group)
@@ -314,10 +315,10 @@ class GeminiOptimizer(OptimizerWrapper):
                     continue

                 if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
-                    self.chunk_manager.move_chunk(chunk32, get_current_device())
+                    self.chunk_manager.move_chunk(chunk32, get_accelerator().get_current_device())
                     # stores grad now
-                    self.chunk_manager.move_chunk(chunk16, get_current_device())
-                    self.module.set_chunk_grad_device(chunk16, get_current_device())
+                    self.chunk_manager.move_chunk(chunk16, get_accelerator().get_current_device())
+                    self.module.set_chunk_grad_device(chunk16, get_accelerator().get_current_device())
                     fp32_params_used_cuda_margin_mem += chunk32.payload_mem

         for group in self.param_groups:
@@ -328,7 +329,7 @@ class GeminiOptimizer(OptimizerWrapper):
                     state = self.optim.state[fake_param]
                     for k, v in state.items():
                         if isinstance(v, torch.Tensor):
-                            state[k] = v.to(get_current_device())
+                            state[k] = v.to(get_accelerator().get_current_device())

     def _register_states_(self):
         for group in self.optim.param_groups:
@@ -551,7 +552,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self,
         param_id: int,
         state_names: list,
-        device: torch.device = get_current_device(),
+        device: torch.device = get_accelerator().get_current_device(),
         dtype: torch.dtype = torch.float32,
     ) -> torch.Tensor:
         """

View File

@@ -1,6 +1,6 @@
 from typing import Optional

-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
 from colossalai.zero.gemini.chunk import ChunkManager

 from .memory_stats import MemStats
@@ -33,4 +33,4 @@ class ChunkMemStatsCollector(MemStatsCollector):
     def cuda_margin_mem(self) -> float:
         from colossalai.legacy.utils.memory import colo_device_memory_capacity

-        return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda
+        return colo_device_memory_capacity(get_accelerator().get_current_device()) - self._memstats.max_overall_cuda

View File

@@ -5,7 +5,7 @@ from time import sleep, time

 import torch

-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator


 class MemoryMonitor:
@@ -77,7 +77,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
         super().__init__()
         self.keep_measuring = False

-        current_device = get_current_device()
+        current_device = get_accelerator().get_current_device()

         def _set_cuda_device():
             torch.cuda.set_device(current_device)
@@ -116,7 +116,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
         while self.keep_measuring:
             max_usage = max(
                 max_usage,
-                colo_device_memory_used(get_current_device()),
+                colo_device_memory_used(get_accelerator().get_current_device()),
             )
             sleep(self.interval)
         return max_usage

View File

@@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type

 import torch

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.utils.memory import colo_device_memory_capacity
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.chunk import Chunk

 from .chunk import Chunk, ChunkManager
@@ -85,7 +85,7 @@ class StaticPlacementPolicy(PlacementPolicy):
             # init offload optim settings
             # keep gathered chunks are in CUDA
             if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
-                device = get_current_device()
+                device = get_accelerator().get_current_device()
             else:
                 device = torch.device("cpu")
                 # real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
@@ -140,7 +140,7 @@ class AutoPlacementPolicy(PlacementPolicy):
             int: the volume of memory that is evicted
         """
         start = time()
-        cuda_capacity = colo_device_memory_capacity(get_current_device())
+        cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
         used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
         if warmup:
             # We designate a part of CUDA memory for model data in warmup iterations.
@@ -194,7 +194,7 @@ class AutoPlacementPolicy(PlacementPolicy):
                 # init offload optim settings
                 # keep gathered chunks are in CUDA
                 if chunk.keep_gathered:
-                    grads_device_map[p] = get_current_device()
+                    grads_device_map[p] = get_accelerator().get_current_device()
                 else:
                     grads_device_map[p] = torch.device("cpu")

View File

@@ -6,7 +6,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn

-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator

 from .chunk import Chunk
@@ -18,11 +18,11 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype):
     if chunk.cuda_shard is not None:
         shard_temp = chunk.cuda_shard
     else:
-        shard_temp = chunk.cpu_shard.to(get_current_device())
+        shard_temp = chunk.cpu_shard.to(get_accelerator().get_current_device())

     shard_temp = shard_temp.to(dtype)

-    total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device())
+    total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_accelerator().get_current_device())
     gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
     dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
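
The helper above relies on `torch.chunk` returning views into the destination buffer, so `dist.all_gather` fills the flat tensor in place. A self-contained sketch of the same pattern (names are illustrative; assumes an initialized process group):

```python
import torch
import torch.distributed as dist


def gather_shards(shard: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    # Destination buffer holding every rank's shard back to back.
    total = torch.zeros(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
    # torch.chunk returns views into `total`, so all_gather writes in place.
    gather_list = list(torch.chunk(total, chunks=world_size, dim=0))
    dist.all_gather(tensor_list=gather_list, tensor=shard, group=group)
    return total
```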

View File

@@ -12,7 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer

-import colossalai.utils.device as device_utils
+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_mixin import (
     BF16MixedPrecisionMixin,
     FP16MixedPrecisionMixin,
@@ -22,9 +22,6 @@ from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.moe_tensor.api import is_moe_tensor

-# from colossalai.tensor import ColoParameter, ProcessGroup
-from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device
-
 from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, ParameterStore
@@ -183,7 +180,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # intialize communication stream for
         # communication-compuation overlapping
         if self._overlap_communication:
-            self._comm_stream = device_utils.Stream()
+            self._comm_stream = get_accelerator().Stream()

         # reduction hook is only used if overlapping communication
         # or stage 2 is used
@@ -217,7 +214,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         return len(self._working_param_groups)

     def _sanity_checks(self):
-        assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required"
+        assert get_accelerator().name in ["cuda", "npu"], "device is required"
         for param_group in self.optim.param_groups:
             group_params = param_group["params"]
             for param in group_params:
@@ -228,7 +225,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def _create_master_param_current_rank(self, param_list):
         # split each param evenly by world size
         params_current_rank = []
-        device = "cpu" if self._cpu_offload else get_current_device()
+        device = "cpu" if self._cpu_offload else get_accelerator().get_current_device()

         for param in param_list:
             padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
@@ -340,11 +337,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             if len(moe_grad_list) > 0:
                 moe_flat_grads.record_stream(stream)
             # waiting for ops in the default stream finishing
-            stream.wait_stream(device_utils.current_stream())
+            stream.wait_stream(get_accelerator().current_stream())
         else:
-            stream = device_utils.current_stream()
+            stream = get_accelerator().current_stream()

-        with device_utils.stream(stream):
+        with get_accelerator().stream(stream):
             group_id = self._bucket_store.current_group_id

             if self.moe_extra_dp_pg is None:
@@ -486,7 +483,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         # clear reduced grads
         if self._overlap_communication:
-            device_utils.synchronize()
+            get_accelerator().synchronize()

         self.zero_grad()
@@ -505,7 +502,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         # clear reduced grads
         if self._overlap_communication:
-            device_utils.synchronize()
+            get_accelerator().synchronize()

         self.zero_grad()
@@ -621,7 +618,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])

         # update working partition updated by the current rank
-        device = get_current_device()
+        device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
@@ -661,7 +658,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         norm_type = float(norm_type)
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float)
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
+            )
             dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
             total_norm = total_norm_cuda.item()
@@ -673,7 +672,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # Sum across all model parallel GPUs.
         total_norm_exponentiated_cuda = torch.tensor(
-            [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float
+            [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
         )
         torch.distributed.all_reduce(
             total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
@@ -765,7 +764,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             Dict: the pytorch form state_dict
         """
         zero_state = dict()
-        device = get_current_device()
+        device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
@@ -827,7 +826,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         ret_block = dict()
         ret_block_size = 0

-        device = get_current_device()
+        device = get_accelerator().get_current_device()
         local_states = self.optim.state_dict()["state"]
         for param_idx, states in local_states.items():
             current_block_size = 0
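
The `_comm_stream` hunks above implement communication-computation overlap: gradient reductions run on a side stream that first waits on the default stream, and `synchronize()` fences before the reduced grads are reused. A minimal sketch of that pattern against the plain `torch.cuda` stream API (the accelerator object proxies equivalent calls on other backends; assumes `torch.distributed` is initialized):

```python
import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()


def reduce_on_side_stream(flat_grads: torch.Tensor) -> None:
    # Keep the tensor alive until the side stream is done with it.
    flat_grads.record_stream(comm_stream)
    # The side stream must see the grads produced on the default stream.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        dist.all_reduce(flat_grads)


# ... later, before the optimizer consumes the reduced grads:
torch.cuda.synchronize()
```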

View File

@@ -45,7 +45,6 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 ```

 ## Define Plugin
 Create a `HybridParallelPlugin` object and specify the desired parallelism strategies to be used. In this example, both pipeline parallelism and ZeRO-1 are used simultaneously.
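
A minimal sketch of such a plugin configuration, matching the strategies named above (argument values are illustrative; consult the plugin documentation for the full signature):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,           # no tensor parallelism in this sketch
    pp_size=2,           # two pipeline stages
    num_microbatches=4,  # microbatches per pipeline step
    zero_stage=1,        # ZeRO-1 optimizer-state sharding
    precision="fp16",
)
booster = Booster(plugin=plugin)
```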

View File

@@ -43,7 +43,6 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 ```

 ### Define plugin
 Define a [`HybridParallelPlugin`](../basics/booster_plugins.md) object and specify the parallelism strategies to use; in this example, pipeline parallelism and ZeRO-1 are used together.

View File

@@ -16,10 +16,10 @@ from utils.global_vars import get_tensorboard_writer, get_timers, set_global_var
 from utils.logger import Logger

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.context import ParallelMode
 from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor import ProcessGroup, ShardSpec
-from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
@@ -53,7 +53,7 @@ def main():
     set_global_variables(launch_time, args.tensorboard_path)

     world_size = torch.distributed.get_world_size()
-    get_current_device()
+    get_accelerator().get_current_device()

     # build model, optimizer and criterion
     if args.distplan.startswith("CAI"):
@@ -67,7 +67,10 @@ def main():
         # build GPT model
         with ColoInitContext(
-            device=get_current_device(), dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg
+            device=get_accelerator().get_current_device(),
+            dtype=torch.half,
+            default_dist_spec=default_dist_spec,
+            default_pg=shard_pg,
         ):
             config, model, numel = get_model(args, logger)
@@ -78,7 +81,7 @@ def main():
     elif args.distplan == "CAI_Gemini":
         gemini_config = dict(
             strict_ddp_mode=args.tp_degree == 1,
-            device=get_current_device(),
+            device=get_accelerator().get_current_device(),
             placement_policy=args.placement,
             pin_memory=True,
             hidden_dim=model.config.hidden_size,

View File

@@ -20,11 +20,11 @@ from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 disable_existing_loggers()
 logger = get_dist_logger()
@@ -386,7 +386,7 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))

         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32
+            torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32
             pipeline = DiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 torch_dtype=torch_dtype,
@@ -401,7 +401,7 @@ def main(args):
             sample_dataset = PromptDataset(args.class_prompt, num_new_images)
             sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

-            pipeline.to(get_current_device())
+            pipeline.to(get_accelerator().get_current_device())

             for example in tqdm(
                 sample_dataloader,
@@ -578,8 +578,8 @@ def main(args):
     # Move text_encode and vae to gpu.
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
-    vae.to(get_current_device(), dtype=weight_dtype)
-    text_encoder.to(get_current_device(), dtype=weight_dtype)
+    vae.to(get_accelerator().get_current_device(), dtype=weight_dtype)
+    text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader))
@@ -613,7 +613,7 @@ def main(args):
             torch.cuda.reset_peak_memory_stats()
             # Move batch to gpu
             for key, value in batch.items():
-                batch[key] = value.to(get_current_device(), non_blocking=True)
+                batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True)

             # Convert images to latent space
             optimizer.zero_grad()

View File

@@ -21,13 +21,13 @@ from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 disable_existing_loggers()
 logger = get_dist_logger()
@@ -385,7 +385,7 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))

         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32
+            torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32
             pipeline = DiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 torch_dtype=torch_dtype,
@@ -400,7 +400,7 @@ def main(args):
             sample_dataset = PromptDataset(args.class_prompt, num_new_images)
             sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

-            pipeline.to(get_current_device())
+            pipeline.to(get_accelerator().get_current_device())

             for example in tqdm(
                 sample_dataloader,
@@ -598,8 +598,8 @@ def main(args):
     # Move text_encode and vae to gpu.
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
-    vae.to(get_current_device(), dtype=weight_dtype)
-    text_encoder.to(get_current_device(), dtype=weight_dtype)
+    vae.to(get_accelerator().get_current_device(), dtype=weight_dtype)
+    text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader))
@@ -633,7 +633,7 @@ def main(args):
             torch.cuda.reset_peak_memory_stats()
             # Move batch to gpu
             for key, value in batch.items():
-                batch[key] = value.to(get_current_device(), non_blocking=True)
+                batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True)

             # Convert images to latent space
             optimizer.zero_grad()

View File

@@ -13,12 +13,12 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 # ==============================
 # Prepare Hyperparameters
@@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
 @torch.no_grad()
 def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
     model.eval()
-    correct = torch.zeros(1, dtype=torch.int64, device=get_current_device())
-    total = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+    correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
+    total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
     for images, labels in test_dataloader:
         images = images.cuda()
         labels = labels.cuda()

View File

@@ -33,9 +33,10 @@ def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224

 def colo_memory_cap(size_in_GB):
-    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
+    from colossalai.accelerator import get_accelerator
+    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction

-    cuda_capacity = colo_device_memory_capacity(get_current_device())
+    cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
     if size_in_GB * (1024**3) < cuda_capacity:
         colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
         print(f"Limiting GPU memory usage to {size_in_GB} GB")

View File

@@ -6,10 +6,9 @@ import torch.distributed as dist
 import transformers

 import colossalai
-import colossalai.utils.device as device_utils
+from colossalai.accelerator import get_accelerator
 from colossalai.inference import InferenceEngine
 from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
-from colossalai.utils.device import get_current_device

 GIGABYTE = 1024**3
 MEGABYTE = 1024 * 1024
@@ -52,7 +51,7 @@ CONFIG_MAP = {

 def data_gen(batch_size: int = 4, seq_len: int = 512):
-    input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device())
+    input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_accelerator().get_current_device())
     attention_mask = torch.ones_like(input_ids)
     data = dict(input_ids=input_ids, attention_mask=attention_mask)
     return data
@@ -97,9 +96,9 @@ def print_details_info(outputs, model_config, args, whole_end2end):
         msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n"

     if torch.cuda.is_available():
-        msg += f"-------Memory Summary Device:{device_utils.current_device()}-------\n"
-        msg += f"Max memory allocated: {device_utils.max_memory_allocated() / GIGABYTE:.2f} GB\n"
-        msg += f"Max memory reserved: {device_utils.max_memory_reserved() / GIGABYTE:.2f} GB\n"
+        msg += f"-------Memory Summary Device:{get_accelerator().current_device()}-------\n"
+        msg += f"Max memory allocated: {get_accelerator().max_memory_allocated() / GIGABYTE:.2f} GB\n"
+        msg += f"Max memory reserved: {get_accelerator().max_memory_reserved() / GIGABYTE:.2f} GB\n"

     print(msg)

View File

@@ -5,9 +5,9 @@ import torch.distributed as dist
 from transformers import LlamaForCausalLM, LlamaTokenizer

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.inference import InferenceEngine
 from colossalai.testing import spawn
-from colossalai.utils.device import get_current_device

 INPUT_TEXTS = [
     "What is the longest river in the world?",
@@ -57,7 +57,7 @@ def run_inference(args):
     )

     inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True)
-    inputs = {k: v.to(get_current_device()) for k, v in inputs.items()}
+    inputs = {k: v.to(get_accelerator().get_current_device()) for k, v in inputs.items()}
     outputs = engine.generate(inputs)

     if rank == 0:

View File

@@ -18,11 +18,11 @@ from transformers import (
 )

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 # ==============================
 # Prepare Hyperparameters
@@ -59,7 +59,7 @@ def evaluate_model(
     use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()

-    accum_loss = torch.zeros(1, device=get_current_device())
+    accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
     for batch in dataloader:
         batch = move_to_cuda(batch)
         labels = batch["labels"]
@@ -88,8 +88,10 @@ def evaluate_model(
             object_list = [None, None]
             dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)

-            metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
-            accum_loss.add_(object_list[1].to(get_current_device()))
+            metric.add_batch(
+                predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels
+            )
+            accum_loss.add_(object_list[1].to(get_accelerator().get_current_device()))
         else:
             batch = move_to_cuda(batch)

View File

@@ -7,13 +7,13 @@ from model_zoo import GPTLMLoss, get_gpt2_components
 from torch.utils._pytree import tree_map

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
 from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
 from colossalai.auto_parallel.offload.solver import NOT_NVML
 from colossalai.fx.profiler import parameter_size
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import spawn
-from colossalai.utils import get_current_device


 def parse_args():
@@ -41,7 +41,7 @@ def train_gpt(args):
             64,
             8,
         ),
-        device=get_current_device(),
+        device=get_accelerator().get_current_device(),
     )
     criterion = GPTLMLoss()

View File

@@ -12,12 +12,12 @@ from commons.utils import get_data, get_profile_context, get_tflops, get_time_st
 from packaging import version

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 CAI_VERSION = colossalai.__version__
@@ -141,7 +141,11 @@ def main():
     criterion = GPTLMLoss()

     torch.manual_seed(123)
     if args.distplan.startswith("CAI"):
-        ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext()
+        ctx = (
+            LazyInitContext(default_device=get_accelerator().get_current_device())
+            if args.distplan == "CAI_Gemini"
+            else nullcontext()
+        )
         # build GPT model
         with ctx:
             model = model_builder(args.model_type)(checkpoint=True)

View File

@@ -13,11 +13,11 @@ from tqdm import tqdm
 from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 # ==============================
 # Prepare Hyperparameters
@@ -54,7 +54,7 @@ def evaluate_model(
     use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()

-    accum_loss = torch.zeros(1, device=get_current_device())
+    accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
     for batch in dataloader:
         batch = move_to_cuda(batch)
         labels = batch["labels"]
@@ -83,8 +83,10 @@ def evaluate_model(
             object_list = [None, None]
             dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)

-            metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
-            accum_loss.add_(object_list[1].to(get_current_device()))
+            metric.add_batch(
+                predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels
+            )
+            accum_loss.add_(object_list[1].to(get_accelerator().get_current_device()))
         else:
             batch = move_to_cuda(batch)

View File

@@ -5,6 +5,7 @@ from torch import nn as nn
 from torch.nn import functional as F
 from torch.nn.parameter import Parameter

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import ParallelMode, seed
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.base_layer import ParallelLayer
@@ -12,7 +13,6 @@ from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_b
 from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row
 from colossalai.legacy.nn.layer.utils import divide
 from colossalai.legacy.registry import LAYERS, LOSSES
-from colossalai.utils import get_current_device


 class VocabParallelEmbedding(torch.nn.Module):
@@ -96,7 +96,9 @@ class VocabParallelEmbedding(torch.nn.Module):
         if position_ids is not None:
             position_ids = position_ids.view(-1, input_shape[-1])
         if position_ids is None:
-            position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
+            position_ids = torch.arange(
+                0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device()
+            )
             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
         position_embeddings = self.position_embeddings(position_ids)
@@ -194,7 +196,7 @@ class VocabParallelEmbedding1D(torch.nn.Module):
         self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

         # Allocate weights and initialize.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs))
         init.uniform_(self.weight, -1, 1)
@@ -439,7 +441,9 @@ class HiddenParallelEmbedding(torch.nn.Module):
         if position_ids is not None:
             position_ids = position_ids.view(-1, input_shape[-1])
         if position_ids is None:
-            position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
+            position_ids = torch.arange(
+                0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device()
+            )
             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
         position_embeddings = self.position_embeddings(position_ids)
@@ -532,7 +536,7 @@ class HiddenParallelEmbedding1D(torch.nn.Module):
         self._weight = None

         # Allocate weights and initialize.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs))
         init.uniform_(self.weight, -1, 1)

View File

@@ -13,13 +13,12 @@ from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaForCausalLM

 import colossalai
-import colossalai.utils.device as device_utils
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 # ==============================
 # Constants
@@ -166,7 +165,7 @@ def main():
     # Initialize Model and Optimizer
     # ==============================
     init_ctx = (
-        LazyInitContext(default_device=get_current_device())
+        LazyInitContext(default_device=get_accelerator().get_current_device())
         if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
         else nullcontext()
     )
@@ -197,7 +196,9 @@ def main():
     torch.set_default_dtype(torch.bfloat16)
     model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
     torch.set_default_dtype(torch.float)
-    coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
+    coordinator.print_on_master(
+        f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
+    )
     coordinator.print_on_master(
         f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
     )
@@ -223,7 +224,7 @@ def main():
         performance_evaluator.on_step_end(**batch)
     performance_evaluator.on_fit_end()
-    coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
+    coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")

 if __name__ == "__main__":
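The two memory prints above reduce to a short pattern worth noting: max_memory_allocated() is read from the accelerator exactly as the diff shows, and the MB conversion is the same /1024**2 used in the benchmark. A hedged sketch:

# Sketch only: report peak device memory through the accelerator API.
from colossalai.accelerator import get_accelerator

peak_mb = get_accelerator().max_memory_allocated() / 1024**2
print(f"Max device memory usage: {peak_mb:.2f} MB")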


@@ -8,7 +8,7 @@ from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.utils.data import DataLoader, Dataset, DistributedSampler

-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator

 class StatefulDistributedSampler(DistributedSampler):
@@ -108,7 +108,9 @@ class RandomDataset(Dataset):
     def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
         self.num_samples = num_samples
         self.max_length = max_length
-        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
+        self.input_ids = torch.randint(
+            0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
+        )
         self.attention_mask = torch.ones_like(self.input_ids)

     def __len__(self):
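The RandomDataset change only reflows one long line, but the underlying idea is worth isolating: synthetic batches are allocated directly on the current device, so the benchmark never pays a host-to-device copy. A sketch with illustrative shapes (the example files use num_samples, max_length, and vocab_size):

import torch

from colossalai.accelerator import get_accelerator

# Shapes here are placeholders; tokens land straight on the device.
input_ids = torch.randint(0, 32000, (16, 128), device=get_accelerator().get_current_device())
attention_mask = torch.ones_like(input_ids)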


@@ -21,13 +21,13 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from transformers.models.llama.tokenization_llama import LlamaTokenizer

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 def get_model_numel(model: nn.Module) -> int:
@@ -191,7 +191,9 @@ def main():
     config = LlamaConfig.from_pretrained(args.model_path)
     # use lazy init when using GeminiPlugin
     init_ctx = (
-        LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, GeminiPlugin)
+        else nullcontext()
     )

     with init_ctx:


@@ -5,9 +5,8 @@ import torch
 import torch.distributed as dist
 from torch import Tensor

-import colossalai.utils.device as device_utils
+from colossalai.accelerator import get_accelerator
 from colossalai.cluster import DistCoordinator
-from colossalai.utils.device import get_current_device

 def divide(x: float, y: float) -> float:
@@ -22,7 +21,7 @@ def divide(x: float, y: float) -> float:
 def all_reduce_mean(x: float, world_size: int) -> float:
     if world_size == 1:
         return x
-    tensor = torch.tensor([x], device=get_current_device())
+    tensor = torch.tensor([x], device=get_accelerator().get_current_device())
     dist.all_reduce(tensor)
     tensor = tensor / world_size
     return tensor.item()
@@ -86,13 +85,13 @@ class PerformanceEvaluator:
         self.disable = self.ignore_steps > 0 and step < self.ignore_steps
         if self.disable:
             return
-        device_utils.synchronize()
+        get_accelerator().synchronize()
         self.timer.start()

     def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
         if self.disable:
             return
-        device_utils.synchronize()
+        get_accelerator().synchronize()
         self.timer.end()

         batch_size, seq_len = input_ids.shape
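The synchronize() calls above matter because device kernels launch asynchronously: without a barrier, the host-side timer would stop before the step's kernels actually finish, understating step time. A hedged sketch of the same timing discipline, where run_step is a hypothetical stand-in for one training iteration:

import time

from colossalai.accelerator import get_accelerator

def timed_step(run_step) -> float:
    get_accelerator().synchronize()  # drain work queued before this step
    start = time.perf_counter()
    run_step()
    get_accelerator().synchronize()  # wait for this step's kernels to finish
    return time.perf_counter() - start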


@@ -20,13 +20,13 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from transformers.models.llama.tokenization_llama import LlamaTokenizer

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 MODEL_CONFIGS = {
     "7b": LlamaConfig(max_position_embeddings=4096),
@@ -227,7 +227,9 @@ def main():
     config = MODEL_CONFIGS[args.config]
     # use lazy init when using GeminiPlugin
     init_ctx = (
-        LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, GeminiPlugin)
+        else nullcontext()
     )

     with init_ctx:


@@ -14,6 +14,7 @@ from transformers.models.llama import LlamaConfig
 from utils import PerformanceEvaluator, get_model_numel

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
@@ -21,7 +22,6 @@ from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import skip_init
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 def move_to_cuda(batch, device):
@@ -64,13 +64,15 @@ class RandomDataset(Dataset):
             )
             self.input_ids.append(encode["input_ids"])
             self.attention_mask.append(encode["attention_mask"])
-            self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
-            self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
+            self.input_ids = torch.cat(self.input_ids, dim=0).to(get_accelerator().get_current_device())
+            self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_accelerator().get_current_device())
             repeat_times = num_samples // self.input_ids.shape[0] + 1
             self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
             self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
         else:
-            self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
+            self.input_ids = torch.randint(
+                0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
+            )
             self.attention_mask = torch.ones_like(self.input_ids)

     def __len__(self):


@@ -15,6 +15,7 @@ from transformers import T5Tokenizer
 from transformers.models.llama import LlamaConfig

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
@@ -22,7 +23,6 @@ from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import skip_init
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 def move_to_cuda(batch, device):
@@ -61,7 +61,9 @@ class RandomDataset(Dataset):
     def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
         self.num_samples = num_samples
         self.max_length = max_length
-        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
+        self.input_ids = torch.randint(
+            0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
+        )
         self.attention_mask = torch.ones_like(self.input_ids)

     def __len__(self):


@@ -14,12 +14,12 @@ from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 from torch.utils.data import DataLoader, Dataset

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn import HybridAdam
-from colossalai.utils import get_current_device

 # constants
@@ -159,7 +159,11 @@ if args.distplan == "colossalai":
     logger.info(f"plugin: {plugin}")
     booster = Booster(plugin=plugin, **booster_kwargs)

-    ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == "gemini" else nullcontext()
+    ctx = (
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if args.plugin == "gemini"
+        else nullcontext()
+    )

     with ctx:
         model = PaLM(num_tokens=50304, dim=4096, depth=64)
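Several files above repeat the same lazy-init idiom: materialize parameters directly on the accelerator's current device when the plugin supports it, and fall back to eager construction otherwise. A compact hedged sketch, with build_model as a hypothetical constructor standing in for PaLM or LlamaForCausalLM:

from contextlib import nullcontext

from colossalai.accelerator import get_accelerator
from colossalai.lazy import LazyInitContext

def materialize(build_model, lazy: bool):
    # Lazy init defers allocation and places parameters straight onto the
    # accelerator's device; otherwise construct the model eagerly.
    ctx = (
        LazyInitContext(default_device=get_accelerator().get_current_device())
        if lazy
        else nullcontext()
    )
    with ctx:
        return build_model()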


@@ -13,12 +13,12 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 # ==============================
 # Prepare Hyperparameters
@@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
 @torch.no_grad()
 def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
     model.eval()
-    correct = torch.zeros(1, dtype=torch.int64, device=get_current_device())
-    total = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+    correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
+    total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
     for images, labels in test_dataloader:
         images = images.cuda()
         labels = labels.cuda()


@@ -13,13 +13,13 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.lr_scheduler import LinearWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 # ==============================
 # Prepare Hyperparameters
@@ -73,8 +73,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
 @torch.no_grad()
 def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
     model.eval()
-    correct = torch.zeros(1, dtype=torch.int64, device=get_current_device())
-    total = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+    correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
+    total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
     for images, labels in test_dataloader:
         images = images.cuda()
         labels = labels.cuda()


@@ -12,11 +12,11 @@ from tqdm import tqdm
 from transformers import AutoConfig, BertForSequenceClassification, get_linear_schedule_with_warmup

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 # ==============================
 # Prepare Hyperparameters
@@ -45,7 +45,7 @@ def evaluate(
     model.eval()

     def evaluate_subset(dataloader: DataLoader):
-        accum_loss = torch.zeros(1, device=get_current_device())
+        accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         for batch in dataloader:
             batch = move_to_cuda(batch)
             outputs = model(**batch)


@@ -51,13 +51,13 @@ from transformers import (
 from transformers.utils.versions import require_version

 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.tensor import ProcessGroup
 from colossalai.legacy.utils import get_dataloader
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
 from colossalai.zero import GeminiOptimizer

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -249,9 +249,9 @@ def parse_args():
 def colo_memory_cap(size_in_GB):
-    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
+    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction

-    cuda_capacity = colo_device_memory_capacity(get_current_device())
+    cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
     if size_in_GB * (1024**3) < cuda_capacity:
         colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
         print("Using {} GB of GPU memory".format(size_in_GB))
@@ -265,7 +265,9 @@ class DummyDataloader:
         self.vocab_size = vocab_size

     def generate(self):
-        input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device())
+        input_ids = torch.randint(
+            0, self.vocab_size, (self.batch_size, self.seq_len), device=get_accelerator().get_current_device()
+        )
         attention_mask = torch.ones_like(input_ids)
         return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}
@@ -390,7 +392,7 @@ def main():
     if args.init_in_cpu:
         init_dev = torch.device("cpu")
     else:
-        init_dev = get_current_device()
+        init_dev = get_accelerator().get_current_device()

     cai_version = colossalai.__version__
     logger.info(f"using Colossal-AI version {cai_version}")
@@ -439,7 +441,9 @@ def main():
         except ImportError:
             # this works for unreleased main branch, and this may be released on 0.2.9
             from colossalai.zero import GeminiDDP
-        model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
+        model = GeminiDDP(
+            model, device=get_accelerator().get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True
+        )
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
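Assembling both sides of the colo_memory_cap hunk above gives roughly the following, assuming the legacy colossalai.utils helpers remain importable in the target version:

from colossalai.accelerator import get_accelerator
from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction

def colo_memory_cap(size_in_GB):
    # Cap this process at size_in_GB of device memory, expressed as a
    # fraction of the current device's total capacity.
    capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
    if size_in_GB * (1024**3) < capacity:
        colo_set_process_memory_fraction(size_in_GB * (1024**3) / capacity)
        print("Using {} GB of GPU memory".format(size_in_GB))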
